Patrick McHardy wrote:
> Thinking about it .. using RCU seems entirely unneccessary, assuming
> callers don't close the kernel socket and call netlink_has_listeners()
> afterwards, the pointer is always valid.

Here is the latest version, without RCU; the bitmask is always valid.
[NETLINK]: Add netlink_has_listeners() to avoid unnecessary event generation

Keep a bitmask of multicast groups with subscribed listeners to allow
checking for listeners before generating multicast messages.

Queries don't perform any locking, which may result in false positives;
however, it is guaranteed that any new subscriptions are visible before
bind() or setsockopt() returns.

Signed-off-by: Patrick McHardy <[EMAIL PROTECTED]>

---
commit ca0a229142f320780bbca94034287b76acf693c0
tree d0a8fc320d96bb6d29da203859dfa2ad3da53468
parent 3ee68c4af3fd7228c1be63254b9f884614f9ebb2
author Patrick McHardy <[EMAIL PROTECTED]> Tue, 31 Jan 2006 02:23:00 +0100
committer Patrick McHardy <[EMAIL PROTECTED]> Tue, 31 Jan 2006 02:23:00 +0100

 include/linux/netlink.h  |    1 +
 net/netlink/af_netlink.c |   87 ++++++++++++++++++++++++++++++++++++----------
 2 files changed, 69 insertions(+), 19 deletions(-)

diff --git a/include/linux/netlink.h b/include/linux/netlink.h
index 6a2ccf7..03ecb89 100644
--- a/include/linux/netlink.h
+++ b/include/linux/netlink.h
@@ -151,6 +151,7 @@ struct netlink_skb_parms
 
 extern struct sock *netlink_kernel_create(int unit, unsigned int groups, void 
(*input)(struct sock *sk, int len), struct module *module);
 extern void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err);
