Implement depth tracking for KStackWatch to support stack-level filtering.
Each task's recursive entry depth is stored in a global hash table keyed by
pid:

 - get_recursive_depth()/set_recursive_depth() manage per-task depth
 - reset_recursive_depth() clears all tracked entries
 - entry/exit handlers increment or decrement the depth and skip arming
   or disarming the watchpoint if the current depth does not match the
   configured depth.

Because depth is tracked per task, this keeps working across task
scheduling and in interrupt context (where the depth is charged to the
interrupted task, i.e. current). This lets KStackWatch monitor one specific
recursion level without redundant triggers.

Signed-off-by: Jinchao Wang <wangjinchao...@gmail.com>
---
 mm/kstackwatch/stack.c | 105 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 103 insertions(+), 2 deletions(-)

diff --git a/mm/kstackwatch/stack.c b/mm/kstackwatch/stack.c
index 00cb38085a9f..8758b8b94f7f 100644
--- a/mm/kstackwatch/stack.c
+++ b/mm/kstackwatch/stack.c
@@ -1,6 +1,8 @@
 // SPDX-License-Identifier: GPL-2.0
 
 #include <linux/fprobe.h>
+#include <linux/hashtable.h>
+#include <linux/hash.h>
 #include <linux/interrupt.h>
 #include <linux/kprobes.h>
 #include <linux/percpu.h>
@@ -12,6 +14,81 @@
 
 struct ksw_config *probe_config;
 
