An rds_connection can get added during netns deletion between lines 528 and 529 of
506 static void rds_tcp_kill_sock(struct net *net) : /* code to pull out all the rds_connections that should be destroyed */ : 528 spin_unlock_irq(&rds_tcp_conn_lock); 529 list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node) 530 rds_conn_destroy(tc->t_cpath->cp_conn); Such an rds_connection would miss out on the rds_conn_destroy() loop (that cancels all pending work) and (if it was scheduled after netns deletion) could trigger a use-after-free. A similar race-window exists for the module unload path in rds_tcp_exit -> rds_tcp_destroy_conns. To avoid the addition of new rds_connections during kill_sock or netns_delete, this patch introduces a per-netns flag, RTN_DELETE_PENDING, that will cause RDS connection creation to fail. RCU is used to make sure that we wait for the critical section of __rds_conn_create threads (that may have started before the setting of RTN_DELETE_PENDING) to complete before starting the connection destruction. Reported-by: syzbot+bbd8e9a06452cc480...@syzkaller.appspotmail.com Signed-off-by: Sowmini Varadhan <sowmini.varad...@oracle.com> --- net/rds/connection.c | 3 ++ net/rds/tcp.c | 82 ++++++++++++++++++++++++++++++++----------------- net/rds/tcp.h | 1 + 3 files changed, 57 insertions(+), 29 deletions(-) diff --git a/net/rds/connection.c b/net/rds/connection.c index b10c0ef..2ae539d 100644 --- a/net/rds/connection.c +++ b/net/rds/connection.c @@ -220,8 +220,10 @@ static void __rds_conn_path_init(struct rds_connection *conn, is_outgoing); conn->c_path[i].cp_index = i; } + rcu_read_lock(); ret = trans->conn_alloc(conn, gfp); if (ret) { + rcu_read_unlock(); kfree(conn->c_path); kmem_cache_free(rds_conn_slab, conn); conn = ERR_PTR(ret); @@ -283,6 +285,7 @@ static void __rds_conn_path_init(struct rds_connection *conn, } } spin_unlock_irqrestore(&rds_conn_lock, flags); + rcu_read_unlock(); out: return conn; diff --git a/net/rds/tcp.c b/net/rds/tcp.c index 9920d2f..2bdd3cc 100644 --- a/net/rds/tcp.c +++ b/net/rds/tcp.c @@ -274,14 +274,13 @@
static int rds_tcp_laddr_check(struct net *net, __be32 addr) static void rds_tcp_conn_free(void *arg) { struct rds_tcp_connection *tc = arg; - unsigned long flags; rdsdebug("freeing tc %p\n", tc); - spin_lock_irqsave(&rds_tcp_conn_lock, flags); + spin_lock_bh(&rds_tcp_conn_lock); if (!tc->t_tcp_node_detached) list_del(&tc->t_tcp_node); - spin_unlock_irqrestore(&rds_tcp_conn_lock, flags); + spin_unlock_bh(&rds_tcp_conn_lock); kmem_cache_free(rds_tcp_conn_slab, tc); } @@ -296,7 +295,7 @@ static int rds_tcp_conn_alloc(struct rds_connection *conn, gfp_t gfp) tc = kmem_cache_alloc(rds_tcp_conn_slab, gfp); if (!tc) { ret = -ENOMEM; - break; + goto fail; } mutex_init(&tc->t_conn_path_lock); tc->t_sock = NULL; @@ -306,14 +305,25 @@ static int rds_tcp_conn_alloc(struct rds_connection *conn, gfp_t gfp) conn->c_path[i].cp_transport_data = tc; tc->t_cpath = &conn->c_path[i]; + tc->t_tcp_node_detached = true; - spin_lock_irq(&rds_tcp_conn_lock); - tc->t_tcp_node_detached = false; - list_add_tail(&tc->t_tcp_node, &rds_tcp_conn_list); - spin_unlock_irq(&rds_tcp_conn_lock); rdsdebug("rds_conn_path [%d] tc %p\n", i, conn->c_path[i].cp_transport_data); } + spin_lock_bh(&rds_tcp_conn_lock); + if (rds_tcp_netns_delete_pending(rds_conn_net(conn))) { + rdsdebug("RTN_DELETE_PENDING\n"); + ret = -ENETDOWN; + spin_unlock_bh(&rds_tcp_conn_lock); + goto fail; + } + for (i = 0; i < RDS_MPATH_WORKERS; i++) { + tc = conn->c_path[i].cp_transport_data; + tc->t_tcp_node_detached = false; + list_add_tail(&tc->t_tcp_node, &rds_tcp_conn_list); + } + spin_unlock_bh(&rds_tcp_conn_lock); +fail: if (ret) { for (j = 0; j < i; j++) rds_tcp_conn_free(conn->c_path[j].cp_transport_data); @@ -332,23 +342,6 @@ static bool list_has_conn(struct list_head *list, struct rds_connection *conn) return false; } -static void rds_tcp_destroy_conns(void) -{ - struct rds_tcp_connection *tc, *_tc; - LIST_HEAD(tmp_list); - - /* avoid calling conn_destroy with irqs off */ - spin_lock_irq(&rds_tcp_conn_lock); - 
list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) { - if (!list_has_conn(&tmp_list, tc->t_cpath->cp_conn)) - list_move_tail(&tc->t_tcp_node, &tmp_list); - } - spin_unlock_irq(&rds_tcp_conn_lock); - - list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node) - rds_conn_destroy(tc->t_cpath->cp_conn); -} - static void rds_tcp_exit(void); struct rds_transport rds_tcp_transport = { @@ -382,8 +375,30 @@ struct rds_tcp_net { struct ctl_table *ctl_table; int sndbuf_size; int rcvbuf_size; + unsigned long rtn_flags; +#define RTN_DELETE_PENDING 0 }; +static void rds_tcp_destroy_conns(void) +{ + struct rds_tcp_connection *tc, *_tc; + struct rds_tcp_net *rtn = net_generic(&init_net, rds_tcp_netid); + LIST_HEAD(tmp_list); + + /* avoid calling conn_destroy with irqs off */ + set_bit(RTN_DELETE_PENDING, &rtn->rtn_flags); + synchronize_rcu(); + spin_lock_bh(&rds_tcp_conn_lock); + list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) { + if (!list_has_conn(&tmp_list, tc->t_cpath->cp_conn)) + list_move_tail(&tc->t_tcp_node, &tmp_list); + } + spin_unlock_bh(&rds_tcp_conn_lock); + + list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node) + rds_conn_destroy(tc->t_cpath->cp_conn); +} + /* All module specific customizations to the RDS-TCP socket should be done in * rds_tcp_tune() and applied after socket creation. 
*/ @@ -504,6 +519,13 @@ static void __net_exit rds_tcp_exit_net(struct net *net) .size = sizeof(struct rds_tcp_net), }; +bool rds_tcp_netns_delete_pending(struct net *net) +{ + struct rds_tcp_net *rtn = net_generic(net, rds_tcp_netid); + + return test_bit(RTN_DELETE_PENDING, &rtn->rtn_flags); +} + static void rds_tcp_kill_sock(struct net *net) { struct rds_tcp_connection *tc, *_tc; @@ -513,7 +535,9 @@ static void rds_tcp_kill_sock(struct net *net) rtn->rds_tcp_listen_sock = NULL; rds_tcp_listen_stop(lsock, &rtn->rds_tcp_accept_w); - spin_lock_irq(&rds_tcp_conn_lock); + set_bit(RTN_DELETE_PENDING, &rtn->rtn_flags); + synchronize_rcu(); + spin_lock_bh(&rds_tcp_conn_lock); list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) { struct net *c_net = read_pnet(&tc->t_cpath->cp_conn->c_net); @@ -526,7 +550,7 @@ static void rds_tcp_kill_sock(struct net *net) tc->t_tcp_node_detached = true; } } - spin_unlock_irq(&rds_tcp_conn_lock); + spin_unlock_bh(&rds_tcp_conn_lock); list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node) rds_conn_destroy(tc->t_cpath->cp_conn); } @@ -574,7 +598,7 @@ static void rds_tcp_sysctl_reset(struct net *net) { struct rds_tcp_connection *tc, *_tc; - spin_lock_irq(&rds_tcp_conn_lock); + spin_lock_bh(&rds_tcp_conn_lock); list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) { struct net *c_net = read_pnet(&tc->t_cpath->cp_conn->c_net); @@ -584,7 +608,7 @@ static void rds_tcp_sysctl_reset(struct net *net) /* reconnect with new parameters */ rds_conn_path_drop(tc->t_cpath, false); } - spin_unlock_irq(&rds_tcp_conn_lock); + spin_unlock_bh(&rds_tcp_conn_lock); } static int rds_tcp_skbuf_handler(struct ctl_table *ctl, int write, diff --git a/net/rds/tcp.h b/net/rds/tcp.h index c6fa080..b07dbd7 100644 --- a/net/rds/tcp.h +++ b/net/rds/tcp.h @@ -60,6 +60,7 @@ void rds_tcp_restore_callbacks(struct socket *sock, u64 rds_tcp_map_seq(struct rds_tcp_connection *tc, u32 seq); extern struct rds_transport rds_tcp_transport; void 
rds_tcp_accept_work(struct sock *sk); +bool rds_tcp_netns_delete_pending(struct net *net); /* tcp_connect.c */ int rds_tcp_conn_path_connect(struct rds_conn_path *cp); -- 1.7.1