On Thu, Apr 27, 2017 at 1:31 PM, Cong Wang <xiyou.wangc...@gmail.com> wrote:
> On Wed, Apr 26, 2017 at 2:20 PM, Paul Moore <p...@paul-moore.com> wrote:
>> Thanks for the report, this is the only one like it that I've seen.
>> I'm looking at the code in Linus' tree and I'm not seeing anything
>> obvious ... looking at the trace above it appears that the problem is
>> when get_net() goes to bump the refcount and the passed net pointer is
>> NULL; unless I'm missing something, the only way this would happen in
>> kauditd_thread() is if the auditd_conn.pid value is non-zero but the
>> auditd_conn.net pointer is NULL.
>>
>> That shouldn't happen.
>>
>
> Looking at the code that reads/writes the global auditd_conn,
> I don't see how it even works with RCU+spinlock. RCU operates
> on pointers, and you have to make a copy, as its name implies.
> But it looks like you simply use RCU+spinlock as a traditional
> rwlock, which doesn't work.

The attached patch seems to work for me: I booted my VM
4 times, and so far there has been no crash or warning.

Please let me know if it looks reasonable to you.
diff --git a/kernel/audit.c b/kernel/audit.c
index a871bf8..182d31b 100644
--- a/kernel/audit.c
+++ b/kernel/audit.c
@@ -110,7 +110,6 @@ struct audit_net {
  * @pid: auditd PID
  * @portid: netlink portid
  * @net: the associated network namespace
- * @lock: spinlock to protect write access
  *
  * Description:
  * This struct is RCU protected; you must either hold the RCU lock for reading
@@ -120,8 +119,9 @@ static struct auditd_connection {
        int pid;
        u32 portid;
        struct net *net;
-       spinlock_t lock;
-} auditd_conn;
+       struct rcu_head rcu;
+} *auditd_conn;
+static DEFINE_SPINLOCK(auditd_conn_lock);
 
 /* If audit_rate_limit is non-zero, limit the rate of sending audit records
  * to that number per second.  This prevents DoS attacks, but results in
@@ -223,9 +223,11 @@ struct audit_reply {
 int auditd_test_task(const struct task_struct *task)
 {
        int rc;
+       pid_t pid;
 
        rcu_read_lock();
-       rc = (auditd_conn.pid && task->tgid == auditd_conn.pid ? 1 : 0);
+       pid = rcu_dereference(auditd_conn)->pid;
+       rc = (pid && task->tgid == pid ? 1 : 0);
        rcu_read_unlock();
 
        return rc;
@@ -426,30 +428,37 @@ static int audit_set_failure(u32 state)
        return audit_do_config_change("audit_failure", &audit_failure, state);
 }
 
+static void auditd_conn_free(struct rcu_head *rcu)
+{
+       struct auditd_connection *cn = container_of(rcu, struct auditd_connection, rcu);
+
+       if (cn->net)
+               put_net(cn->net);
+       kfree(cn);
+}
+
 /**
- * auditd_set - Set/Reset the auditd connection state
- * @pid: auditd PID
- * @portid: auditd netlink portid
- * @net: auditd network namespace pointer
+ * auditd_set - Set the auditd connection state
+ * @new_conn: the new auditd connection
  *
  * Description:
  * This function will obtain and drop network namespace references as
  * necessary.
  */
-static void auditd_set(int pid, u32 portid, struct net *net)
+static void auditd_set(struct auditd_connection *new_conn)
 {
+       struct auditd_connection *old_conn;
        unsigned long flags;
 
-       spin_lock_irqsave(&auditd_conn.lock, flags);
-       auditd_conn.pid = pid;
-       auditd_conn.portid = portid;
-       if (auditd_conn.net)
-               put_net(auditd_conn.net);
-       if (net)
-               auditd_conn.net = get_net(net);
-       else
-               auditd_conn.net = NULL;
-       spin_unlock_irqrestore(&auditd_conn.lock, flags);
+       if (new_conn->net)
+               get_net(new_conn->net);
+       spin_lock_irqsave(&auditd_conn_lock, flags);
+       old_conn = rcu_dereference_protected(auditd_conn,
+                                       lockdep_is_held(&auditd_conn_lock));
+       rcu_assign_pointer(auditd_conn, new_conn);
+       spin_unlock_irqrestore(&auditd_conn_lock, flags);
+
+       call_rcu(&old_conn->rcu, auditd_conn_free);
 }
 
 /**
@@ -537,19 +546,24 @@ static void kauditd_retry_skb(struct sk_buff *skb)
 
 /**
  * auditd_reset - Disconnect the auditd connection
+ * @flags: GFP flags
  *
  * Description:
  * Break the auditd/kauditd connection and move all the queued records into the
  * hold queue in case auditd reconnects.
  */
