diff --git a/kernel/audit.c b/kernel/audit.c index 10bc2bad2adf0..a7c6a50477aa6 100644 --- a/kernel/audit.c +++ b/kernel/audit.c @@ -112,18 +112,19 @@ struct audit_net { * @pid: auditd PID * @portid: netlink portid * @net: the associated network namespace - * @lock: spinlock to protect write access + * @rcu: RCU head * * Description: * This struct is RCU protected; you must either hold the RCU lock for reading - * or the included spinlock for writing. + * or the associated spinlock for writing. */ static struct auditd_connection { struct pid *pid; u32 portid; struct net *net; - spinlock_t lock; -} auditd_conn; + struct rcu_head rcu; +} *auditd_conn = NULL; +static DEFINE_SPINLOCK(auditd_conn_lock); /* If audit_rate_limit is non-zero, limit the rate of sending audit records * to that number per second. This prevents DoS attacks, but results in @@ -215,9 +216,11 @@ struct audit_reply { int auditd_test_task(struct task_struct *task) { int rc; + struct auditd_connection *ac; rcu_read_lock(); - rc = (auditd_conn.pid && auditd_conn.pid == task_tgid(task) ? 1 : 0); + ac = rcu_dereference(auditd_conn); + rc = (ac && ac->pid == task_tgid(task) ? 1 : 0); rcu_read_unlock(); return rc; @@ -225,22 +228,21 @@ int auditd_test_task(struct task_struct *task) /** * auditd_pid_vnr - Return the auditd PID relative to the namespace - * @auditd: the auditd connection * * Description: - * Returns the PID in relation to the namespace, 0 on failure. This function - * takes the RCU read lock internally, but if the caller needs to protect the - * auditd_connection pointer it should take the RCU read lock as well. + * Returns the PID in relation to the namespace, 0 on failure. */ -static pid_t auditd_pid_vnr(const struct auditd_connection *auditd) +static pid_t auditd_pid_vnr(void) { pid_t pid; + const struct auditd_connection *ac; rcu_read_lock(); - if (!auditd || !auditd->pid) + ac = rcu_dereference(auditd_conn); + if (!ac || !ac->pid) pid = 0; else - pid = pid_vnr(auditd->pid); + pid = pid_vnr(ac->pid); rcu_read_unlock(); return pid; @@ -433,6 +435,24 @@ static int audit_set_failure(u32 state) return audit_do_config_change("audit_failure", &audit_failure, state); } +/** + * auditd_conn_free - RCU helper to release an auditd connection struct + * @rcu: RCU head + * + * Description: + * Drop any references inside the auditd connection tracking struct and free + * the memory. + */ + static void auditd_conn_free(struct rcu_head *rcu) + { + struct auditd_connection *ac; + + ac = container_of(rcu, struct auditd_connection, rcu); + put_pid(ac->pid); + put_net(ac->net); + kfree(ac); + } + /** * auditd_set - Set/Reset the auditd connection state * @pid: auditd PID @@ -441,27 +461,33 @@ static int audit_set_failure(u32 state) * * Description: * This function will obtain and drop network namespace references as - * necessary. + * necessary. Returns zero on success, negative values on failure. */ -static void auditd_set(struct pid *pid, u32 portid, struct net *net) +static int auditd_set(struct pid *pid, u32 portid, struct net *net) { unsigned long flags; + struct auditd_connection *ac_old, *ac_new; - spin_lock_irqsave(&auditd_conn.lock, flags); - if (auditd_conn.pid) - put_pid(auditd_conn.pid); - if (pid) - auditd_conn.pid = get_pid(pid); - else - auditd_conn.pid = NULL; - auditd_conn.portid = portid; - if (auditd_conn.net) - put_net(auditd_conn.net); - if (net) - auditd_conn.net = get_net(net); - else - auditd_conn.net = NULL; - spin_unlock_irqrestore(&auditd_conn.lock, flags); + if (!pid || !net) + return -EINVAL; + + ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL); + if (!ac_new) + return -ENOMEM; + ac_new->pid = get_pid(pid); + ac_new->portid = portid; + ac_new->net = get_net(net); + + spin_lock_irqsave(&auditd_conn_lock, flags); + ac_old = rcu_dereference_protected(auditd_conn, + lockdep_is_held(&auditd_conn_lock)); + rcu_assign_pointer(auditd_conn, ac_new); + spin_unlock_irqrestore(&auditd_conn_lock, flags); + + if (ac_old) + call_rcu(&ac_old->rcu, auditd_conn_free); + + return 0; } /** @@ -556,13 +582,19 @@ static void kauditd_retry_skb(struct sk_buff *skb) */ static void auditd_reset(void) { + unsigned long flags; struct sk_buff *skb; + struct auditd_connection *ac_old; /* if it isn't already broken, break the connection */ - rcu_read_lock(); - if (auditd_conn.pid) - auditd_set(0, 0, NULL); - rcu_read_unlock(); + spin_lock_irqsave(&auditd_conn_lock, flags); + ac_old = rcu_dereference_protected(auditd_conn, + lockdep_is_held(&auditd_conn_lock)); + rcu_assign_pointer(auditd_conn, NULL); + spin_unlock_irqrestore(&auditd_conn_lock, flags); + + if (ac_old) + call_rcu(&ac_old->rcu, auditd_conn_free); /* flush all of the main and retry queues to the hold queue */ while ((skb = skb_dequeue(&audit_retry_queue))) @@ -588,6 +620,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb) u32 portid; struct net *net; struct sock *sk; + struct auditd_connection *ac; /* NOTE: we can't call netlink_unicast while in the RCU section so * take a reference to the network namespace and grab local @@ -597,15 +630,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb) * section netlink_unicast() should safely return an error */ rcu_read_lock(); - if (!auditd_conn.pid) { + ac = rcu_dereference(auditd_conn); + if (!ac) { rcu_read_unlock(); rc = -ECONNREFUSED; goto err; } - net = auditd_conn.net; - get_net(net); + net = get_net(ac->net); sk = audit_get_sk(net); - portid = auditd_conn.portid; + portid = ac->portid; rcu_read_unlock(); rc = netlink_unicast(sk, skb, portid, 0); @@ -740,6 +773,7 @@ static int kauditd_thread(void *dummy) u32 portid = 0; struct net *net = NULL; struct sock *sk = NULL; + struct auditd_connection *ac; #define UNICAST_RETRIES 5 @@ -747,14 +781,14 @@ static int kauditd_thread(void *dummy) while (!kthread_should_stop()) { /* NOTE: see the lock comments in auditd_send_unicast_skb() */ rcu_read_lock(); - if (!auditd_conn.pid) { + ac = rcu_dereference(auditd_conn); + if (!ac) { rcu_read_unlock(); goto main_queue; } - net = auditd_conn.net; - get_net(net); + net = get_net(ac->net); sk = audit_get_sk(net); - portid = auditd_conn.portid; + portid = ac->portid; rcu_read_unlock(); /* attempt to flush the hold queue */ @@ -1117,7 +1151,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) s.failure = audit_failure; /* NOTE: use pid_vnr() so the PID is relative to the current * namespace */ - s.pid = auditd_pid_vnr(&auditd_conn); + s.pid = auditd_pid_vnr(); s.rate_limit = audit_rate_limit; s.backlog_limit = audit_backlog_limit; s.lost = atomic_read(&audit_lost); @@ -1160,7 +1194,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) /* test the auditd connection */ audit_replace(req_pid); - auditd_pid = auditd_pid_vnr(&auditd_conn); + auditd_pid = auditd_pid_vnr(); /* only the current auditd can unregister itself */ if ((!new_pid) && (new_pid != auditd_pid)) { audit_log_config_change("audit_pid", new_pid, @@ -1174,19 +1208,30 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) return -EEXIST; } - if (audit_enabled != AUDIT_OFF) - audit_log_config_change("audit_pid", new_pid, - auditd_pid, 1); - if (new_pid) { /* register a new auditd connection */ - auditd_set(req_pid, NETLINK_CB(skb).portid, - sock_net(NETLINK_CB(skb).sk)); + err = auditd_set(req_pid, + NETLINK_CB(skb).portid, + sock_net(NETLINK_CB(skb).sk)); + if (audit_enabled != AUDIT_OFF) + audit_log_config_change("audit_pid", + new_pid, + auditd_pid, + err ? 0 : 1); + if (err) + return err; + /* try to process any backlog */ wake_up_interruptible(&kauditd_wait); - } else + } else { + if (audit_enabled != AUDIT_OFF) + audit_log_config_change("audit_pid", + new_pid, + auditd_pid, 1); + /* unregister the auditd connection */ auditd_reset(); + } } if (s.mask & AUDIT_STATUS_RATE_LIMIT) { err = audit_set_rate_limit(s.rate_limit); @@ -1454,10 +1499,11 @@ static void __net_exit audit_net_exit(struct net *net) { struct audit_net *aunet = net_generic(net, audit_net_id); - rcu_read_lock(); - if (net == auditd_conn.net) - auditd_reset(); - rcu_read_unlock(); + /* NOTE: you would think that we would want to check the auditd + * connection and potentially reset it here if it lives in this + * namespace, but since the auditd connection tracking struct holds a + * reference to this namespace (see auditd_set()) we are only ever + * going to get here after that connection has been released */ netlink_kernel_release(aunet->sk); } @@ -1481,9 +1527,6 @@ static int __init audit_init(void) sizeof(struct audit_buffer), 0, SLAB_PANIC, NULL); - memset(&auditd_conn, 0, sizeof(auditd_conn)); - spin_lock_init(&auditd_conn.lock); - skb_queue_head_init(&audit_queue); skb_queue_head_init(&audit_retry_queue); skb_queue_head_init(&audit_hold_queue);