On Fri, Jan 12, 2018 at 10:11:11AM -0800, John Fastabend wrote:
> This implements a BPF ULP layer to allow policy enforcement and
> monitoring at the socket layer. In order to support this a new
> program type BPF_PROG_TYPE_SK_MSG is used to run the policy at
> the sendmsg/sendpage hook. To attach the policy to sockets a
> sockmap is used with a new program attach type BPF_SK_MSG_VERDICT.
> 
> Similar to previous sockmap usages, when a sock is added to a
> sockmap via a map update and the map has a BPF_SK_MSG_VERDICT
> program attached, the BPF ULP layer is created on the socket and
> the attached BPF_PROG_TYPE_SK_MSG program is run for every msg in
> the sendmsg case and for every page/offset in the sendpage case.
> 
> BPF_PROG_TYPE_SK_MSG Semantics/API:
> 
> BPF_PROG_TYPE_SK_MSG supports only two return codes, SK_PASS and
> SK_DROP. Returning SK_DROP frees the copied data in the sendmsg
> case and leaves the data untouched in the sendpage case; both
> cases return -EACCES to the user. Returning SK_PASS allows the
> msg to be sent.
> 
> In the sendmsg case data is copied into kernel-space buffers
> before the BPF program runs. In the sendpage case data is never
> copied, so users may change the data after the BPF program has
> run. (A flag will be added shortly for cases where the copy must
> always be performed.)
> 
> The verdict from the BPF_PROG_TYPE_SK_MSG program applies to the
> entire msg in the sendmsg() case and to the entire page/offset in
> the sendpage case. This avoids ambiguity about how to handle
> mixed return codes in the sendmsg case. The readable/writable
> data provided to the program in the sendmsg case may not be the
> entire message; in fact, for large sends this is likely the case.
> The data range that can be read is part of the sk_msg_md
> structure. This is because, similar to the tc cls_bpf case, the
> data is stored in a scatter-gather list. Future work will address
> this shortcoming and allow users to pull in more data if needed
> (similar to TC BPF).
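
fwiw, a minimal SK_MSG program under the semantics described above
would look roughly like the sketch below. Untested; the section name
and the (void *)(long) pointer casts are the usual BPF program
conventions and are assumed here, not taken from this patch:

  SEC("sk_msg")
  int msg_filter(struct sk_msg_md *msg)
  {
          /* data/data_end only cover the first scatter-gather element,
           * so every access must be bounds checked
           */
          void *data = (void *)(long)msg->data;
          void *data_end = (void *)(long)msg->data_end;
          unsigned char *first = data;

          if (first + 1 > (unsigned char *)data_end)
                  return SK_PASS;

          /* example policy: drop msgs whose first byte is 0xff */
          if (*first == 0xff)
                  return SK_DROP;

          return SK_PASS;
  }
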
> 
> The helper msg_redirect_map() can be used to select the socket
> to send the data on. This is used similarly to existing redirect
> use cases and allows policy to redirect msgs.
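
and the redirect case from the program side would presumably look
something like this. Again an untested sketch; the map name
"sock_map_tx" and the bpf_msg_redirect_map() wrapper name are
assumptions based on the helper added in this patch:

  struct bpf_map_def SEC("maps") sock_map_tx = {
          .type = BPF_MAP_TYPE_SOCKMAP,
          .key_size = sizeof(int),
          .value_size = sizeof(int),
          .max_entries = 20,
  };

  SEC("sk_msg")
  int msg_redirect(struct sk_msg_md *msg)
  {
          /* send every msg to the sock stored at index 1 in the map */
          return bpf_msg_redirect_map(msg, &sock_map_tx, 1, 0);
  }
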
> 
> Pseudo code simple example:
> 
> The basic logic to attach a program to a socket is as follows,
> 
>   // load the programs
>   bpf_prog_load(SOCKMAP_TCP_MSG_PROG, BPF_PROG_TYPE_SK_MSG,
>               &obj, &msg_prog);
> 
>   // lookup the sockmap
>   bpf_map_msg = bpf_object__find_map_by_name(obj, "my_sock_map");
> 
>   // get fd for sockmap
>   map_fd_msg = bpf_map__fd(bpf_map_msg);
> 
>   // attach program to sockmap
>   bpf_prog_attach(msg_prog, map_fd_msg, BPF_SK_MSG_VERDICT, 0);
> 
> Adding sockets to the map is done in the normal way,
> 
>   // Add a socket 'fd' to sockmap at location 'i'
>   bpf_map_update_elem(map_fd_msg, &i, &fd, BPF_ANY);
> 
> After the above any socket attached to "my_sock_map", in this case
> 'fd', will run the BPF msg verdict program (msg_prog) on every
> sendmsg and sendpage system call.
> 
> For a complete example see BPF selftests bpf/sockmap_tcp_msg_*.c and
> test_maps.c
> 
> Implementation notes:
> 
> It seemed simplest, to me at least, to use a refcnt to ensure
> the psock is not lost across the sendmsg copy into the sg, the
> BPF program running on the data in sg_data, and the final pass to
> the TCP stack. Some performance testing may show a better method
> that avoids the refcnt cost, but for now use the simpler method.
> 
> Another item that will come after basic support is in place is
> support for the MSG_MORE flag. At the moment we call sendpages
> even if the MSG_MORE flag is set. An enhancement would be to
> collect the pages into a larger scatterlist and pass it down the
> stack. Notice that bpf_tcp_sendmsg() could support this with some
> additional state saved across sendmsg calls. I built the code to
> support this without needing refactoring work. Other flags TBD
> include the ZEROCOPY flag.
> 
> Yet another detail that needs some thought is the size of the
> scatterlist. Currently, we use MAX_SKB_FRAGS simply because this
> was already being used in the TLS case. Future work to improve
> the kernel sk APIs to tune this depending on workload may be
> useful. This is a trade-off between memory usage and B/s
> performance.
> 
> Signed-off-by: John Fastabend <john.fastab...@gmail.com>

overall design looks clean. imo a huge improvement over the first version.

Few nits:

> ---
>  include/linux/bpf.h       |    1 
>  include/linux/bpf_types.h |    1 
>  include/linux/filter.h    |   10 +
>  include/net/tcp.h         |    2 
>  include/uapi/linux/bpf.h  |   28 +++
>  kernel/bpf/sockmap.c      |  485 ++++++++++++++++++++++++++++++++++++++++++++-
>  kernel/bpf/syscall.c      |   14 +
>  kernel/bpf/verifier.c     |    5 
>  net/core/filter.c         |  106 ++++++++++
>  9 files changed, 638 insertions(+), 14 deletions(-)
> 
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index 9e03046..14cdb4d 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -21,6 +21,7 @@
>  struct perf_event;
>  struct bpf_prog;
>  struct bpf_map;
> +struct sock;
>  
>  /* map is generic key/value storage optionally accesible by eBPF programs */
>  struct bpf_map_ops {
> diff --git a/include/linux/bpf_types.h b/include/linux/bpf_types.h
> index 19b8349..5e2e8a4 100644
> --- a/include/linux/bpf_types.h
> +++ b/include/linux/bpf_types.h
> @@ -13,6 +13,7 @@
>  BPF_PROG_TYPE(BPF_PROG_TYPE_LWT_XMIT, lwt_xmit)
>  BPF_PROG_TYPE(BPF_PROG_TYPE_SOCK_OPS, sock_ops)
>  BPF_PROG_TYPE(BPF_PROG_TYPE_SK_SKB, sk_skb)
> +BPF_PROG_TYPE(BPF_PROG_TYPE_SK_MSG, sk_msg)
>  #endif
>  #ifdef CONFIG_BPF_EVENTS
>  BPF_PROG_TYPE(BPF_PROG_TYPE_KPROBE, kprobe)
> diff --git a/include/linux/filter.h b/include/linux/filter.h
> index 425056c..f1e9833 100644
> --- a/include/linux/filter.h
> +++ b/include/linux/filter.h
> @@ -507,6 +507,15 @@ struct xdp_buff {
>       struct xdp_rxq_info *rxq;
>  };
>  
> +struct sk_msg_buff {
> +     void *data;
> +     void *data_end;
> +     struct scatterlist sg_data[MAX_SKB_FRAGS];
> +     __u32 key;
> +     __u32 flags;
> +     struct bpf_map *map;
> +};
> +
>  /* Compute the linear packet data range [data, data_end) which
>   * will be accessed by various program types (cls_bpf, act_bpf,
>   * lwt, ...). Subsystems allowing direct data access must (!)
> @@ -769,6 +778,7 @@ int xdp_do_redirect(struct net_device *dev,
>  void bpf_warn_invalid_xdp_action(u32 act);
>  
>  struct sock *do_sk_redirect_map(struct sk_buff *skb);
> +struct sock *do_msg_redirect_map(struct sk_msg_buff *md);
>  
>  #ifdef CONFIG_BPF_JIT
>  extern int bpf_jit_enable;
> diff --git a/include/net/tcp.h b/include/net/tcp.h
> index a99ceb8..7f56c3c 100644
> --- a/include/net/tcp.h
> +++ b/include/net/tcp.h
> @@ -1984,6 +1984,7 @@ static inline void tcp_listendrop(const struct sock *sk)
>  
>  enum {
>       TCP_ULP_TLS,
> +     TCP_ULP_BPF,
>  };
>  
>  struct tcp_ulp_ops {
> @@ -2001,6 +2002,7 @@ struct tcp_ulp_ops {
>  int tcp_register_ulp(struct tcp_ulp_ops *type);
>  void tcp_unregister_ulp(struct tcp_ulp_ops *type);
>  int tcp_set_ulp(struct sock *sk, const char *name);
> +int tcp_set_ulp_id(struct sock *sk, const int ulp);
>  void tcp_get_available_ulp(char *buf, size_t len);
>  void tcp_cleanup_ulp(struct sock *sk);
>  
> diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
> index 405317f..bf649ae 100644
> --- a/include/uapi/linux/bpf.h
> +++ b/include/uapi/linux/bpf.h
> @@ -133,6 +133,7 @@ enum bpf_prog_type {
>       BPF_PROG_TYPE_SOCK_OPS,
>       BPF_PROG_TYPE_SK_SKB,
>       BPF_PROG_TYPE_CGROUP_DEVICE,
> +     BPF_PROG_TYPE_SK_MSG,
>  };
>  
>  enum bpf_attach_type {
> @@ -143,6 +144,7 @@ enum bpf_attach_type {
>       BPF_SK_SKB_STREAM_PARSER,
>       BPF_SK_SKB_STREAM_VERDICT,
>       BPF_CGROUP_DEVICE,
> +     BPF_SK_MSG_VERDICT,
>       __MAX_BPF_ATTACH_TYPE
>  };
>  
> @@ -687,6 +689,15 @@ enum bpf_attach_type {
>   * int bpf_override_return(pt_regs, rc)
>   *   @pt_regs: pointer to struct pt_regs
>   *   @rc: the return value to set
> + *
> + * int bpf_msg_redirect_map(map, key, flags)
> + *     Redirect msg to a sock in map using key as a lookup key for the
> + *     sock in map.
> + *     @map: pointer to sockmap
> + *     @key: key to lookup sock in map
> + *     @flags: reserved for future use
> + *     Return: SK_PASS
> + *
>   */
>  #define __BPF_FUNC_MAPPER(FN)                \
>       FN(unspec),                     \
> @@ -747,7 +758,8 @@ enum bpf_attach_type {
>       FN(perf_event_read_value),      \
>       FN(perf_prog_read_value),       \
>       FN(getsockopt),                 \
> -     FN(override_return),
> +     FN(override_return),            \
> +     FN(msg_redirect_map),
>  
>  /* integer value in 'imm' field of BPF_CALL instruction selects which helper
>   * function eBPF program intends to call
> @@ -909,6 +921,20 @@ enum sk_action {
>       SK_PASS,
>  };
>  
> +/* User return codes for SK_MSG prog type. */
> +enum sk_msg_action {
> +     SK_MSG_DROP = 0,
> +     SK_MSG_PASS,
> +};
> +
> +/* user accessible metadata for SK_MSG packet hook, new fields must
> + * be added to the end of this structure
> + */
> +struct sk_msg_md {
> +     __u32 data;
> +     __u32 data_end;
> +};
> +
>  #define BPF_TAG_SIZE 8
>  
>  struct bpf_prog_info {
> diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
> index 972608f..5793f3a 100644
> --- a/kernel/bpf/sockmap.c
> +++ b/kernel/bpf/sockmap.c
> @@ -38,6 +38,7 @@
>  #include <linux/skbuff.h>
>  #include <linux/workqueue.h>
>  #include <linux/list.h>
> +#include <linux/mm.h>
>  #include <net/strparser.h>
>  #include <net/tcp.h>
>  
> @@ -47,6 +48,7 @@
>  struct bpf_stab {
>       struct bpf_map map;
>       struct sock **sock_map;
> +     struct bpf_prog *bpf_tx_msg;
>       struct bpf_prog *bpf_parse;
>       struct bpf_prog *bpf_verdict;
>  };
> @@ -74,6 +76,7 @@ struct smap_psock {
>       struct sk_buff *save_skb;
>  
>       struct strparser strp;
> +     struct bpf_prog *bpf_tx_msg;
>       struct bpf_prog *bpf_parse;
>       struct bpf_prog *bpf_verdict;
>       struct list_head maps;
> @@ -90,6 +93,8 @@ struct smap_psock {
>       void (*save_state_change)(struct sock *sk);
>  };
>  
> +static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
> +
>  static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
>  {
>       return rcu_dereference_sk_user_data(sk);
> @@ -99,8 +104,439 @@ enum __sk_action {
>       __SK_DROP = 0,
>       __SK_PASS,
>       __SK_REDIRECT,
> +     __SK_NONE,
>  };
>  
> +static int memcopy_from_iter(struct sock *sk, struct scatterlist *sg,
> +                          int sg_num, struct iov_iter *from, int bytes)
> +{
> +     int i, rc = 0;
> +
> +     for (i = 0; i < sg_num; ++i) {
> +             int copy = sg[i].length;
> +             char *to = sg_virt(&sg[i]);
> +
> +             if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
> +                     rc = copy_from_iter_nocache(to, copy, from);
> +             else
> +                     rc = copy_from_iter(to, copy, from);
> +
> +             if (rc != copy) {
> +                     rc = -EFAULT;
> +                     goto out;
> +             }
> +
> +             bytes -= copy;
> +             if (!bytes)
> +                     break;
> +     }
> +out:
> +     return rc;
> +}
> +
> +static int bpf_tcp_push(struct sock *sk, struct scatterlist *sg,
> +                     int *sg_end, int flags, bool charge)
> +{
> +     int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
> +     int offset, ret = 0;
> +     struct page *p;
> +     size_t size;
> +
> +     size = sg->length;
> +     offset = sg->offset;
> +
> +     while (1) {
> +             if (sg_is_last(sg))
> +                     sendpage_flags = flags;
> +
> +             tcp_rate_check_app_limited(sk);
> +             p = sg_page(sg);
> +retry:
> +             ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);
> +             if (ret != size) {
> +                     if (ret > 0) {
> +                             offset += ret;
> +                             size -= ret;
> +                             goto retry;
> +                     }
> +
> +                     if (charge)
> +                             sk_mem_uncharge(sk,
> +                                             sg->length - size - sg->offset);

should the bool argument be called 'uncharge' instead ?

> +
> +                     sg->offset = offset;
> +                     sg->length = size;
> +                     return ret;
> +             }
> +
> +             put_page(p);
> +             if (charge)
> +                     sk_mem_uncharge(sk, sg->length);
> +             *sg_end += 1;
> +             sg = sg_next(sg);
> +             if (!sg)
> +                     break;
> +
> +             offset = sg->offset;
> +             size = sg->length;
> +     }
> +
> +     return 0;
> +}
> +
> +static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
> +{
> +     md->data = sg_virt(md->sg_data);
> +     md->data_end = md->data + md->sg_data->length;
> +}
> +
> +static void return_mem_sg(struct sock *sk, struct scatterlist *sg, int end)
> +{
> +     int i;
> +
> +     for (i = 0; i < end; ++i)
> +             sk_mem_uncharge(sk, sg[i].length);
> +}
> +
> +static int free_sg(struct sock *sk, struct scatterlist *sg, int start, int len)
> +{
> +     int i, free = 0;
> +
> +     for (i = start; i < len; ++i) {
> +             free += sg[i].length;
> +             sk_mem_uncharge(sk, sg[i].length);
> +             put_page(sg_page(&sg[i]));
> +     }
> +
> +     return free;
> +}
> +
> +static unsigned int smap_do_tx_msg(struct sock *sk,
> +                                struct smap_psock *psock,
> +                                struct sk_msg_buff *md)
> +{
> +     struct bpf_prog *prog;
> +     unsigned int rc, _rc;
> +
> +     preempt_disable();
> +     rcu_read_lock();
> +
> +     /* If the policy was removed mid-send then default to 'accept' */
> +     prog = READ_ONCE(psock->bpf_tx_msg);
> +     if (unlikely(!prog)) {
> +             _rc = SK_PASS;
> +             goto verdict;
> +     }
> +
> +     bpf_compute_data_pointers_sg(md);
> +     _rc = (*prog->bpf_func)(md, prog->insnsi);
> +
> +verdict:
> +     rcu_read_unlock();
> +     preempt_enable();
> +
> +     /* Moving return codes from UAPI namespace into internal namespace */
> +     rc = ((_rc == SK_PASS) ?
> +           (md->map ? __SK_REDIRECT : __SK_PASS) :
> +           __SK_DROP);
> +
> +     return rc;
> +}
> +
> +static int bpf_tcp_sendmsg_do_redirect(struct scatterlist *sg, int sg_num,
> +                                    struct sk_msg_buff *md, int flags)
> +{
> +     int i, sg_curr = 0, err, free;
> +     struct smap_psock *psock;
> +     struct sock *sk;
> +
> +     rcu_read_lock();
> +     sk = do_msg_redirect_map(md);
> +     if (unlikely(!sk))
> +             goto out_rcu;
> +
> +     psock = smap_psock_sk(sk);
> +     if (unlikely(!psock))
> +             goto out_rcu;
> +
> +     if (!refcount_inc_not_zero(&psock->refcnt))
> +             goto out_rcu;
> +
> +     rcu_read_unlock();
> +     lock_sock(sk);
> +     err = bpf_tcp_push(sk, sg, &sg_curr, flags, false);
> +     if (unlikely(err))
> +             goto out;
> +     release_sock(sk);
> +     smap_release_sock(psock, sk);
> +     return 0;
> +out_rcu:
> +     rcu_read_unlock();
> +out:
> +     for (i = sg_curr; i < sg_num; ++i) {
> +             free += sg[i].length;
> +             put_page(sg_page(&sg[i]));
> +     }
> +     return free;

error path keeps rcu_lock and sk locked?
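
Something along these lines would release both on the error path as
well (completely untested sketch, only to illustrate the point; it
also initializes 'free', which otherwise can be returned
uninitialized):

  static int bpf_tcp_sendmsg_do_redirect(struct scatterlist *sg, int sg_num,
                                         struct sk_msg_buff *md, int flags)
  {
          int i, sg_curr = 0, err, free = 0;
          struct smap_psock *psock;
          struct sock *sk;

          rcu_read_lock();
          sk = do_msg_redirect_map(md);
          if (unlikely(!sk))
                  goto out_rcu;

          psock = smap_psock_sk(sk);
          if (unlikely(!psock))
                  goto out_rcu;

          if (!refcount_inc_not_zero(&psock->refcnt))
                  goto out_rcu;

          rcu_read_unlock();

          lock_sock(sk);
          err = bpf_tcp_push(sk, sg, &sg_curr, flags, false);
          release_sock(sk);
          smap_release_sock(psock, sk);
          if (likely(!err))
                  return 0;
          goto out;
  out_rcu:
          rcu_read_unlock();
  out:
          /* free whatever was not consumed so the caller can adjust 'copied' */
          for (i = sg_curr; i < sg_num; ++i) {
                  free += sg[i].length;
                  put_page(sg_page(&sg[i]));
          }
          return free;
  }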

> +}
> +
> +static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
> +{
> +     int err = 0, eval = __SK_NONE, sg_size = 0, sg_num = 0;
> +     int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
> +     struct sk_msg_buff md = {0};
> +     struct smap_psock *psock;
> +     size_t copy, copied = 0;
> +     struct scatterlist *sg;
> +     long timeo;
> +
> +     sg = md.sg_data;
> +     sg_init_table(sg, MAX_SKB_FRAGS);
> +
> +     /* Its possible a sock event or user removed the psock _but_ the ops
> +      * have not been reprogrammed yet so we get here. In this case fallback
> +      * to tcp_sendmsg. Note this only works because we _only_ ever allow
> +      * a single ULP there is no hierarchy here.
> +      */
> +     rcu_read_lock();
> +     psock = smap_psock_sk(sk);
> +     if (unlikely(!psock)) {
> +             rcu_read_unlock();
> +             return tcp_sendmsg(sk, msg, size);
> +     }
> +
> +     /* Increment the psock refcnt to ensure its not released while sending a
> +      * message. Required because sk lookup and bpf programs are used in
> +      * separate rcu critical sections. Its OK if we lose the map entry
> +      * but we can't lose the sock reference, possible when the refcnt hits
> +      * zero and garbage collection calls sock_put().
> +      */
> +     if (!refcount_inc_not_zero(&psock->refcnt)) {
> +             rcu_read_unlock();
> +             return tcp_sendmsg(sk, msg, size);
> +     }
> +
> +     rcu_read_unlock();
> +
> +     lock_sock(sk);
> +     timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
> +
> +     while (msg_data_left(msg)) {
> +             int sg_curr;
> +
> +             if (sk->sk_err) {
> +                     err = sk->sk_err;
> +                     goto out_err;
> +             }
> +
> +             copy = msg_data_left(msg);
> +             if (!sk_stream_memory_free(sk))
> +                     goto wait_for_sndbuf;
> +
> +             /* sg_size indicates bytes already allocated and sg_num
> +              * is last sg element used. This is used when alloc_sg
> +              * partially allocates a scatterlist and then is sent
> +              * to wait for memory. In normal case (no memory pressure)
> +              * both sg_nun and sg_size are zero.
> +              */
> +             copy = copy - sg_size;
> +             err = sk_alloc_sg(sk, copy, sg, &sg_num, &sg_size, 0);
> +             if (err) {
> +                     if (err != -ENOSPC)
> +                             goto wait_for_memory;
> +                     copy = sg_size;
> +             }
> +
> +             err = memcopy_from_iter(sk, sg, sg_num, &msg->msg_iter, copy);
> +             if (err < 0) {
> +                     free_sg(sk, sg, 0, sg_num);
> +                     goto out_err;
> +             }
> +
> +             copied += copy;
> +
> +             /* If msg is larger than MAX_SKB_FRAGS we can send multiple
> +              * scatterlists per msg. However BPF decisions apply to the
> +              * entire msg.
> +              */
> +             if (eval == __SK_NONE)
> +                     eval = smap_do_tx_msg(sk, psock, &md);

it seems sk_alloc_sg() will put 64k bytes into sg_data,
but this program will see only the first SG ?
and it's typically going to be one page only ?
then what's the value of waiting for MAX_SKB_FRAGS ?

> +
> +             switch (eval) {
> +             case __SK_PASS:
> +                     sg_mark_end(sg + sg_num - 1);
> +                     err = bpf_tcp_push(sk, sg, &sg_curr, flags, true);
> +                     if (unlikely(err)) {
> +                             copied -= free_sg(sk, sg, sg_curr, sg_num);
> +                             goto out_err;
> +                     }
> +                     break;
> +             case __SK_REDIRECT:
> +                     sg_mark_end(sg + sg_num - 1);
> +                     goto do_redir;
> +             case __SK_DROP:
> +             default:
> +                     copied -= free_sg(sk, sg, 0, sg_num);
> +                     goto out_err;
> +             }
> +
> +             sg_num = 0;
> +             sg_size = 0;
> +             continue;
> +wait_for_sndbuf:
> +             set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
> +wait_for_memory:
> +             err = sk_stream_wait_memory(sk, &timeo);
> +             if (err)
> +                     goto out_err;
> +     }
> +out_err:
> +     if (err < 0)
> +             err = sk_stream_error(sk, msg->msg_flags, err);
> +     release_sock(sk);
> +     smap_release_sock(psock, sk);
> +     return copied ? copied : err;
> +
> +do_redir:
> +     /* To avoid deadlock with multiple socks all doing redirects to
> +      * each other we must first drop the current sock lock and release
> +      * the psock. Then get the redirect socket (assuming it still
> +      * exists), take it's lock, and finally do the send here. If the
> +      * redirect fails there is nothing to do, we don't want to blame
> +      * the sender for remote socket failures. Instead we simply
> +      * continue making forward progress.
> +      */
> +     return_mem_sg(sk, sg, sg_num);
> +     release_sock(sk);
> +     smap_release_sock(psock, sk);
> +     copied -= bpf_tcp_sendmsg_do_redirect(sg, sg_num, &md, flags);
> +     return copied;
> +}
> +
> +static int bpf_tcp_sendpage_do_redirect(struct page *page, int offset,
> +                                     size_t size, int flags,
> +                                     struct sk_msg_buff *md)
> +{
> +     struct smap_psock *psock;
> +     struct sock *sk;
> +     int rc;
> +
> +     rcu_read_lock();
> +     sk = do_msg_redirect_map(md);
> +     if (unlikely(!sk))
> +             goto out_rcu;
> +
> +     psock = smap_psock_sk(sk);
> +     if (unlikely(!psock))
> +             goto out_rcu;
> +
> +     if (!refcount_inc_not_zero(&psock->refcnt))
> +             goto out_rcu;
> +
> +     rcu_read_unlock();
> +
> +     lock_sock(sk);
> +     rc = tcp_sendpage_locked(sk, page, offset, size, flags);
> +     release_sock(sk);
> +
> +     smap_release_sock(psock, sk);
> +     return rc;
> +out_rcu:
> +     rcu_read_unlock();
> +     return -EINVAL;
> +}
> +
> +static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
> +                         int offset, size_t size, int flags)
> +{
> +     struct smap_psock *psock;
> +     int rc, _rc = __SK_PASS;
> +     struct bpf_prog *prog;
> +     struct sk_msg_buff md;
> +
> +     preempt_disable();
> +     rcu_read_lock();
> +     psock = smap_psock_sk(sk);
> +     if (unlikely(!psock))
> +             goto verdict;
> +
> +     /* If the policy was removed mid-send then default to 'accept' */
> +     prog = READ_ONCE(psock->bpf_tx_msg);
> +     if (unlikely(!prog))
> +             goto verdict;
> +
> +     /* Calculate pkt data pointers and run BPF program */
> +     md.data = page_address(page) + offset;
> +     md.data_end = md.data + size;
> +     _rc = (*prog->bpf_func)(&md, prog->insnsi);
> +
> +verdict:
> +     rcu_read_unlock();
> +     preempt_enable();
> +
> +     /* Moving return codes from UAPI namespace into internal namespace */
> +     rc = ((_rc == SK_PASS) ? __SK_PASS : __SK_DROP);
> +
> +     switch (rc) {
> +     case __SK_PASS:
> +             lock_sock(sk);
> +             rc = tcp_sendpage_locked(sk, page, offset, size, flags);
> +             release_sock(sk);
> +             break;
> +     case __SK_REDIRECT:
> +             smap_release_sock(psock, sk);
> +             rc = bpf_tcp_sendpage_do_redirect(page, offset, size, flags,
> +                                               &md);

looks like this path wasn't tested,
since the rc = ...; line above cannot return REDIRECT...
probably there should be a common helper for both bpf_tcp_*() funcs
to call into bpf and convert rc.
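
e.g. something along these lines (untested sketch, helper name made
up), with the caller computing md->data/md->data_end first and
zero-initializing md so md->map is only set by the redirect helper:

  static unsigned int bpf_tcp_msg_verdict(struct smap_psock *psock,
                                          struct sk_msg_buff *md)
  {
          struct bpf_prog *prog;
          unsigned int _rc = SK_PASS;

          preempt_disable();
          rcu_read_lock();

          /* If the policy was removed mid-send then default to 'accept' */
          prog = READ_ONCE(psock->bpf_tx_msg);
          if (prog)
                  _rc = (*prog->bpf_func)(md, prog->insnsi);

          rcu_read_unlock();
          preempt_enable();

          /* Move the UAPI verdict into the internal __SK_* namespace; a
           * call to the redirect helper is visible via md->map in both
           * the sendmsg and sendpage paths.
           */
          return (_rc == SK_PASS) ?
                 (md->map ? __SK_REDIRECT : __SK_PASS) : __SK_DROP;
  }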

> +             break;
> +     case __SK_DROP:
> +     default:
> +             rc = -EACCES;
> +     }
> +
> +     return rc;
> +}
> +
> +static int bpf_tcp_msg_add(struct smap_psock *psock,
> +                        struct sock *sk,
> +                        struct bpf_prog *tx_msg)
> +{
> +     struct bpf_prog *orig_tx_msg;
> +
> +     orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
> +     if (orig_tx_msg)
> +             bpf_prog_put(orig_tx_msg);

the function is replacing the program. why is it called bpf_tcp_msg_add ?

> +
> +     return tcp_set_ulp_id(sk, TCP_ULP_BPF);
> +}
> +
> +struct proto tcp_bpf_proto;
> +static int bpf_tcp_init(struct sock *sk)
> +{
> +     sk->sk_prot = &tcp_bpf_proto;
> +     return 0;
> +}
> +
> +static void bpf_tcp_release(struct sock *sk)
> +{
> +     sk->sk_prot = &tcp_prot;
> +}
> +
> +static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
> +     .name                   = "bpf_tcp",
> +     .uid                    = TCP_ULP_BPF,
> +     .owner                  = NULL,
> +     .init                   = bpf_tcp_init,
> +     .release                = bpf_tcp_release,
> +};
> +
> +static int bpf_tcp_ulp_register(void)
> +{
> +     tcp_bpf_proto = tcp_prot;
> +     tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
> +     tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
> +     return tcp_register_ulp(&bpf_tcp_ulp_ops);

I don't see a corresponding tcp_unregister_ulp().

> +}
> +
>  static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
>  {
>       struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
> @@ -165,8 +601,6 @@ static void smap_report_sk_error(struct smap_psock *psock, int err)
>       sk->sk_error_report(sk);
>  }
>  
> -static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
> -
>  /* Called with lock_sock(sk) held */
>  static void smap_state_change(struct sock *sk)
>  {
> @@ -317,6 +751,7 @@ static void smap_write_space(struct sock *sk)
>  
>  static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
>  {
> +     tcp_cleanup_ulp(sk);
>       if (!psock->strp_enabled)
>               return;
>       sk->sk_data_ready = psock->save_data_ready;
> @@ -384,7 +819,6 @@ static int smap_parse_func_strparser(struct strparser *strp,
>       return rc;
>  }
>  
> -
>  static int smap_read_sock_done(struct strparser *strp, int err)
>  {
>       return err;
> @@ -456,6 +890,8 @@ static void smap_gc_work(struct work_struct *w)
>               bpf_prog_put(psock->bpf_parse);
>       if (psock->bpf_verdict)
>               bpf_prog_put(psock->bpf_verdict);
> +     if (psock->bpf_tx_msg)
> +             bpf_prog_put(psock->bpf_tx_msg);
>  
>       list_for_each_entry_safe(e, tmp, &psock->maps, list) {
>               list_del(&e->list);
> @@ -491,8 +927,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock,
>  
>  static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
>  {
> -     struct bpf_stab *stab;
> -     int err = -EINVAL;
> +     struct bpf_stab *stab; int err = -EINVAL;
>       u64 cost;
>  
>       if (!capable(CAP_NET_ADMIN))
> @@ -506,6 +941,10 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
>       if (attr->value_size > KMALLOC_MAX_SIZE)
>               return ERR_PTR(-E2BIG);
>  
> +     err = bpf_tcp_ulp_register();
> +     if (err && err != -EEXIST)
> +             return ERR_PTR(err);
> +
>       stab = kzalloc(sizeof(*stab), GFP_USER);
>       if (!stab)
>               return ERR_PTR(-ENOMEM);
> @@ -590,6 +1029,8 @@ static void sock_map_free(struct bpf_map *map)
>               bpf_prog_put(stab->bpf_verdict);
>       if (stab->bpf_parse)
>               bpf_prog_put(stab->bpf_parse);
> +     if (stab->bpf_tx_msg)
> +             bpf_prog_put(stab->bpf_tx_msg);
>  
>       sock_map_remove_complete(stab);
>  }
> @@ -684,7 +1125,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
>  {
>       struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
>       struct smap_psock_map_entry *e = NULL;
> -     struct bpf_prog *verdict, *parse;
> +     struct bpf_prog *verdict, *parse, *tx_msg;
>       struct sock *osock, *sock;
>       struct smap_psock *psock;
>       u32 i = *(u32 *)key;
> @@ -710,6 +1151,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
>        */
>       verdict = READ_ONCE(stab->bpf_verdict);
>       parse = READ_ONCE(stab->bpf_parse);
> +     tx_msg = READ_ONCE(stab->bpf_tx_msg);
>  
>       if (parse && verdict) {
>               /* bpf prog refcnt may be zero if a concurrent attach operation
> @@ -728,6 +1170,17 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
>               }
>       }
>  
> +     if (tx_msg) {
> +             tx_msg = bpf_prog_inc_not_zero(stab->bpf_tx_msg);

prog_inc_not_zero() looks scary here.
Why is 'not_zero' necessary ?

> +             if (IS_ERR(tx_msg)) {
> +                     if (verdict)
> +                             bpf_prog_put(verdict);
> +                     if (parse)
> +                             bpf_prog_put(parse);
> +                     return PTR_ERR(tx_msg);
> +             }
> +     }
> +
