The IP TOS field can be changed with setsockopt(IP_TOS) or by iptables;
this patch adds a new way to set the TOS field of a process's sockets,
based on the cgroup the process belongs to.

The usage:

    1. Mount the ip_tos cgroup and set the TOS value:
    mount -t cgroup -o ip_tos ip_tos /cgroups/tos
    echo tos_value >/cgroups/tos/ip_tos.tos
    2. Then move processes into the cgroup, or create processes in it.

Signed-off-by: jimyan <jim...@baidu.com>
Signed-off-by: Li RongQing <lirongq...@baidu.com>
---
 include/linux/cgroup_subsys.h |   4 ++
 include/net/tos_cgroup.h      |  35 ++++++++++++
 net/ipv4/Kconfig              |  10 ++++
 net/ipv4/Makefile             |   1 +
 net/ipv4/af_inet.c            |   2 +
 net/ipv4/tos_cgroup.c         | 128 ++++++++++++++++++++++++++++++++++++++++++
 net/ipv6/af_inet6.c           |   2 +
 7 files changed, 182 insertions(+)
 create mode 100644 include/net/tos_cgroup.h
 create mode 100644 net/ipv4/tos_cgroup.c

diff --git a/include/linux/cgroup_subsys.h b/include/linux/cgroup_subsys.h
index acb77dcff3b4..1b86eda1c23e 100644
--- a/include/linux/cgroup_subsys.h
+++ b/include/linux/cgroup_subsys.h
@@ -61,6 +61,10 @@ SUBSYS(pids)
 SUBSYS(rdma)
 #endif
 
+#if IS_ENABLED(CONFIG_IP_TOS_CGROUP)
+SUBSYS(ip_tos)
+#endif
+
 /*
  * The following subsystems are not supported on the default hierarchy.
  */
diff --git a/include/net/tos_cgroup.h b/include/net/tos_cgroup.h
new file mode 100644
index 000000000000..0868e921faf3
--- /dev/null
+++ b/include/net/tos_cgroup.h
@@ -0,0 +1,35 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+
+#ifndef _IP_TOS_CGROUP_H
+#define _IP_TOS_CGROUP_H
+
+#include <linux/cgroup.h>
+#include <linux/hardirq.h>
+
+struct tos_cgroup_state {
+       struct cgroup_subsys_state css;
+       u32 tos;        /* TOS value read back by task_ip_tos() */
+};
+
+#if IS_ENABLED(CONFIG_IP_TOS_CGROUP)
+static inline u32 task_ip_tos(struct task_struct *p)
+{
+       u32 tos;
+
+       if (in_interrupt())
+               return 0;
+
+       rcu_read_lock();
+       tos = container_of(task_css(p, ip_tos_cgrp_id),
+                       struct tos_cgroup_state, css)->tos;
+       rcu_read_unlock();
+
+       return tos;
+}
+#else /* !CONFIG_IP_TOS_CGROUP */
+static inline u32 task_ip_tos(struct task_struct *p)
+{
+       return 0;
+}
+#endif /* CONFIG_IP_TOS_CGROUP */
+#endif  /* _IP_TOS_CGROUP_H */
diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig
index 80dad301361d..57070bbb0394 100644
--- a/net/ipv4/Kconfig
+++ b/net/ipv4/Kconfig
@@ -753,3 +753,13 @@ config TCP_MD5SIG
          on the Internet.
 
          If unsure, say N.
+
+config IP_TOS_CGROUP
+       bool "ip tos cgroup"
+       depends on CGROUPS
+       default n
+       ---help---
+         Say Y here if you want the IP TOS of a socket's packets to be
+         set from the control cgroup of the process that owns the socket.
+
+         The TOS value is set through the cgroup's ip_tos.tos file.
diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile
index a07b7dd06def..12c708142d1f 100644
--- a/net/ipv4/Makefile
+++ b/net/ipv4/Makefile
@@ -61,6 +61,7 @@ obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o
 obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o
 obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o
 obj-$(CONFIG_NETLABEL) += cipso_ipv4.o
+obj-$(CONFIG_IP_TOS_CGROUP) += tos_cgroup.o
 
 obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \
                      xfrm4_output.o xfrm4_protocol.o
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index eaed0367e669..e2dd902b06dd 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -120,6 +120,7 @@
 #include <linux/mroute.h>
 #endif
 #include <net/l3mdev.h>
+#include <net/tos_cgroup.h>
 
 #include <trace/events/sock.h>
 
