Allow looking up a sock_common. This gives eBPF programs
access to timewait and request sockets.

Signed-off-by: Lorenz Bauer <l...@cloudflare.com>
---
 include/uapi/linux/bpf.h |  20 ++++++-
 kernel/bpf/verifier.c    |   3 +-
 net/core/filter.c        | 113 +++++++++++++++++++++++++++++++++++----
 3 files changed, 124 insertions(+), 12 deletions(-)

diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 983b25cb608d..8e4f8276942a 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -2374,6 +2374,23 @@ union bpf_attr {
  *     Return
  *             A **struct bpf_sock** pointer on success, or NULL in
  *             case of failure.
+ *
+ * struct bpf_sock *bpf_skc_lookup_tcp(void *ctx, struct bpf_sock_tuple *tuple, u32 tuple_size, u64 netns, u64 flags)
+ *     Description
+ *             Look for TCP socket matching *tuple*, optionally in a child
+ *             network namespace *netns*. The return value must be checked,
+ *             and if non-**NULL**, released via **bpf_sk_release**\ ().
+ *
+ *             This function is identical to bpf_sk_lookup_tcp, except that it
+ *             also returns timewait or request sockets. Use bpf_sk_fullsock
+ *             or bpf_tcp_sock to access the full structure.
+ *
+ *             This helper is available only if the kernel was compiled with
+ *             **CONFIG_NET** configuration option.
+ *     Return
+ *             Pointer to **struct bpf_sock**, or **NULL** in case of failure.
+ *             For sockets with reuseport option, the **struct bpf_sock**
+ *             result is from **reuse->socks**\ [] using the hash of the tuple.
  */
 #define __BPF_FUNC_MAPPER(FN)          \
        FN(unspec),                     \
@@ -2474,7 +2491,8 @@ union bpf_attr {
        FN(sk_fullsock),                \
        FN(tcp_sock),                   \
        FN(skb_ecn_set_ce),             \
-       FN(get_listener_sock),
+       FN(get_listener_sock),          \
+       FN(skc_lookup_tcp),
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
  * function eBPF program intends to call
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index f60d9df4e00a..94420942af32 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -369,7 +369,8 @@ static bool is_release_function(enum bpf_func_id func_id)
 static bool is_acquire_function(enum bpf_func_id func_id)
 {
        return func_id == BPF_FUNC_sk_lookup_tcp ||
-               func_id == BPF_FUNC_sk_lookup_udp;
+               func_id == BPF_FUNC_sk_lookup_udp ||
+               func_id == BPF_FUNC_skc_lookup_tcp;
 }
 
 static bool is_ptr_cast_function(enum bpf_func_id func_id)
diff --git a/net/core/filter.c b/net/core/filter.c
index f879791ea53f..f5210773cfd8 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -5156,15 +5156,15 @@ static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple,
        return sk;
 }
 
-/* bpf_sk_lookup performs the core lookup for different types of sockets,
+/* bpf_skc_lookup performs the core lookup for different types of sockets,
  * taking a reference on the socket if it doesn't have the flag SOCK_RCU_FREE.
  * Returns the socket as an 'unsigned long' to simplify the casting in the
  * callers to satisfy BPF_CALL declarations.
  */
 static unsigned long
-__bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
-               struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id,
-               u64 flags)
+__bpf_skc_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+                struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id,
+                u64 flags)
 {
        struct sock *sk = NULL;
        u8 family = AF_UNSPEC;
@@ -5192,15 +5192,28 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
                put_net(net);
        }
 
-       if (sk)
-               sk = sk_to_full_sk(sk);
 out:
        return (unsigned long) sk;
 }
 
 static unsigned long
-bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
-             u8 proto, u64 netns_id, u64 flags)
+__bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+               struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id,
+               u64 flags)
+{
+       struct sock *sk;
+
+       sk = (struct sock *)__bpf_skc_lookup(skb, tuple, len, caller_net,
+                                           ifindex, proto, netns_id, flags);
+       if (sk)
+               sk = sk_to_full_sk(sk);
+
+       return (unsigned long)sk;
+}
+
+static unsigned long
+bpf_skc_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+              u8 proto, u64 netns_id, u64 flags)
 {
        struct net *caller_net;
        int ifindex;
@@ -5213,10 +5226,42 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
                ifindex = 0;
        }
 
