Update BPF_CGROUP_RUN_PROG_INET_EGRESS() callers to support returning
congestion notifications.

If BPF_CGROUP_RUN_PROG_INET_EGRESS() returns a value other than
NET_XMIT_SUCCESS or NET_XMIT_CN, the skb is dropped and the value
is returned to the caller.

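Otherwise, the skb is passed to the output function. If the output
function returns a value other than NET_XMIT_SUCCESS, that value is
returned to the caller. If the output function returns NET_XMIT_SUCCESS,
the return value of BPF_CGROUP_RUN_PROG_INET_EGRESS() (NET_XMIT_SUCCESS
or NET_XMIT_CN) is returned instead, so that a congestion notification
raised by the BPF program is propagated to the caller.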

Signed-off-by: Lawrence Brakmo <bra...@fb.com>
---
 net/ipv4/ip_output.c  | 39 ++++++++++++++++++++++-----------------
 net/ipv6/ip6_output.c | 22 +++++++++++++---------
 2 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index c80188875f39..efa0b9a195b4 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -292,43 +292,48 @@ static int ip_finish_output_gso(struct net *net, struct sock *sk,
 static int ip_finish_output(struct net *net, struct sock *sk, struct sk_buff *skb)
 {
        unsigned int mtu;
+       int ret_bpf;
        int ret;
 
-       ret = BPF_CGROUP_RUN_PROG_INET_EGRESS(sk, skb);
-       if (ret) {
+       ret_bpf = BPF_CGROUP_RUN_PROG_INET_EGRESS(sk, skb);
+       if (ret_bpf != NET_XMIT_SUCCESS && ret_bpf != NET_XMIT_CN) {
                kfree_skb(skb);
-               return ret;
+               return ret_bpf;
        }
 
 #if defined(CONFIG_NETFILTER) && defined(CONFIG_XFRM)
        /* Policy lookup after SNAT yielded a new policy */
        if (skb_dst(skb)->xfrm) {
                IPCB(skb)->flags |= IPSKB_REROUTED;
-               return dst_output(net, sk, skb);
-       }
+               ret = dst_output(net, sk, skb);
+       } else
 #endif
-       mtu = ip_skb_dst_mtu(sk, skb);
-       if (skb_is_gso(skb))
-               return ip_finish_output_gso(net, sk, skb, mtu);
-
-       if (skb->len > mtu || (IPCB(skb)->flags & IPSKB_FRAG_PMTU))
-               return ip_fragment(net, sk, skb, mtu, ip_finish_output2);
-
-       return ip_finish_output2(net, sk, skb);
+       {
+               mtu = ip_skb_dst_mtu(sk, skb);
+               if (skb_is_gso(skb))
+                       ret = ip_finish_output_gso(net, sk, skb, mtu);
+               else if (skb->len > mtu || (IPCB(skb)->flags & IPSKB_FRAG_PMTU))
+                       ret = ip_fragment(net, sk, skb, mtu, ip_finish_output2);
+               else
+                       ret = ip_finish_output2(net, sk, skb);
+       }
+       return ret ? : ret_bpf;
 }
 
 static int ip_mc_finish_output(struct net *net, struct sock *sk,
                               struct sk_buff *skb)
 {
+       int ret_bpf;
        int ret;
 
-       ret = BPF_CGROUP_RUN_PROG_INET_EGRESS(sk, skb);
-       if (ret) {
+       ret_bpf = BPF_CGROUP_RUN_PROG_INET_EGRESS(sk, skb);
+       if (ret_bpf != NET_XMIT_SUCCESS && ret_bpf != NET_XMIT_CN) {
                kfree_skb(skb);
-               return ret;
+               return ret_bpf;
        }
 
-       return dev_loopback_xmit(net, sk, skb);
+       ret = dev_loopback_xmit(net, sk, skb);
+       return ret ? : ret_bpf;
 }
 
 int ip_mc_output(struct net *net, struct sock *sk, struct sk_buff *skb)
diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c
index edbd12067170..53a838d82a21 100644
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -130,28 +130,32 @@ static int ip6_finish_output2(struct net *net, struct sock *sk, struct sk_buff *
 
 static int ip6_finish_output(struct net *net, struct sock *sk, struct sk_buff *skb)
 {
+       int ret_bpf;
        int ret;
 
-       ret = BPF_CGROUP_RUN_PROG_INET_EGRESS(sk, skb);
-       if (ret) {
+       ret_bpf = BPF_CGROUP_RUN_PROG_INET_EGRESS(sk, skb);
+       if (ret_bpf != NET_XMIT_SUCCESS && ret_bpf != NET_XMIT_CN) {
                kfree_skb(skb);
-               return ret;
+               return ret_bpf;
        }
 
 #if defined(CONFIG_NETFILTER) && defined(CONFIG_XFRM)
        /* Policy lookup after SNAT yielded a new policy */
        if (skb_dst(skb)->xfrm) {
                IPCB(skb)->flags |= IPSKB_REROUTED;
-               return dst_output(net, sk, skb);
-       }
+               ret = dst_output(net, sk, skb);
+       } else
 #endif
 
        if ((skb->len > ip6_skb_dst_mtu(skb) && !skb_is_gso(skb)) ||
            dst_allfrag(skb_dst(skb)) ||
-           (IP6CB(skb)->frag_max_size && skb->len > IP6CB(skb)->frag_max_size))
-               return ip6_fragment(net, sk, skb, ip6_finish_output2);
-       else
-               return ip6_finish_output2(net, sk, skb);
+           (IP6CB(skb)->frag_max_size && skb->len >
+            IP6CB(skb)->frag_max_size)) {
+               ret = ip6_fragment(net, sk, skb, ip6_finish_output2);
+       } else {
+               ret = ip6_finish_output2(net, sk, skb);
+       }
+       return ret ? : ret_bpf;
 }
 
 int ip6_output(struct net *net, struct sock *sk, struct sk_buff *skb)
-- 
2.17.1

Reply via email to