From: Zhang Cen <[email protected]>

SK_MSG uses the msg->sg.copy bitmap to track per-scatterlist-entry
provenance. Entries whose bit is set are copied before data/data_end
are exposed to SK_MSG BPF programs for direct packet access.

bpf_msg_pull_data(), bpf_msg_push_data(), and bpf_msg_pop_data()
rewrite the sk_msg scatterlist ring by collapsing, splitting, and
shifting entries. These operations move msg->sg.data[] entries, but the
matching bits in the parallel copy bitmap can be left behind on the old
slots. An entry that must be copied can then end up at msg->sg.start
with its copy bit clear and be exposed as directly writable packet data.

This corruption path requires an attached SK_MSG BPF program that calls
the mutating helpers; ordinary sockmap/TLS traffic that never runs
push/pop/pull helper sequences is not affected.

Keep msg->sg.copy synchronized with scatterlist entry moves, preserve
the copy bit when an entry is split, clear it when a helper replaces an
entry with a private page, and clear slots vacated by pull-data
compaction.

Fixes: 015632bb30da ("bpf: sk_msg program helper bpf_sk_msg_pull_data")
Fixes: 6fff607e2f14 ("bpf: sk_msg program helper bpf_msg_push_data")
Fixes: 7246d8ed4dcc ("bpf: helper to pop data from messages")
Cc: [email protected]
Co-developed-by: Han Guidong <[email protected]>
Signed-off-by: Han Guidong <[email protected]>
Reviewed-by: John Fastabend <[email protected]>
Reviewed-by: Emil Tsalapatis <[email protected]>
Signed-off-by: Zhang Cen <[email protected]>
Signed-off-by: Jiayuan Chen <[email protected]>
---
 net/core/filter.c | 88 ++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 83 insertions(+), 5 deletions(-)

diff --git a/net/core/filter.c b/net/core/filter.c
index 6e345ca65ca14..643411e292ce5 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -2654,6 +2654,38 @@ static void sk_msg_reset_curr(struct sk_msg *msg)
        }
 }
 
+static bool sk_msg_elem_is_copy(const struct sk_msg *msg, u32 i)
+{
+       return test_bit(i, msg->sg.copy);
+}
+
+static void sk_msg_clear_elem_copy(struct sk_msg *msg, u32 i)
+{
+       __clear_bit(i, msg->sg.copy);
+}
+
+static void sk_msg_set_elem_copy(struct sk_msg *msg, u32 i)
+{
+       __set_bit(i, msg->sg.copy);
+}
+
+static void sk_msg_clear_copy_range(struct sk_msg *msg, u32 start, u32 end)
+{
+       while (start != end) {
+               sk_msg_clear_elem_copy(msg, start);
+               sk_msg_iter_var_next(start);
+       }
+}
+
+/* msg->sg.copy describes the entry, not the slot, so move the bit along
+ * with the scatterlist entry when relocating it within the ring.
+ */
+static void sk_msg_sg_move(struct sk_msg *msg, u32 dst, u32 src)
+{
+       msg->sg.data[dst] = msg->sg.data[src];
+       __assign_bit(dst, msg->sg.copy, sk_msg_elem_is_copy(msg, src));
+}
+
 static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
        .func           = bpf_msg_cork_bytes,
        .gpl_only       = false,
@@ -2692,7 +2724,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
         * account for the headroom.
         */
        bytes_sg_total = start - offset + bytes;
