Add tests that check that getsockopt(TCP_AO_GET_KEYS) returns the right
keys, and reports the right number of matching keys, when different
filters are used.

Sample output:

> # ok 114 filter keys: by sndid, rcvid, address
> # ok 115 filter keys: by is_current
> # ok 116 filter keys: by is_rnext
> # ok 117 filter keys: by sndid, rcvid
> # ok 118 filter keys: correct nkeys when in.nkeys < matches

Acked-by: Dmitry Safonov <0x7f454...@gmail.com>
Signed-off-by: Leo Stone <leocst...@gmail.com>
---
v3:
  - Ordered locals in reverse xmas tree order
  - Separated socket fd declaration from assignment
  - Broke lines longer than 80 columns
v2: https://lore.kernel.org/netdev/20241016055823.21299-1-leocst...@gmail.com/
  - Changed 2 unnecessary test_error calls to test_fail
  - Added another test to make sure getsockopt returns the right nkeys
  value when the input nkeys is smaller than the number of matching keys
  - Removed the TODO that this patch addresses
v1: https://lore.kernel.org/netdev/20241014213313.15100-1-leocst...@gmail.com/

Thanks to the reviewers for their time and feedback!
---
 .../selftests/net/tcp_ao/setsockopt-closed.c  | 186 +++++++++++++++++-
 1 file changed, 181 insertions(+), 5 deletions(-)

diff --git a/tools/testing/selftests/net/tcp_ao/setsockopt-closed.c b/tools/testing/selftests/net/tcp_ao/setsockopt-closed.c
index 084db4ecdff6..0abb9807d742 100644
--- a/tools/testing/selftests/net/tcp_ao/setsockopt-closed.c
+++ b/tools/testing/selftests/net/tcp_ao/setsockopt-closed.c
@@ -6,6 +6,8 @@
 
 static union tcp_addr tcp_md5_client;
 
+/* Nr of keys prepare_test_keys() installs; also sizes the key arrays */
+#define FILTER_TEST_NKEYS 16
+
 static int test_port = 7788;
 static void make_listen(int sk)
 {
@@ -813,23 +815,197 @@ static void duplicate_tests(void)
        setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: SendID 
differs");
 }
 