+#define DEPTH_HASH_BITS 8
+
+/*
+ * Per-task recursion bookkeeping for the watched function.
+ * An entry exists only while @depth > 0; a task without an entry is at
+ * depth 0.
+ */
+struct depth_entry {
+       pid_t pid;              /* key: current->pid of the tracked task */
+       int depth;              /* number of nested calls currently in flight */
+       struct hlist_node node; /* link in depth_hash */
+};
+
+/* pid -> depth_entry; every access is serialized by depth_hash_lock. */
+static DEFINE_HASHTABLE(depth_hash, DEPTH_HASH_BITS);
+static DEFINE_SPINLOCK(depth_hash_lock);
+
+/*
+ * Return how many calls of the watched function the current task has in
+ * flight (0 when the task is not tracked).
+ *
+ * Called from probe handlers, which run in whatever context the probed
+ * function runs in, including interrupt context.  Take the lock with IRQs
+ * disabled so an interrupt on this CPU cannot spin on a lock this CPU
+ * already holds.
+ */
+static int get_recursive_depth(void)
+{
+       struct depth_entry *entry;
+       pid_t pid = current->pid;
+       unsigned long flags;
+       int depth = 0;
+
+       spin_lock_irqsave(&depth_hash_lock, flags);
+       hash_for_each_possible(depth_hash, entry, node,
+                              hash_32(pid, DEPTH_HASH_BITS)) {
+               if (entry->pid == pid) {
+                       depth = entry->depth;
+                       break;
+               }
+       }
+       spin_unlock_irqrestore(&depth_hash_lock, flags);
+       return depth;
+}
+
+/*
+ * Record @depth as the current task's recursion depth.
+ *
+ * A positive depth updates the task's entry, creating it on first use.
+ * A depth of 0 or below removes the entry, so the table only holds tasks
+ * that are inside the watched function.  The allocation is GFP_ATOMIC
+ * because this runs from probe handlers with the lock held and IRQs off;
+ * if it fails the task stays untracked and later lookups report depth 0.
+ */
+static void set_recursive_depth(int depth)
+{
+       struct depth_entry *entry;
+       pid_t pid = current->pid;
+       unsigned long flags;
+       bool found = false;
+
+       spin_lock_irqsave(&depth_hash_lock, flags);
+       hash_for_each_possible(depth_hash, entry, node,
+                              hash_32(pid, DEPTH_HASH_BITS)) {
+               if (entry->pid == pid) {
+                       found = true;
+                       break;
+               }
+       }
+
+       if (found) {
+               if (depth > 0) {
+                       entry->depth = depth;
+               } else {
+                       /* Never keep a zero or negative depth around. */
+                       hash_del(&entry->node);
+                       kfree(entry);
+               }
+       } else if (depth > 0) {
+               entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
+               if (entry) {
+                       entry->pid = pid;
+                       entry->depth = depth;
+                       hash_add(depth_hash, &entry->node,
+                                hash_32(pid, DEPTH_HASH_BITS));
+               } else {
+                       pr_warn_ratelimited("KSW: no memory to track depth for pid %d\n",
+                                           pid);
+               }
+       }
+       spin_unlock_irqrestore(&depth_hash_lock, flags);
+}
+
+/*
+ * Drop every tracked task's depth entry.  ksw_stack_init() calls this before
+ * setting up the probes, so depths left over from a previous session do not
+ * carry into the new one.  Uses the same IRQ-safe locking as the other
+ * depth_hash users.
+ */
+static void reset_recursive_depth(void)
+{
+       struct depth_entry *entry;
+       struct hlist_node *tmp;
+       unsigned long flags;
+       int bkt;
+
+       spin_lock_irqsave(&depth_hash_lock, flags);
+       hash_for_each_safe(depth_hash, bkt, tmp, entry, node) {
+               hash_del(&entry->node);
+               kfree(entry);
+       }
+       spin_unlock_irqrestore(&depth_hash_lock, flags);
+}
+
 /* Find canary address in current stack frame */
 static unsigned long ksw_stack_find_canary(struct pt_regs *regs)
 {
@@ -119,10 +196,21 @@ static struct fprobe exit_probe_fprobe;
 static void ksw_stack_entry_handler(struct kprobe *p, struct pt_regs *regs,
                                   unsigned long flags)
 {
+       int cur_depth;
        int ret;
        u64 watch_addr;
        u64 watch_len;
 
+       /*
+        * cur_depth counts the calls this task already has in flight, so the
+        * outermost call is depth 0.  Always count the call so the exit
+        * handler stays balanced, but only arm the watchpoint at the
+        * configured level.
+        */
+       cur_depth = get_recursive_depth();
+       set_recursive_depth(cur_depth + 1);
+
+       if (cur_depth != probe_config->depth) {
+               /* Hit on every non-matching nested call: keep it out of the default log. */
+               pr_debug("KSW: config_depth:%u cur_depth:%d entry skipping\n",
+                        probe_config->depth, cur_depth);
+               return;
+       }
+
        ret = ksw_stack_prepare_watch(regs, probe_config, &watch_addr,
                                      &watch_len);
        if (ret) {
@@ -132,8 +220,8 @@ static void ksw_stack_entry_handler(struct kprobe *p, struct pt_regs *regs,
 
        ret = ksw_watch_on(watch_addr, watch_len);
        if (ret) {
-               pr_err("KSW: failed to watch on addr:0x%llx len:%llx %d\n",
-                      watch_addr, watch_len, ret);
+               pr_err("KSW: failed to watch on depth:%d addr:0x%llx len:%llx %d\n",
+                      cur_depth, watch_addr, watch_len, ret);
                return;
        }
 }
@@ -142,6 +230,17 @@ static void ksw_stack_exit_handler(struct fprobe *fp, unsigned long ip,
                                   unsigned long ret_ip,
                                   struct ftrace_regs *regs, void *data)
 {
+       int cur_depth;
+
+       /*
+        * Undo the entry handler's increment.  cur_depth is then the depth
+        * this call was entered at.  A negative value means the task was not
+        * being tracked (e.g. its tracking entry could not be allocated), so
+        * there is nothing to undo and no watchpoint of ours to remove.
+        */
+       cur_depth = get_recursive_depth() - 1;
+       if (cur_depth < 0)
+               return;
+       set_recursive_depth(cur_depth);
+
+       if (cur_depth != probe_config->depth) {
+               /* Hit on every non-matching nested return: keep it out of the default log. */
+               pr_debug("KSW: config_depth:%u cur_depth:%d exit skipping\n",
+                        probe_config->depth, cur_depth);
+               return;
+       }
+
        ksw_watch_off();
 }
 
@@ -150,6 +249,8 @@ int ksw_stack_init(struct ksw_config *config)
        int ret;
        char *symbuf = NULL;
 
+       reset_recursive_depth();
+
        /* Setup entry probe */
        memset(&entry_probe, 0, sizeof(entry_probe));
        entry_probe.symbol_name = config->function;
-- 
2.43.0


Reply via email to