-static void auditd_reset(void)
+static void auditd_reset(gfp_t flags)
 {
        struct sk_buff *skb;
+       struct auditd_connection *null_conn;
 
+       null_conn = kzalloc(sizeof(*null_conn), flags);
+       if (!null_conn)
+               return;
        /* if it isn't already broken, break the connection */
        rcu_read_lock();
-       if (auditd_conn.pid)
-               auditd_set(0, 0, NULL);
+       if (rcu_dereference(auditd_conn)->pid)
+               auditd_set(null_conn);
        rcu_read_unlock();
 
        /* flush all of the main and retry queues to the hold queue */
@@ -585,15 +599,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
         *       section netlink_unicast() should safely return an error */
 
        rcu_read_lock();
-       if (!auditd_conn.pid) {
+       if (!rcu_dereference(auditd_conn)->pid) {
                rcu_read_unlock();
                rc = -ECONNREFUSED;
                goto err;
        }
-       net = auditd_conn.net;
+       net = rcu_dereference(auditd_conn)->net;
        get_net(net);
        sk = audit_get_sk(net);
-       portid = auditd_conn.portid;
+       portid = rcu_dereference(auditd_conn)->portid;
        rcu_read_unlock();
 
        rc = netlink_unicast(sk, skb, portid, 0);
@@ -605,7 +619,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
 
 err:
        if (rc == -ECONNREFUSED)
-               auditd_reset();
+               auditd_reset(GFP_KERNEL);
        return rc;
 }
 
@@ -735,14 +749,14 @@ static int kauditd_thread(void *dummy)
        while (!kthread_should_stop()) {
                /* NOTE: see the lock comments in auditd_send_unicast_skb() */
                rcu_read_lock();
-               if (!auditd_conn.pid) {
+               if (!rcu_dereference(auditd_conn)->pid) {
                        rcu_read_unlock();
                        goto main_queue;
                }
-               net = auditd_conn.net;
+               net = rcu_dereference(auditd_conn)->net;
                get_net(net);
                sk = audit_get_sk(net);
-               portid = auditd_conn.portid;
+               portid = rcu_dereference(auditd_conn)->portid;
                rcu_read_unlock();
 
                /* attempt to flush the hold queue */
@@ -751,7 +765,7 @@ static int kauditd_thread(void *dummy)
                                        NULL, kauditd_rehold_skb);
                if (rc < 0) {
                        sk = NULL;
-                       auditd_reset();
+                       auditd_reset(GFP_KERNEL);
                        goto main_queue;
                }
 
@@ -761,7 +775,7 @@ static int kauditd_thread(void *dummy)
                                        NULL, kauditd_hold_skb);
                if (rc < 0) {
                        sk = NULL;
-                       auditd_reset();
+                       auditd_reset(GFP_KERNEL);
                        goto main_queue;
                }
 
@@ -774,7 +788,7 @@ static int kauditd_thread(void *dummy)
                                        kauditd_send_multicast_skb,
                                        kauditd_retry_skb);
                if (sk == NULL || rc < 0)
-                       auditd_reset();
+                       auditd_reset(GFP_KERNEL);
                sk = NULL;
 
                /* drop our netns reference, no auditd sends past this line */
@@ -1103,7 +1117,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                s.enabled               = audit_enabled;
                s.failure               = audit_failure;
                rcu_read_lock();
-               s.pid                   = auditd_conn.pid;
+               s.pid                   = rcu_dereference(auditd_conn)->pid;
                rcu_read_unlock();
                s.rate_limit            = audit_rate_limit;
                s.backlog_limit         = audit_backlog_limit;
@@ -1139,17 +1153,22 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                        int new_pid = s.pid;
                        pid_t auditd_pid;
                        pid_t requesting_pid = task_tgid_vnr(current);
+                       struct auditd_connection *new;
 
+                       new = kmalloc(sizeof(*new), GFP_KERNEL);
+                       if (!new)
+                               return -ENOMEM;
                        /* test the auditd connection */
                        audit_replace(requesting_pid);
 
                        rcu_read_lock();
-                       auditd_pid = auditd_conn.pid;
+                       auditd_pid = rcu_dereference(auditd_conn)->pid;
                        /* only the current auditd can unregister itself */
                        if ((!new_pid) && (requesting_pid != auditd_pid)) {
                                rcu_read_unlock();
                                audit_log_config_change("audit_pid", new_pid,
                                                        auditd_pid, 0);
+                               kfree(new);
                                return -EACCES;
                        }
                        /* replacing a healthy auditd is not allowed */
@@ -1157,6 +1176,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                                rcu_read_unlock();
                                audit_log_config_change("audit_pid", new_pid,
                                                        auditd_pid, 0);
+                               kfree(new);
                                return -EEXIST;
                        }
                        rcu_read_unlock();
@@ -1166,15 +1186,16 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                                                        auditd_pid, 1);
 
                        if (new_pid) {
+                               new->pid = new_pid;
+                               new->portid = NETLINK_CB(skb).portid;
+                               new->net = sock_net(NETLINK_CB(skb).sk);
                                /* register a new auditd connection */
-                               auditd_set(new_pid,
-                                          NETLINK_CB(skb).portid,
-                                          sock_net(NETLINK_CB(skb).sk));
+                               auditd_set(new);
                                /* try to process any backlog */
                                wake_up_interruptible(&kauditd_wait);
                        } else
                                /* unregister the auditd connection */
-                               auditd_reset();
+                               auditd_reset(GFP_KERNEL);
                }
                if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
                        err = audit_set_rate_limit(s.rate_limit);
@@ -1448,8 +1469,8 @@ static void __net_exit audit_net_exit(struct net *net)
        struct audit_net *aunet = net_generic(net, audit_net_id);
 
        rcu_read_lock();
-       if (net == auditd_conn.net)
-               auditd_reset();
+       if (net == rcu_dereference(auditd_conn)->net)
+               auditd_reset(GFP_ATOMIC);
        rcu_read_unlock();
 
        netlink_kernel_release(aunet->sk);
@@ -1470,8 +1491,9 @@ static int __init audit_init(void)
        if (audit_initialized == AUDIT_DISABLED)
                return 0;
 
-       memset(&auditd_conn, 0, sizeof(auditd_conn));
-       spin_lock_init(&auditd_conn.lock);
+       auditd_conn = kzalloc(sizeof(*auditd_conn), GFP_KERNEL);
+       if (!auditd_conn)
+               return -ENOMEM;
 
        skb_queue_head_init(&audit_queue);
        skb_queue_head_init(&audit_retry_queue);

Reply via email to