The RDS module sits on top of TCP (rds_tcp) and IB (rds_rdma), so messages arrive in the form of an skb (over TCP) or a scatterlist (over IB/RDMA). However, because socket filters only deal with skbs (i.e. struct sk_buff as the bpf context), we can only use a socket filter for rds_tcp and not for rds_rdma.
Considering one filtering solution for RDS, it seems that the common denominator between sk_buff and scatterlist is scatterlist. Therefore, this patch converts skb to sgvec and invoke sg_filter_run for rds_tcp and simply invoke sg_filter_run for IB/rds_rdma. Signed-off-by: Tushar Dave <tushar.n.d...@oracle.com> Reviewed-by: Sowmini Varadhan <sowmini.varad...@oracle.com> --- net/rds/ib.c | 1 + net/rds/ib.h | 1 + net/rds/ib_recv.c | 12 ++++++ net/rds/rds.h | 1 + net/rds/recv.c | 12 ++++++ net/rds/tcp.c | 1 + net/rds/tcp.h | 2 + net/rds/tcp_recv.c | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 8 files changed, 137 insertions(+), 1 deletion(-) diff --git a/net/rds/ib.c b/net/rds/ib.c index eba75c1..6c40652 100644 --- a/net/rds/ib.c +++ b/net/rds/ib.c @@ -527,6 +527,7 @@ struct rds_transport rds_ib_transport = { .conn_path_shutdown = rds_ib_conn_path_shutdown, .inc_copy_to_user = rds_ib_inc_copy_to_user, .inc_free = rds_ib_inc_free, + .inc_to_sg_get = rds_ib_inc_to_sg_get, .cm_initiate_connect = rds_ib_cm_initiate_connect, .cm_handle_connect = rds_ib_cm_handle_connect, .cm_connect_complete = rds_ib_cm_connect_complete, diff --git a/net/rds/ib.h b/net/rds/ib.h index 73427ff..0a12b41 100644 --- a/net/rds/ib.h +++ b/net/rds/ib.h @@ -404,6 +404,7 @@ int rds_ib_update_ipaddr(struct rds_ib_device *rds_ibdev, void rds_ib_recv_free_caches(struct rds_ib_connection *ic); void rds_ib_recv_refill(struct rds_connection *conn, int prefill, gfp_t gfp); void rds_ib_inc_free(struct rds_incoming *inc); +int rds_ib_inc_to_sg_get(struct rds_incoming *inc, struct scatterlist **sg); int rds_ib_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to); void rds_ib_recv_cqe_handler(struct rds_ib_connection *ic, struct ib_wc *wc, struct rds_ib_ack_state *state); diff --git a/net/rds/ib_recv.c b/net/rds/ib_recv.c index 2f16146..0054c7c 100644 --- a/net/rds/ib_recv.c +++ b/net/rds/ib_recv.c @@ -219,6 +219,18 @@ void rds_ib_inc_free(struct rds_incoming *inc) 
rds_ib_recv_cache_put(&ibinc->ii_cache_entry, &ic->i_cache_incs); } +int rds_ib_inc_to_sg_get(struct rds_incoming *inc, struct scatterlist **sg) +{ + struct rds_ib_incoming *ibinc; + struct rds_page_frag *frag; + + ibinc = container_of(inc, struct rds_ib_incoming, ii_inc); + frag = list_entry(ibinc->ii_frags.next, struct rds_page_frag, f_item); + *sg = &frag->f_sg; + + return 0; +} + static void rds_ib_recv_clear_one(struct rds_ib_connection *ic, struct rds_ib_recv_work *recv) { diff --git a/net/rds/rds.h b/net/rds/rds.h index 6bfaf05..9f3e4df 100644 --- a/net/rds/rds.h +++ b/net/rds/rds.h @@ -542,6 +542,7 @@ struct rds_transport { int (*recv_path)(struct rds_conn_path *cp); int (*inc_copy_to_user)(struct rds_incoming *inc, struct iov_iter *to); void (*inc_free)(struct rds_incoming *inc); + int (*inc_to_sg_get)(struct rds_incoming *inc, struct scatterlist **sg); int (*cm_handle_connect)(struct rdma_cm_id *cm_id, struct rdma_cm_event *event, bool isv6); diff --git a/net/rds/recv.c b/net/rds/recv.c index 1271965..424042e 100644 --- a/net/rds/recv.c +++ b/net/rds/recv.c @@ -290,6 +290,8 @@ void rds_recv_incoming(struct rds_connection *conn, struct in6_addr *saddr, struct sock *sk; unsigned long flags; struct rds_conn_path *cp; + struct sk_filter *filter; + int result = __SOCKSG_PASS; inc->i_conn = conn; inc->i_rx_jiffies = jiffies; @@ -374,6 +376,16 @@ void rds_recv_incoming(struct rds_connection *conn, struct in6_addr *saddr, /* We can be racing with rds_release() which marks the socket dead. 
*/ sk = rds_rs_to_sk(rs); + rcu_read_lock(); + filter = rcu_dereference(sk->sk_filter); + if (filter && conn->c_trans->inc_to_sg_get) { + struct scatterlist *sg = NULL; + + if (conn->c_trans->inc_to_sg_get(inc, &sg) == 0) + result = sg_filter_run(sk, sg); + } + rcu_read_unlock(); + /* serialize with rds_release -> sock_orphan */ write_lock_irqsave(&rs->rs_recv_lock, flags); if (!sock_flag(sk, SOCK_DEAD)) { diff --git a/net/rds/tcp.c b/net/rds/tcp.c index b9bbcf3..b0683e6 100644 --- a/net/rds/tcp.c +++ b/net/rds/tcp.c @@ -464,6 +464,7 @@ struct rds_transport rds_tcp_transport = { .conn_path_shutdown = rds_tcp_conn_path_shutdown, .inc_copy_to_user = rds_tcp_inc_copy_to_user, .inc_free = rds_tcp_inc_free, + .inc_to_sg_get = rds_tcp_inc_to_sg_get, .stats_info_copy = rds_tcp_stats_info_copy, .exit = rds_tcp_exit, .t_owner = THIS_MODULE, diff --git a/net/rds/tcp.h b/net/rds/tcp.h index 3c69361..e4ea16e 100644 --- a/net/rds/tcp.h +++ b/net/rds/tcp.h @@ -7,6 +7,7 @@ struct rds_tcp_incoming { struct rds_incoming ti_inc; struct sk_buff_head ti_skb_list; + struct scatterlist *sg; }; struct rds_tcp_connection { @@ -82,6 +83,7 @@ void rds_tcp_restore_callbacks(struct socket *sock, int rds_tcp_recv_path(struct rds_conn_path *cp); void rds_tcp_inc_free(struct rds_incoming *inc); int rds_tcp_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to); +int rds_tcp_inc_to_sg_get(struct rds_incoming *inc, struct scatterlist **sg); /* tcp_send.c */ void rds_tcp_xmit_path_prepare(struct rds_conn_path *cp); diff --git a/net/rds/tcp_recv.c b/net/rds/tcp_recv.c index 42c5ff1..22d84f2 100644 --- a/net/rds/tcp_recv.c +++ b/net/rds/tcp_recv.c @@ -50,14 +50,113 @@ static void rds_tcp_inc_purge(struct rds_incoming *inc) void rds_tcp_inc_free(struct rds_incoming *inc) { struct rds_tcp_incoming *tinc; + int i; + tinc = container_of(inc, struct rds_tcp_incoming, ti_inc); rds_tcp_inc_purge(inc); + + if (tinc->sg) { + for (i = 0; i < sg_nents(tinc->sg); i++) { + struct page *page; + + page = 
sg_page(&tinc->sg[i]); + put_page(page); + } + kfree(tinc->sg); + } + rdsdebug("freeing tinc %p inc %p\n", tinc, inc); kmem_cache_free(rds_tcp_incoming_slab, tinc); } +#define MAX_SG MAX_SKB_FRAGS +int rds_tcp_inc_to_sg_get(struct rds_incoming *inc, struct scatterlist **sg) +{ + struct rds_tcp_incoming *tinc; + struct sk_buff *skb; + int num_sg = 0; + int i; + + tinc = container_of(inc, struct rds_tcp_incoming, ti_inc); + + /* For now we are assuming that the max sg elements we need is MAX_SG. + * To determine actual number of sg elements we need to traverse the + * skb queue e.g. + * + * skb_queue_walk(&tinc->ti_skb_list, skb) { + * num_sg += skb_shinfo(skb)->nr_frags + 1; + * } + */ + tinc->sg = kzalloc(sizeof(*tinc->sg) * MAX_SG, GFP_KERNEL); + if (!tinc->sg) + return -ENOMEM; + + sg_init_table(tinc->sg, MAX_SG); + skb_queue_walk(&tinc->ti_skb_list, skb) { + num_sg += skb_to_sgvec_nomark(skb, &tinc->sg[num_sg], 0, + skb->len); + } + + /* packet can have zero length */ + if (num_sg <= 0) { + kfree(tinc->sg); + tinc->sg = NULL; + return -ENODATA; + } + + sg_mark_end(&tinc->sg[num_sg - 1]); + *sg = tinc->sg; + + for (i = 0; i < num_sg; i++) + get_page(sg_page(&tinc->sg[i])); + + return 0; +} + +static int rds_tcp_inc_copy_sg_to_user(struct rds_incoming *inc, + struct iov_iter *to) +{ + struct rds_tcp_incoming *tinc; + struct scatterlist *sg; + unsigned long copied = 0; + unsigned long len; + u8 i = 0; + + tinc = container_of(inc, struct rds_tcp_incoming, ti_inc); + len = be32_to_cpu(inc->i_hdr.h_len); + sg = tinc->sg; + + do { + struct page *page; + unsigned long n, copy, to_copy; + + sg = &tinc->sg[i]; + copy = sg->length; + page = sg_page(sg); + to_copy = iov_iter_count(to); + to_copy = min_t(unsigned long, to_copy, copy); + + n = copy_page_to_iter(page, sg->offset, to_copy, to); + if (n != copy) + return -EFAULT; + + rds_stats_add(s_copy_to_user, to_copy); + copied += to_copy; + sg->offset += to_copy; + sg->length -= to_copy; + + if (!sg->length) + i++; + + if 
(copied == len) + break; + } while (i != sg_nents(tinc->sg)); + return copied; +} /* - * this is pretty lame, but, whatever. + * This is pretty lame, but, whatever. + * Note: bpf filter can change RDS packet and if so then the modified packet is + * contained in the form of scatterlist, not skb. */ int rds_tcp_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to) { @@ -70,6 +169,12 @@ int rds_tcp_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to) tinc = container_of(inc, struct rds_tcp_incoming, ti_inc); + /* if tinc->sg is not NULL means bpf filter ran on packet and so packet + * now is in the form of scatterlist. + */ + if (tinc->sg) + return rds_tcp_inc_copy_sg_to_user(inc, to); + skb_queue_walk(&tinc->ti_skb_list, skb) { unsigned long to_copy, skb_off; for (skb_off = 0; skb_off < skb->len; skb_off += to_copy) { @@ -176,6 +281,7 @@ static int rds_tcp_data_recv(read_descriptor_t *desc, struct sk_buff *skb, desc->error = -ENOMEM; goto out; } + tinc->sg = NULL; tc->t_tinc = tinc; rdsdebug("alloced tinc %p\n", tinc); rds_inc_path_init(&tinc->ti_inc, cp, -- 1.8.3.1