I'm not sure this is a good idea to begin with: the refcount
is right next to the state spinlock, which is taken for both tx and rx ops,
and this complicates debugging quite a bit.

Also expect UAF or refcount bugs, especially xfrm_states that never get freed.

In case someone wants to experiment with this, I've placed it here:
https://git.breakpoint.cc/cgit/fw/net-next.git/log/?h=xfrm_pcpu_gc_04

Summary of old/new locking model:

Old locking rules:
 From control plane:
   - new states get refcount of 1
   - calls xfrm_state_put() when done
   - calls xfrm_state_hold() to make sure
     state won't go away

 From data plane:
  - calls xfrm_state_hold_rcu, which tells the caller when
    the refcount is already 0 (i.e., state is being destroyed)
  - calls xfrm_state_put when done

 From gc worker:
  - steals current gc list head
  - calls synchronize_rcu
  - can free all entries in the list after this

The last xfrm_state_put will observe the refcount transition to 0;
this places the xfrm_state on the gc list, where it's picked
up (and freed) by the xfrm state gc worker.

New locking rules:
 From control plane:
   - (NEW) must *explicitly* call xfrm_state_delete()
     when state should be removed
   - otherwise same as above
 From data plane:
   - calls xfrm_state_hold() to get ref,
     xfrm_state_put when done
 From gc worker:
  - steals current gc list head
  - calls synchronize_rcu
  - entries on list MAY STILL BE IN USE, so
    * must check refcount for each (expensive, as it
      requires summing pcpu refcounts for each state)
    * needs to place entries that it can't free yet
      back on gc list and resched itself

XXX:
- survives minimal testing (netns + veth + esp tunnel mode)
- refcount for newly allocated states should probably be
  changed to 0, and only do 0 -> 1 increment right before
  state is published (added to state hashes).
- probably misses several error handling spots that now would
  have to call xfrm_state_delete() rather than xfrm_state_put().

Signed-off-by: Florian Westphal <f...@strlen.de>
---
 include/net/xfrm.h    | 20 ++--------
 net/key/af_key.c      |  4 +-
 net/xfrm/xfrm_state.c | 90 ++++++++++++++++++++++++-------------------
 net/xfrm/xfrm_user.c  |  2 +-
 4 files changed, 57 insertions(+), 59 deletions(-)