@@ -356,6 +357,7 @@ static int inet_create(struct net *net, struct socket 
*sock, int protocol,
        inet->mc_index  = 0;
        inet->mc_list   = NULL;
        inet->rcv_tos   = 0;
+       inet->tos       = task_ip_tos(current);
 
        sk_refcnt_debug_inc(sk);
 
diff --git a/net/ipv4/tos_cgroup.c b/net/ipv4/tos_cgroup.c
new file mode 100644
index 000000000000..dbb828f5b464
--- /dev/null
+++ b/net/ipv4/tos_cgroup.c
@@ -0,0 +1,128 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <linux/module.h>
+#include <linux/types.h>
+#include <linux/errno.h>
+#include <linux/slab.h>
+#include <linux/skbuff.h>
+#include <linux/cgroup.h>
+#include <net/sock.h>
+#include <net/inet_sock.h>
+#include <net/tos_cgroup.h>
+#include <linux/fdtable.h>
+#include <net/route.h>
+#include <net/inet_ecn.h>
+#include <linux/sched/task.h>
+
+static inline
+struct tos_cgroup_state *css_tos_cgroup(struct cgroup_subsys_state *css)
+{
+       return css ? container_of(css, struct tos_cgroup_state, css) : NULL;
+}
+
+static inline struct tos_cgroup_state *task_tos_cgroup(struct task_struct 
*task)
+{
+       return css_tos_cgroup(task_css(task, ip_tos_cgrp_id));
+}
+
+static struct cgroup_subsys_state
+*cgrp_css_alloc(struct cgroup_subsys_state *parent_css)
+{
+       struct tos_cgroup_state *cs;
+
+       cs = kzalloc(sizeof(*cs), GFP_KERNEL);
+       if (!cs)
+               return ERR_PTR(-ENOMEM);
+
+       return &cs->css;
+}
+
+static void cgrp_css_free(struct cgroup_subsys_state *css)
+{
+       kfree(css_tos_cgroup(css));
+}
+
+static int update_tos(const void *v, struct file *file, unsigned int n)
+{
+       int err;
+       struct socket *sock = sock_from_file(file, &err);
+       unsigned char val = (unsigned char)*(u64 *)v;
+
+       if (sock && (sock->sk->sk_family == PF_INET ||
+                               sock->sk->sk_family == PF_INET6)) {
+               struct inet_sock *inet = inet_sk(sock->sk);
+
+               lock_sock(sock->sk);
+               if (sock->sk->sk_type == SOCK_STREAM) {
+                       val &= ~INET_ECN_MASK;
+                       val |= inet->tos & INET_ECN_MASK;
+               }
+               if (inet->tos != val) {
+                       inet->tos = val;
+                       sock->sk->sk_priority = rt_tos2priority(val);
+                       sk_dst_reset(sock->sk);
+               }
+               release_sock(sock->sk);
+       }
+       return 0;
+}
+
+static void cgrp_attach(struct cgroup_taskset *tset)
+{
+       struct task_struct *p;
+       struct cgroup_subsys_state *css;
+       u64 v;
+
+       cgroup_taskset_for_each(p, css, tset) {
+               task_lock(p);
+               v = task_tos_cgroup(p)->tos;
+               iterate_fd(p->files, 0, update_tos, (void *)&v);
+               task_unlock(p);
+       }
+}
+
+static u64 read_tos(struct cgroup_subsys_state *css, struct cftype *cft)
+{
+       return css_tos_cgroup(css)->tos;
+}
+
+/* Set the cgroup TOS (0-255, else -EINVAL) and retag its tasks' sockets.
+ * NOTE(review): update_tos() takes lock_sock() while task_lock() is held;
+ * confirm this cannot sleep in atomic context.
+ */
+static int
+write_tos(struct cgroup_subsys_state *css, struct cftype *cft, u64 value)
+{
+       struct css_task_iter it;
+       struct task_struct *task = NULL;
+
+       if (value > 255)
+               return -EINVAL;
+
+       css_tos_cgroup(css)->tos = (u32)value;
+       css_task_iter_start(css, 0, &it);
+       while ((task = css_task_iter_next(&it))) {
+               task_lock(task);
+               iterate_fd(task->files, 0, update_tos, (void *)&value);
+               task_unlock(task);
+       }
+       css_task_iter_end(&it);
+       return 0;
+}
+
+static struct cftype ss_files[] = {
+       {
+               .name = "tos",
+               .read_u64 = read_tos,
+               .write_u64 = write_tos,
+       },
+       { }     /* terminate */
+};
+
+struct cgroup_subsys ip_tos_cgrp_subsys = {
+       .css_alloc      = cgrp_css_alloc,  /* zeroed state: tos starts at 0 */
+       .css_free       = cgrp_css_free,   /* kfree()s the state */
+       .attach         = cgrp_attach,     /* retags sockets of moved tasks */
+       .legacy_cftypes = ss_files,        /* provides the "tos" file */
+};
+
+MODULE_LICENSE("GPL v2");
diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c
index 8da0b513f188..d33e240613e0 100644
--- a/net/ipv6/af_inet6.c
+++ b/net/ipv6/af_inet6.c
@@ -44,6 +44,7 @@
 #include <linux/icmpv6.h>
 #include <linux/netfilter_ipv6.h>
 
+#include <net/tos_cgroup.h>
 #include <net/ip.h>
 #include <net/ipv6.h>
 #include <net/udp.h>
@@ -223,6 +224,7 @@ static int inet6_create(struct net *net, struct socket 
*sock, int protocol,
        inet->mc_index  = 0;
        inet->mc_list   = NULL;
        inet->rcv_tos   = 0;
+       inet->tos       = task_ip_tos(current);
 
        if (net->ipv4.sysctl_ip_no_pmtu_disc)
                inet->pmtudisc = IP_PMTUDISC_DONT;
-- 
2.11.0

Reply via email to