-       if (!test_bit(i, msg->sg.copy) && bytes_sg_total <= len)
+       if (!sk_msg_elem_is_copy(msg, i) && bytes_sg_total <= len)
                goto out;
 
        /* At this point we need to linearize multiple scatterlist
@@ -2738,6 +2770,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
        } while (i != last_sge);
 
        sg_set_page(&msg->sg.data[first_sge], page, copy, 0);
+       sk_msg_clear_elem_copy(msg, first_sge);
 
        /* To repair sg ring we need to shift entries. If we only
         * had a single entry though we can just replace it and
@@ -2747,8 +2780,14 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
        shift = last_sge > first_sge ?
                last_sge - first_sge - 1 :
                NR_MSG_FRAG_IDS - first_sge + last_sge - 1;
-       if (!shift)
+       if (!shift) {
+               sk_msg_clear_elem_copy(msg, msg->sg.end);
                goto out;
+       }
+
+       i = first_sge;
+       sk_msg_iter_var_next(i);
+       sk_msg_clear_copy_range(msg, i, last_sge);
 
        i = first_sge;
        sk_msg_iter_var_next(i);
@@ -2762,16 +2801,18 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
                if (move_from == msg->sg.end)
                        break;
 
-               msg->sg.data[i] = msg->sg.data[move_from];
+               sk_msg_sg_move(msg, i, move_from);
                msg->sg.data[move_from].length = 0;
                msg->sg.data[move_from].page_link = 0;
                msg->sg.data[move_from].offset = 0;
+               sk_msg_clear_elem_copy(msg, move_from);
                sk_msg_iter_var_next(i);
        } while (1);
 
        msg->sg.end = msg->sg.end - shift > msg->sg.end ?
                      msg->sg.end - shift + NR_MSG_FRAG_IDS :
                      msg->sg.end - shift;
+       sk_msg_clear_elem_copy(msg, msg->sg.end);
 out:
        sk_msg_reset_curr(msg);
        msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
@@ -2792,8 +2833,10 @@ static const struct bpf_func_proto bpf_msg_pull_data_proto = {
 BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start,
           u32, len, u64, flags)
 {
+       bool sge_copy = false, nsge_copy = false, nnsge_copy = false;
        struct scatterlist sge, nsge, nnsge, rsge = {0}, *psge;
        u32 new, i = 0, l = 0, space, copy = 0, offset = 0;
+       bool rsge_copy = false;
        u8 *raw, *to, *from;
        struct page *page;
 
@@ -2869,6 +2912,7 @@ BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start,
                        sk_msg_iter_var_prev(i);
                psge = sk_msg_elem(msg, i);
                rsge = sk_msg_elem_cpy(msg, i);
+               rsge_copy = sk_msg_elem_is_copy(msg, i);
 
                psge->length = start - offset;
                rsge.length -= psge->length;
@@ -2894,23 +2938,34 @@ BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start,
        /* Shift one or two slots as needed */
        sge = sk_msg_elem_cpy(msg, new);
        sg_unmark_end(&sge);
+       sge_copy = sk_msg_elem_is_copy(msg, new);
 
        nsge = sk_msg_elem_cpy(msg, i);
+       nsge_copy = sk_msg_elem_is_copy(msg, i);
        if (rsge.length) {
                sk_msg_iter_var_next(i);
                nnsge = sk_msg_elem_cpy(msg, i);
+               nnsge_copy = sk_msg_elem_is_copy(msg, i);
                sk_msg_iter_next(msg, end);
        }
 
        while (i != msg->sg.end) {
                msg->sg.data[i] = sge;
+               if (sge_copy)
+                       sk_msg_set_elem_copy(msg, i);
+               else
+                       sk_msg_clear_elem_copy(msg, i);
                sge = nsge;
+               sge_copy = nsge_copy;
                sk_msg_iter_var_next(i);
                if (rsge.length) {
                        nsge = nnsge;
+                       nsge_copy = nnsge_copy;
                        nnsge = sk_msg_elem_cpy(msg, i);
+                       nnsge_copy = sk_msg_elem_is_copy(msg, i);
                } else {
                        nsge = sk_msg_elem_cpy(msg, i);
+                       nsge_copy = sk_msg_elem_is_copy(msg, i);
                }
        }
 
@@ -2918,13 +2973,18 @@ BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start,
        /* Place newly allocated data buffer */
        sk_mem_charge(msg->sk, len);
        msg->sg.size += len;
