Insert KASAN shadow memory checks before memory load and store
operations in JIT-compiled BPF programs. This helps detect memory safety
bugs such as use-after-free and out-of-bounds accesses at runtime.

The targeted instructions are BPF_ST, BPF_STX (including BPF_ATOMIC) and
BPF_LDX, but not all of them are instrumented:
- if the load/store instruction only accesses the program stack,
  emit_kasan_check silently skips the instrumentation, as stack
  accesses are already monitored by page guards.
- if the load/store instruction is a BPF_PROBE_MEM or BPF_PROBE_ATOMIC
  instruction, it is not instrumented: the address it is passed may
  legitimately fault (which is why BPF_PROBE_* instructions have their
  own fault handling), so the corresponding KASAN check could fault as
  well.

Signed-off-by: Alexis Lothoré (eBPF Foundation) <[email protected]>
---
Changes in v3:
- fix build failure with LLVM 23

Changes in v2:
- support BPF_ATOMIC instructions
- support BPF_ST
- always pass the correct instruction to the KASAN check
---
 arch/x86/net/bpf_jit_comp.c | 72 +++++++++++++++++++++++++++++++++++++--------
 1 file changed, 60 insertions(+), 12 deletions(-)

diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index b70cecfec179..a383ffc8f289 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -1576,17 +1576,31 @@ static int emit_atomic_rmw_index(u8 **pprog, u32 
atomic_op, u32 size,
        return 0;
 }
 
-static int emit_atomic_ld_st(u8 **pprog, u32 atomic_op, u32 dst_reg,
-                            u32 src_reg, s16 off, u8 bpf_size)
+static int emit_atomic_ld_st(struct bpf_verifier_env *env, u8 **pprog,
+                            struct bpf_insn *insn, u8 *ip, u32 dst_reg,
+                            u32 src_reg, bool accesses_stack_only)
 {
+       u32 atomic_op = insn->imm;
+       int err;
+
        switch (atomic_op) {
        case BPF_LOAD_ACQ:
+               err = emit_kasan_check(env, pprog, src_reg, insn, ip, false,
+                                      accesses_stack_only);
+               if (err)
+                       return err;
                /* dst_reg = smp_load_acquire(src_reg + off16) */
-               emit_ldx(pprog, bpf_size, dst_reg, src_reg, off);
+               emit_ldx(pprog, BPF_SIZE(insn->code), dst_reg, src_reg,
+                        insn->off);
                break;
        case BPF_STORE_REL:
+               err = emit_kasan_check(env, pprog, dst_reg, insn, ip, true,
+                                      accesses_stack_only);
+               if (err)
+                       return err;
                /* smp_store_release(dst_reg + off16, src_reg) */
-               emit_stx(pprog, bpf_size, dst_reg, src_reg, off);
+               emit_stx(pprog, BPF_SIZE(insn->code), dst_reg, src_reg,
+                        insn->off);
                break;
        default:
                pr_err("bpf_jit: unknown atomic load/store opcode %02x\n",
@@ -1964,10 +1978,12 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                const s32 imm32 = insn->imm;
                u32 dst_reg = insn->dst_reg;
                u32 src_reg = insn->src_reg;
+               bool accesses_stack_only;
                u8 b2 = 0, b3 = 0;
                u8 *start_of_ldx;
                s64 jmp_offset;
                s32 insn_off;
+               int insn_idx;
                u8 jmp_cond;
                u8 *func;
                int nops;
@@ -1984,6 +2000,10 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                        EMIT_ENDBR();
 
                ip = image + addrs[i - 1] + (prog - temp);
+               insn_idx = i - 1 + bpf_prog->aux->subprog_start;
+               accesses_stack_only =
+                       env ? !env->insn_aux_data[insn_idx].non_stack_access :
+                             false;
 
                switch (insn->code) {
                        /* ALU */
@@ -2364,6 +2384,11 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                case BPF_ST | BPF_MEM | BPF_H:
                case BPF_ST | BPF_MEM | BPF_W:
                case BPF_ST | BPF_MEM | BPF_DW:
+                       err = emit_kasan_check(env, &prog, dst_reg, insn, ip,
+                                              true, accesses_stack_only);
+                       if (err)
+                               return err;
+
                        emit_st(&prog, insn, dst_reg, outgoing_arg_base,
                                outgoing_rsp);
                        break;
@@ -2383,6 +2408,10 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                                insn_off = outgoing_arg_base - outgoing_rsp - 
insn_off - 16;
                                dst_reg = BPF_REG_FP;
                        }
+                       err = emit_kasan_check(env, &prog, dst_reg, insn, ip,
+                                              true, accesses_stack_only);
+                       if (err)
+                               return err;
                        emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, 
insn_off);
                        break;
 
@@ -2544,6 +2573,12 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                                /* populate jmp_offset for JAE above to jump to 
start_of_ldx */
                                start_of_ldx = prog;
                                end_of_jmp[-1] = start_of_ldx - end_of_jmp;
+                       } else {
+                               err = emit_kasan_check(env, &prog, src_reg,
+                                                      insn, ip, false,
+                                                      accesses_stack_only);
+                               if (err)
+                                       return err;
                        }
                        if (BPF_MODE(insn->code) == BPF_PROBE_MEMSX ||
                            BPF_MODE(insn->code) == BPF_MEMSX)
@@ -2605,14 +2640,14 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                        }
                        fallthrough;
                case BPF_STX | BPF_ATOMIC | BPF_W:
