In the case where we need a specific number of bytes before a verdict can be assigned, even if the data spans multiple sendmsg or sendfile calls, the BPF program may use msg_cork_bytes().
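For example, a verdict program that needs a complete application header before it can pass or drop a message could cork until the header has been queued. The sketch below is not part of this patch: the 12-byte header length, the first-byte check, the section name, and the helper declarations pulled in from bpf_helpers.h are all illustrative assumptions.

  #include <linux/bpf.h>
  #include "bpf_helpers.h"

  #define HDR_LEN 12

  SEC("sk_msg")
  int msg_verdict_prog(struct sk_msg_md *msg)
  {
          void *data = (void *)(long)msg->data;
          void *data_end = (void *)(long)msg->data_end;

          /* Not enough bytes queued yet: request that HDR_LEN bytes be
           * corked; the program is re-run once they have accumulated.
           */
          if (data + HDR_LEN > data_end) {
                  bpf_msg_cork_bytes(msg, HDR_LEN);
                  return SK_PASS;
          }

          /* Full header visible: validate it and assign the verdict. */
          if (*(unsigned char *)data != 0x01)
                  return SK_DROP;

          return SK_PASS;
  }

  char _license[] SEC("license") = "GPL";

Note that a program cannot assume the cork request is always honored (for example when no buffer space remains), so it should re-check the available data length as above before parsing.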
The extreme case is a user can call sendmsg repeatedly with 1-byte msg segments. Obviously, this is bad for performance but is still valid. If the BPF program needs N bytes to validate a header it can use msg_cork_bytes to specify N bytes and the BPF program will not be called again until N bytes have been accumulated. Signed-off-by: John Fastabend <john.fastab...@gmail.com> --- include/linux/filter.h | 2 include/uapi/linux/bpf.h | 3 kernel/bpf/sockmap.c | 334 ++++++++++++++++++++++++++++++++++++++++------ net/core/filter.c | 16 ++ 4 files changed, 310 insertions(+), 45 deletions(-) diff --git a/include/linux/filter.h b/include/linux/filter.h index 805a566..6058a1b 100644 --- a/include/linux/filter.h +++ b/include/linux/filter.h @@ -511,6 +511,8 @@ struct sk_msg_buff { void *data; void *data_end; int apply_bytes; + int cork_bytes; + int sg_copybreak; int sg_start; int sg_curr; int sg_end; diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h index e50c61f..cfcc002 100644 --- a/include/uapi/linux/bpf.h +++ b/include/uapi/linux/bpf.h @@ -770,7 +770,8 @@ enum bpf_attach_type { FN(override_return), \ FN(sock_ops_cb_flags_set), \ FN(msg_redirect_map), \ - FN(msg_apply_bytes), + FN(msg_apply_bytes), \ + FN(msg_cork_bytes), /* integer value in 'imm' field of BPF_CALL instruction selects which helper * function eBPF program intends to call diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c index 98c6a3b..f637a83 100644 --- a/kernel/bpf/sockmap.c +++ b/kernel/bpf/sockmap.c @@ -78,8 +78,10 @@ struct smap_psock { /* datapath variables for tx_msg ULP */ struct sock *sk_redir; int apply_bytes; + int cork_bytes; int sg_size; int eval; + struct sk_msg_buff *cork; struct strparser strp; struct bpf_prog *bpf_tx_msg; @@ -140,22 +142,30 @@ static int bpf_tcp_init(struct sock *sk) return 0; } +static void smap_release_sock(struct smap_psock *psock, struct sock *sock); +static int free_start_sg(struct sock *sk, struct sk_msg_buff *md); + static void bpf_tcp_release(struct sock *sk) { struct smap_psock *psock; rcu_read_lock(); psock = smap_psock_sk(sk); + if (unlikely(!psock)) + goto out; - if (likely(psock)) { - sk->sk_prot = psock->sk_proto; - psock->sk_proto = NULL; + if (psock->cork) { + free_start_sg(psock->sock, psock->cork); + kfree(psock->cork); + psock->cork = NULL; } + + sk->sk_prot = psock->sk_proto; + psock->sk_proto = NULL; +out: rcu_read_unlock(); } -static void smap_release_sock(struct smap_psock *psock, struct sock *sock); - static void bpf_tcp_close(struct sock *sk, long timeout) { void (*close_fun)(struct sock *sk, long timeout); @@ -211,14 +221,25 @@ static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, int bytes) { struct scatterlist *sg = md->sg_data; - int i = md->sg_curr, rc = 0; + int i = md->sg_curr, rc = -ENOSPC; do { int copy; char *to; - copy = sg[i].length; - to = sg_virt(&sg[i]); + if (md->sg_copybreak >= sg[i].length) { + md->sg_copybreak = 0; + + if (++i == MAX_SKB_FRAGS) + i = 0; + + if (i == md->sg_end) + break; + } + + copy = sg[i].length - md->sg_copybreak; + to = sg_virt(&sg[i]) + md->sg_copybreak; + md->sg_copybreak += copy; if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY) rc = copy_from_iter_nocache(to, copy, from); @@ -234,6 +255,7 @@ static int memcopy_from_iter(struct sock *sk, if (!bytes) break; + md->sg_copybreak = 0; if (++i == MAX_SKB_FRAGS) i = 0; } while (i != md->sg_end); @@ -328,6 +350,33 @@ static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md) } while (i != md->sg_end); } +static void free_bytes_sg(struct 
sock *sk, int bytes, struct sk_msg_buff *md) +{ + struct scatterlist *sg = md->sg_data; + int i = md->sg_start, free; + + while (bytes && sg[i].length) { + free = sg[i].length; + if (bytes < free) { + sg[i].length -= bytes; + sg[i].offset += bytes; + sk_mem_uncharge(sk, bytes); + break; + } + + sk_mem_uncharge(sk, sg[i].length); + put_page(sg_page(&sg[i])); + bytes -= sg[i].length; + sg[i].length = 0; + sg[i].page_link = 0; + sg[i].offset = 0; + i++; + + if (i == MAX_SKB_FRAGS) + i = 0; + } +} + static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md) { struct scatterlist *sg = md->sg_data; @@ -510,6 +559,9 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); while (msg_data_left(msg)) { + bool cork = false, enospc = false; + struct sk_msg_buff *m; + if (sk->sk_err) { err = sk->sk_err; goto out_err; @@ -519,32 +571,76 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; - md.sg_curr = md.sg_end; - err = sk_alloc_sg(sk, copy, sg, - md.sg_start, &md.sg_end, &sg_copy, - md.sg_end); + m = psock->cork_bytes ? psock->cork : &md; + m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end; + err = sk_alloc_sg(sk, copy, m->sg_data, + m->sg_start, &m->sg_end, &sg_copy, + m->sg_end - 1); if (err) { if (err != -ENOSPC) goto wait_for_memory; + enospc = true; copy = sg_copy; } - err = memcopy_from_iter(sk, &md, &msg->msg_iter, copy); + err = memcopy_from_iter(sk, m, &msg->msg_iter, copy); if (err < 0) { - free_curr_sg(sk, &md); + free_curr_sg(sk, m); goto out_err; } psock->sg_size += copy; copied += copy; sg_copy = 0; + + /* When bytes are being corked skip running BPF program and + * applying verdict unless there is no more buffer space. In + * the ENOSPC case simply run BPF prorgram with currently + * accumulated data. We don't have much choice at this point + * we could try extending the page frags or chaining complex + * frags but even in these cases _eventually_ we will hit an + * OOM scenario. More complex recovery schemes may be + * implemented in the future, but BPF programs must handle + * the case where apply_cork requests are not honored. The + * canonical method to verify this is to check data length. + */ + if (psock->cork_bytes) { + if (copy > psock->cork_bytes) + psock->cork_bytes = 0; + else + psock->cork_bytes -= copy; + + if (psock->cork_bytes && !enospc) + goto out_cork; + + /* All cork bytes accounted for re-run filter */ + psock->eval = __SK_NONE; + psock->cork_bytes = 0; + } more_data: /* If msg is larger than MAX_SKB_FRAGS we can send multiple * scatterlists per msg. However BPF decisions apply to the * entire msg. 
*/ if (psock->eval == __SK_NONE) - psock->eval = smap_do_tx_msg(sk, psock, &md); + psock->eval = smap_do_tx_msg(sk, psock, m); + + if (m->cork_bytes && + m->cork_bytes > psock->sg_size && !enospc) { + psock->cork_bytes = m->cork_bytes - psock->sg_size; + if (!psock->cork) { + psock->cork = kcalloc(1, + sizeof(struct sk_msg_buff), + GFP_ATOMIC | __GFP_NOWARN); + + if (!psock->cork) { + err = -ENOMEM; + goto out_err; + } + } + memcpy(psock->cork, m, sizeof(*m)); + goto out_cork; + } send = psock->sg_size; if (psock->apply_bytes && psock->apply_bytes < send) @@ -552,9 +648,9 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) switch (psock->eval) { case __SK_PASS: - err = bpf_tcp_push(sk, send, &md, flags, true); + err = bpf_tcp_push(sk, send, m, flags, true); if (unlikely(err)) { - copied -= free_start_sg(sk, &md); + copied -= free_start_sg(sk, m); goto out_err; } @@ -576,13 +672,23 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) psock->apply_bytes -= send; } - return_mem_sg(sk, send, &md); + if (psock->cork) { + cork = true; + psock->cork = NULL; + } + + return_mem_sg(sk, send, m); release_sock(sk); err = bpf_tcp_sendmsg_do_redirect(redir, send, - &md, flags); + m, flags); lock_sock(sk); + if (cork) { + free_start_sg(sk, m); + kfree(m); + m = NULL; + } if (unlikely(err)) { copied -= err; goto out_redir; @@ -592,21 +698,23 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) break; case __SK_DROP: default: - copied -= free_start_sg(sk, &md); - + free_bytes_sg(sk, send, m); if (psock->apply_bytes) { if (psock->apply_bytes < send) psock->apply_bytes = 0; else psock->apply_bytes -= send; } - psock->sg_size -= copied; + copied -= send; + psock->sg_size -= send; err = -EACCES; break; } bpf_md_init(psock); - if (sg[md.sg_start].page_link && sg[md.sg_start].length) + if (m && + m->sg_data[m->sg_start].page_link && + m->sg_data[m->sg_start].length) goto more_data; continue; wait_for_sndbuf: @@ -623,6 +731,47 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) release_sock(sk); smap_release_sock(psock, sk); return copied ? 
copied : err; +out_cork: + release_sock(sk); + smap_release_sock(psock, sk); + return copied; +} + +static int bpf_tcp_sendpage_sg_locked(struct sock *sk, + struct sk_msg_buff *m, + int send, + int flags) +{ + int copied = 0; + + do { + struct scatterlist *sg = &m->sg_data[m->sg_start]; + struct page *p = sg_page(sg); + int off = sg->offset; + int len = sg->length; + int err; + + if (len > send) + len = send; + + err = tcp_sendpage_locked(sk, p, off, len, flags); + if (err < 0) + break; + + sg->length -= len; + sg->offset += len; + copied += len; + send -= len; + if (!sg->length) { + sg->page_link = 0; + put_page(p); + m->sg_start++; + if (m->sg_start == MAX_SKB_FRAGS) + m->sg_start = 0; + } + } while (send && m->sg_start != m->sg_end); + + return copied; } static int bpf_tcp_sendpage_do_redirect(struct sock *sk, @@ -644,7 +793,10 @@ static int bpf_tcp_sendpage_do_redirect(struct sock *sk, rcu_read_unlock(); lock_sock(sk); - rc = tcp_sendpage_locked(sk, page, offset, size, flags); + if (md) + rc = bpf_tcp_sendpage_sg_locked(sk, md, size, flags); + else + rc = tcp_sendpage_locked(sk, page, offset, size, flags); release_sock(sk); smap_release_sock(psock, sk); @@ -657,10 +809,10 @@ static int bpf_tcp_sendpage_do_redirect(struct sock *sk, static int bpf_tcp_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags) { - struct sk_msg_buff md = {0}; + struct sk_msg_buff md = {0}, *m = NULL; + bool cork = false, enospc = false; struct smap_psock *psock; - int send, total = 0, rc = __SK_NONE; - int orig_size = size; + int send, total = 0, rc; struct bpf_prog *prog; struct sock *redir; @@ -686,19 +838,90 @@ static int bpf_tcp_sendpage(struct sock *sk, struct page *page, preempt_enable(); lock_sock(sk); + + psock->sg_size += size; +do_cork: + if (psock->cork_bytes) { + struct scatterlist *sg; + + m = psock->cork; + sg = &m->sg_data[m->sg_end]; + sg_set_page(sg, page, send, offset); + get_page(page); + sk_mem_charge(sk, send); + m->sg_end++; + cork = true; + + if (send > psock->cork_bytes) + psock->cork_bytes = 0; + else + psock->cork_bytes -= send; + + if (m->sg_end == MAX_SKB_FRAGS) + m->sg_end = 0; + + if (m->sg_end == m->sg_start) { + enospc = true; + psock->cork_bytes = 0; + } + + if (!psock->cork_bytes) + psock->eval = __SK_NONE; + + if (!enospc && psock->cork_bytes) { + total = send; + goto out_err; + } + } more_sendpage_data: if (psock->eval == __SK_NONE) psock->eval = smap_do_tx_msg(sk, psock, &md); + if (md.cork_bytes && !enospc && md.cork_bytes > psock->sg_size) { + psock->cork_bytes = md.cork_bytes; + if (!psock->cork) { + psock->cork = kzalloc(sizeof(struct sk_msg_buff), + GFP_ATOMIC | __GFP_NOWARN); + + if (!psock->cork) { + psock->sg_size -= size; + total = -ENOMEM; + goto out_err; + } + } + + if (!cork) { + send = psock->sg_size; + goto do_cork; + } + } + + send = psock->sg_size; if (psock->apply_bytes && psock->apply_bytes < send) send = psock->apply_bytes; - switch (rc) { + switch (psock->eval) { case __SK_PASS: - rc = tcp_sendpage_locked(sk, page, offset, send, flags); - if (rc < 0) { - total = total ? : rc; - goto out_err; + /* When data is corked once cork bytes limit is reached + * we may send more data then the current sendfile call + * is expecting. To handle this we have to fixup return + * codes. However, if there is an error there is nothing + * to do but continue. We can not go back in time and + * give errors to data we have already consumed. + */ + if (m) { + rc = bpf_tcp_sendpage_sg_locked(sk, m, send, flags); + if (rc < 0) { + total = total ? 
: rc; + goto out_err; + } + sk_mem_uncharge(sk, rc); + } else { + rc = tcp_sendpage_locked(sk, page, offset, send, flags); + if (rc < 0) { + total = total ? : rc; + goto out_err; + } } if (psock->apply_bytes) { @@ -711,7 +934,7 @@ static int bpf_tcp_sendpage(struct sock *sk, struct page *page, total += rc; psock->sg_size -= rc; offset += rc; - size -= rc; + send -= rc; break; case __SK_REDIRECT: redir = psock->sk_redir; @@ -728,12 +951,30 @@ static int bpf_tcp_sendpage(struct sock *sk, struct page *page, /* sock lock dropped must not dereference psock below */ rc = bpf_tcp_sendpage_do_redirect(redir, page, offset, send, - flags, &md); + flags, m); lock_sock(sk); - if (rc > 0) { - offset += rc; - psock->sg_size -= rc; - send -= rc; + if (m) { + int free = free_start_sg(sk, m); + + if (rc > 0) { + sk_mem_uncharge(sk, rc); + free = rc + free; + } + psock->sg_size -= free; + psock->cork_bytes = 0; + send = 0; + if (psock->apply_bytes) { + if (psock->apply_bytes > free) + psock->apply_bytes -= free; + else + psock->apply_bytes = 0; + } + } else { + if (rc > 0) { + offset += rc; + psock->sg_size -= rc; + send -= rc; + } } if ((total && rc > 0) || (!total && rc < 0)) @@ -741,7 +982,8 @@ static int bpf_tcp_sendpage(struct sock *sk, struct page *page, break; case __SK_DROP: default: - return_mem_sg(sk, send, &md); + if (m) + free_bytes_sg(sk, send, m); if (psock->apply_bytes) { if (psock->apply_bytes > send) psock->apply_bytes -= send; @@ -749,18 +991,17 @@ static int bpf_tcp_sendpage(struct sock *sk, struct page *page, psock->apply_bytes -= 0; } psock->sg_size -= send; - size -= send; - total += send; - rc = -EACCES; + total = total ? : -EACCES; + goto out_err; } bpf_md_init(psock); - if (size) + if (psock->sg_size) goto more_sendpage_data; out_err: release_sock(sk); smap_release_sock(psock, sk); - return total <= orig_size ? total : orig_size; + return total <= size ? total : size; } static void bpf_tcp_msg_add(struct smap_psock *psock, @@ -1077,6 +1318,11 @@ static void smap_gc_work(struct work_struct *w) if (psock->bpf_tx_msg) bpf_prog_put(psock->bpf_tx_msg); + if (psock->cork) { + free_start_sg(psock->sock, psock->cork); + kfree(psock->cork); + } + list_for_each_entry_safe(e, tmp, &psock->maps, list) { list_del(&e->list); kfree(e); diff --git a/net/core/filter.c b/net/core/filter.c index df2a8f4..2c73af0 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -1942,6 +1942,20 @@ struct sock *do_msg_redirect_map(struct sk_msg_buff *msg) .arg2_type = ARG_ANYTHING, }; +BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg_buff *, msg, u64, bytes) +{ + msg->cork_bytes = bytes; + return 0; +} + +static const struct bpf_func_proto bpf_msg_cork_bytes_proto = { + .func = bpf_msg_cork_bytes, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb) { return task_get_classid(skb); @@ -3650,6 +3664,8 @@ static const struct bpf_func_proto *sk_msg_func_proto(enum bpf_func_id func_id) return &bpf_msg_redirect_map_proto; case BPF_FUNC_msg_apply_bytes: return &bpf_msg_apply_bytes_proto; + case BPF_FUNC_msg_cork_bytes: + return &bpf_msg_cork_bytes_proto; default: return bpf_base_func_proto(func_id); }