New sockets cloned from listening sockets that are in a sockmap must not
inherit the psock that holds the link to the sockmap. Otherwise, child sockets
unintentionally share the sockmap entry with the listening socket, which
leads to a double free when the sockets are closed.

Prevent it by overriding the accept callback. In it, we restore the child's
protocol and write-space callbacks and clear its pointer to the psock.

Signed-off-by: Jakub Sitnicki <ja...@cloudflare.com>
---
 net/ipv4/tcp_bpf.c | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 8a56e09cfb0e..5838aaba4ce0 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -582,6 +582,58 @@ static void tcp_bpf_close(struct sock *sk, long timeout)
        saved_close(sk, timeout);
 }
 
+/* Accept a connection on a listening socket that is in a sockmap.
+ *
+ * The accepted child is a clone of the listener and so starts out with
+ * the listener's psock pointer (sk_user_data), sk_prot and
+ * sk_write_space. Sharing the psock ties the child to the listener's
+ * sockmap entry and leads to a double free on close, so reset the child
+ * to the protocol and write-space callbacks saved in the listener's psock
+ * and clear its psock pointer.
+ *
+ * Returns the child socket, or NULL if the underlying accept callback
+ * fails (the error is presumably reported through @err).
+ */
+static struct sock *tcp_bpf_accept(struct sock *sk, int flags, int *err,
+                                  bool kern)
+{
+       void (*saved_write_space)(struct sock *sk);
+       struct proto *saved_proto;
+       struct sk_psock *psock;
+       struct sock *child;
+
+       rcu_read_lock();
+       psock = sk_psock(sk);
+       if (unlikely(!psock)) {
+               rcu_read_unlock();
+               /* NOTE(review): the child fixup below is skipped here. A
+                * child cloned while the listener still had a psock would
+                * keep the inherited pointer and ops; confirm this cannot
+                * happen.
+                */
+               return sk->sk_prot->accept(sk, flags, err, kern);
+       }
+       /* Snapshot the saved callbacks while the psock is RCU-protected. */
+       saved_proto = psock->sk_proto;
+       saved_write_space = psock->saved_write_space;
+       rcu_read_unlock();
+
+       child = saved_proto->accept(sk, flags, err, kern);
+       if (!child)
+               return NULL;
+
+       /* Child must not inherit psock or its ops.
+        *
+        * NOTE(review): only children that reach accept() get here; a
+        * child still queued when the listener is closed presumably keeps
+        * the shared psock - confirm that path is handled.
+        */
+       rcu_assign_sk_user_data(child, NULL);
+       child->sk_prot = saved_proto;
+       child->sk_write_space = saved_write_space;
+       return child;
+}
+
 enum {
        TCP_BPF_IPV4,
        TCP_BPF_IPV6,
@@ -606,6 +658,7 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
        prot[TCP_BPF_BASE].close                = tcp_bpf_close;
        prot[TCP_BPF_BASE].recvmsg              = tcp_bpf_recvmsg;
        prot[TCP_BPF_BASE].stream_memory_read   = tcp_bpf_stream_read;
+       prot[TCP_BPF_BASE].accept               = tcp_bpf_accept;
 
        prot[TCP_BPF_TX]                        = prot[TCP_BPF_BASE];
        prot[TCP_BPF_TX].sendmsg                = tcp_bpf_sendmsg;
-- 
2.20.1

Reply via email to