The test attaches a kprobe to the connect() syscall
(SEC("ksyscall/connect")) and checks that bpf_probe_write_user() can
modify the user-space sockaddr argument. However, the probe is
system-wide: any other task that calls connect() while it is attached,
including threads of unrelated tests, also triggers it and has its
sockaddr silently overwritten, causing flaky failures in those tests.

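Fix this by adding a single-entry pid_map (BPF_MAP_TYPE_ARRAY) in which
the test stores its PID as returned by getpid(). The connect and
socketcall handlers now return early unless the upper 32 bits of
bpf_get_current_pid_tgid() (the TGID) match the stored value, confining
the hook to the test process only.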

Also remove the TODO comment that tracked this issue.

Signed-off-by: Sun Jian <[email protected]>
---
 .../selftests/bpf/prog_tests/probe_user.c     | 13 ++++++++++-
 .../selftests/bpf/progs/test_probe_user.c     | 22 +++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/tools/testing/selftests/bpf/prog_tests/probe_user.c b/tools/testing/selftests/bpf/prog_tests/probe_user.c
index 8721671321de..f50ee4066759 100644
--- a/tools/testing/selftests/bpf/prog_tests/probe_user.c
+++ b/tools/testing/selftests/bpf/prog_tests/probe_user.c
@@ -13,13 +13,15 @@ void serial_test_probe_user(void)
        enum { prog_count = ARRAY_SIZE(prog_names) };
        const char *obj_file = "./test_probe_user.bpf.o";
        DECLARE_LIBBPF_OPTS(bpf_object_open_opts, opts, );
-       int err, results_map_fd, sock_fd, duration = 0;
+       int err, results_map_fd, pid_map_fd, sock_fd, duration = 0;
        struct sockaddr curr, orig, tmp;
        struct sockaddr_in *in = (struct sockaddr_in *)&curr;
        struct bpf_link *kprobe_links[prog_count] = {};
        struct bpf_program *kprobe_progs[prog_count];
        struct bpf_object *obj;
        static const int zero = 0;
+       __u32 key = 0;
+       __u32 pid;
        size_t i;
 
        obj = bpf_object__open_file(obj_file, &opts);
@@ -38,6 +40,15 @@ void serial_test_probe_user(void)
        if (CHECK(err, "obj_load", "err %d\n", err))
                goto cleanup;
 
+       pid_map_fd = bpf_find_map(__func__, obj, "pid_map");
+       if (CHECK(pid_map_fd < 0, "find_pid_map", "err %d\n", pid_map_fd))
+               goto cleanup;
+
+       pid = getpid();
+       err = bpf_map_update_elem(pid_map_fd, &key, &pid, BPF_ANY);
+       if (CHECK(err, "update_pid_map", "err %d\n", err))
+               goto cleanup;
+
        results_map_fd = bpf_find_map(__func__, obj, "test_pro.bss");
        if (CHECK(results_map_fd < 0, "find_bss_map",
                  "err %d\n", results_map_fd))
diff --git a/tools/testing/selftests/bpf/progs/test_probe_user.c b/tools/testing/selftests/bpf/progs/test_probe_user.c
index a8e501af9604..20a13b984fb0 100644
--- a/tools/testing/selftests/bpf/progs/test_probe_user.c
+++ b/tools/testing/selftests/bpf/progs/test_probe_user.c
@@ -5,6 +5,24 @@
 #include <bpf/bpf_core_read.h>
 #include "bpf_misc.h"
 
+struct {
+       __uint(type, BPF_MAP_TYPE_ARRAY);
+       __uint(max_entries, 1);
+       __type(key, __u32);
+       __type(value, __u32);
+} pid_map SEC(".maps");
+
+static __always_inline int pid_ok(void)
+{
+       __u32 key = 0, *expected_pid;
+
+       expected_pid =  bpf_map_lookup_elem(&pid_map, &key);
+       if (!expected_pid ||
+           *expected_pid != (bpf_get_current_pid_tgid() >> 32))
+               return 0;
+       return 1;
+}
+
 static struct sockaddr_in old;
 
 static int handle_sys_connect_common(struct sockaddr_in *uservaddr)
@@ -22,6 +40,8 @@ SEC("ksyscall/connect")
 int BPF_KSYSCALL(handle_sys_connect, int fd, struct sockaddr_in *uservaddr,
                 int addrlen)
 {
+       if (!pid_ok())
+               return 0;
        return handle_sys_connect_common(uservaddr);
 }
 
@@ -36,6 +56,8 @@ int BPF_KSYSCALL(handle_sys_socketcall, int call, unsigned long *args)
        if (call == SYS_CONNECT) {
                struct sockaddr_in *uservaddr;
 
+               if (!pid_ok())
+                       return 0;
                bpf_probe_read_user(&uservaddr, sizeof(uservaddr), &args[1]);
                return handle_sys_connect_common(uservaddr);
        }
-- 
2.43.0


Reply via email to