New version of the netlink_has_listeners() patch.

Changes:

- Fix missing listeners bitmap update when there was no delta in the
  number of subscribed groups
- Use RCU to protect nltable listeners bitmap

[NETLINK]: Add netlink_has_listeners() for checking for multicast listeners

netlink_has_listeners() should be used to avoid unnecessary event message
generation if there are no listeners.

Signed-off-by: Patrick McHardy <[EMAIL PROTECTED]>

---
commit e32451ce5600d4e0d7f74486b840279986f05cfb
tree 19a1553d346d8877ffd95065f07a6b632338d6cc
parent 3ee68c4af3fd7228c1be63254b9f884614f9ebb2
author Patrick McHardy <[EMAIL PROTECTED]> Tue, 31 Jan 2006 01:20:38 +0100
committer Patrick McHardy <[EMAIL PROTECTED]> Tue, 31 Jan 2006 01:20:38 +0100

 include/linux/netlink.h  |    1 
 net/netlink/af_netlink.c |  100 ++++++++++++++++++++++++++++++++++++----------
 2 files changed, 80 insertions(+), 21 deletions(-)

diff --git a/include/linux/netlink.h b/include/linux/netlink.h
index 6a2ccf7..03ecb89 100644
--- a/include/linux/netlink.h
+++ b/include/linux/netlink.h
@@ -151,6 +151,7 @@ struct netlink_skb_parms
 
 extern struct sock *netlink_kernel_create(int unit, unsigned int groups, void (*input)(struct sock *sk, int len), struct module *module);
 extern void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err);
