On 11/16/17 16:50, Michael Paquier wrote:
> On Thu, Nov 16, 2017 at 10:47 PM, Peter Eisentraut
> <peter.eisentr...@2ndquadrant.com> wrote:
>> Pushed 0001, will continue with 0002.
> 
> Thanks!

I have combed through the 0002 patch in detail now.  I have attached my
result.

I cleaned up a bunch of variable names to be more precise.  I also
removed some unnecessary code.  For example the whole deal about
channel_binding_advertised was not necessary, because it's implied by
the chosen SASL mechanism.  We also don't need to have separate USE_SSL
sections in most cases, because state->ssl_in_use already takes care of
that.

I made some significant changes to the logic.

The selection of the channel binding flag (n/y/p) in the client seemed
wrong.  Your code treated 'y' as an error, but I think that is a
legitimate case, for example a PG11 client connecting to a PG10 server
over SSL.  The client supports channel binding in that case and
(correctly) thinks the server does not, so we use the 'y' flag and
proceed normally without channel binding.
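
As a rough illustration of the flag selection described above, here is a minimal sketch. It is not code from the patch: `choose_cbind_flag` and its parameters (in particular `server_offers_plus`) are hypothetical names made up for this example.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>

/*
 * Hypothetical helper (not part of the patch): write the gs2-cbind-flag
 * that a client would send ("n", "y" or "p=<type>") into buf.
 */
static void
choose_cbind_flag(bool ssl_in_use, bool server_offers_plus,
                  const char *cbind_type, char *buf, size_t buflen)
{
    assert(buf != NULL && buflen > 0);

    if (ssl_in_use && server_offers_plus)
        snprintf(buf, buflen, "p=%s", cbind_type); /* use channel binding */
    else if (ssl_in_use)
        snprintf(buf, buflen, "y");  /* client could, server apparently cannot */
    else
        snprintf(buf, buflen, "n");  /* client cannot do channel binding */
}
```

The middle branch is the PG11-client-to-PG10-server case: the connection uses SSL, the server does not advertise the -PLUS mechanism, and the exchange proceeds without channel binding.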

The selection of the mechanism in the client was similarly incorrect, I
think.  The code would not tolerate a situation where an SSL connection
is in use but the server does not advertise the -PLUS mechanism, which
again could be from a PG10 server.
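
A minimal sketch of a mechanism selection that tolerates this case, assuming the server's list arrives as consecutive NUL-terminated names ended by an empty string (which is how the backend side of the patch builds it). `select_sasl_mechanism` is a hypothetical name, not a function from the patch.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>

/*
 * Hypothetical helper: pick a mechanism from a list of NUL-terminated
 * names that is itself terminated by an empty string.  Returns NULL if
 * the server offers nothing we understand.
 */
static const char *
select_sasl_mechanism(const char *mech_list, bool ssl_in_use)
{
    const char *plain = NULL;
    const char *plus = NULL;
    const char *p;

    assert(mech_list != NULL);

    for (p = mech_list; *p != '\0'; p += strlen(p) + 1)
    {
        if (strcmp(p, "SCRAM-SHA-256-PLUS") == 0)
            plus = p;
        else if (strcmp(p, "SCRAM-SHA-256") == 0)
            plain = p;
    }

    /* Prefer channel binding, but only if we are on SSL and it is offered. */
    if (ssl_in_use && plus != NULL)
        return plus;
    return plain;
}
```

With a list that lacks the -PLUS entry, as a PG10 server would send, the function falls back to plain SCRAM-SHA-256 even when SSL is in use.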

The creation of the channel binding data didn't match the spec, because
the gs2-header (p=type,,) was not included in the data put through
base64.  This was done incorrectly on both server and client, so the
protocol still worked.  (However, in the non-channel-binding case we
hardcode "biws", which is exactly the base64 encoding of the gs2-header
"n,,".  So that was inconsistent.)
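
To make the encoding concrete, here is a self-contained sketch of building the channel-binding attribute as base64(gs2-header || binding data). It uses its own small base64 routine rather than PostgreSQL's `pg_b64_encode`, and the names `b64_encode` and `build_cbind_attr` are hypothetical, chosen only for this example.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

static const char b64chars[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/* Standard base64 with padding; returns a malloc'd string or NULL. */
static char *
b64_encode(const unsigned char *src, size_t len)
{
    char       *out = malloc(4 * ((len + 2) / 3) + 1);
    char       *o = out;
    size_t      i;

    if (out == NULL)
        return NULL;
    for (i = 0; i < len; i += 3)
    {
        unsigned int v = (unsigned int) src[i] << 16;

        if (i + 1 < len)
            v |= (unsigned int) src[i + 1] << 8;
        if (i + 2 < len)
            v |= src[i + 2];
        *o++ = b64chars[(v >> 18) & 63];
        *o++ = b64chars[(v >> 12) & 63];
        *o++ = (i + 1 < len) ? b64chars[(v >> 6) & 63] : '=';
        *o++ = (i + 2 < len) ? b64chars[v & 63] : '=';
    }
    *o = '\0';
    return out;
}

/* Hypothetical helper: base64 of the gs2-header followed by binding data. */
static char *
build_cbind_attr(const char *gs2_header,
                 const unsigned char *data, size_t data_len)
{
    size_t      hlen;
    unsigned char *input;
    char       *result;

    assert(gs2_header != NULL);
    hlen = strlen(gs2_header);
    input = malloc(hlen + data_len + 1);
    if (input == NULL)
        return NULL;
    memcpy(input, gs2_header, hlen);
    if (data_len > 0)
        memcpy(input + hlen, data, data_len);
    result = b64_encode(input, hlen + data_len);
    free(input);
    return result;
}
```

Called with the header "n,," and no binding data this yields "biws", and with "y,," it yields "eSws". With channel binding, the header is "p=tls-unique,," and the data is the TLS Finished message.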

I think we also need to backpatch a bug fix into PG10 so that the server
can accept base64("y,,") as channel binding data.  Otherwise, the above
scenario of a PG11 client connecting to a PG10 server over SSL will
currently fail because the server will not accept the channel binding data.

Please check my patch and think through these changes.  I'm happy to
commit the patch as is if there are no additional insights.

-- 
Peter Eisentraut              http://www.2ndQuadrant.com/
PostgreSQL Development, 24x7 Support, Remote DBA, Training & Services
From ae28f935473afe9f754a7a0ec2a6eca0162ab445 Mon Sep 17 00:00:00 2001
From: Peter Eisentraut <pete...@gmx.net>
Date: Fri, 17 Nov 2017 14:15:50 -0500
Subject: [PATCH] Support channel binding 'tls-unique' in SCRAM

This is the basic feature set using OpenSSL to support the feature.  In
order to allow the frontend and the backend to fetch the sent and
expected TLS finish messages, a PG-like API is added to be able to make
the interface pluggable for other SSL implementations.

This commit also adds a lot of basic infrastructure to facilitate the
addition of future channel binding types as well as libpq parameters to
control the SASL mechanism names and channel binding names. Those will
be added by upcoming commits.

A set of basic tests to stress default channel binding handling is added
as well, though those are part of the SSL suite.
---
 doc/src/sgml/protocol.sgml               |  31 ++++--
 src/backend/libpq/auth-scram.c           | 180 +++++++++++++++++++++++++------
 src/backend/libpq/auth.c                 |  54 ++++++++--
 src/backend/libpq/be-secure-openssl.c    |  24 +++++
 src/include/libpq/libpq-be.h             |   1 +
 src/include/libpq/scram.h                |  10 +-
 src/interfaces/libpq/fe-auth-scram.c     | 170 +++++++++++++++++++++++++----
 src/interfaces/libpq/fe-auth.c           |  90 +++++++++++-----
 src/interfaces/libpq/fe-auth.h           |   7 +-
 src/interfaces/libpq/fe-secure-openssl.c |  27 +++++
 src/interfaces/libpq/libpq-int.h         |   5 +-
 src/test/ssl/ServerSetup.pm              |  27 +++--
 src/test/ssl/t/001_ssltests.pl           |   2 +-
 src/test/ssl/t/002_scram.pl              |  38 +++++++
 14 files changed, 554 insertions(+), 112 deletions(-)
 create mode 100644 src/test/ssl/t/002_scram.pl

diff --git a/doc/src/sgml/protocol.sgml b/doc/src/sgml/protocol.sgml
index 6d4dcf83ac..4d3b6446c4 100644
--- a/doc/src/sgml/protocol.sgml
+++ b/doc/src/sgml/protocol.sgml
@@ -1461,10 +1461,11 @@ <title>SASL Authentication</title>
 
 <para>
 <firstterm>SASL</firstterm> is a framework for authentication in connection-oriented
-protocols. At the moment, <productname>PostgreSQL</productname> implements only one SASL
-authentication mechanism, SCRAM-SHA-256, but more might be added in the
-future. The below steps illustrate how SASL authentication is performed in
-general, while the next subsection gives more details on SCRAM-SHA-256.
+protocols. At the moment, <productname>PostgreSQL</productname> implements two SASL
+authentication mechanisms, SCRAM-SHA-256 and SCRAM-SHA-256-PLUS. More
+might be added in the future. The below steps illustrate how SASL
+authentication is performed in general, while the next subsection gives
+more details on SCRAM-SHA-256 and SCRAM-SHA-256-PLUS.
 </para>
 
 <procedure>
@@ -1518,9 +1519,10 @@ <title>SASL Authentication Message Flow</title>
   <title>SCRAM-SHA-256 authentication</title>
 
   <para>
-    <firstterm>SCRAM-SHA-256</firstterm> (called just <firstterm>SCRAM</firstterm> from now on) is
-    the only implemented SASL mechanism, at the moment. It is described in detail
-    in RFC 7677 and RFC 5802.
+   The implemented SASL mechanisms at the moment
+   are <literal>SCRAM-SHA-256</literal> and its variant with channel
+   binding <literal>SCRAM-SHA-256-PLUS</literal>. They are described in
+   detail in RFC 7677 and RFC 5802.
   </para>
 
   <para>
@@ -1547,7 +1549,10 @@ <title>SCRAM-SHA-256 authentication</title>
   </para>
 
   <para>
-<firstterm>Channel binding</firstterm> has not been implemented yet.
+<firstterm>Channel binding</firstterm> is supported in PostgreSQL builds with
+SSL support. The SASL mechanism name for SCRAM with channel binding
+is <literal>SCRAM-SHA-256-PLUS</literal>.  The only channel binding type
+supported at the moment is <literal>tls-unique</literal>, defined in RFC 5929.
   </para>
 
 <procedure>
@@ -1556,13 +1561,19 @@ <title>Example</title>
 <para>
   The server sends an AuthenticationSASL message. It includes a list of
   SASL authentication mechanisms that the server can accept.
+  This will be <literal>SCRAM-SHA-256-PLUS</literal>
+  and <literal>SCRAM-SHA-256</literal> if the server is built with SSL
+  support, or else just the latter.
 </para>
 </step>
 <step id="scram-client-first">
 <para>
   The client responds by sending a SASLInitialResponse message, which
-  indicates the chosen mechanism, <literal>SCRAM-SHA-256</literal>. In the Initial
-  Client response field, the message contains the SCRAM
+  indicates the chosen mechanism, <literal>SCRAM-SHA-256</literal> or
+  <literal>SCRAM-SHA-256-PLUS</literal>. (A client is free to choose either
+  mechanism, but for better security it should choose the channel-binding
+  variant if it can support it.) In the Initial Client response field,
+  the message contains the SCRAM
   <structname>client-first-message</structname>.
 </para>
 </step>
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index ec4bb9a88e..62c8d67e56 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -17,8 +17,6 @@
  *      by the SASLprep profile, we skip the SASLprep pre-processing and use
  *      the raw bytes in calculating the hash.
  *
- * - Channel binding is not supported yet.
- *
  *
  * The password stored in pg_authid consists of the iteration count, salt,
  * StoredKey and ServerKey.
@@ -112,6 +110,11 @@ typedef struct
 
        const char *username;           /* username from startup packet */
 
+       bool            ssl_in_use;
+       const char *tls_finished_message;
+       size_t          tls_finished_len;
+       char       *channel_binding_type;
+
        int                     iterations;
        char       *salt;                       /* base64-encoded */
        uint8           StoredKey[SCRAM_KEY_LEN];
@@ -168,7 +171,11 @@ static char *scram_mock_salt(const char *username);
  * it will fail, as if an incorrect password was given.
  */
 void *
-pg_be_scram_init(const char *username, const char *shadow_pass)
+pg_be_scram_init(const char *username,
+                                const char *shadow_pass,
+                                bool ssl_in_use,
+                                const char *tls_finished_message,
+                                size_t tls_finished_len)
 {
        scram_state *state;
        bool            got_verifier;
@@ -176,6 +183,10 @@ pg_be_scram_init(const char *username, const char *shadow_pass)
        state = (scram_state *) palloc0(sizeof(scram_state));
        state->state = SCRAM_AUTH_INIT;
        state->username = username;
+       state->ssl_in_use = ssl_in_use;
+       state->tls_finished_message = tls_finished_message;
+       state->tls_finished_len = tls_finished_len;
+       state->channel_binding_type = NULL;
 
        /*
         * Parse the stored password verifier.
@@ -773,31 +784,88 @@ read_client_first_message(scram_state *state, char *input)
         *------
         */
 
-       /* read gs2-cbind-flag */
+       /*
+        * Read gs2-cbind-flag.  (For details see also RFC 5802 Section 6 "Channel
+        * Binding".)
+        */
        switch (*input)
        {
                case 'n':
-                       /* Client does not support channel binding */
+                       /*
+                        * The client does not support channel binding or has simply
+                        * decided to not use it.  In that case just let it go.
+                        */
+                       input++;
+                       if (*input != ',')
+                               ereport(ERROR,
+                                               (errcode(ERRCODE_PROTOCOL_VIOLATION),
+                                                errmsg("malformed SCRAM message"),
+                                                errdetail("Comma expected, but found character \"%s\".",
+                                                                  sanitize_char(*input))));
                        input++;
                        break;
                case 'y':
-                       /* Client supports channel binding, but we're not doing it today */
+                       /*
+                        * The client supports channel binding and thinks that the server
+                        * does not.  In this case, the server must fail authentication if
+                        * it supports channel binding, which in this implementation is
+                        * the case if a connection is using SSL.
+                        */
+                       if (state->ssl_in_use)
+                               ereport(ERROR,
+                                               (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
+                                                errmsg("SCRAM channel binding negotiation error"),
+                                                errdetail("The client supports SCRAM channel binding but thinks the server does not.  "
+                                                                  "However, this server does support channel binding.")));
+                       input++;
+                       if (*input != ',')
+                               ereport(ERROR,
+                                               (errcode(ERRCODE_PROTOCOL_VIOLATION),
+                                                errmsg("malformed SCRAM message"),
+                                                errdetail("Comma expected, but found character \"%s\".",
+                                                                  sanitize_char(*input))));
                        input++;
                        break;
                case 'p':
-
                        /*
-                        * Client requires channel binding.  We don't support it.
-                        *
-                        * RFC 5802 specifies a particular error code,
-                        * e=server-does-support-channel-binding, for this.  But it can
-                        * only be sent in the server-final message, and we don't want to
-                        * go through the motions of the authentication, knowing it will
-                        * fail, just to send that error message.
+                        * The client requires channel binding.  Channel binding type
+                        * follows, e.g., "p=tls-unique".
                         */
-                       ereport(ERROR,
-                                       (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
-                                        errmsg("client requires SCRAM channel binding, but it is not supported")));
+                       {
+                               char *channel_binding_type;
+
+                               if (!state->ssl_in_use)
+                                       /*
+                                        * Without SSL, we don't support channel binding.
+                                        *
+                                        * RFC 5802 specifies a particular error code,
+                                        * e=server-does-support-channel-binding, for this.  But
+                                        * it can only be sent in the server-final message, and we
+                                        * don't want to go through the motions of the
+                                        * authentication, knowing it will fail, just to send that
+                                        * error message.
+                                        */
+                                       ereport(ERROR,
+                                                       (errcode(ERRCODE_PROTOCOL_VIOLATION),
+                                                        errmsg("client requires SCRAM channel binding, but it is not supported")));
+
+                               /*
+                                * Read value provided by client; only tls-unique is supported
+                                * for now.  XXX Not sure whether it would be safe to print
+                                * the name of an unsupported binding type in the error
+                                * message.  Pranksters could print arbitrary strings into the
+                                * log that way.
+                                */
+                               channel_binding_type = read_attr_value(&input, 'p');
+                               if (strcmp(channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) != 0)
+                                       ereport(ERROR,
+                                                       (errcode(ERRCODE_PROTOCOL_VIOLATION),
+                                                        (errmsg("unsupported SCRAM channel-binding type"))));
+
+                               /* Save the name for handling of subsequent messages */
+                               state->channel_binding_type = pstrdup(channel_binding_type);
+                       }
+                       break;
                default:
                        ereport(ERROR,
                                        (errcode(ERRCODE_PROTOCOL_VIOLATION),
@@ -805,13 +873,6 @@ read_client_first_message(scram_state *state, char *input)
                                         errdetail("Unexpected channel-binding flag \"%s\".",
                                                           sanitize_char(*input))));
        }
-       if (*input != ',')
-               ereport(ERROR,
-                               (errcode(ERRCODE_PROTOCOL_VIOLATION),
-                                errmsg("malformed SCRAM message"),
-                                errdetail("Comma expected, but found character \"%s\".",
-                                                  sanitize_char(*input))));
-       input++;
 
        /*
         * Forbid optional authzid (authorization identity).  We don't support it.
@@ -1032,14 +1093,73 @@ read_client_final_message(scram_state *state, char *input)
         */
 
        /*
-        * Read channel-binding.  We don't support channel binding, so it's
-        * expected to always be "biws", which is "n,,", base64-encoded.
+        * Read channel binding.  This repeats the channel-binding flags and is
+        * then followed by the actual binding data depending on the type.
         */
        channel_binding = read_attr_value(&p, 'c');
-       if (strcmp(channel_binding, "biws") != 0)
-               ereport(ERROR,
-                               (errcode(ERRCODE_PROTOCOL_VIOLATION),
-                                (errmsg("unexpected SCRAM channel-binding attribute in client-final-message"))));
+       if (state->channel_binding_type)
+       {
+               const char *cbind_data = NULL;
+               size_t          cbind_data_len = 0;
+               size_t          cbind_header_len;
+               char       *cbind_input;
+               size_t          cbind_input_len;
+               char       *b64_message;
+               int                     b64_message_len;
+
+               /*
+                * Fetch data appropriate for channel binding type
+                */
+               if (strcmp(state->channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
+               {
+                       cbind_data = state->tls_finished_message;
+                       cbind_data_len = state->tls_finished_len;
+               }
+               else
+               {
+                       /* should not happen */
+                       elog(ERROR, "invalid channel binding type");
+               }
+
+               /* should not happen */
+               if (cbind_data == NULL || cbind_data_len == 0)
+                       elog(ERROR, "empty channel binding data for channel binding type \"%s\"",
+                                state->channel_binding_type);
+
+               cbind_header_len = 4 + strlen(state->channel_binding_type); /* p=type,, */
+               cbind_input_len = cbind_header_len + cbind_data_len;
+               cbind_input = palloc(cbind_input_len);
+               snprintf(cbind_input, cbind_input_len, "p=%s,,", state->channel_binding_type);
+               memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
+
+               b64_message = palloc(pg_b64_enc_len(cbind_input_len) + 1);
+               b64_message_len = pg_b64_encode(cbind_input, cbind_input_len,
+                                                                               b64_message);
+               b64_message[b64_message_len] = '\0';
+
+               /*
+                * Compare the value sent by the client with the value expected by
+                * the server.
+                */
+               if (strcmp(channel_binding, b64_message) != 0)
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
+                                        (errmsg("SCRAM channel binding check failed"))));
+       }
+       else
+       {
+               /*
+                * If we are not using channel binding, the binding data is expected
+                * to always be "biws", which is "n,," base64-encoded, or "eSws",
+                * which is "y,,".
+                */
+               if (strcmp(channel_binding, "biws") != 0 &&
+                       strcmp(channel_binding, "eSws") != 0)
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_PROTOCOL_VIOLATION),
+                                        (errmsg("unexpected SCRAM channel-binding attribute in client-final-message"))));
+       }
+
        state->client_final_nonce = read_attr_value(&p, 'r');
 
        /* ignore optional extensions */
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 6c915a7289..2dd3328d71 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -860,6 +860,8 @@ CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail)
 static int
 CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 {
+       char       *sasl_mechs;
+       char       *p;
        int                     mtype;
        StringInfoData buf;
        void       *scram_opaq;
@@ -869,6 +871,8 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
        int                     inputlen;
        int                     result;
        bool            initial;
+       char       *tls_finished = NULL;
+       size_t          tls_finished_len = 0;
 
        /*
         * SASL auth is not supported for protocol versions before 3, because it
@@ -885,12 +889,39 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 
        /*
         * Send the SASL authentication request to user.  It includes the list of
-        * authentication mechanisms (which is trivial, because we only support
-        * SCRAM-SHA-256 at the moment).  The extra "\0" is for an empty string to
-        * terminate the list.
+        * authentication mechanisms that are supported.  The mechanisms are
+        * advertised in decreasing order of importance, so the channel-binding
+        * variants go first, if they are supported.  Channel binding is only
+        * supported in SSL builds.
         */
-       sendAuthRequest(port, AUTH_REQ_SASL, SCRAM_SHA256_NAME "\0",
-                                       strlen(SCRAM_SHA256_NAME) + 2);
+       sasl_mechs = palloc(strlen(SCRAM_SHA256_PLUS_NAME) +
+                                               strlen(SCRAM_SHA256_NAME) + 3);
+       p = sasl_mechs;
+
+       if (port->ssl_in_use)
+       {
+               strcpy(p, SCRAM_SHA256_PLUS_NAME);
+               p += strlen(SCRAM_SHA256_PLUS_NAME) + 1;
+       }
+
+       strcpy(p, SCRAM_SHA256_NAME);
+       p += strlen(SCRAM_SHA256_NAME) + 1;
+
+       /* Put another '\0' to mark that list is finished. */
+       p[0] = '\0';
+
+       sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs, p - sasl_mechs + 1);
+       pfree(sasl_mechs);
+
+#ifdef USE_SSL
+       /*
+        * Get data for channel binding.
+        */
+       if (port->ssl_in_use)
+       {
+               tls_finished = be_tls_get_peer_finished(port, &tls_finished_len);
+       }
+#endif
 
        /*
         * Initialize the status tracker for message exchanges.
@@ -903,7 +934,11 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
         * This is because we don't want to reveal to an attacker what usernames
         * are valid, nor which users have a valid password.
         */
-       scram_opaq = pg_be_scram_init(port->user_name, shadow_pass);
+       scram_opaq = pg_be_scram_init(port->user_name,
+                                                                 shadow_pass,
+                                                                 port->ssl_in_use,
+                                                                 tls_finished,
+                                                                 tls_finished_len);
 
        /*
         * Loop through SASL message exchange.  This exchange can consist of
@@ -951,12 +986,9 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
                 {
                         const char *selected_mech;
 
-                       /*
-                        * We only support SCRAM-SHA-256 at the moment, so anything else
-                        * is an error.
-                        */
                         selected_mech = pq_getmsgrawstring(&buf);
-                       if (strcmp(selected_mech, SCRAM_SHA256_NAME) != 0)
+                       if (strcmp(selected_mech, SCRAM_SHA256_NAME) != 0 &&
+                               strcmp(selected_mech, SCRAM_SHA256_PLUS_NAME) != 0)
                         {
                                 ereport(ERROR,
                                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
diff --git a/src/backend/libpq/be-secure-openssl.c b/src/backend/libpq/be-secure-openssl.c
index fe15227a77..1e3e19f5e0 100644
--- a/src/backend/libpq/be-secure-openssl.c
+++ b/src/backend/libpq/be-secure-openssl.c
@@ -1215,6 +1215,30 @@ be_tls_get_peerdn_name(Port *port, char *ptr, size_t len)
                ptr[0] = '\0';
 }
 
+/*
+ * Routine to get the expected TLS Finished message information from the
+ * client, useful for authorization when doing channel binding.
+ *
+ * Result is a palloc'd copy of the TLS Finished message with its size.
+ */
+char *
+be_tls_get_peer_finished(Port *port, size_t *len)
+{
+       char            dummy[1];
+       char       *result;
+
+       /*
+        * OpenSSL does not offer an API to directly get the length of the
+        * expected TLS Finished message, so just do a dummy call to grab this
+        * information to allow caller to do an allocation with a correct size.
+        */
+       *len = SSL_get_peer_finished(port->ssl, dummy, sizeof(dummy));
+       result = palloc(*len);
+       (void) SSL_get_peer_finished(port->ssl, result, *len);
+
+       return result;
+}
+
 /*
  * Convert an X509 subject name to a cstring.
  *
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index 7bde744d51..856e0439d5 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -209,6 +209,7 @@ extern bool be_tls_get_compression(Port *port);
 extern void be_tls_get_version(Port *port, char *ptr, size_t len);
 extern void be_tls_get_cipher(Port *port, char *ptr, size_t len);
 extern void be_tls_get_peerdn_name(Port *port, char *ptr, size_t len);
+extern char *be_tls_get_peer_finished(Port *port, size_t *len);
 #endif
 
 extern ProtocolVersion FrontendProtocol;
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index 0166e1945d..99560d3d2f 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -13,8 +13,12 @@
 #ifndef PG_SCRAM_H
 #define PG_SCRAM_H
 
-/* Name of SCRAM-SHA-256 per IANA */
+/* Name of SCRAM mechanisms per IANA */
 #define SCRAM_SHA256_NAME "SCRAM-SHA-256"
+#define SCRAM_SHA256_PLUS_NAME "SCRAM-SHA-256-PLUS"    /* with channel binding */
+
+/* Channel binding types */
+#define SCRAM_CHANNEL_BINDING_TLS_UNIQUE       "tls-unique"
 
 /* Status codes for message exchange */
 #define SASL_EXCHANGE_CONTINUE         0
@@ -22,7 +26,9 @@
 #define SASL_EXCHANGE_FAILURE          2
 
 /* Routines dedicated to authentication */
-extern void *pg_be_scram_init(const char *username, const char *shadow_pass);
+extern void *pg_be_scram_init(const char *username, const char *shadow_pass,
+                                        bool ssl_in_use, const char *tls_finished_message,
+                                        size_t tls_finished_len);
 extern int pg_be_scram_exchange(void *opaq, char *input, int inputlen,
                                         char **output, int *outputlen, char **logdetail);
 
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index edfd42df85..e633e56434 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -17,6 +17,7 @@
 #include "common/base64.h"
 #include "common/saslprep.h"
 #include "common/scram-common.h"
+#include "libpq/scram.h"
 #include "fe-auth.h"
 
 /* These are needed for getpid(), in the fallback implementation */
@@ -44,6 +45,11 @@ typedef struct
        /* These are supplied by the user */
        const char *username;
        char       *password;
+       bool            ssl_in_use;
+       char       *tls_finished_message;
+       size_t          tls_finished_len;
+       char       *sasl_mechanism;
+       const char *channel_binding_type;
 
        /* We construct these */
        uint8           SaltedPassword[SCRAM_KEY_LEN];
@@ -79,25 +85,50 @@ static bool pg_frontend_random(char *dst, int len);
 
 /*
  * Initialize SCRAM exchange status.
+ *
+ * The non-const char* arguments should be passed in malloc'ed.  They will be
+ * freed by pg_fe_scram_free().
  */
 void *
-pg_fe_scram_init(const char *username, const char *password)
+pg_fe_scram_init(const char *username,
+                                const char *password,
+                                bool ssl_in_use,
+                                const char *sasl_mechanism,
+                                char *tls_finished_message,
+                                size_t tls_finished_len)
 {
        fe_scram_state *state;
        char       *prep_password;
        pg_saslprep_rc rc;
 
+       Assert(sasl_mechanism != NULL);
+
        state = (fe_scram_state *) malloc(sizeof(fe_scram_state));
        if (!state)
                return NULL;
        memset(state, 0, sizeof(fe_scram_state));
        state->state = FE_SCRAM_INIT;
        state->username = username;
+       state->ssl_in_use = ssl_in_use;
+       state->tls_finished_message = tls_finished_message;
+       state->tls_finished_len = tls_finished_len;
+       state->sasl_mechanism = strdup(sasl_mechanism);
+       if (!state->sasl_mechanism)
+       {
+               free(state);
+               return NULL;
+       }
+
+       /*
+        * Store channel binding type.  Only one type is currently supported.
+        */
+       state->channel_binding_type = SCRAM_CHANNEL_BINDING_TLS_UNIQUE;
 
        /* Normalize the password with SASLprep, if possible */
        rc = pg_saslprep(password, &prep_password);
        if (rc == SASLPREP_OOM)
        {
+               free(state->sasl_mechanism);
                free(state);
                return NULL;
        }
@@ -106,6 +137,7 @@ pg_fe_scram_init(const char *username, const char *password)
                prep_password = strdup(password);
                if (!prep_password)
                {
+                       free(state->sasl_mechanism);
                        free(state);
                        return NULL;
                }
@@ -125,6 +157,10 @@ pg_fe_scram_free(void *opaq)
 
        if (state->password)
                free(state->password);
+       if (state->tls_finished_message)
+               free(state->tls_finished_message);
+       if (state->sasl_mechanism)
+               free(state->sasl_mechanism);
 
        /* client messages */
        if (state->client_nonce)
@@ -297,9 +333,10 @@ static char *
 build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
 {
        char            raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
-       char       *buf;
-       char            buflen;
+       char       *result;
+       int                     channel_info_len;
        int                     encoded_len;
+       PQExpBufferData buf;
 
        /*
         * Generate a "raw" nonce.  This is converted to ASCII-printable form by
@@ -328,26 +365,61 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
         * prepared with SASLprep, the message parsing would fail if it includes
         * '=' or ',' characters.
         */
-       buflen = 8 + strlen(state->client_nonce) + 1;
-       buf = malloc(buflen);
-       if (buf == NULL)
+
+       initPQExpBuffer(&buf);
+
+       /*
+        * First build the gs2-header with channel binding information.
+        */
+       if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
        {
-               printfPQExpBuffer(errormessage,
-                                                 libpq_gettext("out of memory\n"));
-               return NULL;
+               Assert(state->ssl_in_use);
+               appendPQExpBuffer(&buf, "p=%s", state->channel_binding_type);
        }
-       snprintf(buf, buflen, "n,,n=,r=%s", state->client_nonce);
-
-       state->client_first_message_bare = strdup(buf + 3);
-       if (!state->client_first_message_bare)
+       else if (state->ssl_in_use)
        {
-               free(buf);
-               printfPQExpBuffer(errormessage,
-                                                 libpq_gettext("out of memory\n"));
-               return NULL;
+               /*
+                * Client supports channel binding, but thinks the server does not.
+                */
+               appendPQExpBuffer(&buf, "y");
        }
+       else
+       {
+               /*
+                * Client does not support channel binding.
+                */
+               appendPQExpBuffer(&buf, "n");
+       }
+
+       if (PQExpBufferDataBroken(buf))
+               goto oom_error;
+
+       channel_info_len = buf.len;
+
+       appendPQExpBuffer(&buf, ",,n=,r=%s", state->client_nonce);
+       if (PQExpBufferDataBroken(buf))
+               goto oom_error;
+
+       /*
+        * The first message content needs to be saved without channel binding
+        * information.
+        */
+       state->client_first_message_bare = strdup(buf.data + channel_info_len + 2);
+       if (!state->client_first_message_bare)
+               goto oom_error;
+
+       result = strdup(buf.data);
+       if (result == NULL)
+               goto oom_error;
+
+       termPQExpBuffer(&buf);
+       return result;
 
-       return buf;
+oom_error:
+       termPQExpBuffer(&buf);
+       printfPQExpBuffer(errormessage,
+                                         libpq_gettext("out of memory\n"));
+       return NULL;
 }
 
 /*
@@ -366,7 +438,67 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
         * Construct client-final-message-without-proof.  We need to remember it
         * for verifying the server proof in the final step of authentication.
         */
-       appendPQExpBuffer(&buf, "c=biws,r=%s", state->nonce);
+       if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
+       {
+               char       *cbind_data;
+               size_t          cbind_data_len;
+               size_t          cbind_header_len;
+               char       *cbind_input;
+               size_t          cbind_input_len;
+
+               if (strcmp(state->channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
+               {
+                       cbind_data = state->tls_finished_message;
+                       cbind_data_len = state->tls_finished_len;
+               }
+               else
+               {
+                       /* should not happen */
+                       termPQExpBuffer(&buf);
+                       printfPQExpBuffer(errormessage,
+                                                         libpq_gettext("invalid channel binding type\n"));
+                       return NULL;
+               }
+
+               /* should not happen */
+               if (cbind_data == NULL || cbind_data_len == 0)
+               {
+                       termPQExpBuffer(&buf);
+                       printfPQExpBuffer(errormessage,
+                                                         libpq_gettext("empty channel binding data for channel binding type \"%s\"\n"),
+                                                         state->channel_binding_type);
+                       return NULL;
+               }
+
+               appendPQExpBuffer(&buf, "c=");
+
+               cbind_header_len = 4 + strlen(state->channel_binding_type); /* p=type,, */
+               cbind_input_len = cbind_header_len + cbind_data_len;
+               cbind_input = malloc(cbind_input_len);
+               if (!cbind_input)
+                       goto oom_error;
+               snprintf(cbind_input, cbind_input_len, "p=%s,,", state->channel_binding_type);
+               memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
+
+               if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len)))
+               {
+                       free(cbind_input);
+                       goto oom_error;
+               }
+               buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len);
+               buf.data[buf.len] = '\0';
+
+               free(cbind_input);
+       }
+       else if (state->ssl_in_use)
+               appendPQExpBuffer(&buf, "c=eSws"); /* base64 of "y,," */
+       else
+               appendPQExpBuffer(&buf, "c=biws"); /* base64 of "n,," */
+
+       if (PQExpBufferDataBroken(buf))
+               goto oom_error;
+
+       appendPQExpBuffer(&buf, ",r=%s", state->nonce);
        if (PQExpBufferDataBroken(buf))
                goto oom_error;
 
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 382558f3f8..9d394919ef 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -491,6 +491,9 @@ pg_SASL_init(PGconn *conn, int payloadlen)
        bool            success;
        const char *selected_mechanism;
        PQExpBufferData mechanism_buf;
+       char       *tls_finished = NULL;
+       size_t          tls_finished_len = 0;
+       char       *password;
 
        initPQExpBuffer(&mechanism_buf);
 
@@ -504,7 +507,8 @@ pg_SASL_init(PGconn *conn, int payloadlen)
        /*
         * Parse the list of SASL authentication mechanisms in the
         * AuthenticationSASL message, and select the best mechanism that we
-        * support.  (Only SCRAM-SHA-256 is supported at the moment.)
+        * support.  SCRAM-SHA-256-PLUS and SCRAM-SHA-256 are the only ones
+        * supported at the moment, listed by order of decreasing importance.
         */
        selected_mechanism = NULL;
        for (;;)
@@ -523,35 +527,17 @@ pg_SASL_init(PGconn *conn, int payloadlen)
                        break;
 
                /*
-                * If we have already selected a mechanism, just skip through the rest
-                * of the list.
+                * Select the mechanism to use.  Pick SCRAM-SHA-256-PLUS over anything
+                * else.  Pick SCRAM-SHA-256 if nothing else has already been picked.
+                * If we add more mechanisms, a more refined priority mechanism might
+                * become necessary.
                 */
-               if (selected_mechanism)
-                       continue;
-
-               /*
-                * Do we support this mechanism?
-                */
-               if (strcmp(mechanism_buf.data, SCRAM_SHA256_NAME) == 0)
-               {
-                       char       *password;
-
-                       conn->password_needed = true;
-                       password = conn->connhost[conn->whichhost].password;
-                       if (password == NULL)
-                               password = conn->pgpass;
-                       if (password == NULL || password[0] == '\0')
-                       {
-                               printfPQExpBuffer(&conn->errorMessage,
-                                                                 PQnoPasswordSupplied);
-                               goto error;
-                       }
-
-                       conn->sasl_state = pg_fe_scram_init(conn->pguser, password);
-                       if (!conn->sasl_state)
-                               goto oom_error;
+               if (conn->ssl_in_use &&
+                       strcmp(mechanism_buf.data, SCRAM_SHA256_PLUS_NAME) == 0)
+                               selected_mechanism = SCRAM_SHA256_PLUS_NAME;
+               else if (strcmp(mechanism_buf.data, SCRAM_SHA256_NAME) == 0 &&
+                                !selected_mechanism)
                        selected_mechanism = SCRAM_SHA256_NAME;
-               }
        }
 
        if (!selected_mechanism)
@@ -561,6 +547,54 @@ pg_SASL_init(PGconn *conn, int payloadlen)
                goto error;
        }
 
+       /*
+        * Now that the SASL mechanism has been chosen for the exchange,
+        * initialize its state information.
+        */
+
+       /*
+        * First, select the password to use for the exchange, complaining if
+        * there isn't one.  Currently, all supported SASL mechanisms require a
+        * password, so we can just go ahead here without further distinction.
+        */
+       conn->password_needed = true;
+       password = conn->connhost[conn->whichhost].password;
+       if (password == NULL)
+               password = conn->pgpass;
+       if (password == NULL || password[0] == '\0')
+       {
+               printfPQExpBuffer(&conn->errorMessage,
+                                                 PQnoPasswordSupplied);
+               goto error;
+       }
+
+#ifdef USE_SSL
+       /*
+        * Get data for channel binding.
+        */
+       if (strcmp(selected_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
+       {
+               tls_finished = pgtls_get_finished(conn, &tls_finished_len);
+               if (tls_finished == NULL)
+                       goto oom_error;
+       }
+#endif
+
+       /*
+        * Initialize the SASL state information with all the information
+        * gathered during the initial exchange.
+        *
+        * Note: Only tls-unique is supported for the moment.
+        */
+       conn->sasl_state = pg_fe_scram_init(conn->pguser,
+                                                                               password,
+                                                                               conn->ssl_in_use,
+                                                                               selected_mechanism,
+                                                                               tls_finished,
+                                                                               tls_finished_len);
+       if (!conn->sasl_state)
+               goto oom_error;
+
        /* Get the mechanism-specific Initial Client Response, if any */
        pg_fe_scram_exchange(conn->sasl_state,
                                                 NULL, -1,
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 5dc6bb5341..1525a52742 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -23,7 +23,12 @@ extern int   pg_fe_sendauth(AuthRequest areq, int payloadlen, PGconn *conn);
 extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
 /* Prototypes for functions in fe-auth-scram.c */
-extern void *pg_fe_scram_init(const char *username, const char *password);
+extern void *pg_fe_scram_init(const char *username,
+                                                         const char *password,
+                                                         bool ssl_in_use,
+                                                         const char *sasl_mechanism,
+                                                         char *tls_finished_message,
+                                                         size_t tls_finished_len);
 extern void pg_fe_scram_free(void *opaq);
 extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
                                         char **output, int *outputlen,
diff --git a/src/interfaces/libpq/fe-secure-openssl.c b/src/interfaces/libpq/fe-secure-openssl.c
index 2f29820e82..61d161b367 100644
--- a/src/interfaces/libpq/fe-secure-openssl.c
+++ b/src/interfaces/libpq/fe-secure-openssl.c
@@ -393,6 +393,33 @@ pgtls_write(PGconn *conn, const void *ptr, size_t len)
        return n;
 }
 
+/*
+ *     Get the TLS Finished message sent during the last handshake
+ *
+ * This information is useful for callers doing channel binding during
+ * authentication.
+ */
+char *
+pgtls_get_finished(PGconn *conn, size_t *len)
+{
+       char            dummy[1];
+       char       *result;
+
+       /*
+        * OpenSSL does not offer an API to directly get the length of the TLS
+        * Finished message sent, so first do a dummy call to grab this
+        * information and then do an allocation with the correct size.
+        */
+       *len = SSL_get_finished(conn->ssl, dummy, sizeof(dummy));
+       result = malloc(*len);
+       if (result == NULL)
+               return NULL;
+       (void) SSL_get_finished(conn->ssl, result, *len);
+
+       return result;
+}
+
+
 /* ------------------------------------------------------------ */
 /*                                             OpenSSL specific code                                   */
 /* ------------------------------------------------------------ */
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 42913604e3..8412ee8160 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -453,11 +453,13 @@ struct pg_conn
        /* Assorted state for SASL, SSL, GSS, etc */
        void       *sasl_state;
 
+       /* SSL structures */
+       bool            ssl_in_use;
+
 #ifdef USE_SSL
        bool            allow_ssl_try;  /* Allowed to try SSL negotiation */
        bool            wait_ssl_try;   /* Delay SSL negotiation until after
                                                                 * attempting normal connection */
-       bool            ssl_in_use;
 #ifdef USE_OPENSSL
        SSL                *ssl;                        /* SSL status, if have SSL connection */
        X509       *peer;                       /* X509 cert of server */
@@ -668,6 +670,7 @@ extern void pgtls_close(PGconn *conn);
 extern ssize_t pgtls_read(PGconn *conn, void *ptr, size_t len);
 extern bool pgtls_read_pending(PGconn *conn);
 extern ssize_t pgtls_write(PGconn *conn, const void *ptr, size_t len);
+extern char *pgtls_get_finished(PGconn *conn, size_t *len);
 
 /*
  * this is so that we can check if a connection is non-blocking internally
diff --git a/src/test/ssl/ServerSetup.pm b/src/test/ssl/ServerSetup.pm
index ad2e036602..02f8028b2b 100644
--- a/src/test/ssl/ServerSetup.pm
+++ b/src/test/ssl/ServerSetup.pm
@@ -57,19 +57,21 @@ sub test_connect_ok
 {
        my $common_connstr = $_[0];
        my $connstr = $_[1];
+       my $test_name = $_[2];
 
        my $result =
          run_test_psql("$common_connstr $connstr", "(should succeed)");
-       ok($result, $connstr);
+       ok($result, $test_name || $connstr);
 }
 
 sub test_connect_fails
 {
        my $common_connstr = $_[0];
        my $connstr = $_[1];
+       my $test_name = $_[2];
 
        my $result = run_test_psql("$common_connstr $connstr", "(should fail)");
-       ok(!$result, "$connstr (should fail)");
+       ok(!$result, $test_name || "$connstr (should fail)");
 }
 
 # Copy a set of files, taking into account wildcards
@@ -89,8 +91,7 @@ sub copy_files
 
 sub configure_test_server_for_ssl
 {
-       my $node       = $_[0];
-       my $serverhost = $_[1];
+       my ($node, $serverhost, $authmethod, $password, $password_enc) = @_;
 
        my $pgdata = $node->data_dir;
 
@@ -100,6 +101,15 @@ sub configure_test_server_for_ssl
        $node->psql('postgres', "CREATE DATABASE trustdb");
        $node->psql('postgres', "CREATE DATABASE certdb");
 
+       # Update password of each user as needed.
+       if (defined($password))
+       {
+               $node->psql('postgres',
+"SET password_encryption='$password_enc'; ALTER USER ssltestuser PASSWORD '$password';");
+               $node->psql('postgres',
+"SET password_encryption='$password_enc'; ALTER USER anotheruser PASSWORD '$password';");
+       }
+
        # enable logging etc.
        open my $conf, '>>', "$pgdata/postgresql.conf";
        print $conf "fsync=off\n";
@@ -129,7 +139,7 @@ sub configure_test_server_for_ssl
        $node->restart;
 
        # Change pg_hba after restart because hostssl requires ssl=on
-       configure_hba_for_ssl($node, $serverhost);
+       configure_hba_for_ssl($node, $serverhost, $authmethod);
 }
 
 # Change the configuration to use given server cert file, and reload
@@ -157,8 +167,7 @@ sub switch_server_cert
 
 sub configure_hba_for_ssl
 {
-       my $node       = $_[0];
-       my $serverhost = $_[1];
+       my ($node, $serverhost, $authmethod) = @_;
        my $pgdata     = $node->data_dir;
 
   # Only accept SSL connections from localhost. Our tests don't depend on this
@@ -169,9 +178,9 @@ sub configure_hba_for_ssl
        print $hba
 "# TYPE  DATABASE        USER            ADDRESS                 METHOD\n";
        print $hba
-"hostssl trustdb         ssltestuser     $serverhost/32            trust\n";
+"hostssl trustdb         ssltestuser     $serverhost/32            $authmethod\n";
        print $hba
-"hostssl trustdb         ssltestuser     ::1/128                 trust\n";
+"hostssl trustdb         ssltestuser     ::1/128                 $authmethod\n";
        print $hba
 "hostssl certdb          ssltestuser     $serverhost/32            cert\n";
        print $hba
diff --git a/src/test/ssl/t/001_ssltests.pl b/src/test/ssl/t/001_ssltests.pl
index 890e3051a2..a0a06825c6 100644
--- a/src/test/ssl/t/001_ssltests.pl
+++ b/src/test/ssl/t/001_ssltests.pl
@@ -32,7 +32,7 @@
 $ENV{PGHOST} = $node->host;
 $ENV{PGPORT} = $node->port;
 $node->start;
-configure_test_server_for_ssl($node, $SERVERHOSTADDR);
+configure_test_server_for_ssl($node, $SERVERHOSTADDR, 'trust');
 switch_server_cert($node, 'server-cn-only');
 
 ### Part 1. Run client-side tests.
diff --git a/src/test/ssl/t/002_scram.pl b/src/test/ssl/t/002_scram.pl
new file mode 100644
index 0000000000..25f75bd52a
--- /dev/null
+++ b/src/test/ssl/t/002_scram.pl
@@ -0,0 +1,38 @@
+# Test SCRAM authentication and TLS channel binding types
+
+use strict;
+use warnings;
+use PostgresNode;
+use TestLib;
+use Test::More tests => 1;
+use ServerSetup;
+use File::Copy;
+
+# This is the hostname used to connect to the server.
+my $SERVERHOSTADDR = '127.0.0.1';
+
+# Allocation of base connection string shared among multiple tests.
+my $common_connstr;
+
+# Set up the server.
+
+note "setting up data directory";
+my $node = get_new_node('master');
+$node->init;
+
+# PGHOST is enforced here to set up the node, subsequent connections
+# will use a dedicated connection string.
+$ENV{PGHOST} = $node->host;
+$ENV{PGPORT} = $node->port;
+$node->start;
+
+# Configure server for SSL connections, with password handling.
+configure_test_server_for_ssl($node, $SERVERHOSTADDR, "scram-sha-256",
+                                                         "pass", "scram-sha-256");
+switch_server_cert($node, 'server-cn-only');
+$ENV{PGPASSWORD} = "pass";
+$common_connstr =
+"user=ssltestuser dbname=trustdb sslmode=require hostaddr=$SERVERHOSTADDR";
+
+test_connect_ok($common_connstr, '',
+                               "SCRAM authentication with default channel binding");
-- 
2.15.0
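
For anyone who wants to double-check the hardcoded "biws" and "eSws" values in build_client_final_message(), here is a minimal standalone sketch of how the c= attribute is formed as base64(gs2-header || channel binding data).  The helper names sketch_b64_encode() and sketch_build_cbind_attr() are made up for illustration and are not libpq functions (libpq uses pg_b64_encode()), and the fixed 256-byte scratch buffer is just an assumption of the sketch.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Hypothetical helper, not libpq's pg_b64_encode(): plain base64 with padding. */
static size_t
sketch_b64_encode(const unsigned char *src, size_t len, char *dst)
{
    static const char tbl[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    size_t      i;
    size_t      o = 0;

    /* Encode all complete 3-byte groups. */
    for (i = 0; i + 2 < len; i += 3)
    {
        dst[o++] = tbl[src[i] >> 2];
        dst[o++] = tbl[((src[i] & 0x03) << 4) | (src[i + 1] >> 4)];
        dst[o++] = tbl[((src[i + 1] & 0x0f) << 2) | (src[i + 2] >> 6)];
        dst[o++] = tbl[src[i + 2] & 0x3f];
    }
    /* Encode the 1 or 2 leftover bytes, padding with '='. */
    if (i < len)
    {
        dst[o++] = tbl[src[i] >> 2];
        if (i + 1 < len)
        {
            dst[o++] = tbl[((src[i] & 0x03) << 4) | (src[i + 1] >> 4)];
            dst[o++] = tbl[(src[i + 1] & 0x0f) << 2];
        }
        else
        {
            dst[o++] = tbl[(src[i] & 0x03) << 4];
            dst[o++] = '=';
        }
        dst[o++] = '=';
    }
    dst[o] = '\0';
    return o;
}

/*
 * Build the value of the "c=" attribute.  flag is 'n', 'y' or 'p';
 * cb_type and cb_data are only looked at for 'p'.  Returns the number of
 * characters written, not counting the terminating NUL.
 */
static size_t
sketch_build_cbind_attr(char flag, const char *cb_type,
                        const unsigned char *cb_data, size_t cb_len,
                        char *out, size_t out_size)
{
    unsigned char input[256];   /* assumption: large enough for the sketch */
    size_t      len = 0;

    input[len++] = (unsigned char) flag;
    if (flag == 'p')
    {
        size_t      type_len = strlen(cb_type);

        /* flag + '=' + type + ",," + data must fit in the scratch buffer */
        assert(1 + 1 + type_len + 2 + cb_len <= sizeof(input));
        input[len++] = '=';
        memcpy(input + len, cb_type, type_len);
        len += type_len;
    }
    /* The gs2-header always ends in ",," (we never send an authzid). */
    input[len++] = ',';
    input[len++] = ',';
    if (flag == 'p')
    {
        memcpy(input + len, cb_data, cb_len);
        len += cb_len;
    }

    assert(out_size >= 4 * ((len + 2) / 3) + 1);
    return sketch_b64_encode(input, len, out);
}
```

With flag 'n' this yields "biws" and with flag 'y' it yields "eSws", which matches the two constants used in the non-channel-binding branches of the patch.  With flag 'p' and type "tls-unique" the output starts with the encoding of the gs2-header, followed by the encoded TLS Finished data.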

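The mechanism selection in pg_SASL_init() can also be exercised in isolation.  The following is only a sketch: sketch_select_mechanism() and its array-based interface are hypothetical (the real code parses the mechanism list out of the AuthenticationSASL message), and the mechanism names are spelled out as string literals rather than taken from the SCRAM_SHA256_NAME and SCRAM_SHA256_PLUS_NAME macros.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>

/*
 * Pick the SASL mechanism from the list the server advertised.
 * SCRAM-SHA-256-PLUS wins whenever SSL is in use and the server offers it;
 * otherwise fall back to SCRAM-SHA-256 if offered.  Returns NULL if none
 * of the offered mechanisms is usable.
 */
static const char *
sketch_select_mechanism(const char *const *offered, size_t n, bool ssl_in_use)
{
    const char *selected = NULL;
    size_t      i;

    assert(offered != NULL || n == 0);

    for (i = 0; i < n; i++)
    {
        if (ssl_in_use && strcmp(offered[i], "SCRAM-SHA-256-PLUS") == 0)
            selected = "SCRAM-SHA-256-PLUS";
        else if (strcmp(offered[i], "SCRAM-SHA-256") == 0 && !selected)
            selected = "SCRAM-SHA-256";
    }
    return selected;
}
```

The case that motivated the change is a server that only offers SCRAM-SHA-256 while SSL is in use (for example PG10): the sketch then returns "SCRAM-SHA-256", and the client goes on to send the 'y' flag rather than failing.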