+extern int netlink_has_listeners(struct sock *sk, unsigned int group);
 extern int netlink_unicast(struct sock *ssk, struct sk_buff *skb, __u32 pid, 
int nonblock);
 extern int netlink_broadcast(struct sock *ssk, struct sk_buff *skb, __u32 pid,
                             __u32 group, gfp_t allocation);
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 2101b45..c4c4a42 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -106,6 +106,7 @@ struct nl_pid_hash {
 struct netlink_table {
        struct nl_pid_hash hash;
        struct hlist_head mc_list;
+       unsigned long *listeners;
        unsigned int nl_nonroot;
        unsigned int groups;
        struct module *module;
@@ -296,6 +297,38 @@ static inline int nl_pid_hash_dilute(str
 
 static const struct proto_ops netlink_ops;
 
+static void
+netlink_update_listeners(struct netlink_table *tbl, struct sock *sk)
+{
+       struct hlist_node *node;
+       unsigned long mask;
+       unsigned int i;
+
+       for (i = 0; i < NLGRPSZ(tbl->groups)/sizeof(unsigned long); i++) {
+               mask = 0;
+               sk_for_each_bound(sk, node, &tbl->mc_list) /* sk = cursor */
+                       if (i < NLGRPSZ(nlk_sk(sk)->ngroups)/sizeof(unsigned long))
+                               mask |= nlk_sk(sk)->groups[i];
+               tbl->listeners[i] = mask;
+       }
+       /* listeners = OR of all bound sockets' groups. Called with the table
+        * "grabbed", so it is visible before bind()/setsockopt() return. */
+}
+
+static void
+netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
+{
+       struct netlink_sock *nlk = nlk_sk(sk);
+       int was_bound = nlk->subscriptions != 0;
+
+       if (was_bound && !subscriptions)        /* no groups left: unlink */
+               __sk_del_bind_node(sk);
+       else if (!was_bound && subscriptions)   /* first group: add to mc_list */
+               sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list);
+       netlink_update_listeners(&nl_table[sk->sk_protocol], sk);
+       nlk->subscriptions = subscriptions;
+}
+
 static int netlink_insert(struct sock *sk, u32 pid)
 {
        struct nl_pid_hash *hash = &nl_table[sk->sk_protocol].hash;
@@ -456,12 +489,18 @@ static int netlink_release(struct socket
        if (nlk->module)
                module_put(nlk->module);
 
+       netlink_table_grab();
+       if (nlk->subscriptions) {
+               memset(nlk->groups, 0, NLGRPSZ(nlk->ngroups));
+               netlink_update_subscriptions(sk, 0);
+       }
+
        if (nlk->flags & NETLINK_KERNEL_SOCKET) {
-               netlink_table_grab();
+               kfree(nl_table[sk->sk_protocol].listeners);
                nl_table[sk->sk_protocol].module = NULL;
                nl_table[sk->sk_protocol].registered = 0;
-               netlink_table_ungrab();
        }
+       netlink_table_ungrab();
 
        kfree(nlk->groups);
        nlk->groups = NULL;
@@ -514,18 +553,6 @@ static inline int netlink_capable(struct
               capable(CAP_NET_ADMIN);
 } 
 
-static void
-netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
-{
-       struct netlink_sock *nlk = nlk_sk(sk);
-
-       if (nlk->subscriptions && !subscriptions)
-               __sk_del_bind_node(sk);
-       else if (!nlk->subscriptions && subscriptions)
-               sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list);
-       nlk->subscriptions = subscriptions;
-}
-
 static int netlink_alloc_groups(struct sock *sk)
 {
        struct netlink_sock *nlk = nlk_sk(sk);
@@ -554,6 +581,7 @@ static int netlink_bind(struct socket *s
        struct sock *sk = sock->sk;
        struct netlink_sock *nlk = nlk_sk(sk);
        struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr;
+       unsigned int subscriptions;
        int err;
        
        if (nladdr->nl_family != AF_NETLINK)
@@ -585,10 +613,10 @@ static int netlink_bind(struct socket *s
                return 0;
 
        netlink_table_grab();
-       netlink_update_subscriptions(sk, nlk->subscriptions +
-                                        hweight32(nladdr->nl_groups) -
-                                        hweight32(nlk->groups[0]));
-       nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; 
+       subscriptions = nlk->subscriptions + hweight32(nladdr->nl_groups)
+                                          - hweight32(nlk->groups[0]);
+       nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups;
+       netlink_update_subscriptions(sk, subscriptions);
        netlink_table_ungrab();
 
        return 0;
@@ -806,6 +834,17 @@ retry:
        return netlink_sendskb(sk, skb, ssk->sk_protocol);
 }
 
+int netlink_has_listeners(struct sock *sk, unsigned int group)
+{
+       struct netlink_table *tbl = &nl_table[sk->sk_protocol];
+
+       BUG_ON(!(nlk_sk(sk)->flags & NETLINK_KERNEL_SOCKET));
+       if (group == 0 || group > tbl->groups)  /* groups are 1-based */
+               return 0;
+       return test_bit(group - 1, tbl->listeners);     /* no locking */
+}
+EXPORT_SYMBOL_GPL(netlink_has_listeners);
+
 static __inline__ int netlink_broadcast_deliver(struct sock *sk, struct 
sk_buff *skb)
 {
        struct netlink_sock *nlk = nlk_sk(sk);
@@ -1235,6 +1274,7 @@ netlink_kernel_create(int unit, unsigned
        struct socket *sock;
        struct sock *sk;
        struct netlink_sock *nlk;
+       unsigned long *listeners = NULL;
 
        if (!nl_table)
                return NULL;
@@ -1248,6 +1288,13 @@ netlink_kernel_create(int unit, unsigned
        if (__netlink_create(sock, unit) < 0)
                goto out_sock_release;
 
+       if (groups < 32)
+               groups = 32;
+
+       listeners = kzalloc(NLGRPSZ(groups), GFP_KERNEL);
+       if (!listeners)
+               goto out_sock_release;
+
        sk = sock->sk;
        sk->sk_data_ready = netlink_data_ready;
        if (input)
@@ -1260,7 +1307,8 @@ netlink_kernel_create(int unit, unsigned
        nlk->flags |= NETLINK_KERNEL_SOCKET;
 
        netlink_table_grab();
-       nl_table[unit].groups = groups < 32 ? 32 : groups;
+       nl_table[unit].groups = groups;
+       nl_table[unit].listeners = listeners;
        nl_table[unit].module = module;
        nl_table[unit].registered = 1;
        netlink_table_ungrab();
@@ -1268,6 +1316,7 @@ netlink_kernel_create(int unit, unsigned
        return sk;
 
 out_sock_release:
+       kfree(listeners);
        sock_release(sock);
        return NULL;
 }

Reply via email to