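Move the code that builds and installs a sock_flow_table out of
rps_sock_flow_sysctl() and into its own function,
rps_create_sock_flow_table(), so that it can be called for other use
cases. sock_flow_mutex moves to file scope as part of the change.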
---
 net/core/sysctl_net_core.c | 102 +++++++++++++++++++++----------------
 1 file changed, 57 insertions(+), 45 deletions(-)

diff --git a/net/core/sysctl_net_core.c b/net/core/sysctl_net_core.c
index f93f8ace6c56..9c7d46fbb75a 100644
--- a/net/core/sysctl_net_core.c
+++ b/net/core/sysctl_net_core.c
@@ -46,66 +46,78 @@ int sysctl_devconf_inherit_init_net __read_mostly;
 EXPORT_SYMBOL(sysctl_devconf_inherit_init_net);
 
 #ifdef CONFIG_RPS
+/*
+ * rps_create_sock_flow_table - resize, reset or remove the global RPS
+ * sock flow table.
+ * @size: requested number of entries; 0 removes the table.  Non-zero
+ *        values are rounded up to a power of two and may not exceed 1 << 29.
+ * @orig_size: number of entries in @orig_table (0 if there is none).
+ * @orig_table: table currently published in rps_sock_flow_table, or NULL.
+ * @force: allocate a fresh table even if the rounded size equals @orig_size.
+ *
+ * Every entry of the resulting table is set to RPS_NO_CPU.  A table that
+ * gets replaced is freed after synchronize_rcu().  The sysctl handler in
+ * this file calls this function with sock_flow_mutex held.
+ *
+ * Returns 0 on success, -EINVAL if @size is too large, or -ENOMEM if the
+ * new table cannot be allocated.
+ */
+static int rps_create_sock_flow_table(size_t size, size_t orig_size,
+                                     struct rps_sock_flow_table *orig_table,
+                                     bool force)
+{
+       struct rps_sock_flow_table *sock_table = NULL;
+       size_t i;
+
+       if (size) {
+               /* Enforce limit to prevent overflow */
+               if (size > 1 << 29)
+                       return -EINVAL;
+
+               size = roundup_pow_of_two(size);
+               if (size != orig_size || force) {
+                       sock_table = vmalloc(RPS_SOCK_FLOW_TABLE_SIZE(size));
+                       if (!sock_table)
+                               return -ENOMEM;
+
+                       /*
+                        * The sysctl handler set rps_cpu_mask every time it
+                        * allocated a table; keep doing that here so moving
+                        * the code does not change behavior.
+                        */
+                       rps_cpu_mask = roundup_pow_of_two(nr_cpu_ids) - 1;
+                       sock_table->mask = size - 1;
+               } else {
+                       /* Same size and not forced: reuse the table in place */
+                       sock_table = orig_table;
+               }
+
+               for (i = 0; i < size; i++)
+                       sock_table->ents[i] = RPS_NO_CPU;
+       }
+
+       if (sock_table != orig_table) {
+               rcu_assign_pointer(rps_sock_flow_table, sock_table);
+               if (sock_table) {
+                       static_branch_inc(&rps_needed);
+                       static_branch_inc(&rfs_needed);
+               }
+               if (orig_table) {
+                       static_branch_dec(&rps_needed);
+                       static_branch_dec(&rfs_needed);
+                       /* Wait for readers before freeing the old table */
+                       synchronize_rcu();
+                       vfree(orig_table);
+               }
+       }
+
+       return 0;
+}
+
+static DEFINE_MUTEX(sock_flow_mutex);
+
 static int rps_sock_flow_sysctl(struct ctl_table *table, int write,
                                void *buffer, size_t *lenp, loff_t *ppos)
 {
        unsigned int orig_size, size;
-       int ret, i;
+       int ret;
        struct ctl_table tmp = {
                .data = &size,
                .maxlen = sizeof(size),
                .mode = table->mode
        };
-       struct rps_sock_flow_table *orig_sock_table, *sock_table;
-       static DEFINE_MUTEX(sock_flow_mutex);
+       struct rps_sock_flow_table *sock_table;
 
        mutex_lock(&sock_flow_mutex);
 
-       orig_sock_table = rcu_dereference_protected(rps_sock_flow_table,
-                                       lockdep_is_held(&sock_flow_mutex));
-       size = orig_size = orig_sock_table ? orig_sock_table->mask + 1 : 0;
+       sock_table = rcu_dereference_protected(rps_sock_flow_table,
+                                              lockdep_is_held(&sock_flow_mutex));
+       size = sock_table ? sock_table->mask + 1 : 0;
+       orig_size = size;
 
        ret = proc_dointvec(&tmp, write, buffer, lenp, ppos);
 
-       if (write) {
-               if (size) {
-                       if (size > 1<<29) {
-                               /* Enforce limit to prevent overflow */
-                               mutex_unlock(&sock_flow_mutex);
-                               return -EINVAL;
-                       }
-                       size = roundup_pow_of_two(size);
-                       if (size != orig_size) {
-                               sock_table =
-                                   vmalloc(RPS_SOCK_FLOW_TABLE_SIZE(size));
-                               if (!sock_table) {
-                                       mutex_unlock(&sock_flow_mutex);
-                                       return -ENOMEM;
-                               }
-                               rps_cpu_mask = roundup_pow_of_two(nr_cpu_ids) - 1;
-                               sock_table->mask = size - 1;
-                       } else
-                               sock_table = orig_sock_table;
-
-                       for (i = 0; i < size; i++)
-                               sock_table->ents[i] = RPS_NO_CPU;
-               } else
-                       sock_table = NULL;
-
-               if (sock_table != orig_sock_table) {
-                       rcu_assign_pointer(rps_sock_flow_table, sock_table);
-                       if (sock_table) {
-                               static_branch_inc(&rps_needed);
-                               static_branch_inc(&rfs_needed);
-                       }
-                       if (orig_sock_table) {
-                               static_branch_dec(&rps_needed);
-                               static_branch_dec(&rfs_needed);
-                               synchronize_rcu();
-                               vfree(orig_sock_table);
-                       }
-               }
-       }
+       if (write)
+               ret = rps_create_sock_flow_table(size, orig_size,
+                                                sock_table, false);
 
        mutex_unlock(&sock_flow_mutex);
 
-- 
2.25.1

Reply via email to