From: Daniel P. Berrangé <[email protected]>

In the transition to Post-Quantum Cryptography, it will often be
desirable to load multiple sets of certificates, some with RSA/ECC
and some with ML-DSA. This extends the TLS context code to support
loading many certificate/key pairs, passed as two parallel
NULL-terminated arrays of file paths (certs[i] is paired with keys[i]).

Signed-off-by: Daniel P. Berrangé <[email protected]>
---
 src/libvirt_probes.d         |  3 ++-
 src/remote/remote_daemon.c   |  6 +++--
 src/rpc/virnettlscontext.c   | 51 ++++++++++++++++++++----------------
 src/rpc/virnettlscontext.h   | 26 +++++++++---------
 tests/virnettlscontexttest.c | 10 ++++---
 tests/virnettlssessiontest.c |  9 ++++---
 6 files changed, 58 insertions(+), 47 deletions(-)

diff --git a/src/libvirt_probes.d b/src/libvirt_probes.d
index 6fac10a2bf..d9e75d9797 100644
--- a/src/libvirt_probes.d
+++ b/src/libvirt_probes.d
@@ -54,7 +54,8 @@ provider libvirt {
        # file: src/rpc/virnettlscontext.c
        # prefix: rpc
        probe rpc_tls_context_new(void *ctxt, const char *cacert, const char *cacrl,
-                                 const char *cert, const char *key, int sanityCheckCert, int requireValidCert, int isServer);
+                                 const char **certs, const char **keys,
+                                 int sanityCheckCert, int requireValidCert, int isServer);
        probe rpc_tls_context_dispose(void *ctxt);
        probe rpc_tls_context_session_allow(void *ctxt, void *sess, const char *dname);
        probe rpc_tls_context_session_deny(void *ctxt, void *sess, const char *dname);