+extern int netlink_has_listeners(struct sock *sk, unsigned int group);
 extern int netlink_unicast(struct sock *ssk, struct sk_buff *skb, __u32 pid, int nonblock);
 extern int netlink_broadcast(struct sock *ssk, struct sk_buff *skb, __u32 pid,
                              __u32 group, gfp_t allocation);
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 2101b45..bf963f6 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -106,6 +106,7 @@ struct nl_pid_hash {
 struct netlink_table {
        struct nl_pid_hash hash;
        struct hlist_head mc_list;
+       unsigned long *listeners;       /* OR of bound sockets' groups; RCU */
        unsigned int nl_nonroot;
        unsigned int groups;
        struct module *module;
@@ -296,6 +297,51 @@ static inline int nl_pid_hash_dilute(str
 
 static const struct proto_ops netlink_ops;
 
+/*
+ * Recompute tbl->listeners: word i becomes the OR of groups[i] over all
+ * sockets bound on tbl->mc_list. Called with the netlink table grabbed.
+ * @sk is reused as the list cursor; its incoming value is ignored.
+ */
+static void
+netlink_update_listeners(struct netlink_table *tbl, struct sock *sk)
+{
+       struct netlink_sock *nlk;
+       struct hlist_node *node;
+       unsigned long mask;
+       unsigned int i;
+
+       /* reset to NULL when the protocol's kernel socket is released */
+       if (tbl->listeners == NULL)
+               return;
+
+       for (i = 0; i < NLGRPSZ(tbl->groups)/sizeof(unsigned long); i++) {
+               mask = 0;
+               sk_for_each_bound(sk, node, &tbl->mc_list) {
+                       nlk = nlk_sk(sk);
+                       /* nlk->ngroups may differ from tbl->groups */
+                       if (i < NLGRPSZ(nlk->ngroups)/sizeof(unsigned long))
+                               mask |= nlk->groups[i];
+               }
+               tbl->listeners[i] = mask;
+       }
+       /* make sure updates are visible before bind() or setsockopt() return */
+       smp_wmb();
+}
+
+static void
+netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
+{
+       struct netlink_table *tbl = &nl_table[sk->sk_protocol];
+       struct netlink_sock *nlk = nlk_sk(sk);
+
+       if (nlk->subscriptions && !subscriptions)
+               __sk_del_bind_node(sk);
+       else if (!nlk->subscriptions && subscriptions)
+               sk_add_bind_node(sk, &tbl->mc_list);
+       netlink_update_listeners(tbl, sk);
+       nlk->subscriptions = subscriptions;
+}
+
 static int netlink_insert(struct sock *sk, u32 pid)
 {
        struct nl_pid_hash *hash = &nl_table[sk->sk_protocol].hash;
@@ -453,15 +486,29 @@ static int netlink_release(struct socket
                notifier_call_chain(&netlink_chain, NETLINK_URELEASE, &n);
        }       
 
-       if (nlk->module)
-               module_put(nlk->module);
+       netlink_table_grab();
+       if (nlk->subscriptions) {
+               memset(nlk->groups, 0, NLGRPSZ(nlk->ngroups));
+               netlink_update_subscriptions(sk, 0);
+       }
 
        if (nlk->flags & NETLINK_KERNEL_SOCKET) {
-               netlink_table_grab();
+               unsigned long *listeners;
+
+               listeners = nl_table[sk->sk_protocol].listeners;
+               rcu_assign_pointer(nl_table[sk->sk_protocol].listeners, NULL);
                nl_table[sk->sk_protocol].module = NULL;
                nl_table[sk->sk_protocol].registered = 0;
                netlink_table_ungrab();
-       }
+               /* synchronize_rcu() sleeps: wait for readers only after ungrab */
+               synchronize_rcu();
+               kfree(listeners);
+       } else {
+               netlink_table_ungrab();
+       }
+
+       if (nlk->module)
+               module_put(nlk->module);
 
        kfree(nlk->groups);
        nlk->groups = NULL;
@@ -514,18 +559,6 @@ static inline int netlink_capable(struct
               capable(CAP_NET_ADMIN);
 } 
 
-static void
-netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
-{
-       struct netlink_sock *nlk = nlk_sk(sk);
-
-       if (nlk->subscriptions && !subscriptions)
-               __sk_del_bind_node(sk);
-       else if (!nlk->subscriptions && subscriptions)
-               sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list);
-       nlk->subscriptions = subscriptions;
-}
-
 static int netlink_alloc_groups(struct sock *sk)
 {
        struct netlink_sock *nlk = nlk_sk(sk);
@@ -554,6 +587,7 @@ static int netlink_bind(struct socket *s
        struct sock *sk = sock->sk;
        struct netlink_sock *nlk = nlk_sk(sk);
        struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr;
+       unsigned int subscriptions;     /* new count of subscribed groups */
        int err;
        
        if (nladdr->nl_family != AF_NETLINK)
@@ -585,10 +619,10 @@ static int netlink_bind(struct socket *s
                return 0;
 
        netlink_table_grab();
-       netlink_update_subscriptions(sk, nlk->subscriptions +
-                                        hweight32(nladdr->nl_groups) -
-                                        hweight32(nlk->groups[0]));
-       nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; 
+       subscriptions = nlk->subscriptions + hweight32(nladdr->nl_groups)
+                                          - hweight32(nlk->groups[0]);
+       nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups;
+       netlink_update_subscriptions(sk, subscriptions); /* reads new groups[] */
        netlink_table_ungrab();
 
        return 0;
@@ -806,6 +840,29 @@ retry:
        return netlink_sendskb(sk, skb, ssk->sk_protocol);
 }
 
+/**
+ * netlink_has_listeners - check whether a multicast group has subscribers
+ * @sk: netlink socket; only its protocol (sk->sk_protocol) is used
+ * @group: multicast group number, starting at 1
+ *
+ * Returns nonzero if some socket is subscribed to @group, and 0 for
+ * group 0, for groups beyond the table size, or after the protocol's
+ * kernel socket has been released.
+ */
+int netlink_has_listeners(struct sock *sk, unsigned int group)
+{
+       unsigned long *listeners;
+       int res = 0;
+
+       rcu_read_lock();
+       listeners = rcu_dereference(nl_table[sk->sk_protocol].listeners);
+       if (listeners && group - 1 < nl_table[sk->sk_protocol].groups)
+               res = test_bit(group - 1, listeners);
+       rcu_read_unlock();
+       return res;
+}
+EXPORT_SYMBOL_GPL(netlink_has_listeners);
+
 static __inline__ int netlink_broadcast_deliver(struct sock *sk, struct sk_buff *skb)
 {
        struct netlink_sock *nlk = nlk_sk(sk);
@@ -1235,6 +1283,7 @@ netlink_kernel_create(int unit, unsigned
        struct socket *sock;
        struct sock *sk;
        struct netlink_sock *nlk;
+       unsigned long *listeners = NULL;
 
        if (!nl_table)
                return NULL;
@@ -1248,6 +1297,13 @@ netlink_kernel_create(int unit, unsigned
        if (__netlink_create(sock, unit) < 0)
                goto out_sock_release;
 
+       if (groups < 32)
+               groups = 32;
+
+       listeners = kzalloc(NLGRPSZ(groups), GFP_KERNEL);
+       if (!listeners)
+               goto out_sock_release;
+
        sk = sock->sk;
        sk->sk_data_ready = netlink_data_ready;
        if (input)
@@ -1260,7 +1316,9 @@ netlink_kernel_create(int unit, unsigned
        nlk->flags |= NETLINK_KERNEL_SOCKET;
 
        netlink_table_grab();
-       nl_table[unit].groups = groups < 32 ? 32 : groups;
+       nl_table[unit].groups = groups;
+       /* readers use rcu_dereference() without the table lock */
+       rcu_assign_pointer(nl_table[unit].listeners, listeners);
        nl_table[unit].module = module;
        nl_table[unit].registered = 1;
        netlink_table_ungrab();
@@ -1268,6 +1325,7 @@ netlink_kernel_create(int unit, unsigned
        return sk;
 
 out_sock_release:
+       kfree(listeners);
        sock_release(sock);
        return NULL;
 }

Reply via email to