Act ct will send packets to conntrack on a specific zone,
and when commiting a connection, a ct label and ct mark can be set.

Signed-off-by: Paul Blakey <pa...@mellanox.com>
---
 include/net/tc_act/tc_ct.h        |  37 +++
 include/uapi/linux/tc_act/tc_ct.h |  29 +++
 net/sched/Kconfig                 |  11 +
 net/sched/Makefile                |   1 +
 net/sched/act_ct.c                | 465 ++++++++++++++++++++++++++++++++++++++
 5 files changed, 543 insertions(+)
 create mode 100644 include/net/tc_act/tc_ct.h
 create mode 100644 include/uapi/linux/tc_act/tc_ct.h
 create mode 100644 net/sched/act_ct.c

diff --git a/include/net/tc_act/tc_ct.h b/include/net/tc_act/tc_ct.h
new file mode 100644
index 0000000..4a16375
--- /dev/null
+++ b/include/net/tc_act/tc_ct.h
@@ -0,0 +1,37 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef __NET_TC_CT_H
+#define __NET_TC_CT_H
+
+#include <net/act_api.h>
+#include <uapi/linux/tc_act/tc_ct.h>
+
+#define TCA_ACT_CT_LABEL_SIZE 4
+struct tcf_ct {
+       struct tc_action common;
+       struct net *net;
+       struct nf_conn *tmpl;
+       u32 labels[TCA_ACT_CT_LABEL_SIZE];
+       u32 labels_mask[TCA_ACT_CT_LABEL_SIZE];
+       u32 mark;
+       u32 mark_mask;
+       u16 zone;
+       bool commit;
+};
+
+#define to_ct(a) ((struct tcf_ct *)a)
+
+static inline bool is_tcf_ct(const struct tc_action *a)
+{
+#ifdef CONFIG_NET_CLS_ACT
+       if (a->ops && a->ops->type == TCA_ACT_CT)
+               return true;
+#endif
+       return false;
+}
+
+static inline struct tcf_ct *tcf_ct(const struct tc_action *a)
+{
+       return to_ct(a);
+}
+
+#endif /* __NET_TC_CT_H */
diff --git a/include/uapi/linux/tc_act/tc_ct.h 
b/include/uapi/linux/tc_act/tc_ct.h
new file mode 100644
index 0000000..6dbd771
--- /dev/null
+++ b/include/uapi/linux/tc_act/tc_ct.h
@@ -0,0 +1,29 @@
+/* SPDX-License-Identifier: GPL-2.0 WITH Linux-syscall-note */
+#ifndef __UAPI_TC_CT_H
+#define __UAPI_TC_CT_H
+
+#include <linux/types.h>
+#include <linux/pkt_cls.h>
+
+#define TCA_ACT_CT 18
+
+struct tc_ct {
+       tc_gen;
+       __u16 zone;
+       __u32 labels[4];
+       __u32 labels_mask[4];
+       __u32 mark;
+       __u32 mark_mask;
+       bool commit;
+};
+
+enum {
+       TCA_CT_UNSPEC,
+       TCA_CT_PARMS,
+       TCA_CT_TM,
+       TCA_CT_PAD,
+       __TCA_CT_MAX
+};
+#define TCA_CT_MAX (__TCA_CT_MAX - 1)
+
+#endif /* __UAPI_TC_CT_H */
diff --git a/net/sched/Kconfig b/net/sched/Kconfig
index 1b9afde..935a327 100644
--- a/net/sched/Kconfig
+++ b/net/sched/Kconfig
@@ -912,6 +912,17 @@ config NET_ACT_TUNNEL_KEY
          To compile this code as a module, choose M here: the
          module will be called act_tunnel_key.
 
+config NET_ACT_CT
+        tristate "connection tracking action"
+        depends on NET_CLS_ACT
+        ---help---
+         Say Y here to allow sending the packets to conntrack module
+
+         If unsure, say N.
+
+         To compile this code as a module, choose M here: the
+         module will be called act_ct.
+
 config NET_IFE_SKBMARK
         tristate "Support to encoding decoding skb mark on IFE action"
         depends on NET_ACT_IFE
diff --git a/net/sched/Makefile b/net/sched/Makefile
index 8a40431..c0a02de 100644
--- a/net/sched/Makefile
+++ b/net/sched/Makefile
@@ -27,6 +27,7 @@ obj-$(CONFIG_NET_IFE_SKBMARK) += act_meta_mark.o
 obj-$(CONFIG_NET_IFE_SKBPRIO)  += act_meta_skbprio.o
 obj-$(CONFIG_NET_IFE_SKBTCINDEX)       += act_meta_skbtcindex.o
 obj-$(CONFIG_NET_ACT_TUNNEL_KEY)+= act_tunnel_key.o
