I'm not sure this is a good idea to begin with: the refcount sits right next to the state spinlock, which is taken for both tx and rx ops, and this complicates debugging quite a bit.
Also expect UAF or refcount bugs, especially xfrm_states that never get freed.

In case someone wants to experiment with this, I've placed it here:
https://git.breakpoint.cc/cgit/fw/net-next.git/log/?h=xfrm_pcpu_gc_04

Summary of the old/new locking model:

Old locking rules:

From control plane:
 - new states get a refcount of 1
 - calls xfrm_state_put() when done
 - calls xfrm_state_hold() to make sure the state won't go away

From data plane:
 - calls xfrm_state_hold_rcu(), which reports when the refcount is already 0
   (i.e., the state is being destroyed)
 - calls xfrm_state_put() when done

From gc worker:
 - steals the current gc list head
 - calls synchronize_rcu()
 - can free all entries on the list after this

The last xfrm_state_put() observes the refcount transition to 0 and places the
xfrm_state on the gc list, where it is picked up (and freed) by the xfrm state
gc worker.

New locking rules:

From control plane:
 - (NEW) must *explicitly* call xfrm_state_delete() when the state should be
   removed
 - otherwise same as above

From data plane:
 - calls xfrm_state_hold() to get a ref, xfrm_state_put() when done

From gc worker:
 - steals the current gc list head
 - calls synchronize_rcu()
 - entries on the list MAY STILL BE IN USE, so it
   * must check the refcount of each entry (expensive, as this means summing
     the pcpu refcounts for each state)
   * needs to place entries that it can't free yet back on the gc list and
     reschedule itself

XXX:
 - survives minimal testing (netns + veth + esp tunnel mode)
 - the refcount for newly allocated states should probably start at 0, with
   the 0 -> 1 increment done right before the state is published (added to
   the state hashes)
 - probably misses several error handling spots that would now have to call
   xfrm_state_delete() rather than xfrm_state_put()

Signed-off-by: Florian Westphal <f...@strlen.de>
---
 include/net/xfrm.h    | 20 ++--------
 net/key/af_key.c      |  4 +-
 net/xfrm/xfrm_state.c | 90 ++++++++++++++++++++++++-------------------
 net/xfrm/xfrm_user.c  |  2 +-
 4 files changed, 57 insertions(+), 59 deletions(-)

diff --git a/include/net/xfrm.h b/include/net/xfrm.h
index 728d4e7b82c0..f9be94a3a029 100644
--- a/include/net/xfrm.h
+++ b/include/net/xfrm.h
@@ -153,7 +153,7 @@ struct xfrm_state {
 	struct hlist_node	bysrc;
 	struct hlist_node	byspi;
 
-	refcount_t		refcnt;
+	int __percpu		*pcpu_refcnt;
 	spinlock_t		lock;
 
 	struct xfrm_id		id;
@@ -790,26 +790,14 @@ static inline void xfrm_pols_put(struct xfrm_policy **pols, int npols)
 
 void __xfrm_state_destroy(struct xfrm_state *, bool);
 
-static inline void __xfrm_state_put(struct xfrm_state *x)
+static inline void xfrm_state_hold(struct xfrm_state *x)
 {
-	refcount_dec(&x->refcnt);
+	this_cpu_inc(*x->pcpu_refcnt);
 }
 
 static inline void xfrm_state_put(struct xfrm_state *x)
 {
-	if (refcount_dec_and_test(&x->refcnt))
-		__xfrm_state_destroy(x, false);
-}
-
-static inline void xfrm_state_put_sync(struct xfrm_state *x)
-{
-	if (refcount_dec_and_test(&x->refcnt))
-		__xfrm_state_destroy(x, true);
-}
-
-static inline void xfrm_state_hold(struct xfrm_state *x)
-{
-	refcount_inc(&x->refcnt);
+	this_cpu_dec(*x->pcpu_refcnt);
 }
 
 static inline bool addr_match(const void *token1, const void *token2,
diff --git a/net/key/af_key.c b/net/key/af_key.c
index 5651c29cb5bd..21d4776b5431 100644
--- a/net/key/af_key.c
+++ b/net/key/af_key.c
@@ -1303,7 +1303,7 @@ static struct xfrm_state * pfkey_msg2xfrm_state(struct net *net,
 
 out:
 	x->km.state = XFRM_STATE_DEAD;
-	xfrm_state_put(x);
+	xfrm_state_delete(x);
 	return ERR_PTR(err);
 }
 
@@ -1525,7 +1525,7 @@ static int pfkey_add(struct sock *sk, struct sk_buff *skb, const struct sadb_msg
 
 	if (err < 0) {
 		x->km.state = XFRM_STATE_DEAD;
-		__xfrm_state_put(x);
+		xfrm_state_delete(x);
 		goto out;
 	}
 
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index 8951c09ae9e9..4614df21ff5c 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -31,6 +31,8 @@
 #define xfrm_state_deref_prot(table, net) \
 	rcu_dereference_protected((table), lockdep_is_held(&(net)->xfrm.xfrm_state_lock))
 
+#define XFRM_GC_WAIT_JIFFIES 2
+
 static void xfrm_state_gc_task(struct work_struct *work);
 
 /* Each xfrm_state may be linked to two tables:
@@ -44,14 +46,9 @@ static unsigned int xfrm_state_hashmax __read_mostly = 1 * 1024 * 1024;
 static __read_mostly seqcount_t xfrm_state_hash_generation = SEQCNT_ZERO(xfrm_state_hash_generation);
 static struct kmem_cache *xfrm_state_cache __ro_after_init;
 
-static DECLARE_WORK(xfrm_state_gc_work, xfrm_state_gc_task);
+static DECLARE_DELAYED_WORK(xfrm_state_gc_work, xfrm_state_gc_task);
 static HLIST_HEAD(xfrm_state_gc_list);
 
-static inline bool xfrm_state_hold_rcu(struct xfrm_state __rcu *x)
-{
-	return refcount_inc_not_zero(&x->refcnt);
-}
-
 static inline unsigned int xfrm_dst_hash(struct net *net,
 					 const xfrm_address_t *daddr,
 					 const xfrm_address_t *saddr,
@@ -472,6 +469,7 @@ static const struct xfrm_mode *xfrm_get_mode(unsigned int encap, int family)
 
 void xfrm_state_free(struct xfrm_state *x)
 {
+	free_percpu(x->pcpu_refcnt);
 	kmem_cache_free(xfrm_state_cache, x);
 }
 EXPORT_SYMBOL(xfrm_state_free);
@@ -499,6 +497,15 @@ static void ___xfrm_state_destroy(struct xfrm_state *x)
 	xfrm_state_free(x);
 }
 
+static int xfrm_state_refcnt_read(const struct xfrm_state *x)
+{
+	int i, refcnt = 0;
+
+	for_each_possible_cpu(i)
+		refcnt += *per_cpu_ptr(x->pcpu_refcnt, i);
+	return refcnt;
+}
+
 static void xfrm_state_gc_task(struct work_struct *work)
 {
 	struct xfrm_state *x;
@@ -511,8 +518,22 @@ static void xfrm_state_gc_task(struct work_struct *work)
 
 	synchronize_rcu();
 
-	hlist_for_each_entry_safe(x, tmp, &gc_list, gclist)
+	hlist_for_each_entry_safe(x, tmp, &gc_list, gclist) {
+		if (xfrm_state_refcnt_read(x))
+			continue;
+
+		hlist_del(&x->gclist);
 		___xfrm_state_destroy(x);
+	}
+
+	if (!hlist_empty(&gc_list)) {
+		spin_lock_bh(&xfrm_state_gc_lock);
+		hlist_for_each_entry_safe(x, tmp, &gc_list, gclist)
+			hlist_add_head(&x->gclist, &xfrm_state_gc_list);
+		spin_unlock_bh(&xfrm_state_gc_lock);
+
+		schedule_delayed_work(&xfrm_state_gc_work, XFRM_GC_WAIT_JIFFIES);
+	}
 }
 
 static enum hrtimer_restart xfrm_timer_handler(struct hrtimer *me)
@@ -611,8 +632,14 @@ struct xfrm_state *xfrm_state_alloc(struct net *net)
 	x = kmem_cache_alloc(xfrm_state_cache, GFP_ATOMIC | __GFP_ZERO);
 
 	if (x) {
+		x->pcpu_refcnt = alloc_percpu(int);
+		if (!x->pcpu_refcnt) {
+			kmem_cache_free(xfrm_state_cache, x);
+			return NULL;
+		}
+
 		write_pnet(&x->xs_net, net);
-		refcount_set(&x->refcnt, 1);
+		xfrm_state_hold(x);
 		atomic_set(&x->tunnel_users, 0);
 		INIT_LIST_HEAD(&x->km.all);
 		INIT_HLIST_NODE(&x->bydst);
@@ -634,22 +661,6 @@ struct xfrm_state *xfrm_state_alloc(struct net *net)
 }
 EXPORT_SYMBOL(xfrm_state_alloc);
 
-void __xfrm_state_destroy(struct xfrm_state *x, bool sync)
-{
-	WARN_ON(x->km.state != XFRM_STATE_DEAD);
-
-	if (sync) {
-		synchronize_rcu();
-		___xfrm_state_destroy(x);
-	} else {
-		spin_lock_bh(&xfrm_state_gc_lock);
-		hlist_add_head(&x->gclist, &xfrm_state_gc_list);
-		spin_unlock_bh(&xfrm_state_gc_lock);
-		schedule_work(&xfrm_state_gc_work);
-	}
-}
-EXPORT_SYMBOL(__xfrm_state_destroy);
-
 int __xfrm_state_delete(struct xfrm_state *x)
 {
 	struct net *net = xs_net(x);
@@ -673,6 +684,12 @@ int __xfrm_state_delete(struct xfrm_state *x)
 		 * is what we are dropping here.
 		 */
 		xfrm_state_put(x);
+
+		spin_lock_bh(&xfrm_state_gc_lock);
+		hlist_add_head(&x->gclist, &xfrm_state_gc_list);
+		spin_unlock_bh(&xfrm_state_gc_lock);
+
+		schedule_delayed_work(&xfrm_state_gc_work, 0);
 		err = 0;
 	}
 
@@ -771,10 +788,7 @@ int xfrm_state_flush(struct net *net, u8 proto, bool task_valid, bool sync)
 				err = xfrm_state_delete(x);
 				xfrm_audit_state_delete(x, err ? 0 : 1,
 							task_valid);
-				if (sync)
-					xfrm_state_put_sync(x);
-				else
-					xfrm_state_put(x);
+				xfrm_state_put(x);
 
 				if (!err)
 					cnt++;
@@ -937,8 +951,8 @@ static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark,
 
 		if ((mark & x->mark.m) != x->mark.v)
 			continue;
-		if (!xfrm_state_hold_rcu(x))
-			continue;
+
+		xfrm_state_hold(x);
 		return x;
 	}
 
@@ -962,8 +976,7 @@ static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark,
 
 		if ((mark & x->mark.m) != x->mark.v)
 			continue;
-		if (!xfrm_state_hold_rcu(x))
-			continue;
+		xfrm_state_hold(x);
 		return x;
 	}
 
@@ -1151,10 +1164,7 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
 	}
 out:
 	if (x) {
-		if (!xfrm_state_hold_rcu(x)) {
-			*err = -EAGAIN;
-			x = NULL;
-		}
+		xfrm_state_hold(x);
 	} else {
 		*err = acquire_in_progress ? -EAGAIN : error;
 	}
@@ -1681,7 +1691,7 @@ int xfrm_state_update(struct xfrm_state *x)
 
 		err = 0;
 		x->km.state = XFRM_STATE_DEAD;
-		__xfrm_state_put(x);
+		xfrm_state_put(x);
 	}
 
 fail:
@@ -2372,7 +2382,7 @@ struct xfrm_state_afinfo *xfrm_state_get_afinfo(unsigned int family)
 
 void xfrm_flush_gc(void)
 {
-	flush_work(&xfrm_state_gc_work);
+	flush_delayed_work(&xfrm_state_gc_work);
 }
 EXPORT_SYMBOL(xfrm_flush_gc);
 
@@ -2385,7 +2395,7 @@ void xfrm_state_delete_tunnel(struct xfrm_state *x)
 		if (atomic_read(&t->tunnel_users) == 2)
 			xfrm_state_delete(t);
 		atomic_dec(&t->tunnel_users);
-		xfrm_state_put_sync(t);
+		xfrm_state_put(t);
 		x->tunnel = NULL;
 	}
 }
@@ -2531,7 +2541,7 @@ void xfrm_state_fini(struct net *net)
 	unsigned int sz;
 
 	flush_work(&net->xfrm.state_hash_work);
-	flush_work(&xfrm_state_gc_work);
+	flush_delayed_work(&xfrm_state_gc_work);
 	xfrm_state_flush(net, IPSEC_PROTO_ANY, false, true);
 
 	WARN_ON(!list_empty(&net->xfrm.state_all));
diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c
index a131f9ff979e..4587f1342af0 100644
--- a/net/xfrm/xfrm_user.c
+++ b/net/xfrm/xfrm_user.c
@@ -675,7 +675,7 @@ static int xfrm_add_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
 	if (err < 0) {
 		x->km.state = XFRM_STATE_DEAD;
 		xfrm_dev_state_delete(x);
-		__xfrm_state_put(x);
+		xfrm_state_delete(x);
 		goto out;
 	}
 
-- 
2.21.0
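
As an aside, here is a minimal user-space sketch of the scheme the cover letter
describes: holds and puts only touch a CPU-local counter slot (possibly on
different CPUs), and the gc pass sums all slots, freeing entries whose sum is
zero and keeping the rest queued for a later pass. This is an illustration
only, not kernel code; every name in it (pcpu_ref, ref_hold, ref_put,
ref_read, gc_scan, NR_CPUS_SIM) is made up for the example.

/* Stand-in model of the per-cpu refcount + deferred gc rescan idea. */
#include <stdio.h>
#include <stdlib.h>

#define NR_CPUS_SIM 4			/* stand-in for nr_cpu_ids */

struct pcpu_ref {
	int cnt[NR_CPUS_SIM];		/* stand-in for a real percpu allocation */
	struct pcpu_ref *gc_next;
	const char *name;
};

static struct pcpu_ref *gc_list;	/* entries queued for destruction */

static void ref_hold(struct pcpu_ref *r, int cpu) { r->cnt[cpu]++; }
static void ref_put(struct pcpu_ref *r, int cpu)  { r->cnt[cpu]--; }

/* Sum of all per-cpu slots; zero means no user holds a reference. */
static int ref_read(const struct pcpu_ref *r)
{
	int i, sum = 0;

	for (i = 0; i < NR_CPUS_SIM; i++)
		sum += r->cnt[i];
	return sum;
}

/* One gc pass: free entries that dropped to zero, keep the rest queued. */
static void gc_scan(void)
{
	struct pcpu_ref **pprev = &gc_list, *r;

	while ((r = *pprev) != NULL) {
		if (ref_read(r) == 0) {
			*pprev = r->gc_next;	/* unlink before freeing */
			printf("freeing %s\n", r->name);
			free(r);
		} else {
			printf("%s still has %d refs, rescanning later\n",
			       r->name, ref_read(r));
			pprev = &r->gc_next;	/* leave it on the list */
		}
	}
}

int main(void)
{
	struct pcpu_ref *a = calloc(1, sizeof(*a));
	struct pcpu_ref *b = calloc(1, sizeof(*b));

	a->name = "state-a";
	b->name = "state-b";

	ref_hold(a, 0);			/* reference taken on "cpu" 0 */
	ref_hold(b, 1);			/* reference taken on "cpu" 1 ... */
	ref_put(b, 2);			/* ... dropped on "cpu" 2: sum is 0 */

	/* "delete": queue both for gc regardless of current refcount */
	a->gc_next = gc_list; gc_list = a;
	b->gc_next = gc_list; gc_list = b;

	gc_scan();			/* frees b, keeps a (sum is 1) */
	ref_put(a, 3);			/* last reference dropped elsewhere */
	gc_scan();			/* now a can be freed too */
	return 0;
}

Note how the per-object counter is never read on the hot path; only the gc
pass pays for the cross-CPU summation, which is why the worker has to be able
to requeue entries and reschedule itself instead of freeing unconditionally.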