From: Daniel P. Berrangé <[email protected]>

Future patches will make it possible to load multiple certificate
files. This prepares the sanity checking code to support that by
taking a NULL-terminated array of certificate filenames.

Signed-off-by: Daniel P. Berrangé <[email protected]>
---
 src/rpc/virnettlscert.c    | 35 ++++++++++++++++++++++-------------
 src/rpc/virnettlscert.h    |  2 +-
 src/rpc/virnettlscontext.c |  6 ++++--
 tools/virt-pki-validate.c  |  3 ++-
 4 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/src/rpc/virnettlscert.c b/src/rpc/virnettlscert.c
index 3efc4f0716..6f20b2601b 100644
--- a/src/rpc/virnettlscert.c
+++ b/src/rpc/virnettlscert.c
@@ -440,40 +440,49 @@ int virNetTLSCertLoadListFromFile(const char *certFile,
 #define MAX_CERTS 16
 int virNetTLSCertSanityCheck(bool isServer,
                              const char *cacertFile,
-                             const char *certFile)
+                             const char *const *certFiles) /* NULL-terminated */
 {
-    gnutls_x509_crt_t cert = NULL;
+    g_autofree gnutls_x509_crt_t *certs = NULL;
     gnutls_x509_crt_t cacerts[MAX_CERTS] = { 0 };
     size_t ncacerts = 0;
     size_t i;
     int ret = -1;

-    if ((access(certFile, R_OK) == 0) &&
-        !(cert = virNetTLSCertLoadFromFile(certFile, isServer)))
-        goto cleanup;
+    certs = g_new0(gnutls_x509_crt_t, g_strv_length((gchar **)certFiles));
+    for (i = 0; certFiles[i] != NULL; i++) {
+        if ((access(certFiles[i], R_OK) == 0) &&
+            !(certs[i] = virNetTLSCertLoadFromFile(certFiles[i], isServer)))
+            goto cleanup;
+    }
     if ((access(cacertFile, R_OK) == 0) &&
         virNetTLSCertLoadListFromFile(cacertFile, cacerts,
                                       MAX_CERTS, &ncacerts) < 0)
         goto cleanup;

-    if (cert &&
-        virNetTLSCertCheck(cert, certFile, isServer, false) < 0)
-        goto cleanup;
+    for (i = 0; certFiles[i] != NULL; i++) {
+        if (certs[i] &&
+            virNetTLSCertCheck(certs[i], certFiles[i], isServer, false) < 0)
+            goto cleanup;
+    }

     for (i = 0; i < ncacerts; i++) {
         if (virNetTLSCertCheck(cacerts[i], cacertFile, isServer, true) < 0)
             goto cleanup;
     }

-    if (cert && ncacerts &&
-        virNetTLSCertCheckPair(cert, certFile, cacerts, ncacerts, cacertFile, isServer) < 0)
-        goto cleanup;
+    for (i = 0; ncacerts > 0 && certFiles[i] != NULL; i++) {
+        if (certs[i] &&
+            virNetTLSCertCheckPair(certs[i], certFiles[i], cacerts, ncacerts, cacertFile, isServer) < 0)
+            goto cleanup;
+    }

     ret = 0;

  cleanup:
-    if (cert)
-        gnutls_x509_crt_deinit(cert);
+    for (i = 0; certFiles[i] != NULL; i++) {
+        if (certs[i])
+            gnutls_x509_crt_deinit(certs[i]);
+    }
     for (i = 0; i < ncacerts; i++)
         gnutls_x509_crt_deinit(cacerts[i]);
     return ret;
diff --git a/src/rpc/virnettlscert.h b/src/rpc/virnettlscert.h
index a2f591d172..086d8dc7d6 100644
--- a/src/rpc/virnettlscert.h
+++ b/src/rpc/virnettlscert.h
@@ -28,7 +28,7 @@

 int virNetTLSCertSanityCheck(bool isServer,
                              const char *cacertFile,
-                             const char *certFile);
+                             const char *const *certFiles); /* NULL-terminated list */

 int virNetTLSCertValidateCA(gnutls_x509_crt_t cert,
                             bool isServer);
diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c
index f857bb2339..bb9db90dff 100644
--- a/src/rpc/virnettlscontext.c
+++ b/src/rpc/virnettlscontext.c
@@ -199,6 +199,7 @@ static virNetTLSContext *virNetTLSContextNew(const char *cacert,
 {
     virNetTLSContext *ctxt;
     int err;
+    const char *certs[] = { cert, NULL };

     if (virNetTLSContextInitialize() < 0)
         return NULL;
@@ -224,7 +225,7 @@ static virNetTLSContext *virNetTLSContextNew(const char *cacert,
     }

     if (sanityCheckCert &&
-        virNetTLSCertSanityCheck(isServer, cacert, cert) < 0)
+        virNetTLSCertSanityCheck(isServer, cacert, certs) < 0)
         goto error;

     if (virNetTLSContextLoadCredentials(ctxt, isServer, cacert, cacrl, cert, key) < 0)
@@ -367,6 +368,7 @@ int virNetTLSContextReloadForServer(virNetTLSContext *ctxt,
     g_autofree char *cacrl = NULL;
     g_autofree char *cert = NULL;
     g_autofree char *key = NULL;
+    const char *certs[] = { NULL, NULL }; /* certs[0] set once cert is known */

     x509credBak = g_steal_pointer(&ctxt->x509cred);

@@ -382,7 +384,8 @@ int virNetTLSContextReloadForServer(virNetTLSContext *ctxt,
         goto error;
     }

-    if (virNetTLSCertSanityCheck(true, cacert, cert))
+    certs[0] = cert;
+    if (virNetTLSCertSanityCheck(true, cacert, certs))
         goto error;

     if (virNetTLSContextLoadCredentials(ctxt, true, cacert, cacrl, cert, key))
diff --git a/tools/virt-pki-validate.c b/tools/virt-pki-validate.c
index e693ffaed6..0df289256e 100644
--- a/tools/virt-pki-validate.c
+++ b/tools/virt-pki-validate.c
@@ -163,6 +163,7 @@ virPKIValidateIdentity(bool isServer, bool system, const char *path)
 {
     g_autofree char *cacert = NULL, *cacrl = NULL;
     g_autofree char *cert = NULL, *key = NULL;
+    const char *certs[] = { NULL, NULL }; /* certs[0] set once cert is known */
     bool ok = true;
     const char *scope = isServer ? "SERVER" : "CLIENT";

@@ -274,7 +275,8 @@ virPKIValidateIdentity(bool isServer, bool system, const char *path)

+    certs[0] = cert;
     if (virNetTLSCertSanityCheck(isServer,
                                  cacert,
-                                 cert) < 0) {
+                                 certs) < 0) {
         virValidateFail(VIR_VALIDATE_FAIL, "%s",
                         virGetLastErrorMessage());
         ok = false;
-- 
2.51.1

Reply via email to