On Wed, Dec 20, 2017 at 09:35:55AM +0900, Michael Paquier wrote:
> However, it is possible to simply optimize the frontend code, since in
> pg_SASL_init() we already know the channel binding type selected when
> calling pgtls_get_finished() and pgtls_get_peer_certificate_hash(). So
> while I agree with your point, my opinion is to keep the code as
> simple as possible, and then just optimize the frontend code. What do
> you think?

I have looked at how things could be done symmetrically in the frontend
and backend code, and I have produced the attached patch 0002, which
applies on top of 0001 (the patch implementing tls-server-end-point).
0002 simplifies the interfaces that initialize the SCRAM state data by
saving Port* and PGconn*, which hold most of the data needed for the
exchange, into scram_state and fe_scram_state respectively. With this
patch, cbind_data is generated only when a specific channel binding type
is used, and only the data that this type needs is fetched. So if no
channel binding is used, no additional SSL call is made to get the TLS
finished data or the server certificate hash.

0001 has no real changes compared to the previous versions.

Peter, thoughts?
-- 
Michael
From 69dcf31f5ce5938f9f56a94bf55c8439ea53ed27 Mon Sep 17 00:00:00 2001
From: Michael Paquier <mich...@paquier.xyz>
Date: Fri, 22 Dec 2017 10:49:10 +0900
Subject: [PATCH 1/2] Implement channel binding tls-server-end-point for SCRAM

This channel binding type, defined in RFC 5929, is not the default one
and uses a hash of the server certificate as binding data. On the
frontend, this boils down to getting the certificate from
SSL_get_peer_certificate(), and on the backend from SSL_get_certificate().

The certificate is normally hashed with the hash function of its
signature algorithm, but this must be switched to SHA-256 if that
function is MD5 or SHA-1, so let's be careful about that.
---
 doc/src/sgml/protocol.sgml               |  5 +-
 src/backend/libpq/auth-scram.c           | 26 ++++++++---
 src/backend/libpq/auth.c                 |  8 +++-
 src/backend/libpq/be-secure-openssl.c    | 61 +++++++++++++++++++++++++
 src/include/common/scram-common.h        |  1 +
 src/include/libpq/libpq-be.h             |  1 +
 src/include/libpq/scram.h                |  3 +-
 src/interfaces/libpq/fe-auth-scram.c     | 18 ++++++--
 src/interfaces/libpq/fe-auth.c           | 12 ++++-
 src/interfaces/libpq/fe-auth.h           |  4 +-
 src/interfaces/libpq/fe-secure-openssl.c | 78 ++++++++++++++++++++++++++++++++
 src/interfaces/libpq/libpq-int.h         |  1 +
 src/test/ssl/t/002_scram.pl              |  5 +-
 13 files changed, 207 insertions(+), 16 deletions(-)

diff --git a/doc/src/sgml/protocol.sgml b/doc/src/sgml/protocol.sgml
index 8174e3defa..365f72b51d 100644
--- a/doc/src/sgml/protocol.sgml
+++ b/doc/src/sgml/protocol.sgml
@@ -1576,8 +1576,9 @@ the password is in.
   <para>
 <firstterm>Channel binding</firstterm> is supported in PostgreSQL builds with
 SSL support. The SASL mechanism name for SCRAM with channel binding
-is <literal>SCRAM-SHA-256-PLUS</literal>.  The only channel binding type
-supported at the moment is <literal>tls-unique</literal>, defined in RFC 5929.
+is <literal>SCRAM-SHA-256-PLUS</literal>.  Two channel binding types are
+supported at the moment: <literal>tls-unique</literal>, which is the default,
+and <literal>tls-server-end-point</literal>, both defined in RFC 5929.
   </para>
 
 <procedure>
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index d52a763457..849587d141 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -114,6 +114,8 @@ typedef struct
        bool            ssl_in_use;
        const char *tls_finished_message;
        size_t          tls_finished_len;
+       const char *certificate_hash;
+       size_t          certificate_hash_len;
        char       *channel_binding_type;
 
        int                     iterations;
@@ -176,7 +178,9 @@ pg_be_scram_init(const char *username,
                                 const char *shadow_pass,
                                 bool ssl_in_use,
                                 const char *tls_finished_message,
-                                size_t tls_finished_len)
+                                size_t tls_finished_len,
+                                const char *certificate_hash,
+                                size_t certificate_hash_len)
 {
        scram_state *state;
        bool            got_verifier;
@@ -187,6 +191,8 @@ pg_be_scram_init(const char *username,
        state->ssl_in_use = ssl_in_use;
        state->tls_finished_message = tls_finished_message;
        state->tls_finished_len = tls_finished_len;
+       state->certificate_hash = certificate_hash;
+       state->certificate_hash_len = certificate_hash_len;
        state->channel_binding_type = NULL;
 
        /*
@@ -857,13 +863,15 @@ read_client_first_message(scram_state *state, char *input)
                                }
 
                                /*
-                                * Read value provided by client; only tls-unique is supported
-                                * for now.  (It is not safe to print the name of an
-                                * unsupported binding type in the error message.  Pranksters
-                                * could print arbitrary strings into the log that way.)
+                                * Read value provided by client; only tls-unique and
+                                * tls-server-end-point are supported for now.  (It is
+                                * not safe to print the name of an unsupported binding
+                                * type in the error message.  Pranksters could print
+                                * arbitrary strings into the log that way.)
                                 */
                                channel_binding_type = read_attr_value(&input, 'p');
-                               if (strcmp(channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) != 0)
+                               if (strcmp(channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) != 0 &&
+                                       strcmp(channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_ENDPOINT) != 0)
                                        ereport(ERROR,
                                                        (errcode(ERRCODE_PROTOCOL_VIOLATION),
                                                         (errmsg("unsupported SCRAM channel-binding type"))));
@@ -1123,6 +1131,12 @@ read_client_final_message(scram_state *state, char *input)
                        cbind_data = state->tls_finished_message;
                        cbind_data_len = state->tls_finished_len;
                }
+               else if (strcmp(state->channel_binding_type,
+                                               SCRAM_CHANNEL_BINDING_TLS_ENDPOINT) == 0)
+               {
+                       cbind_data = state->certificate_hash;
+                       cbind_data_len = state->certificate_hash_len;
+               }
                else
                {
                        /* should not happen */
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index b7f9bb1669..700a3bffa4 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -875,6 +875,8 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
        bool            initial;
        char       *tls_finished = NULL;
        size_t          tls_finished_len = 0;
+       char       *certificate_hash = NULL;
+       size_t          certificate_hash_len = 0;
 
        /*
         * SASL auth is not supported for protocol versions before 3, because it
@@ -923,6 +925,8 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
        if (port->ssl_in_use)
        {
                tls_finished = be_tls_get_peer_finished(port, &tls_finished_len);
+               certificate_hash = be_tls_get_certificate_hash(port,
+                                                                                                          &certificate_hash_len);
        }
 #endif
 
@@ -941,7 +945,9 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
                                                                  shadow_pass,
                                                                  port->ssl_in_use,
                                                                  tls_finished,
-                                                                 tls_finished_len);
+                                                                 tls_finished_len,
+                                                                 certificate_hash,
+                                                                 certificate_hash_len);
 
        /*
         * Loop through SASL message exchange.  This exchange can consist of
diff --git a/src/backend/libpq/be-secure-openssl.c b/src/backend/libpq/be-secure-openssl.c
index 1e3e19f5e0..e3e8a535c8 100644
--- a/src/backend/libpq/be-secure-openssl.c
+++ b/src/backend/libpq/be-secure-openssl.c
@@ -1239,6 +1239,67 @@ be_tls_get_peer_finished(Port *port, size_t *len)
        return result;
 }
 
+/*
+ * Get the server certificate hash for channel binding purposes.  Per
+ * RFC 5929 and tls-server-end-point, the TLS server's certificate bytes
+ * need to be hashed with SHA-256 if its signature algorithm uses MD5 or
+ * SHA-1 (https://tools.ietf.org/html/rfc5929#section-4.1).  For any
+ * other signature algorithm, the certificate is hashed with the same
+ * hash function as the one used by that signature algorithm.  The result
+ * is a palloc'd hash of the server certificate, with its size in *len,
+ * or NULL if there is no certificate available.
+char *
+be_tls_get_certificate_hash(Port *port, size_t *len)
+{
+       char    *cert_hash = NULL;
+       X509    *server_cert;
+
+       *len = 0;
+       server_cert = SSL_get_certificate(port->ssl);
+
+       if (server_cert != NULL)
+       {
+               const EVP_MD   *algo_type = NULL;
+               char                    hash[EVP_MAX_MD_SIZE];  /* size for SHA-512 */
+               unsigned int    hash_size;
+               int                             algo_nid;
+
+               /*
+                * Get the signature algorithm of the certificate to determine the
+                * hash algorithm to use for the result.
+                */
+               if (!OBJ_find_sigid_algs(X509_get_signature_nid(server_cert),
+                                                                &algo_nid, NULL))
+                       elog(ERROR, "could not find signature algorithm");
+
+               switch (algo_nid)
+               {
+                       case NID_md5:
+                       case NID_sha1:
+                               algo_type = EVP_sha256();
+                               break;
+
+                       default:
+                               algo_type = EVP_get_digestbynid(algo_nid);
+                               if (algo_type == NULL)
+                                       elog(ERROR, "could not find digest for NID %s",
+                                                OBJ_nid2sn(algo_nid));
+                               break;
+               }
+
+               /* generate and save the certificate hash */
+               if (!X509_digest(server_cert, algo_type, (unsigned char *) hash,
+                                                &hash_size))
+                       elog(ERROR, "could not generate server certificate hash");
+
+               cert_hash = (char *) palloc(hash_size);
+               memcpy(cert_hash, hash, hash_size);
+               *len = hash_size;
+       }
+
+       return cert_hash;
+}
+
 /*
  * Convert an X509 subject name to a cstring.
  *
diff --git a/src/include/common/scram-common.h b/src/include/common/scram-common.h
index 857a60e71f..5aec5cadb8 100644
--- a/src/include/common/scram-common.h
+++ b/src/include/common/scram-common.h
@@ -21,6 +21,7 @@
 
 /* Channel binding types */
 #define SCRAM_CHANNEL_BINDING_TLS_UNIQUE    "tls-unique"
+#define SCRAM_CHANNEL_BINDING_TLS_ENDPOINT     "tls-server-end-point"
 
 /* Length of SCRAM keys (client and server) */
 #define SCRAM_KEY_LEN                          PG_SHA256_DIGEST_LENGTH
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index 856e0439d5..cf9d8b7870 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -210,6 +210,7 @@ extern void be_tls_get_version(Port *port, char *ptr, size_t len);
 extern void be_tls_get_cipher(Port *port, char *ptr, size_t len);
 extern void be_tls_get_peerdn_name(Port *port, char *ptr, size_t len);
 extern char *be_tls_get_peer_finished(Port *port, size_t *len);
+extern char *be_tls_get_certificate_hash(Port *port, size_t *len);
 #endif
 
 extern ProtocolVersion FrontendProtocol;
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index 2c245813d6..7c8f009a3b 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -21,7 +21,8 @@
 /* Routines dedicated to authentication */
 extern void *pg_be_scram_init(const char *username, const char *shadow_pass,
                                 bool ssl_in_use, const char *tls_finished_message,
-                                size_t tls_finished_len);
+                                size_t tls_finished_len, const char *certificate_hash,
+                                size_t certificate_hash_len);
 extern int pg_be_scram_exchange(void *opaq, char *input, int inputlen,
                                         char **output, int *outputlen, char **logdetail);
 
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index b8f7a6b5be..a56fccf12e 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -47,6 +47,8 @@ typedef struct
        bool            ssl_in_use;
        char       *tls_finished_message;
        size_t          tls_finished_len;
+       char       *certificate_hash;
+       size_t          certificate_hash_len;
        char       *sasl_mechanism;
        const char *channel_binding_type;
 
@@ -95,7 +97,9 @@ pg_fe_scram_init(const char *username,
                                 const char *sasl_mechanism,
                                 const char *channel_binding_type,
                                 char *tls_finished_message,
-                                size_t tls_finished_len)
+                                size_t tls_finished_len,
+                                char *certificate_hash,
+                                size_t certificate_hash_len)
 {
        fe_scram_state *state;
        char       *prep_password;
@@ -112,6 +116,8 @@ pg_fe_scram_init(const char *username,
        state->ssl_in_use = ssl_in_use;
        state->tls_finished_message = tls_finished_message;
        state->tls_finished_len = tls_finished_len;
+       state->certificate_hash = certificate_hash;
+       state->certificate_hash_len = certificate_hash_len;
        state->sasl_mechanism = strdup(sasl_mechanism);
        state->channel_binding_type = channel_binding_type;
 
@@ -156,8 +162,8 @@ pg_fe_scram_free(void *opaq)
                free(state->password);
        if (state->tls_finished_message)
                free(state->tls_finished_message);
-       if (state->sasl_mechanism)
-               free(state->sasl_mechanism);
+       if (state->certificate_hash)
+               free(state->certificate_hash);
 
        /* client messages */
        if (state->client_nonce)
@@ -461,6 +467,12 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
                        cbind_data = state->tls_finished_message;
                        cbind_data_len = state->tls_finished_len;
                }
+               else if (strcmp(state->channel_binding_type,
+                                               SCRAM_CHANNEL_BINDING_TLS_ENDPOINT) == 0)
+               {
+                       cbind_data = state->certificate_hash;
+                       cbind_data_len = state->certificate_hash_len;
+               }
                else
                {
                        /* should not happen */
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 3340a9ad93..bb9b0573d1 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -493,6 +493,8 @@ pg_SASL_init(PGconn *conn, int payloadlen)
        PQExpBufferData mechanism_buf;
        char       *tls_finished = NULL;
        size_t          tls_finished_len = 0;
+       char       *certificate_hash = NULL;
+       size_t          certificate_hash_len = 0;
        char       *password;
 
        initPQExpBuffer(&mechanism_buf);
@@ -580,6 +582,12 @@ pg_SASL_init(PGconn *conn, int payloadlen)
                tls_finished = pgtls_get_finished(conn, &tls_finished_len);
                if (tls_finished == NULL)
                        goto oom_error;
+
+               certificate_hash =
+                       pgtls_get_peer_certificate_hash(conn,
+                                                                                       &certificate_hash_len);
+               if (certificate_hash == NULL)
+                       goto error;             /* error message is set */
        }
 #endif
 
@@ -595,7 +603,9 @@ pg_SASL_init(PGconn *conn, int payloadlen)
                                                                                selected_mechanism,
                                                                                conn->scram_channel_binding,
                                                                                tls_finished,
-                                                                               tls_finished_len);
+                                                                               tls_finished_len,
+                                                                               certificate_hash,
+                                                                               certificate_hash_len);
        if (!conn->sasl_state)
                goto oom_error;
 
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index db319ac071..68de8b6e32 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -29,7 +29,9 @@ extern void *pg_fe_scram_init(const char *username,
                                 const char *sasl_mechanism,
                                 const char *channel_binding_type,
                                 char *tls_finished_message,
-                                size_t tls_finished_len);
+                                size_t tls_finished_len,
+                                char *certificate_hash,
+                                size_t certificate_hash_len);
 extern void pg_fe_scram_free(void *opaq);
 extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
                                         char **output, int *outputlen,
diff --git a/src/interfaces/libpq/fe-secure-openssl.c b/src/interfaces/libpq/fe-secure-openssl.c
index 61d161b367..99077c3d9a 100644
--- a/src/interfaces/libpq/fe-secure-openssl.c
+++ b/src/interfaces/libpq/fe-secure-openssl.c
@@ -419,6 +419,84 @@ pgtls_get_finished(PGconn *conn, size_t *len)
        return result;
 }
 
+/*
+ *     Get the hash of the server certificate
+ *
+ * This information is useful for end-point channel binding, where the
+ * server certificate hash is used as a link, per RFC 5929.  If the
+ * signature hash algorithm is MD5 or SHA-1, SHA-256 is used instead,
+ * as per RFC 5929 (https://tools.ietf.org/html/rfc5929#section-4.1).
+ * NULL is sent back to the caller in the event of an error, with an
+ * error message for the caller to consume.
+ */
+char *
+pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len)
+{
+       char       *cert_hash = NULL;
+
+       *len = 0;
+
+       if (conn->peer)
+       {
+               X509               *peer_cert = conn->peer;
+               const EVP_MD   *algo_type = NULL;
+               char                    hash[EVP_MAX_MD_SIZE];  /* size for SHA-512 */
+               unsigned int    hash_size;
+               int                             algo_nid;
+
+               /*
+                * Get the signature algorithm of the certificate to determine the
+                * hash algorithm to use for the result.
+                */
+               if (!OBJ_find_sigid_algs(X509_get_signature_nid(peer_cert),
+                                                                &algo_nid, NULL))
+               {
+                       printfPQExpBuffer(&conn->errorMessage,
+                                                         libpq_gettext("could not find signature algorithm\n"));
+                       return NULL;
+               }
+
+               switch (algo_nid)
+               {
+                       case NID_md5:
+                       case NID_sha1:
+                               algo_type = EVP_sha256();
+                               break;
+
+                       default:
+                               algo_type = EVP_get_digestbynid(algo_nid);
+                               if (algo_type == NULL)
+                               {
+                                       printfPQExpBuffer(&conn->errorMessage,
+                                                                         libpq_gettext("could not find digest for NID %s\n"),
+                                                                         OBJ_nid2sn(algo_nid));
+                                       return NULL;
+                               }
+                               break;
+               }
+
+               if (!X509_digest(peer_cert, algo_type, (unsigned char *) hash,
+                                                &hash_size))
+               {
+                       printfPQExpBuffer(&conn->errorMessage,
+                                                         libpq_gettext("could not generate peer certificate hash\n"));
+                       return NULL;
+               }
+
+               /* save result */
+               cert_hash = (char *) malloc(hash_size);
+               if (cert_hash == NULL)
+               {
+                       printfPQExpBuffer(&conn->errorMessage,
+                                                         libpq_gettext("out of memory\n"));
+                       return NULL;
+               }
+               memcpy(cert_hash, hash, hash_size);
+               *len = hash_size;
+       }
+
+       return cert_hash;
+}
 
 /* ------------------------------------------------------------ */
 /*                                             OpenSSL specific code                                   */
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index f6c1023f37..756c4d61e1 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -672,6 +672,7 @@ extern ssize_t pgtls_read(PGconn *conn, void *ptr, size_t len);
 extern bool pgtls_read_pending(PGconn *conn);
 extern ssize_t pgtls_write(PGconn *conn, const void *ptr, size_t len);
 extern char *pgtls_get_finished(PGconn *conn, size_t *len);
+extern char *pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len);
 
 /*
  * this is so that we can check if a connection is non-blocking internally
diff --git a/src/test/ssl/t/002_scram.pl b/src/test/ssl/t/002_scram.pl
index 324b4888d4..3f425e00f0 100644
--- a/src/test/ssl/t/002_scram.pl
+++ b/src/test/ssl/t/002_scram.pl
@@ -4,7 +4,7 @@ use strict;
 use warnings;
 use PostgresNode;
 use TestLib;
-use Test::More tests => 4;
+use Test::More tests => 5;
 use ServerSetup;
 use File::Copy;
 
@@ -45,6 +45,9 @@ test_connect_ok($common_connstr,
 test_connect_ok($common_connstr,
        "scram_channel_binding=''",
        "SCRAM authentication without channel binding");
+test_connect_ok($common_connstr,
+       "scram_channel_binding=tls-server-end-point",
+       "SCRAM authentication with tls-server-end-point as channel binding");
 test_connect_fails($common_connstr,
        "scram_channel_binding=not-exists",
        "SCRAM authentication with invalid channel binding");
-- 
2.15.1

From d175df57f8df1f0ff497ae61b0c768f3bc9ca0a2 Mon Sep 17 00:00:00 2001
From: Michael Paquier <mich...@paquier.xyz>
Date: Fri, 22 Dec 2017 11:46:16 +0900
Subject: [PATCH 2/2] Refactor channel binding code to fetch cbind_data only
 when necessary

As things stand now, channel binding data is fetched from OpenSSL and saved
into the SASL exchange context for every SSL connection attempting SCRAM
authentication. This data is fetched but never used if no channel binding
is used, or if the channel binding type in use is not the one the data was
fetched for.

Refactor the code so that, in both the frontend and the backend, binding
data is fetched from the SSL stack only when a specific channel binding
type is used. In order to achieve that, save the libpq connection context
directly in the SCRAM exchange state, and add a dependency on SSL in the
low-level SCRAM routines.

This makes the interface in charge of initializing the SCRAM context cleaner,
as all its data now comes from either PGconn* (for the frontend) or Port*
(for the backend).
---
 src/backend/libpq/auth-scram.c           | 47 +++++++----------
 src/backend/libpq/auth.c                 | 25 +--------
 src/include/libpq/scram.h                |  7 ++-
 src/interfaces/libpq/fe-auth-scram.c     | 91 ++++++++++++++++----------------
 src/interfaces/libpq/fe-auth.c           | 33 +-----------
 src/interfaces/libpq/fe-auth.h           | 10 +---
 src/interfaces/libpq/fe-secure-openssl.c | 14 +++--
 src/interfaces/libpq/libpq-int.h         |  3 +-
 8 files changed, 84 insertions(+), 146 deletions(-)

diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index 849587d141..0a50f815ab 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -110,12 +110,8 @@ typedef struct
 
        const char *username;           /* username from startup packet */
 
+       Port       *port;
        char            cbind_flag;
-       bool            ssl_in_use;
-       const char *tls_finished_message;
-       size_t          tls_finished_len;
-       const char *certificate_hash;
-       size_t          certificate_hash_len;
        char       *channel_binding_type;
 
        int                     iterations;
@@ -174,25 +170,15 @@ static char *scram_mock_salt(const char *username);
  * it will fail, as if an incorrect password was given.
  */
 void *
-pg_be_scram_init(const char *username,
-                                const char *shadow_pass,
-                                bool ssl_in_use,
-                                const char *tls_finished_message,
-                                size_t tls_finished_len,
-                                const char *certificate_hash,
-                                size_t certificate_hash_len)
+pg_be_scram_init(Port *port,
+                                const char *shadow_pass)
 {
        scram_state *state;
        bool            got_verifier;
 
        state = (scram_state *) palloc0(sizeof(scram_state));
+       state->port = port;
        state->state = SCRAM_AUTH_INIT;
-       state->username = username;
-       state->ssl_in_use = ssl_in_use;
-       state->tls_finished_message = tls_finished_message;
-       state->tls_finished_len = tls_finished_len;
-       state->certificate_hash = certificate_hash;
-       state->certificate_hash_len = certificate_hash_len;
        state->channel_binding_type = NULL;
 
        /*
@@ -215,7 +201,7 @@ pg_be_scram_init(const char *username,
                                 */
                                ereport(LOG,
                                                (errmsg("invalid SCRAM verifier for user \"%s\"",
-                                                               username)));
+                                                               state->port->user_name)));
                                got_verifier = false;
                        }
                }
@@ -226,7 +212,7 @@ pg_be_scram_init(const char *username,
                         * authentication with an MD5 hash.)
                         */
                        state->logdetail = psprintf(_("User \"%s\" does not have a valid SCRAM verifier."),
-                                                                               state->username);
+                                                                               state->port->user_name);
                        got_verifier = false;
                }
        }
@@ -248,8 +234,8 @@ pg_be_scram_init(const char *username,
         */
        if (!got_verifier)
        {
-               mock_scram_verifier(username, &state->iterations, &state->salt,
-                                                       state->StoredKey, state->ServerKey);
+               mock_scram_verifier(state->port->user_name, &state->iterations,
+                                                       &state->salt, state->StoredKey, state->ServerKey);
                state->doomed = true;
        }
 
@@ -821,7 +807,7 @@ read_client_first_message(scram_state *state, char *input)
                         * it supports channel binding, which in this implementation is
                         * the case if a connection is using SSL.
                         */
-                       if (state->ssl_in_use)
+                       if (state->port->ssl_in_use)
                                ereport(ERROR,
                                                (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
                                                 errmsg("SCRAM channel binding negotiation error"),
@@ -845,7 +831,7 @@ read_client_first_message(scram_state *state, char *input)
                        {
                                char       *channel_binding_type;
 
-                               if (!state->ssl_in_use)
+                               if (!state->port->ssl_in_use)
                                {
                                        /*
                                         * Without SSL, we don't support channel binding.
@@ -1128,14 +1114,19 @@ read_client_final_message(scram_state *state, char *input)
                 */
                if (strcmp(state->channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
                {
-                       cbind_data = state->tls_finished_message;
-                       cbind_data_len = state->tls_finished_len;
+                       /* Fetch data from TLS finished message */
+#ifdef USE_SSL
+                       cbind_data = be_tls_get_peer_finished(state->port, &cbind_data_len);
+#endif
                }
                else if (strcmp(state->channel_binding_type,
                                                SCRAM_CHANNEL_BINDING_TLS_ENDPOINT) == 0)
                {
-                       cbind_data = state->certificate_hash;
-                       cbind_data_len = state->certificate_hash_len;
+                       /* Fetch hash data of server's SSL certificate */
+#ifdef USE_SSL
+                       cbind_data = be_tls_get_certificate_hash(state->port,
+                                                                                                        &cbind_data_len);
+#endif
                }
                else
                {
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 700a3bffa4..bd91e1cd18 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -873,10 +873,6 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
        int                     inputlen;
        int                     result;
        bool            initial;
-       char       *tls_finished = NULL;
-       size_t          tls_finished_len = 0;
-       char       *certificate_hash = NULL;
-       size_t          certificate_hash_len = 0;
 
        /*
         * SASL auth is not supported for protocol versions before 3, because it
@@ -917,19 +913,6 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
        sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs, p - sasl_mechs + 1);
        pfree(sasl_mechs);
 
-#ifdef USE_SSL
-
-       /*
-        * Get data for channel binding.
-        */
-       if (port->ssl_in_use)
-       {
-               tls_finished = be_tls_get_peer_finished(port, &tls_finished_len);
-               certificate_hash = be_tls_get_certificate_hash(port,
-                                                              &certificate_hash_len);
-       }
-#endif
-
        /*
         * Initialize the status tracker for message exchanges.
         *
@@ -941,13 +924,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
         * This is because we don't want to reveal to an attacker what usernames
         * are valid, nor which users have a valid password.
         */
-       scram_opaq = pg_be_scram_init(port->user_name,
-                                                                 shadow_pass,
-                                                                 port->ssl_in_use,
-                                                                 tls_finished,
-                                                                 tls_finished_len,
-                                                                 certificate_hash,
-                                                                 certificate_hash_len);
+       scram_opaq = pg_be_scram_init(port, shadow_pass);
 
        /*
         * Loop through SASL message exchange.  This exchange can consist of
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index 7c8f009a3b..f404f57253 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -13,16 +13,15 @@
 #ifndef PG_SCRAM_H
 #define PG_SCRAM_H
 
+#include "libpq/libpq-be.h"
+
 /* Status codes for message exchange */
 #define SASL_EXCHANGE_CONTINUE         0
 #define SASL_EXCHANGE_SUCCESS          1
 #define SASL_EXCHANGE_FAILURE          2
 
 /* Routines dedicated to authentication */
-extern void *pg_be_scram_init(const char *username, const char *shadow_pass,
-                                bool ssl_in_use, const char *tls_finished_message,
-                                size_t tls_finished_len, const char *certificate_hash,
-                                size_t certificate_hash_len);
+extern void *pg_be_scram_init(Port *port, const char *shadow_pass);
 extern int pg_be_scram_exchange(void *opaq, char *input, int inputlen,
                                         char **output, int *outputlen, char 
**logdetail);
 
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index a56fccf12e..a44338f0f9 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -42,15 +42,9 @@ typedef struct
        fe_scram_state_enum state;
 
        /* These are supplied by the user */
-       const char *username;
+       PGconn     *conn;
        char       *password;
-       bool            ssl_in_use;
-       char       *tls_finished_message;
-       size_t          tls_finished_len;
-       char       *certificate_hash;
-       size_t          certificate_hash_len;
        char       *sasl_mechanism;
-       const char *channel_binding_type;
 
        /* We construct these */
        uint8           SaltedPassword[SCRAM_KEY_LEN];
@@ -91,15 +85,9 @@ static bool pg_frontend_random(char *dst, int len);
  * freed by pg_fe_scram_free().
  */
 void *
-pg_fe_scram_init(const char *username,
+pg_fe_scram_init(PGconn *conn,
                                 const char *password,
-                                bool ssl_in_use,
-                                const char *sasl_mechanism,
-                                const char *channel_binding_type,
-                                char *tls_finished_message,
-                                size_t tls_finished_len,
-                                char *certificate_hash,
-                                size_t certificate_hash_len)
+                                const char *sasl_mechanism)
 {
        fe_scram_state *state;
        char       *prep_password;
@@ -111,15 +99,9 @@ pg_fe_scram_init(const char *username,
        if (!state)
                return NULL;
        memset(state, 0, sizeof(fe_scram_state));
+       state->conn = conn;
        state->state = FE_SCRAM_INIT;
-       state->username = username;
-       state->ssl_in_use = ssl_in_use;
-       state->tls_finished_message = tls_finished_message;
-       state->tls_finished_len = tls_finished_len;
-       state->certificate_hash = certificate_hash;
-       state->certificate_hash_len = certificate_hash_len;
        state->sasl_mechanism = strdup(sasl_mechanism);
-       state->channel_binding_type = channel_binding_type;
 
        if (!state->sasl_mechanism)
        {
@@ -160,10 +142,6 @@ pg_fe_scram_free(void *opaq)
 
        if (state->password)
                free(state->password);
-       if (state->tls_finished_message)
-               free(state->tls_finished_message);
-       if (state->certificate_hash)
-               free(state->certificate_hash);
 
        /* client messages */
        if (state->client_nonce)
@@ -376,11 +354,11 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
         */
        if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
        {
-               Assert(state->ssl_in_use);
-               appendPQExpBuffer(&buf, "p=%s", state->channel_binding_type);
+               Assert(state->conn->ssl_in_use);
+               appendPQExpBuffer(&buf, "p=%s", state->conn->scram_channel_binding);
        }
-       else if (state->channel_binding_type == NULL ||
-                        strlen(state->channel_binding_type) == 0)
+       else if (state->conn->scram_channel_binding == NULL ||
+                        strlen(state->conn->scram_channel_binding) == 0)
        {
                /*
                 * Client has chosen to not show to server that it supports channel
@@ -388,7 +366,7 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
                 */
                appendPQExpBuffer(&buf, "n");
        }
-       else if (state->ssl_in_use)
+       else if (state->conn->ssl_in_use)
        {
                /*
                 * Client supports channel binding, but thinks the server does not.
@@ -456,22 +434,36 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
         */
        if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
        {
-               char       *cbind_data;
-               size_t          cbind_data_len;
+               char       *cbind_data = NULL;
+               size_t          cbind_data_len = 0;
                size_t          cbind_header_len;
                char       *cbind_input;
                size_t          cbind_input_len;
 
-               if (strcmp(state->channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
+               if (strcmp(state->conn->scram_channel_binding, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
                {
-                       cbind_data = state->tls_finished_message;
-                       cbind_data_len = state->tls_finished_len;
+                       /* Fetch data from TLS finished message */
+#ifdef USE_SSL
+                       cbind_data = pgtls_get_finished(state->conn, &cbind_data_len);
+                       if (cbind_data == NULL)
+                               goto oom_error;
+#endif
                }
-               else if (strcmp(state->channel_binding_type,
+               else if (strcmp(state->conn->scram_channel_binding,
                                                SCRAM_CHANNEL_BINDING_TLS_ENDPOINT) == 0)
                {
-                       cbind_data = state->certificate_hash;
-                       cbind_data_len = state->certificate_hash_len;
+                       /* Fetch hash data of server's SSL certificate */
+#ifdef USE_SSL
+                       cbind_data =
+                               pgtls_get_peer_certificate_hash(state->conn,
+                                                               &cbind_data_len,
+                                                               errormessage);
+                       if (cbind_data == NULL)
+                       {
+                               termPQExpBuffer(&buf);  /* errormessage is already set */
+                               return NULL;
+                       }
+#endif
                }
                else
                {
@@ -485,37 +477,46 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
                /* should not happen */
                if (cbind_data == NULL || cbind_data_len == 0)
                {
+                       if (cbind_data != NULL)
+                               free(cbind_data);
                        termPQExpBuffer(&buf);
                        printfPQExpBuffer(errormessage,
                                                          libpq_gettext("empty channel binding data for channel binding type \"%s\"\n"),
-                                                         state->channel_binding_type);
+                                                         state->conn->scram_channel_binding);
                        return NULL;
                }
 
                appendPQExpBuffer(&buf, "c=");
 
-               cbind_header_len = 4 + strlen(state->channel_binding_type); /* p=type,, */
+               /* p=type,, */
+               cbind_header_len = 4 + strlen(state->conn->scram_channel_binding);
                cbind_input_len = cbind_header_len + cbind_data_len;
                cbind_input = malloc(cbind_input_len);
                if (!cbind_input)
+               {
+                       free(cbind_data);
                        goto oom_error;
-               snprintf(cbind_input, cbind_input_len, "p=%s,,", state->channel_binding_type);
+               }
+               snprintf(cbind_input, cbind_input_len, "p=%s,,",
+                                state->conn->scram_channel_binding);
                memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
 
                if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len)))
                {
+                       free(cbind_data);
                        free(cbind_input);
                        goto oom_error;
                }
                buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len);
                buf.data[buf.len] = '\0';
 
+               free(cbind_data);
                free(cbind_input);
        }
-       else if (state->channel_binding_type == NULL ||
-                        strlen(state->channel_binding_type) == 0)
+       else if (state->conn->scram_channel_binding == NULL ||
+                        strlen(state->conn->scram_channel_binding) == 0)
                appendPQExpBuffer(&buf, "c=biws");      /* base64 of "n,," */
-       else if (state->ssl_in_use)
+       else if (state->conn->ssl_in_use)
                appendPQExpBuffer(&buf, "c=eSws");      /* base64 of "y,," */
        else
                appendPQExpBuffer(&buf, "c=biws");      /* base64 of "n,," */
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index bb9b0573d1..b1ea0e7cef 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -491,10 +491,6 @@ pg_SASL_init(PGconn *conn, int payloadlen)
        bool            success;
        const char *selected_mechanism;
        PQExpBufferData mechanism_buf;
-       char       *tls_finished = NULL;
-       size_t          tls_finished_len = 0;
-       char       *certificate_hash = NULL;
-       size_t          certificate_hash_len = 0;
        char       *password;
 
        initPQExpBuffer(&mechanism_buf);
@@ -572,40 +568,15 @@ pg_SASL_init(PGconn *conn, int payloadlen)
                goto error;
        }
 
-#ifdef USE_SSL
-
-       /*
-        * Get data for channel binding.
-        */
-       if (strcmp(selected_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
-       {
-               tls_finished = pgtls_get_finished(conn, &tls_finished_len);
-               if (tls_finished == NULL)
-                       goto oom_error;
-
-               certificate_hash =
-                       pgtls_get_peer_certificate_hash(conn,
-                                                                               &certificate_hash_len);
-               if (certificate_hash == NULL)
-                       goto error;             /* error message is set */
-       }
-#endif
-
        /*
         * Initialize the SASL state information with all the information gathered
         * during the initial exchange.
         *
         * Note: Only tls-unique is supported for the moment.
         */
-       conn->sasl_state = pg_fe_scram_init(conn->pguser,
+       conn->sasl_state = pg_fe_scram_init(conn,
                                                                                password,
-                                                                               conn->ssl_in_use,
-                                                                               selected_mechanism,
-                                                                               conn->scram_channel_binding,
-                                                                               tls_finished,
-                                                                               tls_finished_len,
-                                                                               certificate_hash,
-                                                                               certificate_hash_len);
+                                                                               selected_mechanism);
        if (!conn->sasl_state)
                goto oom_error;
 
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 68de8b6e32..4658a12837 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -23,15 +23,9 @@ extern int   pg_fe_sendauth(AuthRequest areq, int payloadlen, PGconn *conn);
 extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
 /* Prototypes for functions in fe-auth-scram.c */
-extern void *pg_fe_scram_init(const char *username,
+extern void *pg_fe_scram_init(PGconn *conn,
                                 const char *password,
-                                bool ssl_in_use,
-                                const char *sasl_mechanism,
-                                const char *channel_binding_type,
-                                char *tls_finished_message,
-                                size_t tls_finished_len,
-                                char *certificate_hash,
-                                size_t certificate_hash_len);
+                                const char *sasl_mechanism);
 extern void pg_fe_scram_free(void *opaq);
 extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
                                         char **output, int *outputlen,
diff --git a/src/interfaces/libpq/fe-secure-openssl.c b/src/interfaces/libpq/fe-secure-openssl.c
index 99077c3d9a..1acaffc488 100644
--- a/src/interfaces/libpq/fe-secure-openssl.c
+++ b/src/interfaces/libpq/fe-secure-openssl.c
@@ -428,9 +428,13 @@ pgtls_get_finished(PGconn *conn, size_t *len)
  * as per RFC 5929 (https://tools.ietf.org/html/rfc5929#section-4.1).
  * NULL is sent back to the caller in the event of an error, with an
  * error message for the caller to consume.
+ * If an error happens while processing, fill in errorMessage but do
+ * not append it to the connection's error message buffer as this
+ * gets passed down later on.
  */
 char *
-pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len)
+pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len,
+                                                               PQExpBuffer errorMessage)
 {
        char       *cert_hash = NULL;
 
@@ -451,7 +455,7 @@ pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len)
                if (!OBJ_find_sigid_algs(X509_get_signature_nid(peer_cert),
                                                                 &algo_nid, NULL))
                {
-                       printfPQExpBuffer(&conn->errorMessage,
+                       printfPQExpBuffer(errorMessage,
                                                          libpq_gettext("could not find signature algorithm\n"));
                        return NULL;
                }
@@ -467,7 +471,7 @@ pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len)
                                algo_type = EVP_get_digestbynid(algo_nid);
                                if (algo_type == NULL)
                                {
-                                       printfPQExpBuffer(&conn->errorMessage,
+                                       printfPQExpBuffer(errorMessage,
                                                                          libpq_gettext("could not find digest for NID %s\n"),
                                                                          OBJ_nid2sn(algo_nid));
                                        return NULL;
@@ -478,7 +482,7 @@ pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len)
                if (!X509_digest(peer_cert, algo_type, (unsigned char *) hash,
                                                 &hash_size))
                {
-                       printfPQExpBuffer(&conn->errorMessage,
+                       printfPQExpBuffer(errorMessage,
                                                          libpq_gettext("could not generate peer certificate hash\n"));
                        return NULL;
                }
@@ -487,7 +491,7 @@ pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len)
                cert_hash = (char *) malloc(hash_size);
                if (cert_hash == NULL)
                {
-                       printfPQExpBuffer(&conn->errorMessage,
+                       printfPQExpBuffer(errorMessage,
                                                          libpq_gettext("out of memory\n"));
                        return NULL;
                }
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 756c4d61e1..a946aa4048 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -672,7 +672,8 @@ extern ssize_t pgtls_read(PGconn *conn, void *ptr, size_t len);
 extern bool pgtls_read_pending(PGconn *conn);
 extern ssize_t pgtls_write(PGconn *conn, const void *ptr, size_t len);
 extern char *pgtls_get_finished(PGconn *conn, size_t *len);
-extern char *pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len);
+extern char *pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len,
+                                                                                        PQExpBuffer errorMessage);
 
 /*
  * this is so that we can check if a connection is non-blocking internally
-- 
2.15.1
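For reference, the c= attribute assembled in build_client_final_message() is the base64 encoding of the header "p=<channel binding type>,," followed by the raw binding data (the TLS Finished message for tls-unique, the server certificate hash for tls-server-end-point). Below is a minimal standalone C sketch of that computation, assuming the binding data has already been fetched. The helpers b64_enc_len(), b64_encode() and build_cbind_attribute() are hypothetical stand-ins written for this illustration, not libpq's pg_b64_enc_len()/pg_b64_encode(); like the patch, the sketch treats empty binding data as an error.

```c
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical helpers for illustration only; not part of libpq's API. */

static const char b64_table[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/* Length of the base64 encoding of srclen bytes, excluding the NUL. */
static size_t
b64_enc_len(size_t srclen)
{
    return (srclen + 2) / 3 * 4;
}

/* Encode src into dst (b64_enc_len(len) + 1 bytes), NUL-terminate it. */
static size_t
b64_encode(const unsigned char *src, size_t len, char *dst)
{
    size_t      i;
    size_t      j = 0;

    for (i = 0; i + 2 < len; i += 3)
    {
        unsigned long v = ((unsigned long) src[i] << 16) |
            ((unsigned long) src[i + 1] << 8) | src[i + 2];

        dst[j++] = b64_table[(v >> 18) & 0x3F];
        dst[j++] = b64_table[(v >> 12) & 0x3F];
        dst[j++] = b64_table[(v >> 6) & 0x3F];
        dst[j++] = b64_table[v & 0x3F];
    }
    if (i < len)                /* one or two trailing bytes: pad with '=' */
    {
        unsigned long v = (unsigned long) src[i] << 16;

        if (i + 1 < len)
            v |= (unsigned long) src[i + 1] << 8;
        dst[j++] = b64_table[(v >> 18) & 0x3F];
        dst[j++] = b64_table[(v >> 12) & 0x3F];
        dst[j++] = (i + 1 < len) ? b64_table[(v >> 6) & 0x3F] : '=';
        dst[j++] = '=';
    }
    dst[j] = '\0';
    return j;
}

/*
 * Build "c=" base64("p=" type ",," data) for SCRAM-SHA-256-PLUS.  Returns a
 * malloc'd string, or NULL on empty binding data or out-of-memory.
 */
static char *
build_cbind_attribute(const char *cbind_type,
                      const unsigned char *cbind_data, size_t cbind_data_len)
{
    size_t      header_len;
    size_t      input_len;
    unsigned char *input;
    char       *attr;

    assert(cbind_type != NULL);
    if (cbind_data == NULL || cbind_data_len == 0)
        return NULL;            /* same "empty channel binding data" check */

    header_len = 4 + strlen(cbind_type);    /* "p=" + type + ",," */
    input_len = header_len + cbind_data_len;
    input = malloc(input_len + 1);          /* +1 for snprintf's NUL */
    if (input == NULL)
        return NULL;
    snprintf((char *) input, header_len + 1, "p=%s,,", cbind_type);
    memcpy(input + header_len, cbind_data, cbind_data_len);

    attr = malloc(2 + b64_enc_len(input_len) + 1);
    if (attr == NULL)
    {
        free(input);
        return NULL;
    }
    memcpy(attr, "c=", 2);
    b64_encode(input, input_len, attr + 2);
    free(input);
    return attr;
}
```

With type "a" and the single byte "b" as binding data, build_cbind_attribute() returns "c=cD1hLCxi". Encoding "n,," and "y,," with the same b64_encode() gives "biws" and "eSws", the constants the patch uses when channel binding is not in use.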