diff --git a/include/net/xfrm.h b/include/net/xfrm.h
index 728d4e7b82c0..f9be94a3a029 100644
--- a/include/net/xfrm.h
+++ b/include/net/xfrm.h
@@ -153,7 +153,7 @@ struct xfrm_state {
        struct hlist_node       bysrc;
        struct hlist_node       byspi;
 
-       refcount_t              refcnt;
+       int __percpu            *pcpu_refcnt;
        spinlock_t              lock;
 
        struct xfrm_id          id;
@@ -790,26 +790,14 @@ static inline void xfrm_pols_put(struct xfrm_policy 
**pols, int npols)
 
 void __xfrm_state_destroy(struct xfrm_state *, bool);
 
-static inline void __xfrm_state_put(struct xfrm_state *x)
+static inline void xfrm_state_hold(struct xfrm_state *x)
 {
-       refcount_dec(&x->refcnt);
+       this_cpu_inc(*x->pcpu_refcnt);
 }
 
 static inline void xfrm_state_put(struct xfrm_state *x)
 {
-       if (refcount_dec_and_test(&x->refcnt))
-               __xfrm_state_destroy(x, false);
-}
-
-static inline void xfrm_state_put_sync(struct xfrm_state *x)
-{
-       if (refcount_dec_and_test(&x->refcnt))
-               __xfrm_state_destroy(x, true);
-}
-
-static inline void xfrm_state_hold(struct xfrm_state *x)
-{
-       refcount_inc(&x->refcnt);
+       this_cpu_dec(*x->pcpu_refcnt);
 }
 
 static inline bool addr_match(const void *token1, const void *token2,
diff --git a/net/key/af_key.c b/net/key/af_key.c
index 5651c29cb5bd..21d4776b5431 100644
--- a/net/key/af_key.c
+++ b/net/key/af_key.c
@@ -1303,7 +1303,7 @@ static struct xfrm_state * pfkey_msg2xfrm_state(struct 
net *net,
 
 out:
        x->km.state = XFRM_STATE_DEAD;
-       xfrm_state_put(x);
+       xfrm_state_delete(x);
        return ERR_PTR(err);
 }
 
@@ -1525,7 +1525,7 @@ static int pfkey_add(struct sock *sk, struct sk_buff 
*skb, const struct sadb_msg
 
        if (err < 0) {
                x->km.state = XFRM_STATE_DEAD;
-               __xfrm_state_put(x);
+               xfrm_state_delete(x);
                goto out;
        }
 
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index 8951c09ae9e9..4614df21ff5c 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -31,6 +31,8 @@
 #define xfrm_state_deref_prot(table, net) \
        rcu_dereference_protected((table), 
lockdep_is_held(&(net)->xfrm.xfrm_state_lock))
 
+#define XFRM_GC_WAIT_JIFFIES   2
+
 static void xfrm_state_gc_task(struct work_struct *work);
 
 /* Each xfrm_state may be linked to two tables:
@@ -44,14 +46,9 @@ static unsigned int xfrm_state_hashmax __read_mostly = 1 * 
1024 * 1024;
 static __read_mostly seqcount_t xfrm_state_hash_generation = 
SEQCNT_ZERO(xfrm_state_hash_generation);
 static struct kmem_cache *xfrm_state_cache __ro_after_init;
 
-static DECLARE_WORK(xfrm_state_gc_work, xfrm_state_gc_task);
+static DECLARE_DELAYED_WORK(xfrm_state_gc_work, xfrm_state_gc_task);
 static HLIST_HEAD(xfrm_state_gc_list);
 
-static inline bool xfrm_state_hold_rcu(struct xfrm_state __rcu *x)
-{
-       return refcount_inc_not_zero(&x->refcnt);
-}
-
 static inline unsigned int xfrm_dst_hash(struct net *net,
                                         const xfrm_address_t *daddr,
                                         const xfrm_address_t *saddr,
@@ -472,6 +469,7 @@ static const struct xfrm_mode *xfrm_get_mode(unsigned int 
encap, int family)
 
 void xfrm_state_free(struct xfrm_state *x)
 {
+       free_percpu(x->pcpu_refcnt);
        kmem_cache_free(xfrm_state_cache, x);
 }
 EXPORT_SYMBOL(xfrm_state_free);
@@ -499,6 +497,15 @@ static void ___xfrm_state_destroy(struct xfrm_state *x)
        xfrm_state_free(x);
 }
 
+static int xfrm_state_refcnt_read(const struct xfrm_state *x)
+{
+       int i, refcnt = 0;
+
+       for_each_possible_cpu(i)
+               refcnt += *per_cpu_ptr(x->pcpu_refcnt, i);
+       return refcnt;
+}
+
 static void xfrm_state_gc_task(struct work_struct *work)
 {
        struct xfrm_state *x;
@@ -511,8 +518,22 @@ static void xfrm_state_gc_task(struct work_struct *work)
 
        synchronize_rcu();
 
-       hlist_for_each_entry_safe(x, tmp, &gc_list, gclist)
+       hlist_for_each_entry_safe(x, tmp, &gc_list, gclist) {
+               if (xfrm_state_refcnt_read(x))
+                       continue;
+
+               hlist_del(&x->gclist);
                ___xfrm_state_destroy(x);
+       }
+
+       if (!hlist_empty(&gc_list)) {
+               spin_lock_bh(&xfrm_state_gc_lock);
+               hlist_for_each_entry_safe(x, tmp, &gc_list, gclist)
+                       hlist_add_head(&x->gclist, &xfrm_state_gc_list);
+               spin_unlock_bh(&xfrm_state_gc_lock);
+
+               schedule_delayed_work(&xfrm_state_gc_work, 
XFRM_GC_WAIT_JIFFIES);
+       }
 }
 
 static enum hrtimer_restart xfrm_timer_handler(struct hrtimer *me)
@@ -611,8 +632,14 @@ struct xfrm_state *xfrm_state_alloc(struct net *net)
        x = kmem_cache_alloc(xfrm_state_cache, GFP_ATOMIC | __GFP_ZERO);
 
        if (x) {
+               x->pcpu_refcnt = alloc_percpu(int);
+               if (!x->pcpu_refcnt) {
+                       kmem_cache_free(xfrm_state_cache, x);
+                       return NULL;
+               }
+
                write_pnet(&x->xs_net, net);
-               refcount_set(&x->refcnt, 1);
+               xfrm_state_hold(x);
                atomic_set(&x->tunnel_users, 0);
                INIT_LIST_HEAD(&x->km.all);
                INIT_HLIST_NODE(&x->bydst);
@@ -634,22 +661,6 @@ struct xfrm_state *xfrm_state_alloc(struct net *net)
 }
 EXPORT_SYMBOL(xfrm_state_alloc);
 
-void __xfrm_state_destroy(struct xfrm_state *x, bool sync)
-{
-       WARN_ON(x->km.state != XFRM_STATE_DEAD);
-
-       if (sync) {
-               synchronize_rcu();
-               ___xfrm_state_destroy(x);
-       } else {
-               spin_lock_bh(&xfrm_state_gc_lock);
-               hlist_add_head(&x->gclist, &xfrm_state_gc_list);
-               spin_unlock_bh(&xfrm_state_gc_lock);
-               schedule_work(&xfrm_state_gc_work);
-       }
-}
-EXPORT_SYMBOL(__xfrm_state_destroy);
-
 int __xfrm_state_delete(struct xfrm_state *x)
 {
        struct net *net = xs_net(x);
@@ -673,6 +684,12 @@ int __xfrm_state_delete(struct xfrm_state *x)
                 * is what we are dropping here.
                 */
                xfrm_state_put(x);
+
+               spin_lock_bh(&xfrm_state_gc_lock);
+               hlist_add_head(&x->gclist, &xfrm_state_gc_list);
+               spin_unlock_bh(&xfrm_state_gc_lock);
+
+               schedule_delayed_work(&xfrm_state_gc_work, 0);
                err = 0;
        }
 
@@ -771,10 +788,7 @@ int xfrm_state_flush(struct net *net, u8 proto, bool 
task_valid, bool sync)
                                err = xfrm_state_delete(x);
                                xfrm_audit_state_delete(x, err ? 0 : 1,
                                                        task_valid);
-                               if (sync)
-                                       xfrm_state_put_sync(x);
-                               else
-                                       xfrm_state_put(x);
+                               xfrm_state_put(x);
                                if (!err)
                                        cnt++;
 
@@ -937,8 +951,8 @@ static struct xfrm_state *__xfrm_state_lookup(struct net 
*net, u32 mark,
 
                if ((mark & x->mark.m) != x->mark.v)
                        continue;
-               if (!xfrm_state_hold_rcu(x))
-                       continue;
+
+               xfrm_state_hold(x);
                return x;
        }
 
@@ -962,8 +976,7 @@ static struct xfrm_state *__xfrm_state_lookup_byaddr(struct 
net *net, u32 mark,
 
                if ((mark & x->mark.m) != x->mark.v)
                        continue;
-               if (!xfrm_state_hold_rcu(x))
-                       continue;
+               xfrm_state_hold(x);
                return x;
        }
 
@@ -1151,10 +1164,7 @@ xfrm_state_find(const xfrm_address_t *daddr, const 
xfrm_address_t *saddr,
        }
 out:
        if (x) {
-               if (!xfrm_state_hold_rcu(x)) {
-                       *err = -EAGAIN;
-                       x = NULL;
-               }
+               xfrm_state_hold(x);
        } else {
                *err = acquire_in_progress ? -EAGAIN : error;
        }
@@ -1681,7 +1691,7 @@ int xfrm_state_update(struct xfrm_state *x)
 
                err = 0;
                x->km.state = XFRM_STATE_DEAD;
-               __xfrm_state_put(x);
+               xfrm_state_put(x);
        }
 
 fail:
@@ -2372,7 +2382,7 @@ struct xfrm_state_afinfo *xfrm_state_get_afinfo(unsigned 
int family)
 
 void xfrm_flush_gc(void)
 {
-       flush_work(&xfrm_state_gc_work);
+       flush_delayed_work(&xfrm_state_gc_work);
 }
 EXPORT_SYMBOL(xfrm_flush_gc);
 
@@ -2385,7 +2395,7 @@ void xfrm_state_delete_tunnel(struct xfrm_state *x)
                if (atomic_read(&t->tunnel_users) == 2)
                        xfrm_state_delete(t);
                atomic_dec(&t->tunnel_users);
-               xfrm_state_put_sync(t);
+               xfrm_state_put(t);
                x->tunnel = NULL;
        }
 }
@@ -2531,7 +2541,7 @@ void xfrm_state_fini(struct net *net)
        unsigned int sz;
 
        flush_work(&net->xfrm.state_hash_work);
-       flush_work(&xfrm_state_gc_work);
+       flush_delayed_work(&xfrm_state_gc_work);
        xfrm_state_flush(net, IPSEC_PROTO_ANY, false, true);
 
        WARN_ON(!list_empty(&net->xfrm.state_all));
diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c
index a131f9ff979e..4587f1342af0 100644
--- a/net/xfrm/xfrm_user.c
+++ b/net/xfrm/xfrm_user.c
@@ -675,7 +675,7 @@ static int xfrm_add_sa(struct sk_buff *skb, struct nlmsghdr 
*nlh,
        if (err < 0) {
                x->km.state = XFRM_STATE_DEAD;
                xfrm_dev_state_delete(x);
-               __xfrm_state_put(x);
+               xfrm_state_delete(x);
                goto out;
        }
 
-- 
2.21.0

Reply via email to