-               case BPF_STX | BPF_ATOMIC | BPF_DW:
+               case BPF_STX | BPF_ATOMIC | BPF_DW: {
+                       bool is64 = BPF_SIZE(insn->code) == BPF_DW;
+                       u32 real_src_reg = src_reg;
+                       u32 real_dst_reg = dst_reg;
+                       u8 *branch_target;
                        if (insn->imm == (BPF_AND | BPF_FETCH) ||
                            insn->imm == (BPF_OR | BPF_FETCH) ||
                            insn->imm == (BPF_XOR | BPF_FETCH)) {
-                               bool is64 = BPF_SIZE(insn->code) == BPF_DW;
-                               u32 real_src_reg = src_reg;
-                               u32 real_dst_reg = dst_reg;
-                               u8 *branch_target;
 
                                /*
                                 * Can't be implemented with a single x86 insn.
@@ -2626,7 +2661,19 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                                if (dst_reg == BPF_REG_0)
                                        real_dst_reg = BPF_REG_AX;
 
+                               ip += 3;
+                       }
+                       if (!bpf_atomic_is_load_store(insn)) {
+                               err = emit_kasan_check(env, &prog, real_dst_reg,
+                                                      insn, ip, false,
+                                                      accesses_stack_only);
+                               if (err)
+                                       return err;
                                branch_target = prog;
+                       }
+                       if (insn->imm == (BPF_AND | BPF_FETCH) ||
+                           insn->imm == (BPF_OR | BPF_FETCH) ||
+                           insn->imm == (BPF_XOR | BPF_FETCH)) {
                                /* Load old value */
                                emit_ldx(&prog, BPF_SIZE(insn->code),
                                         BPF_REG_0, real_dst_reg, insn->off);
@@ -2658,15 +2705,16 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                        }
 
                        if (bpf_atomic_is_load_store(insn))
-                               err = emit_atomic_ld_st(&prog, insn->imm, 
dst_reg, src_reg,
-                                                       insn->off, 
BPF_SIZE(insn->code));
+                               err = emit_atomic_ld_st(env, &prog, insn, ip,
+                                                       dst_reg, src_reg,
+                                                       accesses_stack_only);
                        else
                                err = emit_atomic_rmw(&prog, insn->imm, 
dst_reg, src_reg,
                                                      insn->off, 
BPF_SIZE(insn->code));
                        if (err)
                                return err;
                        break;
-
+               }
                case BPF_STX | BPF_PROBE_ATOMIC | BPF_B:
                case BPF_STX | BPF_PROBE_ATOMIC | BPF_H:
                        if (!bpf_atomic_is_load_store(insn)) {

-- 
2.54.0


Reply via email to