-       return __bpf_sk_lookup(skb, tuple, len, caller_net, ifindex,
-                             proto, netns_id, flags);
+       return __bpf_skc_lookup(skb, tuple, len, caller_net, ifindex,
+                               proto, netns_id, flags);
 }
 
+static unsigned long
+bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+             u8 proto, u64 netns_id, u64 flags)
+{
+       struct sock *sk;
+
+       sk = (struct sock *)bpf_skc_lookup(skb, tuple, len, proto, netns_id,
+                                         flags);
+       if (sk)
+               sk = sk_to_full_sk(sk);
+
+       return (unsigned long)sk;
+}
+
+BPF_CALL_5(bpf_skc_lookup_tcp, struct sk_buff *, skb,
+          struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
+{
+       return bpf_skc_lookup(skb, tuple, len, IPPROTO_TCP, netns_id, flags);
+}
+
+static const struct bpf_func_proto bpf_skc_lookup_tcp_proto = {
+       .func           = bpf_skc_lookup_tcp,
+       .gpl_only       = false,
+       .pkt_access     = true,
+       .ret_type       = RET_PTR_TO_SOCK_COMMON_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_5(bpf_sk_lookup_tcp, struct sk_buff *, skb,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
@@ -5289,6 +5334,28 @@ static const struct bpf_func_proto bpf_xdp_sk_lookup_udp_proto = {
        .arg5_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_5(bpf_xdp_skc_lookup_tcp, struct xdp_buff *, ctx,
+          struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags)
+{
+       struct net *caller_net = dev_net(ctx->rxq->dev);
+       int ifindex = ctx->rxq->dev->ifindex;
+
+       return __bpf_skc_lookup(NULL, tuple, len, caller_net, ifindex,
+                               IPPROTO_TCP, netns_id, flags);
+}
+
+static const struct bpf_func_proto bpf_xdp_skc_lookup_tcp_proto = {
+       .func           = bpf_xdp_skc_lookup_tcp,
+       .gpl_only       = false,
+       .pkt_access     = true,
+       .ret_type       = RET_PTR_TO_SOCK_COMMON_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_5(bpf_xdp_sk_lookup_tcp, struct xdp_buff *, ctx,
           struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags)
 {
@@ -5311,6 +5378,24 @@ static const struct bpf_func_proto bpf_xdp_sk_lookup_tcp_proto = {
        .arg5_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_5(bpf_sock_addr_skc_lookup_tcp, struct bpf_sock_addr_kern *, ctx,
+          struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
+{
+       return __bpf_skc_lookup(NULL, tuple, len, sock_net(ctx->sk), 0,
+                               IPPROTO_TCP, netns_id, flags);
+}
+
+static const struct bpf_func_proto bpf_sock_addr_skc_lookup_tcp_proto = {
+       .func           = bpf_sock_addr_skc_lookup_tcp,
+       .gpl_only       = false,
+       .ret_type       = RET_PTR_TO_SOCK_COMMON_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_5(bpf_sock_addr_sk_lookup_tcp, struct bpf_sock_addr_kern *, ctx,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
@@ -5586,6 +5671,8 @@ sock_addr_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_sock_addr_sk_lookup_udp_proto;
        case BPF_FUNC_sk_release:
                return &bpf_sk_release_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_sock_addr_skc_lookup_tcp_proto;
 #endif /* CONFIG_INET */
        default:
                return bpf_base_func_proto(func_id);
@@ -5719,6 +5806,8 @@ tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_tcp_sock_proto;
        case BPF_FUNC_get_listener_sock:
                return &bpf_get_listener_sock_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_skc_lookup_tcp_proto;
 #endif
        default:
                return bpf_base_func_proto(func_id);
@@ -5754,6 +5843,8 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_xdp_sk_lookup_tcp_proto;
        case BPF_FUNC_sk_release:
                return &bpf_sk_release_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_xdp_skc_lookup_tcp_proto;
 #endif
        default:
                return bpf_base_func_proto(func_id);
@@ -5846,6 +5937,8 @@ sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_sk_lookup_udp_proto;
        case BPF_FUNC_sk_release:
                return &bpf_sk_release_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_skc_lookup_tcp_proto;
 #endif
        default:
                return bpf_base_func_proto(func_id);
-- 
2.19.1

Reply via email to