diff --git a/src/remote/remote_daemon.c b/src/remote/remote_daemon.c
index 2973813548..e7c8f587c4 100644
--- a/src/remote/remote_daemon.c
+++ b/src/remote/remote_daemon.c
@@ -327,6 +327,9 @@ daemonSetupNetworking(virNetServer *srv,
         if (config->ca_file ||
             config->cert_file ||
             config->key_file) {
+            const char *certs[] = { config->cert_file, NULL };
+            const char *keys[] = { config->key_file, NULL };
+
             if (!config->ca_file) {
                 virReportError(VIR_ERR_CONFIG_UNSUPPORTED, "%s",
                                _("No CA certificate path set to match server key/cert"));
@@ -346,8 +349,7 @@ daemonSetupNetworking(virNetServer *srv,
                       config->ca_file, config->cert_file, config->key_file);
             if (!(ctxt = virNetTLSContextNewServer(config->ca_file,
                                                    config->crl_file,
-                                                   config->cert_file,
-                                                   config->key_file,
+                                                   certs, keys,
                                                    (const char *const*)config->tls_allowed_dn_list,
                                                    config->tls_priority,
                                                    config->tls_no_sanity_certificate ? false : true,
diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c
index bb9db90dff..5e9c262b48 100644
--- a/src/rpc/virnettlscontext.c
+++ b/src/rpc/virnettlscontext.c
@@ -115,10 +115,11 @@ static int virNetTLSContextLoadCredentials(virNetTLSContext *ctxt,
                                            bool isServer,
                                            const char *cacert,
                                            const char *cacrl,
-                                           const char *cert,
-                                           const char *key)
+                                           const char *const *certs,
+                                           const char *const *keys)
 {
     int err;
+    size_t i;
 
     if (cacert && cacert[0] != '\0') {
         if (virNetTLSContextCheckCertFile("CA certificate", cacert, false) < 0)
@@ -157,29 +158,34 @@ static int virNetTLSContextLoadCredentials(virNetTLSContext *ctxt,
         }
     }
 
-    if (cert && cert[0] != '\0' && key && key[0] != '\0') {
+    for (i = 0; certs[i] != NULL && keys[i] != NULL; i++) {
         int rv;
+
+        /* As with the CA certificate above, an empty path means "not set" */
+        if (certs[i][0] == '\0' || keys[i][0] == '\0')
+            continue;
+
-        if ((rv = virNetTLSContextCheckCertFile("certificate", cert, !isServer)) < 0)
+        if ((rv = virNetTLSContextCheckCertFile("certificate", certs[i], !isServer)) < 0)
             return -1;
         if (rv == 0 &&
-            (rv = virNetTLSContextCheckCertFile("private key", key, !isServer)) < 0)
+            (rv = virNetTLSContextCheckCertFile("private key", keys[i], !isServer)) < 0)
             return -1;
 
         if (rv == 0) {
-            VIR_DEBUG("loading cert and key from %s and %s", cert, key);
+            VIR_DEBUG("loading cert and key from %s and %s", certs[i], keys[i]);
             err =
                 gnutls_certificate_set_x509_key_file(ctxt->x509cred,
-                                                     cert, key,
+                                                     certs[i], keys[i],
                                                      GNUTLS_X509_FMT_PEM);
             if (err < 0) {
                 virReportError(VIR_ERR_SYSTEM_ERROR,
                                _("Unable to set x509 key and certificate: %1$s, %2$s: %3$s"),
-                               key, cert, gnutls_strerror(err));
+                               keys[i], certs[i], gnutls_strerror(err));
                 return -1;
             }
         } else {
             VIR_DEBUG("Skipping non-existent cert %s key %s on client",
-                      cert, key);
+                      certs[i], keys[i]);
         }
     }
 
@@ -189,8 +190,8 @@ static int virNetTLSContextLoadCredentials(virNetTLSContext *ctxt,
 
 static virNetTLSContext *virNetTLSContextNew(const char *cacert,
                                              const char *cacrl,
-                                             const char *cert,
-                                             const char *key,
+                                             const char *const *certs,
+                                             const char *const *keys,
                                              const char *const *x509dnACL,
                                              const char *priority,
                                              bool sanityCheckCert,
@@ -199,7 +200,6 @@ static virNetTLSContext *virNetTLSContextNew(const char *cacert,
 {
     virNetTLSContext *ctxt;
     int err;
-    const char *certs[] = { cert, NULL };
 
     if (virNetTLSContextInitialize() < 0)
         return NULL;
@@ -228,7 +228,8 @@ static virNetTLSContext *virNetTLSContextNew(const char *cacert,
         virNetTLSCertSanityCheck(isServer, cacert, certs) < 0)
         goto error;
 
-    if (virNetTLSContextLoadCredentials(ctxt, isServer, cacert, cacrl, cert, key) < 0)
+    if (virNetTLSContextLoadCredentials(ctxt, isServer, cacert, cacrl,
+                                        certs, keys) < 0)
         goto error;
 
     ctxt->requireValidCert = requireValidCert;
@@ -236,8 +237,8 @@ static virNetTLSContext *virNetTLSContextNew(const char *cacert,
     ctxt->isServer = isServer;
 
     PROBE(RPC_TLS_CONTEXT_NEW,
-          "ctxt=%p cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d",
-          ctxt, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer);
+          "ctxt=%p cacert=%s cacrl=%s cert=%p key=%p sanityCheckCert=%d requireValidCert=%d isServer=%d",
+          ctxt, cacert, NULLSTR(cacrl), certs, keys, sanityCheckCert, requireValidCert, isServer);
 
     return ctxt;
 
@@ -313,12 +314,18 @@ static virNetTLSContext *virNetTLSContextNewPath(const char *pkipath,
     g_autofree char *cacrl = NULL;
     g_autofree char *key = NULL;
     g_autofree char *cert = NULL;
+    /* Filled in below: cert & key are still NULL until located */
+    const char *certs[] = { NULL, NULL };
+    const char *keys[] = { NULL, NULL };
 
     if (virNetTLSContextLocateCredentials(pkipath, tryUserPkiPath, isServer,
                                           &cacert, &cacrl, &cert, &key) < 0)
         return NULL;
 
-    return virNetTLSContextNew(cacert, cacrl, cert, key,
+    certs[0] = cert;
+    keys[0] = key;
+
+    return virNetTLSContextNew(cacert, cacrl, certs, keys,
                                x509dnACL, priority, sanityCheckCert,
                                requireValidCert, isServer);
 }
@@ -347,14 +350,14 @@ virNetTLSContext *virNetTLSContextNewClientPath(const char *pkipath,
 
 virNetTLSContext *virNetTLSContextNewServer(const char *cacert,
                                             const char *cacrl,
-                                            const char *cert,
-                                            const char *key,
+                                            const char *const *certs,
+                                            const char *const *keys,
                                             const char *const *x509dnACL,
                                             const char *priority,
                                             bool sanityCheckCert,
                                             bool requireValidCert)
 {
-    return virNetTLSContextNew(cacert, cacrl, cert, key, x509dnACL, priority,
+    return virNetTLSContextNew(cacert, cacrl, certs, keys, x509dnACL, priority,
                                sanityCheckCert, requireValidCert, true);
 }
 
@@ -369,6 +372,8 @@ int virNetTLSContextReloadForServer(virNetTLSContext *ctxt,
     g_autofree char *cert = NULL;
     g_autofree char *key = NULL;
-    const char *certs[] = { cert, NULL };
+    /* Filled in once cert & key have been located, before first use */
+    const char *certs[] = { NULL, NULL };
+    const char *keys[] = { NULL, NULL };
 
     x509credBak = g_steal_pointer(&ctxt->x509cred);
 
@@ -387,7 +391,11 @@ int virNetTLSContextReloadForServer(virNetTLSContext *ctxt,
+    certs[0] = cert;
+    keys[0] = key;
+
     if (virNetTLSCertSanityCheck(true, cacert, certs))
         goto error;
 
-    if (virNetTLSContextLoadCredentials(ctxt, true, cacert, cacrl, cert, key))
+    if (virNetTLSContextLoadCredentials(ctxt, true, cacert, cacrl,
+                                        certs, keys))
         goto error;
 
     gnutls_certificate_free_credentials(x509credBak);
@@ -404,13 +409,13 @@ int virNetTLSContextReloadForServer(virNetTLSContext *ctxt,
 
 virNetTLSContext *virNetTLSContextNewClient(const char *cacert,
                                             const char *cacrl,
-                                            const char *cert,
-                                            const char *key,
+                                            const char *const *certs,
+                                            const char *const *keys,
                                             const char *priority,
                                             bool sanityCheckCert,
                                             bool requireValidCert)
 {
-    return virNetTLSContextNew(cacert, cacrl, cert, key, NULL, priority,
+    return virNetTLSContextNew(cacert, cacrl, certs, keys, NULL, priority,
                                sanityCheckCert, requireValidCert, false);
 }
 
diff --git a/src/rpc/virnettlscontext.h b/src/rpc/virnettlscontext.h
index 11c954ce4b..1e67171e3e 100644
--- a/src/rpc/virnettlscontext.h
+++ b/src/rpc/virnettlscontext.h
@@ -44,21 +44,26 @@ virNetTLSContext *virNetTLSContextNewClientPath(const char *pkipath,
                                                   bool requireValidCert);
 
+/*
+ * For virNetTLSContextNewServer() and virNetTLSContextNewClient(),
+ * @certs and @keys are NULL-terminated arrays of PEM file paths,
+ * where certs[i] is paired with keys[i].
+ */
 virNetTLSContext *virNetTLSContextNewServer(const char *cacert,
-                                              const char *cacrl,
-                                              const char *cert,
-                                              const char *key,
-                                              const char *const *x509dnACL,
-                                              const char *priority,
-                                              bool sanityCheckCert,
-                                              bool requireValidCert);
+                                            const char *cacrl,
+                                            const char *const *certs,
+                                            const char *const *keys,
+                                            const char *const *x509dnACL,
+                                            const char *priority,
+                                            bool sanityCheckCert,
+                                            bool requireValidCert);
 
 virNetTLSContext *virNetTLSContextNewClient(const char *cacert,
-                                              const char *cacrl,
-                                              const char *cert,
-                                              const char *key,
-                                              const char *priority,
-                                              bool sanityCheckCert,
-                                              bool requireValidCert);
+                                            const char *cacrl,
+                                            const char *const *certs,
+                                            const char *const *keys,
+                                            const char *priority,
+                                            bool sanityCheckCert,
+                                            bool requireValidCert);
 
 int virNetTLSContextReloadForServer(virNetTLSContext *ctxt,
                                     bool tryUserPkiPath);
diff --git a/tests/virnettlscontexttest.c b/tests/virnettlscontexttest.c
index 48bdefdd76..47675bffd0 100644
--- a/tests/virnettlscontexttest.c
+++ b/tests/virnettlscontexttest.c
@@ -56,12 +56,14 @@ static int testTLSContextInit(const void *opaque)
     struct testTLSContextData *data = (struct testTLSContextData *)opaque;
     virNetTLSContext *ctxt = NULL;
     int ret = -1;
+    const char *keyFiles[] = { KEYFILE, NULL };
+    const char *certFiles[] = { data->crt, NULL };
 
     if (data->isServer) {
         ctxt = virNetTLSContextNewServer(data->cacrt,
                                          NULL,
-                                         data->crt,
-                                         KEYFILE,
+                                         certFiles,
+                                         keyFiles,
                                          NULL,
                                          "NORMAL",
                                          true,
@@ -69,8 +71,8 @@ static int testTLSContextInit(const void *opaque)
     } else {
         ctxt = virNetTLSContextNewClient(data->cacrt,
                                          NULL,
-                                         data->crt,
-                                         KEYFILE,
+                                         certFiles,
+                                         keyFiles,
                                          "NORMAL",
                                          true,
                                          true);
diff --git a/tests/virnettlssessiontest.c b/tests/virnettlssessiontest.c
index 459e17c52c..e8d64c7da0 100644
--- a/tests/virnettlssessiontest.c
+++ b/tests/virnettlssessiontest.c
@@ -81,6 +81,9 @@ static int testTLSSessionInit(const void *opaque)
     int channel[2];
     bool clientShake = false;
     bool serverShake = false;
+    const char *serverCertFiles[] = { data->servercrt, NULL };
+    const char *clientCertFiles[] = { data->clientcrt, NULL };
+    const char *keyFiles[] = { KEYFILE, NULL };
 
 
     /* We'll use this for our fake client-server connection */
@@ -102,8 +105,7 @@ static int testTLSSessionInit(const void *opaque)
      */
     serverCtxt = virNetTLSContextNewServer(data->servercacrt,
                                            NULL,
-                                           data->servercrt,
-                                           KEYFILE,
+                                           serverCertFiles, keyFiles,
                                            data->wildcards,
                                            "NORMAL",
                                            false,
@@ -111,8 +113,7 @@ static int testTLSSessionInit(const void *opaque)
 
     clientCtxt = virNetTLSContextNewClient(data->clientcacrt,
                                            NULL,
-                                           data->clientcrt,
-                                           KEYFILE,
+                                           clientCertFiles, keyFiles,
                                            "NORMAL",
                                            false,
                                            true);
-- 
2.51.1

Reply via email to