+obj-$(CONFIG_NET_ACT_CT)       += act_ct.o
 obj-$(CONFIG_NET_SCH_FIFO)     += sch_fifo.o
 obj-$(CONFIG_NET_SCH_CBQ)      += sch_cbq.o
 obj-$(CONFIG_NET_SCH_HTB)      += sch_htb.o
diff --git a/net/sched/act_ct.c b/net/sched/act_ct.c
new file mode 100644
index 0000000..61155cc
--- /dev/null
+++ b/net/sched/act_ct.c
@@ -0,0 +1,465 @@
+/*
+ * net/sched/act_ct.c  connection tracking action
+ *
+ * Copyright (c) 2018
+ *
+ * Authors:    Yossi Kuperman <yoss...@mellanox.com>
+ *             Paul Blakey <pa...@mellanox.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+*/
+
+#include <linux/module.h>
+#include <linux/init.h>
+#include <linux/kernel.h>
+#include <linux/skbuff.h>
+#include <linux/rtnetlink.h>
+#include <linux/pkt_cls.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+#include <net/netlink.h>
+#include <net/pkt_sched.h>
+#include <net/act_api.h>
+#include <uapi/linux/tc_act/tc_ct.h>
+#include <net/tc_act/tc_ct.h>
+
+#include <net/netfilter/nf_conntrack.h>
+#include <net/netfilter/nf_conntrack_core.h>
+#include <net/netfilter/nf_conntrack_zones.h>
+#include <net/netfilter/nf_conntrack_helper.h>
+#include <net/netfilter/nf_conntrack_labels.h>
+
+#include <net/pkt_cls.h>
+
+static unsigned int ct_net_id;
+static struct tc_action_ops act_ct_ops;
+
+/* Determine whether skb->_nfct is equal to the result of conntrack lookup. */
+static bool skb_nfct_cached(struct net *net, struct sk_buff *skb, u16 zone_id)
+{
+       enum ip_conntrack_info ctinfo;
+       struct nf_conn *ct;
+
+       ct = nf_ct_get(skb, &ctinfo);
+       if (!ct)
+               return false;
+       if (!net_eq(net, read_pnet(&ct->ct_net)))
+               return false;
+       if (nf_ct_zone(ct)->id != zone_id)
+               return false;
+       return true;
+}
+
+/* Trim the skb to the length specified by the IP/IPv6 header,
+ * removing any trailing lower-layer padding. This prepares the skb
+ * for higher-layer processing that assumes skb->len excludes padding
+ * (such as nf_ip_checksum). The caller needs to pull the skb to the
+ * network header, and ensure ip_hdr/ipv6_hdr points to valid data.
+ */
+static int tcf_skb_network_trim(struct sk_buff *skb)
+{
+       unsigned int len;
+       int err;
+
+       switch (skb->protocol) {
+       case htons(ETH_P_IP):
+               len = ntohs(ip_hdr(skb)->tot_len);
+               break;
+       case htons(ETH_P_IPV6):
+               len = sizeof(struct ipv6hdr)
+                       + ntohs(ipv6_hdr(skb)->payload_len);
+               break;
+       default:
+               len = skb->len;
+       }
+
+       err = pskb_trim_rcsum(skb, len);
+
+       return err;
+}
+
+static u_int8_t tcf_skb_family(struct sk_buff *skb)
+{
+       u_int8_t family = PF_UNSPEC;
+
+       switch (skb->protocol) {
+       case htons(ETH_P_IP):
+               family = PF_INET;
+               break;
+       case htons(ETH_P_IPV6):
+               family = PF_INET6;
+               break;
+       default:
+        break;
+       }
+
+       return family;
+}
+
+static bool labels_nonzero(const u32 *labels_mask)
+{
+       return !!memchr_inv(labels_mask, 0, 4);
+}
+
+static int tcf_ct_act(struct sk_buff *skb, const struct tc_action *a,
+                     struct tcf_result *res)
+{
+       struct net *net = dev_net(skb->dev);
+       struct tcf_ct *ca = to_ct(a);
+       enum ip_conntrack_info ctinfo;
+       struct nf_conn *tmpl = NULL;
+       struct nf_hook_state state;
+       struct nf_conn *ct;
+       int nh_ofs, err;
+       u_int8_t family;
+       bool cached;
+
+       /* The conntrack module expects to be working at L3. */
+       nh_ofs = skb_network_offset(skb);
+       skb_pull_rcsum(skb, nh_ofs);
+
+       err = tcf_skb_network_trim(skb);
+       if (err)
+               goto drop;
+
+       family = tcf_skb_family(skb);
+       if (family == PF_UNSPEC)
+               goto drop;
+
+       state.hook = NF_INET_PRE_ROUTING,
+       state.net = net,
+       state.pf = family;
+
+       spin_lock(&ca->tcf_lock);
+       tcf_lastuse_update(&ca->tcf_tm);
+       bstats_update(&ca->tcf_bstats, skb);
+       tmpl = ca->tmpl;
+
+       /* If we are recirculating packets to match on ct fields and
+        * committing with a separate ct action, then we don't need to
+        * actually run the packet through conntrack twice unless it's for a
+        * different zone. */
+       cached = skb_nfct_cached(net, skb, ca->zone);
+       if (!cached) {
+               /* Associate skb with specified zone. */
+               if (tmpl) {
+                       if (skb_nfct(skb))
+                               nf_conntrack_put(skb_nfct(skb));
+                       nf_conntrack_get(&tmpl->ct_general);
+                       nf_ct_set(skb, tmpl, IP_CT_NEW);
+               }
+
+               err = nf_conntrack_in(skb, &state);
+               if (err != NF_ACCEPT)
+                       goto out;
+       }
+
+       ct = nf_ct_get(skb, &ctinfo);
+       if (!ct)
+               goto out;
+       nf_ct_deliver_cached_events(ct);
+
+       if (ca->commit) {
+               u32 *labels = ca->labels;
+               u32 *labels_m = ca->labels_mask;
+
+#if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK)
+               if (ca->mark_mask) {
+                       u32 ct_mark = ca->mark;
+                       u32 mask = ca->mark_mask;
+                       u32 new_mark;
+
+                       new_mark = ct_mark | (ct->mark & ~(mask));
+                       if (ct->mark != new_mark) {
+                               ct->mark = new_mark;
+                               if (nf_ct_is_confirmed(ct))
+                                       nf_conntrack_event_cache(IPCT_MARK, ct);
+                       }
+               }
+#endif
+               if (!nf_ct_is_confirmed(ct)) {
+                       bool have_mask = labels_nonzero(labels_m);
+                       struct nf_conn_labels *cl, *master_cl;
+
+                       /* Inherit master's labels to the related connection? */
+                       master_cl = ct->master ? nf_ct_labels_find(ct->master) 
: NULL;
+
+                       if (!master_cl && !have_mask)
+                               goto skip; /* Nothing to do. */
+
+                       /* Get labels or add them */
+                       cl = nf_ct_labels_find(ct);
+                       if (!cl) {
+                               nf_ct_labels_ext_add(ct);
+                               cl = nf_ct_labels_find(ct);
+                       }
+                       if (!cl)
+                               goto out;
+
+                       /* Inherit the master's labels, if any.  Must use 
memcpy for backport
+                        * as struct assignment only copies the length field in 
older
+                        * kernels.
+                       */
+                       if (master_cl)
+                               memcpy(cl->bits, master_cl->bits, 
NF_CT_LABELS_MAX_SIZE);
+
+                       if (have_mask) {
+                               u32 *dst = (u32 *)cl->bits;
+                               int i;
+
+                               for (i = 0; i < 4; i++)
+                                       dst[i] = (dst[i] & ~labels_m[i]) | 
(labels[i] & labels_m[i]);
+                       }
+
+                       /* Labels are included in the IPCTNL_MSG_CT_NEW event 
only if the
+                        * IPCT_LABEL bit is set in the event cache.
+                        */
+                       nf_conntrack_event_cache(IPCT_LABEL, ct);
+               } else if (labels_nonzero(labels_m)) {
+                       struct nf_conn_labels *cl;
+
+                       cl = nf_ct_labels_find(ct);
+                       if (!cl) {
+                               nf_ct_labels_ext_add(ct);
+                               cl = nf_ct_labels_find(ct);
+                       }
+
+                       if (!cl)
+                               goto out;
+
+                       nf_connlabels_replace(ct, ca->labels, ca->labels_mask, 
4);
+               }
+skip:
+               /* This will take care of sending queued events even if the 
connection
+                * is already confirmed. */
+               nf_conntrack_confirm(skb);
+       }
+
+out:
+       skb_push(skb, nh_ofs);
+       skb_postpush_rcsum(skb, skb->data, nh_ofs);
+
+       spin_unlock(&ca->tcf_lock);
+       return ca->tcf_action;
+
+drop:
+       spin_lock(&ca->tcf_lock);
+       ca->tcf_qstats.drops++;
+       spin_unlock(&ca->tcf_lock);
+       return TC_ACT_SHOT;
+}
+
+static const struct nla_policy ct_policy[TCA_CT_MAX + 1] = {
+       [TCA_CT_PARMS] = { .len = sizeof(struct tc_ct) },
+};
+
+static int tcf_ct_init(struct net *net, struct nlattr *nla,
+                      struct nlattr *est, struct tc_action **a,
+                      int ovr, int bind, bool rtnl_held,
+                      struct netlink_ext_ack *extack)
+{
+       struct tc_action_net *tn = net_generic(net, ct_net_id);
+       struct nlattr *tb[TCA_CT_MAX + 1];
+       struct nf_conntrack_zone zone;
+       struct nf_conn *tmpl = NULL;
+       bool exists = false;
+       struct tc_ct *parm;
+       struct tcf_ct *ci;
+       int ret, err;
+
+       if (!nla) {
+               NL_SET_ERR_MSG_MOD(extack, "Ct requires attributes to be 
passed");
+               return -EINVAL;
+       }
+
+       ret = nla_parse_nested(tb, TCA_CT_MAX, nla, ct_policy, extack);
+       if (ret < 0)
+               return ret;
+
+       if (!tb[TCA_CT_PARMS]) {
+               NL_SET_ERR_MSG_MOD(extack, "Missing required ct parameters");
+               return -EINVAL;
+       }
+
+       parm = nla_data(tb[TCA_CT_PARMS]);
+
+       err = tcf_idr_check_alloc(tn, &parm->index, a, bind);
+       if (err < 0)
+               return err;
+       exists = err;
+       if (exists && bind)
+               return 0;
+
+       if (!exists) {
+               ret = tcf_idr_create(tn, parm->index, est, a, &act_ct_ops, 
bind, false);
+               if (ret) {
+                       tcf_idr_cleanup(tn, parm->index);
+                       return ret;
+               }
+
+               ci = to_ct(*a);
+               ci->tcf_action = parm->action;
+               ci->net = net;
+               ci->commit = parm->commit;
+               ci->zone = parm->zone;
+#if !IS_ENABLED(CONFIG_NF_CONNTRACK_MARK)
+               if (parm->mark_mask) {
+                       NL_SET_ERR_MSG_MOD(extack, "Mark not supported by 
kernel config");
+                       return -EOPNOTSUPP;
+               }
+#endif
+#if !IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS)
+               if (labels_nonzero(parm->labels_mask)) {
+                       NL_SET_ERR_MSG_MOD(extack, "Labels not supported by 
kernel config");
+                       return -EOPNOTSUPP;
+               }
+#endif
+               if (parm->zone != NF_CT_DEFAULT_ZONE_ID) {
+                       nf_ct_zone_init(&zone, parm->zone,
+                                       NF_CT_DEFAULT_ZONE_DIR, 0);
+
+                       tmpl = nf_ct_tmpl_alloc(net, &zone, GFP_ATOMIC);
+                       if (!tmpl) {
+                               NL_SET_ERR_MSG_MOD(extack, "Failed to allocate 
conntrack template");
+                               tcf_idr_cleanup(tn, parm->index);
+                               return -ENOMEM;
+                       }
+                       __set_bit(IPS_CONFIRMED_BIT, &tmpl->status);
+                       nf_conntrack_get(&tmpl->ct_general);
+               }
+
+               ci->tmpl = tmpl;
+               ci->mark = parm->mark;
+               ci->mark_mask = parm->mark_mask;
+               memcpy(ci->labels, parm->labels, sizeof(parm->labels));
+               memcpy(ci->labels_mask, parm->labels_mask, 
sizeof(parm->labels_mask));
+
+               tcf_idr_insert(tn, *a);
+               ret = ACT_P_CREATED;
+       } else {
+               /* TODO: handle replace */
+               NL_SET_ERR_MSG_MOD(extack, "Ct can only be created");
+               tcf_idr_cleanup(tn, parm->index);
+               return -EOPNOTSUPP;
+       }
+
+       return ret;
+}
+
+static void tcf_ct_release(struct tc_action *a)
+{
+       struct tcf_ct *ci = to_ct(a);
+
+       if (ci->tmpl)
+               nf_conntrack_put(&ci->tmpl->ct_general);
+}
+
+static inline int tcf_ct_dump(struct sk_buff *skb, struct tc_action *a,
+                                   int bind, int ref)
+{
+       unsigned char *b = skb_tail_pointer(skb);
+       struct tcf_ct *ci = to_ct(a);
+
+       struct tc_ct opt = {
+               .index   = ci->tcf_index,
+               .refcnt  = refcount_read(&ci->tcf_refcnt) - ref,
+               .bindcnt = atomic_read(&ci->tcf_bindcnt) - bind,
+       };
+       struct tcf_t t;
+
+       spin_lock_bh(&ci->tcf_lock);
+       opt.action  = ci->tcf_action,
+       opt.zone   = ci->zone,
+       opt.commit = ci->commit,
+       opt.mark = ci->mark,
+       opt.mark_mask = ci->mark_mask,
+       memcpy(opt.labels, ci->labels, sizeof(opt.labels));
+       memcpy(opt.labels_mask, ci->labels_mask, sizeof(opt.labels_mask));
+
+       if (nla_put(skb, TCA_CT_PARMS, sizeof(opt), &opt))
+               goto nla_put_failure;
+
+       tcf_tm_dump(&t, &ci->tcf_tm);
+       if (nla_put_64bit(skb, TCA_CT_TM, sizeof(t), &t, TCA_CT_PAD))
+               goto nla_put_failure;
+       spin_unlock_bh(&ci->tcf_lock);
+
+       return skb->len;
+nla_put_failure:
+       spin_unlock_bh(&ci->tcf_lock);
+       nlmsg_trim(skb, b);
+       return -1;
+}
+
+static int tcf_ct_walker(struct net *net, struct sk_buff *skb,
+                        struct netlink_callback *cb, int type,
+                        const struct tc_action_ops *ops,
+                        struct netlink_ext_ack *extack)
+{
+       struct tc_action_net *tn = net_generic(net, ct_net_id);
+
+       return tcf_generic_walker(tn, skb, cb, type, ops, extack);
+}
+
+static int tcf_ct_search(struct net *net, struct tc_action **a, u32 index)
+{
+       struct tc_action_net *tn = net_generic(net, ct_net_id);
+
+       return tcf_idr_search(tn, a, index);
+}
+
+static struct tc_action_ops act_ct_ops = {
+       .kind           =       "ct",
+       .type           =       TCA_ACT_CT,
+       .owner          =       THIS_MODULE,
+       .act            =       tcf_ct_act,
+       .dump           =       tcf_ct_dump,
+       .init           =       tcf_ct_init,
+       .cleanup        =       tcf_ct_release,
+       .walk           =       tcf_ct_walker,
+       .lookup         =       tcf_ct_search,
+       .size           =       sizeof(struct tcf_ct),
+};
+
+static __net_init int ct_init_net(struct net *net)
+{
+       struct tc_action_net *tn = net_generic(net, ct_net_id);
+
+       return tc_action_net_init(tn, &act_ct_ops);
+}
+
+static void __net_exit ct_exit_net(struct list_head *net_list)
+{
+       tc_action_net_exit(net_list, ct_net_id);
+}
+
+static struct pernet_operations ct_net_ops = {
+       .init = ct_init_net,
+       .exit_batch = ct_exit_net,
+       .id   = &ct_net_id,
+       .size = sizeof(struct tc_action_net),
+};
+
+static int __init ct_init_module(void)
+{
+       char *mark = IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) ? "on" : "off";
+       char *label = IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) ? "on" : "off";
+
+       pr_info("ct action on, mark: %s, label: %s\n", mark, label);
+       return tcf_register_action(&act_ct_ops, &ct_net_ops);
+}
+
+static void __exit ct_cleanup_module(void)
+{
+       tcf_unregister_action(&act_ct_ops, &ct_net_ops);
+}
+
+module_init(ct_init_module);
+module_exit(ct_cleanup_module);
+MODULE_AUTHOR("Yossi Kuperman <yoss...@mellanox.com>");
+MODULE_DESCRIPTION("Connection tracking action");
+MODULE_LICENSE("GPL");
+
-- 
1.8.3.1

Reply via email to