This patch adds support for the SO_KEEPALIVE flag and the TCP-related options
TCP_KEEPIDLE, TCP_KEEPINTVL, TCP_KEEPCNT, TCP_SYNCNT and TCP_USER_TIMEOUT to
the bpf_setsockopt() routine. This is helpful if we want to enable or tune
TCP keepalive for applications that don't do it in their userspace code.
To avoid copy-paste, common code from the classic setsockopt path was moved
to auxiliary functions in the headers.

Signed-off-by: Dmitry Yakunin <z...@yandex-team.ru>
---
 include/net/sock.h |  9 +++++++++
 include/net/tcp.h  | 18 ++++++++++++++++++
 net/core/filter.c  | 39 ++++++++++++++++++++++++++++++++++++++-
 net/core/sock.c    |  9 ---------
 net/ipv4/tcp.c     | 15 ++-------------
 5 files changed, 67 insertions(+), 23 deletions(-)

diff --git a/include/net/sock.h b/include/net/sock.h
index 3e8c6d4..ee35dea 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -879,6 +879,15 @@ static inline void sock_reset_flag(struct sock *sk, enum 
sock_flags flag)
        __clear_bit(flag, &sk->sk_flags);
 }
 
+static inline void sock_valbool_flag(struct sock *sk, enum sock_flags bit,
+                                    int valbool)
+{
+       if (valbool) /* nonzero sets the flag on sk, zero resets it */
+               sock_set_flag(sk, bit);
+       else
+               sock_reset_flag(sk, bit);
+}
+
 static inline bool sock_flag(const struct sock *sk, enum sock_flags flag)
 {
        return test_bit(flag, &sk->sk_flags);
diff --git a/include/net/tcp.h b/include/net/tcp.h
index b681338..ae6a495 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -1465,6 +1465,24 @@ static inline u32 keepalive_time_elapsed(const struct 
tcp_sock *tp)
                          tcp_jiffies32 - tp->rcv_tstamp);
 }
 
+/* Caller must range-check val; stores val * HZ and may reset the timer */
+static inline void keepalive_time_set(struct tcp_sock *tp, int val)
+{
+       struct sock *sk = (struct sock *)tp;
+
+       tp->keepalive_time = val * HZ;
+       if (sock_flag(sk, SOCK_KEEPOPEN) &&
+           !((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) {
+               u32 elapsed = keepalive_time_elapsed(tp);
+
+               if (tp->keepalive_time > elapsed)
+                       elapsed = tp->keepalive_time - elapsed;
+               else
+                       elapsed = 0; /* new time already passed */
+               inet_csk_reset_keepalive_timer(sk, elapsed);
+       }
+}
+
 static inline int tcp_fin_time(const struct sock *sk)
 {
        int fin_timeout = tcp_sk(sk)->linger2 ? : 
sock_net(sk)->ipv4.sysctl_tcp_fin_timeout;
diff --git a/net/core/filter.c b/net/core/filter.c
index a6fc234..1035e43 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -4248,8 +4248,8 @@ static const struct bpf_func_proto 
bpf_get_socket_uid_proto = {
 static int _bpf_setsockopt(struct sock *sk, int level, int optname,
                           char *optval, int optlen, u32 flags)
 {
+       int val, valbool;
        int ret = 0;
-       int val;
 
        if (!sk_fullsock(sk))
                return -EINVAL;
@@ -4260,6 +4260,7 @@ static int _bpf_setsockopt(struct sock *sk, int level, 
int optname,
                if (optlen != sizeof(int))
                        return -EINVAL;
                val = *((int *)optval);
+               valbool = val ? 1 : 0;
 
                /* Only some socketops are supported */
                switch (optname) {
@@ -4298,6 +4299,11 @@ static int _bpf_setsockopt(struct sock *sk, int level, 
int optname,
                                sk_dst_reset(sk);
                        }
                        break;
+               case SO_KEEPALIVE:
+                       if (sk->sk_prot->keepalive)
+                               sk->sk_prot->keepalive(sk, valbool);
+                       sock_valbool_flag(sk, SOCK_KEEPOPEN, valbool);
+                       break;
                default:
                        ret = -EINVAL;
                }
@@ -4358,6 +4364,7 @@ static int _bpf_setsockopt(struct sock *sk, int level, 
int optname,
                        ret = tcp_set_congestion_control(sk, name, false,
                                                         reinit, true);
                } else {
+                       struct inet_connection_sock *icsk = inet_csk(sk);
                        struct tcp_sock *tp = tcp_sk(sk);
 
                        if (optlen != sizeof(int))
@@ -4386,6 +4393,36 @@ static int _bpf_setsockopt(struct sock *sk, int level, 
int optname,
                                else
                                        tp->save_syn = val;
                                break;
+                       case TCP_KEEPIDLE:
+                               if (val < 1 || val > MAX_TCP_KEEPIDLE)
+                                       ret = -EINVAL;
+                               else
+                                       keepalive_time_set(tp, val);
+                               break;
+                       case TCP_KEEPINTVL:
+                               if (val < 1 || val > MAX_TCP_KEEPINTVL)
+                                       ret = -EINVAL;
+                               else
+                                       tp->keepalive_intvl = val * HZ;
+                               break;
+                       case TCP_KEEPCNT:
+                               if (val < 1 || val > MAX_TCP_KEEPCNT)
+                                       ret = -EINVAL;
+                               else
+                                       tp->keepalive_probes = val;
+                               break;
+                       case TCP_SYNCNT:
+                               if (val < 1 || val > MAX_TCP_SYNCNT)
+                                       ret = -EINVAL;
+                               else
+                                       icsk->icsk_syn_retries = val;
+                               break;
+                       case TCP_USER_TIMEOUT:
+                               if (val < 0)
+                                       ret = -EINVAL;
+                               else
+                                       icsk->icsk_user_timeout = val;
+                               break;
                        default:
                                ret = -EINVAL;
                        }
diff --git a/net/core/sock.c b/net/core/sock.c
index fd85e65..9836b01 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -684,15 +684,6 @@ static int sock_getbindtodevice(struct sock *sk, char 
__user *optval,
        return ret;
 }
 
-static inline void sock_valbool_flag(struct sock *sk, enum sock_flags bit,
-                                    int valbool)
-{
-       if (valbool)
-               sock_set_flag(sk, bit);
-       else
-               sock_reset_flag(sk, bit);
-}
-
 bool sk_mc_loop(struct sock *sk)
 {
        if (dev_recursion_level())
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 9700649..7b239e8 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -3003,19 +3003,8 @@ static int do_tcp_setsockopt(struct sock *sk, int level,
        case TCP_KEEPIDLE:
                if (val < 1 || val > MAX_TCP_KEEPIDLE)
                        err = -EINVAL;
-               else {
-                       tp->keepalive_time = val * HZ;
-                       if (sock_flag(sk, SOCK_KEEPOPEN) &&
-                           !((1 << sk->sk_state) &
-                             (TCPF_CLOSE | TCPF_LISTEN))) {
-                               u32 elapsed = keepalive_time_elapsed(tp);
-                               if (tp->keepalive_time > elapsed)
-                                       elapsed = tp->keepalive_time - elapsed;
-                               else
-                                       elapsed = 0;
-                               inet_csk_reset_keepalive_timer(sk, elapsed);
-                       }
-               }
+               else
+                       keepalive_time_set(tp, val);
                break;
        case TCP_KEEPINTVL:
                if (val < 1 || val > MAX_TCP_KEEPINTVL)
-- 
2.7.4

Reply via email to