On Wed, Jul 24, 2024 at 01:32 PM +02, Michal Luczaj wrote:
> Rewrite function to have (unneeded) socket descriptors automatically
> close()d when leaving the scope. Make sure the "ownership" of fds is
> correctly passed via take_fd(); i.e. descriptor returned to caller will
> remain valid.
>
> Suggested-by: Jakub Sitnicki <ja...@cloudflare.com>
> Signed-off-by: Michal Luczaj <m...@rbox.co>
> ---
>  .../selftests/bpf/prog_tests/sockmap_helpers.h     | 57 ++++++++++++----------
>  1 file changed, 32 insertions(+), 25 deletions(-)
>
> diff --git a/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h b/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h
> index ead8ea4fd0da..2e0f9fe459be 100644
> --- a/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h
> +++ b/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h
> @@ -182,6 +182,21 @@
>               __ret;                                                         \
>       })
>  
> +#define take_fd(fd)                                                          \
> +     ({                                                                     \
> +             __auto_type __val = (fd);                                      \
> +             fd = -EBADF;                                                   \
> +             __val;                                                         \
> +     })

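This should probably operate on a pointer to fd, as the __get_and_null
macro in include/linux/cleanup.h does. That way the macro argument is
evaluated only once. As written, fd is evaluated twice (once to read it,
once to reset it), so take_fd(fds[i++]) would increment i twice and
reset the wrong slot. take_fd is effectively __get_and_null(fd, -EBADF).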

> +
> +static inline void close_fd(int *fd)
> +{
> +     if (*fd >= 0)
> +             xclose(*fd);
> +}
> +
> +#define __close_fd __attribute__((cleanup(close_fd)))
> +
>  static inline int poll_connect(int fd, unsigned int timeout_sec)
>  {
>       struct timeval timeout = { .tv_sec = timeout_sec };
> @@ -369,9 +384,10 @@ static inline int socket_loopback(int family, int sotype)
>  
>  static inline int create_pair(int family, int sotype, int *p0, int *p1)
>  {
> +     __close_fd int s, c = -1, p = -1;
>       struct sockaddr_storage addr;
>       socklen_t len = sizeof(addr);
> -     int s, c, p, err;
> +     int err;
>  
>       s = socket_loopback(family, sotype);
>       if (s < 0)
> @@ -379,25 +395,23 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1)
>  
>       err = xgetsockname(s, sockaddr(&addr), &len);
>       if (err)
> -             goto close_s;
> +             return err;
>  
>       c = xsocket(family, sotype, 0);
> -     if (c < 0) {
> -             err = c;
> -             goto close_s;
> -     }
> +     if (c < 0)
> +             return c;
>  
>       err = connect(c, sockaddr(&addr), len);
>       if (err) {
>               if (errno != EINPROGRESS) {
>                       FAIL_ERRNO("connect");
> -                     goto close_c;
> +                     return err;
>               }
>  
>               err = poll_connect(c, IO_TIMEOUT_SEC);
>               if (err) {
>                       FAIL_ERRNO("poll_connect");
> -                     goto close_c;
> +                     return err;
>               }
>       }
>  
> @@ -405,36 +419,29 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1)
>       case SOCK_DGRAM:
>               err = xgetsockname(c, sockaddr(&addr), &len);
>               if (err)
> -                     goto close_c;
> +                     return err;
>  
>               err = xconnect(s, sockaddr(&addr), len);
> -             if (!err) {
> -                     *p0 = s;
> -                     *p1 = c;
> +             if (err)
>                       return err;
> -             }
> +
> +             *p0 = take_fd(s);
>               break;
>       case SOCK_STREAM:
>       case SOCK_SEQPACKET:
>               p = xaccept_nonblock(s, NULL, NULL);
> -             if (p >= 0) {
> -                     *p0 = p;
> -                     *p1 = c;
> -                     goto close_s;
> -             }
> +             if (p < 0)
> +                     return p;
>  
> -             err = p;
> +             *p0 = take_fd(p);
>               break;
>       default:
>               FAIL("Unsupported socket type %#x", sotype);
> -             err = -EOPNOTSUPP;
> +             return -EOPNOTSUPP;
>       }
>  
> -close_c:
> -     close(c);
> -close_s:
> -     close(s);
> -     return err;
> +     *p1 = take_fd(c);
> +     return 0;
>  }
>  
>  static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,

This turned out nice and readable, IMHO.

Reply via email to