Rewrite create_pair() so that socket descriptors it no longer needs are
automatically close()d when they go out of scope. Make sure "ownership"
of the fds handed back to the caller is correctly transferred via
take_fd(), i.e. the descriptors returned to the caller remain valid.

Suggested-by: Jakub Sitnicki <ja...@cloudflare.com>
Signed-off-by: Michal Luczaj <m...@rbox.co>
---
 .../selftests/bpf/prog_tests/sockmap_helpers.h     | 61 +++++++++++++---------
 1 file changed, 36 insertions(+), 25 deletions(-)

diff --git a/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h b/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h
index ead8ea4fd0da..38e35c72bdaa 100644
--- a/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h
+++ b/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h
@@ -17,6 +17,17 @@
 
 #define __always_unused        __attribute__((__unused__))
 
+/* From include/linux/cleanup.h: yield p's current value and reset p to nullvalue */
+#define __get_and_null(p, nullvalue)                                          \
+       ({                                                                     \
+               __auto_type __ptr = &(p);                                      \
+               __auto_type __val = *__ptr;                                    \
+               *__ptr = nullvalue;                                            \
+               __val;                                                         \
+       })
+
+#define take_fd(fd) __get_and_null(fd, -EBADF) /* hand fd off; fd becomes -EBADF */
+
 #define _FAIL(errnum, fmt...)                                                 \
        ({                                                                     \
                error_at_line(0, (errnum), __func__, __LINE__, fmt);           \
@@ -182,6 +193,14 @@
                __ret;                                                         \
        })
 
+static inline void close_fd(int *fd) /* cleanup handler used by __close_fd */
+{
+       if (*fd >= 0) /* negative: never opened, or already handed off via take_fd() */
+               xclose(*fd);
+}
+
+#define __close_fd __attribute__((cleanup(close_fd))) /* close the fd when it goes out of scope */
+
 static inline int poll_connect(int fd, unsigned int timeout_sec)
 {
        struct timeval timeout = { .tv_sec = timeout_sec };
@@ -369,9 +388,10 @@ static inline int socket_loopback(int family, int sotype)
 
 static inline int create_pair(int family, int sotype, int *p0, int *p1)
 {
+       __close_fd int s, c = -1, p = -1;
        struct sockaddr_storage addr;
        socklen_t len = sizeof(addr);
-       int s, c, p, err;
+       int err;
 
        s = socket_loopback(family, sotype);
        if (s < 0)
@@ -379,25 +399,23 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1)
 
        err = xgetsockname(s, sockaddr(&addr), &len);
        if (err)
-               goto close_s;
+               return err;
 
        c = xsocket(family, sotype, 0);
-       if (c < 0) {
-               err = c;
-               goto close_s;
-       }
+       if (c < 0)
+               return c;
 
        err = connect(c, sockaddr(&addr), len);
        if (err) {
                if (errno != EINPROGRESS) {
                        FAIL_ERRNO("connect");
-                       goto close_c;
+                       return err;
                }
 
                err = poll_connect(c, IO_TIMEOUT_SEC);
                if (err) {
                        FAIL_ERRNO("poll_connect");
-                       goto close_c;
+                       return err;
                }
        }
 
@@ -405,36 +423,29 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1)
        case SOCK_DGRAM:
                err = xgetsockname(c, sockaddr(&addr), &len);
                if (err)
-                       goto close_c;
+                       return err;
 
                err = xconnect(s, sockaddr(&addr), len);
-               if (!err) {
-                       *p0 = s;
-                       *p1 = c;
+               if (err)
                        return err;
-               }
+
+               *p0 = take_fd(s);
                break;
        case SOCK_STREAM:
        case SOCK_SEQPACKET:
                p = xaccept_nonblock(s, NULL, NULL);
-               if (p >= 0) {
-                       *p0 = p;
-                       *p1 = c;
-                       goto close_s;
-               }
+               if (p < 0)
+                       return p;
 
-               err = p;
+               *p0 = take_fd(p);
                break;
        default:
                FAIL("Unsupported socket type %#x", sotype);
-               err = -EOPNOTSUPP;
+               return -EOPNOTSUPP;
        }
 
-close_c:
-       close(c);
-close_s:
-       close(s);
-       return err;
+       *p1 = take_fd(c);
+       return 0;
 }
 
 static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,

-- 
2.45.2


Reply via email to