-       __clear_bit(new, msg->sg.copy);
+       sk_msg_clear_elem_copy(msg, new);
        sg_set_page(&msg->sg.data[new], page, len + copy, 0);
        if (rsge.length) {
                get_page(sg_page(&rsge));
                sk_msg_iter_var_next(new);
                msg->sg.data[new] = rsge;
+               if (rsge_copy)
+                       sk_msg_set_elem_copy(msg, new);
+               else
+                       sk_msg_clear_elem_copy(msg, new);
        }
+       sk_msg_clear_elem_copy(msg, msg->sg.end);
 
        sk_msg_reset_curr(msg);
        sk_msg_compute_data_pointers(msg);
@@ -2950,27 +3010,38 @@ static void sk_msg_shift_left(struct sk_msg *msg, int i)
        do {
                prev = i;
                sk_msg_iter_var_next(i);
-               msg->sg.data[prev] = msg->sg.data[i];
+               sk_msg_sg_move(msg, prev, i);
        } while (i != msg->sg.end);
 
        sk_msg_iter_prev(msg, end);
+       sk_msg_clear_elem_copy(msg, msg->sg.end);
 }
 
 static void sk_msg_shift_right(struct sk_msg *msg, int i)
 {
        struct scatterlist tmp, sge;
+       bool tmp_copy, sge_copy;
 
        sk_msg_iter_next(msg, end);
        sge = sk_msg_elem_cpy(msg, i);
+       sge_copy = sk_msg_elem_is_copy(msg, i);
        sk_msg_iter_var_next(i);
        tmp = sk_msg_elem_cpy(msg, i);
+       tmp_copy = sk_msg_elem_is_copy(msg, i);
 
        while (i != msg->sg.end) {
                msg->sg.data[i] = sge;
+               if (sge_copy)
+                       sk_msg_set_elem_copy(msg, i);
+               else
+                       sk_msg_clear_elem_copy(msg, i);
                sk_msg_iter_var_next(i);
                sge = tmp;
+               sge_copy = tmp_copy;
                tmp = sk_msg_elem_cpy(msg, i);
+               tmp_copy = sk_msg_elem_is_copy(msg, i);
        }
+       sk_msg_clear_elem_copy(msg, msg->sg.end);
 }
 
 BPF_CALL_4(bpf_msg_pop_data, struct sk_msg *, msg, u32, start,
@@ -3027,8 +3098,10 @@ BPF_CALL_4(bpf_msg_pop_data, struct sk_msg *, msg, u32, start,
         */
        if (start != offset) {
                struct scatterlist *nsge, *sge = sk_msg_elem(msg, i);
+               bool sge_copy = sk_msg_elem_is_copy(msg, i);
                int a = start - offset;
                int b = sge->length - pop - a;
+               u32 sge_idx = i;
 
                sk_msg_iter_var_next(i);
 
@@ -3041,6 +3114,10 @@ BPF_CALL_4(bpf_msg_pop_data, struct sk_msg *, msg, u32, start,
                                sg_set_page(nsge,
                                            sg_page(sge),
                                            b, sge->offset + pop + a);
+                               if (sge_copy)
+                                       sk_msg_set_elem_copy(msg, i);
+                               else
+                                       sk_msg_clear_elem_copy(msg, i);
                        } else {
                                struct page *page, *orig;
                                u8 *to, *from;
@@ -3057,6 +3134,7 @@ BPF_CALL_4(bpf_msg_pop_data, struct sk_msg *, msg, u32, start,
                                memcpy(to, from, a);
                                memcpy(to + a, from + a + pop, b);
                                sg_set_page(sge, page, a + b, 0);
+                               sk_msg_clear_elem_copy(msg, sge_idx);
                                put_page(orig);
                        }
                        pop = 0;
-- 
2.43.0


Reply via email to