+/*
+ * Dump every key on @sk into @keys, which must have room for
+ * FILTER_TEST_NKEYS entries. The whole buffer is zeroed first; callers
+ * read the number of keys found back from keys[0].nkeys.
+ * Any getsockopt() failure is a hard test error.
+ */
+static void fetch_all_keys(int sk, struct tcp_ao_getsockopt *keys)
+{
+       struct tcp_ao_getsockopt *req = keys;
+       socklen_t optlen = sizeof(*req);
+
+       memset(keys, 0, sizeof(*keys) * FILTER_TEST_NKEYS);
+       /* The request (buffer size + get_all) travels in the first entry */
+       req->nkeys = FILTER_TEST_NKEYS;
+       req->get_all = 1;
+       if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, req, &optlen) != 0)
+               test_error("getsockopt");
+}
+
+/*
+ * Open a new socket, install FILTER_TEST_NKEYS keys on it in random order
+ * and dump them into @keys (room for FILTER_TEST_NKEYS entries) via
+ * fetch_all_keys(), i.e. in the order getsockopt() lists them.
+ * Key i gets SendID/RecvID 100 + i, address this_ip_dest and a password
+ * ending in i. Key 0 is marked current and key 1 rnext. Key 2 instead
+ * gets this_ip_addr and SendID/RecvID 100, so it shares its IDs with
+ * key 0 but not its address.
+ * Returns the socket; it is not closed here.
+ */
+static int prepare_test_keys(struct tcp_ao_getsockopt *keys)
+{
+       const char *test_password = "Test password number ";
+       struct tcp_ao_add test_ao[FILTER_TEST_NKEYS];
+       char test_password_scratch[64] = {};
+       u8 rcvid = 100, sndid = 100;
+       int sk;
+
+       sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+       if (sk < 0)
+               test_error("socket()");
+
+       for (int i = 0; i < FILTER_TEST_NKEYS; i++) {
+               snprintf(test_password_scratch, sizeof(test_password_scratch),
+                        "%s %d", test_password, i);
+               test_prepare_key(&test_ao[i], DEFAULT_TEST_ALGO, this_ip_dest,
+                                false, false, DEFAULT_TEST_PREFIX, 0, sndid++,
+                                rcvid++, 0, 0, strlen(test_password_scratch),
+                                test_password_scratch);
+       }
+       test_ao[0].set_current = 1;
+       test_ao[1].set_rnext = 1;
+       /* One key with a different addr and overlapping sndid, rcvid */
+       tcp_addr_to_sockaddr_in(&test_ao[2].addr, &this_ip_addr, 0);
+       test_ao[2].sndid = 100;
+       test_ao[2].rcvid = 100;
+
+       /*
+        * Add keys in a random order: pick one of the keys not added yet
+        * (slots 0..last), then move the last pending key into the freed
+        * slot. Struct assignment is used because randidx may equal last,
+        * and memcpy() with overlapping source/destination is undefined.
+        */
+       for (int i = 0; i < FILTER_TEST_NKEYS; i++) {
+               int last = FILTER_TEST_NKEYS - 1 - i;
+               int randidx = rand() % (last + 1);
+
+               if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY,
+                              &test_ao[randidx], sizeof(test_ao[randidx])))
+                       test_error("setsockopt()");
+               test_ao[randidx] = test_ao[last];
+       }
+
+       fetch_all_keys(sk, keys);
+
+       return sk;
+}
+
+/*
+ * Return how many of the @nexpected keys in @expected are missing from
+ * the @nactual keys in @actual (0 means every expected key was returned).
+ * Keys are matched on their key bytes, so this assumes passwords are
+ * unique. Each expected key is looked up on its own, so a key returned
+ * twice cannot hide another key that was not returned at all.
+ */
+static int compare_mkts(const struct tcp_ao_getsockopt *expected,
+                       int nexpected,
+                       const struct tcp_ao_getsockopt *actual, int nactual)
+{
+       int missing = 0;
+
+       for (int i = 0; i < nexpected; i++) {
+               bool found = false;
+
+               for (int j = 0; j < nactual && !found; j++)
+                       found = !memcmp(expected[i].key, actual[j].key,
+                                       TCP_AO_MAXKEYLEN);
+               if (!found)
+                       missing++;
+       }
+       return missing;
+}
+
+/*
+ * Run TCP_AO_GET_KEYS on @sk with the request in @filter and check that
+ * exactly @nexpected keys come back and that they are the keys in
+ * @expected. Reports a single test_ok()/test_fail() for @tst.
+ * Always closes @sk and zeroes *@filter, so the caller can reuse it.
+ */
+static void filter_keys_checked(int sk, struct tcp_ao_getsockopt *filter,
+                               struct tcp_ao_getsockopt *expected,
+                               unsigned int nexpected, const char *tst)
+{
+       struct tcp_ao_getsockopt filtered_keys[FILTER_TEST_NKEYS] = {};
+       socklen_t len = sizeof(struct tcp_ao_getsockopt);
+
+       filtered_keys[0] = *filter;
+       /* Offer room for every key so that all matches get copied out */
+       filtered_keys[0].nkeys = FILTER_TEST_NKEYS;
+       if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, filtered_keys, &len))
+               test_error("getsockopt");
+       if (filtered_keys[0].nkeys != nexpected) {
+               test_fail("wrong nr of keys, expected %u got %u", nexpected,
+                         filtered_keys[0].nkeys);
+               goto out_close;
+       }
+       if (compare_mkts(expected, nexpected, filtered_keys,
+                        filtered_keys[0].nkeys)) {
+               test_fail("got wrong keys back");
+               goto out_close;
+       }
+       test_ok("filter keys: %s", tst);
+
+out_close:
+       close(sk);
+       memset(filter, 0, sizeof(struct tcp_ao_getsockopt));
+}
+
+/*
+ * Check TCP_AO_GET_KEYS filtering against the keys installed by
+ * prepare_test_keys(): by sndid + rcvid + address, by is_current, by
+ * is_rnext, by sndid + rcvid with a wildcard address, and that nkeys
+ * reports every match when the buffer holds fewer keys than match.
+ * Each sub-test works on a fresh socket and closes it when done.
+ */
+static void filter_tests(void)
+{
+       struct tcp_ao_getsockopt partial_keys[FILTER_TEST_NKEYS / 2] = {};
+       struct tcp_ao_getsockopt original_keys[FILTER_TEST_NKEYS];
+       struct tcp_ao_getsockopt expected_keys[FILTER_TEST_NKEYS];
+       struct tcp_ao_getsockopt filter = {};
+       int sk, f, nmatches;
+       socklen_t len;
+
+       /* original_keys[] is in dump order: entry 2 is whichever key came third */
+       f = 2;
+       sk = prepare_test_keys(original_keys);
+       filter.rcvid = original_keys[f].rcvid;
+       filter.sndid = original_keys[f].sndid;
+       memcpy(&filter.addr, &original_keys[f].addr,
+              sizeof(original_keys[f].addr));
+       filter.prefix = original_keys[f].prefix;
+       filter_keys_checked(sk, &filter, &original_keys[f], 1,
+                           "by sndid, rcvid, address");
+
+       /*
+        * The searches below scan the whole buffer: fetch_all_keys() zeroes
+        * it, so entries the kernel did not fill never match, and the index
+        * can't run past the array whatever nkeys the kernel reports.
+        */
+       f = -1;
+       sk = prepare_test_keys(original_keys);
+       for (int i = 0; i < FILTER_TEST_NKEYS; i++) {
+               if (original_keys[i].is_current) {
+                       f = i;
+                       break;
+               }
+       }
+       if (f < 0)
+               test_error("No current key after adding one");
+       filter.is_current = 1;
+       filter_keys_checked(sk, &filter, &original_keys[f], 1, "by is_current");
+
+       f = -1;
+       sk = prepare_test_keys(original_keys);
+       for (int i = 0; i < FILTER_TEST_NKEYS; i++) {
+               if (original_keys[i].is_rnext) {
+                       f = i;
+                       break;
+               }
+       }
+       if (f < 0)
+               test_error("No rnext key after adding one");
+       filter.is_rnext = 1;
+       filter_keys_checked(sk, &filter, &original_keys[f], 1, "by is_rnext");
+
+       f = -1;
+       nmatches = 0;
+       sk = prepare_test_keys(original_keys);
+       for (int i = 0; i < FILTER_TEST_NKEYS; i++) {
+               if (original_keys[i].sndid == 100) {
+                       f = i;
+                       memcpy(&expected_keys[nmatches], &original_keys[i],
+                              sizeof(struct tcp_ao_getsockopt));
+                       nmatches++;
+               }
+       }
+       if (f < 0)
+               test_error("No key for sndid 100");
+       if (nmatches != 2)
+               test_error("Should have 2 keys with sndid 100");
+       filter.rcvid = original_keys[f].rcvid;
+       filter.sndid = original_keys[f].sndid;
+       filter.addr.ss_family = test_family;
+       filter_keys_checked(sk, &filter, expected_keys, nmatches,
+                           "by sndid, rcvid");
+
+       /*
+        * Request only FILTER_TEST_NKEYS / 2 keys. nkeys is the capacity of
+        * the buffer the kernel copies keys into, so the buffer must really
+        * hold that many entries (a lone struct would be overrun on the
+        * stack). On return nkeys must still count every matching key.
+        */
+       sk = prepare_test_keys(original_keys);
+       partial_keys[0].get_all = 1;
+       partial_keys[0].nkeys = FILTER_TEST_NKEYS / 2;
+       len = sizeof(struct tcp_ao_getsockopt);
+       if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, partial_keys, &len))
+               test_error("getsockopt");
+       close(sk);
+       if (partial_keys[0].nkeys == FILTER_TEST_NKEYS)
+               test_ok("filter keys: correct nkeys when in.nkeys < matches");
+       else
+               test_fail("filter keys: wrong nkeys, expected %u got %u",
+                         FILTER_TEST_NKEYS, partial_keys[0].nkeys);
+}
+
 static void *client_fn(void *arg)
 {
        if (inet_pton(TEST_FAMILY, __TEST_CLIENT_IP(2), &tcp_md5_client) != 1)
                test_error("Can't convert ip address");
        extend_tests();
        einval_tests();
+       /* getsockopt(TCP_AO_GET_KEYS) filters: nr of keys and keys returned */
+       filter_tests();
        duplicate_tests();
-       /*
-        * TODO: check getsockopt(TCP_AO_GET_KEYS) with different filters
-        * returning proper nr & keys;
-        */
 
        return NULL;
 }
 
 int main(int argc, char *argv[])
 {
-       test_init(121, client_fn, NULL);
+       /* 121 pre-existing checks + 5 results reported by filter_tests() */
+       test_init(126, client_fn, NULL);
        return 0;
 }
-- 
2.43.0


Reply via email to