On Wed, 2021-06-23 at 16:38 +0900, Michael Paquier wrote:
> On Tue, Jun 22, 2021 at 10:37:29PM +0000, Jacob Champion wrote:
> > Currently, the SASL logic is tightly coupled to the SCRAM
> > implementation. This patch untangles the two, by introducing callback
> > structs for both the frontend and backend.
> 
> The approach to define and have a set of callbacks feels natural.

Good, thanks!

> +/* Status codes for message exchange */
> +#define SASL_EXCHANGE_CONTINUE     0
> +#define SASL_EXCHANGE_SUCCESS      1
> +#define SASL_EXCHANGE_FAILURE      2
> 
> It may be better to prefix those with PG_ as they begin to be
> published.

Added in v2.

> +/* Backend mechanism API */
> +typedef void  (*pg_be_sasl_mechanism_func)(Port *, StringInfo);
> +typedef void *(*pg_be_sasl_init_func)(Port *, const char *, const
> char *);
> +typedef int   (*pg_be_sasl_exchange_func)(void *, const char *, int,
> char **, int *, char **);
> +
> +typedef struct
> +{
> +   pg_be_sasl_mechanism_func   get_mechanisms;
> +   pg_be_sasl_init_func        init;
> +   pg_be_sasl_exchange_func    exchange;
> +} pg_be_sasl_mech;
> 
> All this is going to require much more documentation to explain what
> is the role of those callbacks and what they are here for.

Yes, definitely. If the current approach seems generally workable, I'll
get started on that.
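
As a starting point for that documentation, here's a minimal,
self-contained sketch of the contract as I currently understand it:
init() allocates per-exchange state for the selected mechanism, and
exchange() consumes one client message (NULL/-1 when there is no initial
response) and returns one of the PG_SASL_EXCHANGE_* codes. Port and
StringInfo are dropped so that it compiles standalone. toy_be_sasl_mech,
the "TOY-PLAIN" mechanism, and run_exchange() are all hypothetical and
only loosely model pg_be_sasl_mech and SASL_exchange().

```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>

/* Mirrors the PG_SASL_EXCHANGE_* status codes from sasl.h. */
#define PG_SASL_EXCHANGE_CONTINUE 0
#define PG_SASL_EXCHANGE_SUCCESS  1
#define PG_SASL_EXCHANGE_FAILURE  2

/* Hypothetical, simplified stand-in for pg_be_sasl_mech (no Port/StringInfo). */
typedef struct
{
	/* Allocate per-exchange state for the mechanism the client selected. */
	void	   *(*init) (const char *selected_mech, const char *secret);
	/* Consume one client message (NULL/-1 if none); maybe emit a challenge. */
	int			(*exchange) (void *opaq, const char *input, int inputlen,
							 char **output, int *outputlen, char **logdetail);
} toy_be_sasl_mech;

/* Hypothetical "TOY-PLAIN" mechanism: succeeds if the client sends the secret. */
typedef struct
{
	const char *secret;
} toy_state;

static void *
toy_init(const char *selected_mech, const char *secret)
{
	toy_state  *state;

	if (strcmp(selected_mech, "TOY-PLAIN") != 0)
		return NULL;
	state = malloc(sizeof(toy_state));
	if (state)
		state->secret = secret;
	return state;
}

static int
toy_exchange(void *opaq, const char *input, int inputlen,
			 char **output, int *outputlen, char **logdetail)
{
	toy_state  *state = (toy_state *) opaq;

	*output = NULL;
	*outputlen = 0;

	/* No initial response: send an empty challenge and wait for more. */
	if (input == NULL)
	{
		*output = calloc(1, 1);
		return PG_SASL_EXCHANGE_CONTINUE;
	}

	if ((size_t) inputlen == strlen(state->secret) &&
		memcmp(input, state->secret, (size_t) inputlen) == 0)
		return PG_SASL_EXCHANGE_SUCCESS;

	*logdetail = "TOY-PLAIN: secret mismatch";
	return PG_SASL_EXCHANGE_FAILURE;
}

static const toy_be_sasl_mech toy_mech = {toy_init, toy_exchange};

/*
 * Mechanism-agnostic driver, loosely in the spirit of SASL_exchange(): feed
 * client messages to the callbacks until the result is no longer CONTINUE.
 */
static int
run_exchange(const toy_be_sasl_mech *mech, const char *mech_name,
			 const char *secret, const char *const *msgs, int nmsgs)
{
	void	   *opaq;
	int			result = PG_SASL_EXCHANGE_CONTINUE;
	char	   *logdetail = NULL;

	assert(mech->init != NULL && mech->exchange != NULL);
	opaq = mech->init(mech_name, secret);
	if (opaq == NULL)
		return PG_SASL_EXCHANGE_FAILURE;

	for (int i = 0; i < nmsgs && result == PG_SASL_EXCHANGE_CONTINUE; i++)
	{
		char	   *output;
		int			outputlen;

		result = mech->exchange(opaq, msgs[i],
								msgs[i] ? (int) strlen(msgs[i]) : -1,
								&output, &outputlen, &logdetail);
		free(output);			/* a real server would send this to the client */
	}
	free(opaq);
	return result;
}
```

The point of the split is that run_exchange() never looks inside the
opaque state, so a second mechanism only has to supply its own pair of
callbacks.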

> Another thing that is not tackled by this patch is the format of the
> messages exchanged which is something only in SCRAM now.  Perhaps it
> would be better to extract the small-ish routines currently in
> fe-auth-scram.c and auth-scram.c that we use to grab values associated
> to an attribute in an exchange message and put them in a central place
> like an auth-sasl.c and fe-auth-sasl.c.  This move could also make
> sense for the existing init and continue routines for SASL in
> fe-auth.c.

We can. I recommend waiting for another GS2 mechanism implementation,
though.

The attribute/value encoding is not part of core SASL (see [1] for that
RFC), and OAUTHBEARER is not technically a GS2 mechanism -- though it
makes use of a vestigial GS2 header block, apparently in the hopes that
one day it might become one. So we could pull out the similarities now,
but I'd hate to extract the wrong abstractions and make someone else
untangle it later.
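
For reference, the kind of small routine in question pulls a value out
of a comma-separated attribute=value list, such as a SCRAM
server-first-message. Here's a rough sketch of that technique;
sketch_read_attr() is hypothetical and only approximates the existing
helpers in auth-scram.c and fe-auth-scram.c, whose exact signatures and
error reporting differ.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/*
 * Hypothetical helper: parse "<attr>=<value>" at *input, where pairs are
 * separated by ','.  On success, NUL-terminates the value in place, advances
 * *input past the separator, and returns the value.  Returns NULL (leaving
 * *input unchanged) if the next attribute is not 'attr' or '=' is missing.
 */
static char *
sketch_read_attr(char **input, char attr)
{
	char	   *begin;
	char	   *end;

	assert(input != NULL && *input != NULL);
	begin = *input;

	if (*begin != attr || begin[1] != '=')
		return NULL;
	begin += 2;

	end = strchr(begin, ',');
	if (end)
	{
		*end = '\0';			/* terminate the value in place */
		*input = end + 1;
	}
	else
		*input = begin + strlen(begin);	/* last attribute in the message */

	return begin;
}
```

Nothing here is specific to SASL itself, which is part of why I'd rather
not hoist it into a "SASL" file until a second mechanism actually needs
it.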

> +static int
> +CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
> +{
> +   return SASL_exchange(&pg_be_scram_mech, port, shadow_pass, logdetail);
> +}
> It may be cleaner to live without this thin wrapper.  It is a bit
> strange to have a SCRAM API in a file where we want mostly SASL things
> (Okay, uaScram does not count as this is assigned after the HBA
> lookup).  Moving any SASL-related things into a separate file may be a
> cleaner option, especially considering that we have a bit more than
> the exchange itself, like message handling.

Heh, I figured that at ~3500 lines, you all just really wanted the
Check* implementations to live in auth.c. :D

I can definitely move it (into, say, auth-sasl.c?). I'll probably do
that in a second commit, though, since keeping it in place during the
refactor makes the review easier IMO.

> +typedef void *(*pg_sasl_init_func)(PGconn *, const char *, const char
> *);
> +typedef void  (*pg_sasl_exchange_func)(void *, char *, int, char **,
> int *, bool *, bool *);
> +typedef bool  (*pg_sasl_channel_bound_func)(void *);
> +typedef void  (*pg_sasl_free_func)(void *);
> +
> +typedef struct
> +{
> +   pg_sasl_init_func           init;
> +   pg_sasl_exchange_func       exchange;
> +   pg_sasl_channel_bound_func  channel_bound;
> +   pg_sasl_free_func           free;
> +} pg_sasl_mech;
> These would be better into a separate header, with more
> documentation.

Can do. Does libpq-int-sasl.h work as a filename? This should not be
exported to applications.

> It may be more consistent with the backend to name
> that pg_fe_sasl_mech?

Done in v2.

> It looks like there is enough material for a callback able to handle
> channel binding.  In the main patch for OAUTHBEARER, I can see for
> example that the handling of OAUTHBEARER-PLUS is copied from its SCRAM
> sibling.  That does not need to be tackled in the same patch.  Just
> noting it on the way.

OAUTHBEARER doesn't support channel binding -- there's no
OAUTHBEARER-PLUS, and there probably won't ever be, given the
mechanism's simplicity -- so I'd recommend that this wait for a second
GS2 mechanism implementation, as well.

> > (Note that our protocol implementation provides an "additional data"
> > field for the initial client response, but *not* for the authentication
> > outcome. That seems odd to me, but it is what it is, I suppose.)
> 
> You are referring to the protocol implementation as of
> AuthenticationSASLFinal, right?

Yes, but I misremembered. My statement was wrong -- we do allow for
additional data in the authentication outcome from the server.

For AuthenticationSASLFinal, we don't distinguish between "no
additional data" and "additional data of length zero", which IIRC is a
violation of the SASL protocol. That may cause problems with a
theoretical future mechanism implementation, but I don't think it
affects SCRAM. I believe we *do* distinguish between those cases
correctly for the initial client response packet.

Sorry for the confusion; let me double-check again when I have fresh
eyes at the start of the week, before sending you on a goose chase.
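
To illustrate the distinction I mean, assuming I'm reading the protocol
docs correctly: SASLInitialResponse carries an Int32 length in which -1
means "no initial response", so absent and empty are distinguishable,
whereas AuthenticationSASLFinal's data is simply the remainder of the
message, so the two cases look identical there. A minimal sketch of
parsing such a length-prefixed field; sasl_field and parse_sasl_field()
are hypothetical names.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical result of parsing a length-prefixed SASL data field. */
typedef struct
{
	int			present;		/* 0 if the length was -1 ("no data") */
	int32_t		len;			/* declared length; 0 means "empty data" */
	const unsigned char *data;	/* NULL when !present */
} sasl_field;

/*
 * Parse a 4-byte big-endian signed length followed by exactly that many
 * bytes.  A length of -1 means the field is absent, which is distinct from a
 * length of 0.  Returns 0 on success, -1 on malformed input.
 */
static int
parse_sasl_field(const unsigned char *buf, size_t buflen, sasl_field *out)
{
	uint32_t	raw;

	assert(out != NULL && (buf != NULL || buflen == 0));
	if (buflen < 4)
		return -1;

	raw = ((uint32_t) buf[0] << 24) | ((uint32_t) buf[1] << 16) |
		((uint32_t) buf[2] << 8) | (uint32_t) buf[3];
	out->len = (int32_t) raw;

	if (out->len == -1)
	{
		out->present = 0;
		out->data = NULL;
		return buflen == 4 ? 0 : -1;	/* no bytes may follow */
	}
	if (out->len < 0 || (size_t) out->len != buflen - 4)
		return -1;

	out->present = 1;
	out->data = buf + 4;
	return 0;
}
```

Without that sentinel, a parser for AuthenticationSASLFinal can only see
"zero bytes remain", which is the ambiguity I'm worried a future
mechanism could trip over.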

> > Regarding that specific TODO -- I think it'd be good for the framework
> > to fail hard if a mechanism tries to send data during a failure
> > outcome, as it probably means the mechanism isn't implemented to spec.
> 
> Agreed.  That would mean patching libpq to add more safeguards in
> pg_SASL_continue() if I am following correctly.

Right.
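
Roughly, the server-side check that the TODO in SASL_exchange() has in
mind could look like the sketch below; an analogous safeguard could
live in libpq's pg_SASL_continue(). This is only a sketch under those
assumptions: sasl_step_action and classify_step() are hypothetical, and
apart from the new failure check it mirrors what the current loop does
with the exchange result.

```c
#include <assert.h>
#include <stddef.h>

#define PG_SASL_EXCHANGE_CONTINUE 0
#define PG_SASL_EXCHANGE_SUCCESS  1
#define PG_SASL_EXCHANGE_FAILURE  2

/* What a (hypothetical) driver should do after one exchange() call. */
typedef enum
{
	STEP_SEND_NOTHING,			/* no output produced */
	STEP_SEND_SASL_CONTINUE,	/* send AuthenticationSASLContinue */
	STEP_SEND_SASL_FINAL,		/* send AuthenticationSASLFinal */
	STEP_PROTOCOL_VIOLATION		/* mechanism bug: fail hard */
} sasl_step_action;

static sasl_step_action
classify_step(int result, const char *output)
{
	assert(result >= PG_SASL_EXCHANGE_CONTINUE &&
		   result <= PG_SASL_EXCHANGE_FAILURE);

	/* Per the TODO, output alongside a failure outcome is forbidden. */
	if (result == PG_SASL_EXCHANGE_FAILURE)
		return output ? STEP_PROTOCOL_VIOLATION : STEP_SEND_NOTHING;

	if (output == NULL)
		return STEP_SEND_NOTHING;

	return result == PG_SASL_EXCHANGE_SUCCESS ? STEP_SEND_SASL_FINAL
		: STEP_SEND_SASL_CONTINUE;
}
```

Today the FAILURE-with-output case falls through to the else branch and
is sent as AuthenticationSASLContinue, which is the behavior the check
would turn into a hard error.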

Thanks for the review!
--Jacob

[1] https://datatracker.ietf.org/doc/html/rfc5801
From aee9797acd4568fc90da6df04ec417b0da00f3f3 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchamp...@vmware.com>
Date: Tue, 13 Apr 2021 10:25:48 -0700
Subject: [PATCH v2 1/7] auth: generalize SASL mechanisms

Split the SASL logic out from the SCRAM implementation, so that it can
be reused by other mechanisms.  New implementations will implement both
a pg_fe_sasl_mech and a pg_be_sasl_mech.
---
 src/backend/libpq/auth-scram.c       | 48 ++++++++++++++++------------
 src/backend/libpq/auth.c             | 40 +++++++++++++++--------
 src/include/libpq/sasl.h             | 34 ++++++++++++++++++++
 src/include/libpq/scram.h            | 13 ++------
 src/interfaces/libpq/fe-auth-scram.c | 40 ++++++++++++++++-------
 src/interfaces/libpq/fe-auth.c       | 16 +++++++---
 src/interfaces/libpq/fe-auth.h       | 11 ++-----
 src/interfaces/libpq/fe-connect.c    |  6 +---
 src/interfaces/libpq/libpq-int.h     | 14 ++++++++
 9 files changed, 149 insertions(+), 73 deletions(-)
 create mode 100644 src/include/libpq/sasl.h

diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index f9e1026a12..2965ea2ddb 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -101,11 +101,25 @@
 #include "common/sha2.h"
 #include "libpq/auth.h"
 #include "libpq/crypt.h"
+#include "libpq/sasl.h"
 #include "libpq/scram.h"
 #include "miscadmin.h"
 #include "utils/builtins.h"
 #include "utils/timestamp.h"
 
+static void  scram_get_mechanisms(Port *port, StringInfo buf);
+static void *scram_init(Port *port, const char *selected_mech,
+						const char *shadow_pass);
+static int   scram_exchange(void *opaq, const char *input, int inputlen,
+							char **output, int *outputlen, char **logdetail);
+
+/* Mechanism declaration */
+const pg_be_sasl_mech pg_be_scram_mech = {
+	scram_get_mechanisms,
+	scram_init,
+	scram_exchange,
+};
+
 /*
  * Status data for a SCRAM authentication exchange.  This should be kept
  * internal to this file.
@@ -170,16 +184,14 @@ static char *sanitize_str(const char *s);
 static char *scram_mock_salt(const char *username);
 
 /*
- * pg_be_scram_get_mechanisms
- *
  * Get a list of SASL mechanisms that this module supports.
  *
  * For the convenience of building the FE/BE packet that lists the
  * mechanisms, the names are appended to the given StringInfo buffer,
  * separated by '\0' bytes.
  */
-void
-pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
+static void
+scram_get_mechanisms(Port *port, StringInfo buf)
 {
 	/*
 	 * Advertise the mechanisms in decreasing order of importance.  So the
@@ -199,8 +211,6 @@ pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
 }
 
 /*
- * pg_be_scram_init
- *
  * Initialize a new SCRAM authentication exchange status tracker.  This
  * needs to be called before doing any exchange.  It will be filled later
  * after the beginning of the exchange with authentication information.
@@ -215,10 +225,8 @@ pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
  * an authentication exchange, but it will fail, as if an incorrect password
  * was given.
  */
-void *
-pg_be_scram_init(Port *port,
-				 const char *selected_mech,
-				 const char *shadow_pass)
+static void *
+scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
 {
 	scram_state *state;
 	bool		got_secret;
@@ -325,9 +333,9 @@ pg_be_scram_init(Port *port,
  * string at *logdetail that will be sent to the postmaster log (but not
  * the client).
  */
-int
-pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
-					 char **output, int *outputlen, char **logdetail)
+static int
+scram_exchange(void *opaq, const char *input, int inputlen,
+			   char **output, int *outputlen, char **logdetail)
 {
 	scram_state *state = (scram_state *) opaq;
 	int			result;
@@ -346,7 +354,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 
 		*output = pstrdup("");
 		*outputlen = 0;
-		return SASL_EXCHANGE_CONTINUE;
+		return PG_SASL_EXCHANGE_CONTINUE;
 	}
 
 	/*
@@ -379,7 +387,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			*output = build_server_first_message(state);
 
 			state->state = SCRAM_AUTH_SALT_SENT;
-			result = SASL_EXCHANGE_CONTINUE;
+			result = PG_SASL_EXCHANGE_CONTINUE;
 			break;
 
 		case SCRAM_AUTH_SALT_SENT:
@@ -408,7 +416,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			 * erroring out in an application-specific way.  We choose to do
 			 * the latter, so that the error message for invalid password is
 			 * the same for all authentication methods.  The caller will call
-			 * ereport(), when we return SASL_EXCHANGE_FAILURE with no output.
+			 * ereport(), when we return PG_SASL_EXCHANGE_FAILURE with no output.
 			 *
 			 * NB: the order of these checks is intentional.  We calculate the
 			 * client proof even in a mock authentication, even though it's
@@ -417,7 +425,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			 */
 			if (!verify_client_proof(state) || state->doomed)
 			{
-				result = SASL_EXCHANGE_FAILURE;
+				result = PG_SASL_EXCHANGE_FAILURE;
 				break;
 			}
 
@@ -425,16 +433,16 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			*output = build_server_final_message(state);
 
 			/* Success! */
-			result = SASL_EXCHANGE_SUCCESS;
+			result = PG_SASL_EXCHANGE_SUCCESS;
 			state->state = SCRAM_AUTH_FINISHED;
 			break;
 
 		default:
 			elog(ERROR, "invalid SCRAM exchange state");
-			result = SASL_EXCHANGE_FAILURE;
+			result = PG_SASL_EXCHANGE_FAILURE;
 	}
 
-	if (result == SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
+	if (result == PG_SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
 		*logdetail = state->logdetail;
 
 	if (*output)
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 967b5ef73c..82f043a343 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -26,11 +26,11 @@
 #include "commands/user.h"
 #include "common/ip.h"
 #include "common/md5.h"
-#include "common/scram-common.h"
 #include "libpq/auth.h"
 #include "libpq/crypt.h"
 #include "libpq/libpq.h"
 #include "libpq/pqformat.h"
+#include "libpq/sasl.h"
 #include "libpq/scram.h"
 #include "miscadmin.h"
 #include "port/pg_bswap.h"
@@ -51,6 +51,13 @@ static void auth_failed(Port *port, int status, char *logdetail);
 static char *recv_password_packet(Port *port);
 static void set_authn_id(Port *port, const char *id);
 
+/*----------------------------------------------------------------
+ * SASL common authentication
+ *----------------------------------------------------------------
+ */
+static int	SASL_exchange(const pg_be_sasl_mech *mech, Port *port,
+						  char *shadow_pass, char **logdetail);
+
 
 /*----------------------------------------------------------------
  * Password-based authentication methods (password, md5, and scram-sha-256)
@@ -912,12 +919,13 @@ CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail)
 }
 
 static int
-CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
+SASL_exchange(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
+			  char **logdetail)
 {
 	StringInfoData sasl_mechs;
 	int			mtype;
 	StringInfoData buf;
-	void	   *scram_opaq = NULL;
+	void	   *opaq = NULL;
 	char	   *output = NULL;
 	int			outputlen = 0;
 	const char *input;
@@ -931,7 +939,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 	 */
 	initStringInfo(&sasl_mechs);
 
-	pg_be_scram_get_mechanisms(port, &sasl_mechs);
+	mech->get_mechanisms(port, &sasl_mechs);
 	/* Put another '\0' to mark that list is finished. */
 	appendStringInfoChar(&sasl_mechs, '\0');
 
@@ -998,7 +1006,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 			 * This is because we don't want to reveal to an attacker what
 			 * usernames are valid, nor which users have a valid password.
 			 */
-			scram_opaq = pg_be_scram_init(port, selected_mech, shadow_pass);
+			opaq = mech->init(port, selected_mech, shadow_pass);
 
 			inputlen = pq_getmsgint(&buf, 4);
 			if (inputlen == -1)
@@ -1022,12 +1030,11 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 		Assert(input == NULL || input[inputlen] == '\0');
 
 		/*
-		 * we pass 'logdetail' as NULL when doing a mock authentication,
-		 * because we should already have a better error message in that case
+		 * Hand the incoming message to the mechanism implementation.
 		 */
-		result = pg_be_scram_exchange(scram_opaq, input, inputlen,
-									  &output, &outputlen,
-									  logdetail);
+		result = mech->exchange(opaq, input, inputlen,
+								&output, &outputlen,
+								logdetail);
 
 		/* input buffer no longer used */
 		pfree(buf.data);
@@ -1039,17 +1046,18 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 			 */
 			elog(DEBUG4, "sending SASL challenge of length %u", outputlen);
 
-			if (result == SASL_EXCHANGE_SUCCESS)
+			/* TODO: PG_SASL_EXCHANGE_FAILURE with output is forbidden in SASL */
+			if (result == PG_SASL_EXCHANGE_SUCCESS)
 				sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
 			else
 				sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
 
 			pfree(output);
 		}
-	} while (result == SASL_EXCHANGE_CONTINUE);
+	} while (result == PG_SASL_EXCHANGE_CONTINUE);
 
 	/* Oops, Something bad happened */
-	if (result != SASL_EXCHANGE_SUCCESS)
+	if (result != PG_SASL_EXCHANGE_SUCCESS)
 	{
 		return STATUS_ERROR;
 	}
@@ -1057,6 +1065,12 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 	return STATUS_OK;
 }
 
+static int
+CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
+{
+	return SASL_exchange(&pg_be_scram_mech, port, shadow_pass, logdetail);
+}
+
 
 /*----------------------------------------------------------------
  * GSSAPI authentication system
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
new file mode 100644
index 0000000000..1afabf843d
--- /dev/null
+++ b/src/include/libpq/sasl.h
@@ -0,0 +1,34 @@
+/*-------------------------------------------------------------------------
+ *
+ * sasl.h
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/libpq/sasl.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef PG_SASL_H
+#define PG_SASL_H
+
+#include "libpq/libpq-be.h"
+
+/* Status codes for message exchange */
+#define PG_SASL_EXCHANGE_CONTINUE		0
+#define PG_SASL_EXCHANGE_SUCCESS		1
+#define PG_SASL_EXCHANGE_FAILURE		2
+
+/* Backend mechanism API */
+typedef void  (*pg_be_sasl_mechanism_func)(Port *, StringInfo);
+typedef void *(*pg_be_sasl_init_func)(Port *, const char *, const char *);
+typedef int   (*pg_be_sasl_exchange_func)(void *, const char *, int, char **, int *, char **);
+
+typedef struct
+{
+	pg_be_sasl_mechanism_func	get_mechanisms;
+	pg_be_sasl_init_func		init;
+	pg_be_sasl_exchange_func	exchange;
+} pg_be_sasl_mech;
+
+#endif /* PG_SASL_H */
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index 2c879150da..9e4540bde3 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -15,17 +15,10 @@
 
 #include "lib/stringinfo.h"
 #include "libpq/libpq-be.h"
+#include "libpq/sasl.h"
 
-/* Status codes for message exchange */
-#define SASL_EXCHANGE_CONTINUE		0
-#define SASL_EXCHANGE_SUCCESS		1
-#define SASL_EXCHANGE_FAILURE		2
-
-/* Routines dedicated to authentication */
-extern void pg_be_scram_get_mechanisms(Port *port, StringInfo buf);
-extern void *pg_be_scram_init(Port *port, const char *selected_mech, const char *shadow_pass);
-extern int	pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
-								 char **output, int *outputlen, char **logdetail);
+/* Implementation */
+extern const pg_be_sasl_mech pg_be_scram_mech;
 
 /* Routines to handle and check SCRAM-SHA-256 secret */
 extern char *pg_be_scram_build_secret(const char *password);
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index 5881386e37..515ef66f37 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -21,6 +21,22 @@
 #include "fe-auth.h"
 
 
+/* The exported SCRAM callback mechanism. */
+static void *scram_init(PGconn *conn, const char *password,
+						const char *sasl_mechanism);
+static void scram_exchange(void *opaq, char *input, int inputlen,
+						   char **output, int *outputlen,
+						   bool *done, bool *success);
+static bool scram_channel_bound(void *opaq);
+static void scram_free(void *opaq);
+
+const pg_fe_sasl_mech pg_scram_mech = {
+	scram_init,
+	scram_exchange,
+	scram_channel_bound,
+	scram_free,
+};
+
 /*
  * Status of exchange messages used for SCRAM authentication via the
  * SASL protocol.
@@ -72,10 +88,10 @@ static bool calculate_client_proof(fe_scram_state *state,
 /*
  * Initialize SCRAM exchange status.
  */
-void *
-pg_fe_scram_init(PGconn *conn,
-				 const char *password,
-				 const char *sasl_mechanism)
+static void *
+scram_init(PGconn *conn,
+		   const char *password,
+		   const char *sasl_mechanism)
 {
 	fe_scram_state *state;
 	char	   *prep_password;
@@ -128,8 +144,8 @@ pg_fe_scram_init(PGconn *conn,
  * Note that the caller must also ensure that the exchange was actually
  * successful.
  */
-bool
-pg_fe_scram_channel_bound(void *opaq)
+static bool
+scram_channel_bound(void *opaq)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 
@@ -152,8 +168,8 @@ pg_fe_scram_channel_bound(void *opaq)
 /*
  * Free SCRAM exchange status
  */
-void
-pg_fe_scram_free(void *opaq)
+static void
+scram_free(void *opaq)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 
@@ -188,10 +204,10 @@ pg_fe_scram_free(void *opaq)
 /*
  * Exchange a SCRAM message with backend.
  */
-void
-pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
-					 char **output, int *outputlen,
-					 bool *done, bool *success)
+static void
+scram_exchange(void *opaq, char *input, int inputlen,
+			   char **output, int *outputlen,
+			   bool *done, bool *success)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 	PGconn	   *conn = state->conn;
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index e8062647e6..d5cbac108e 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -482,7 +482,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 				 * channel_binding is not disabled.
 				 */
 				if (conn->channel_binding[0] != 'd')	/* disable */
+				{
 					selected_mechanism = SCRAM_SHA_256_PLUS_NAME;
+					conn->sasl = &pg_scram_mech;
+				}
 #else
 				/*
 				 * The client does not support channel binding.  If it is
@@ -516,7 +519,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 		}
 		else if (strcmp(mechanism_buf.data, SCRAM_SHA_256_NAME) == 0 &&
 				 !selected_mechanism)
+		{
 			selected_mechanism = SCRAM_SHA_256_NAME;
+			conn->sasl = &pg_scram_mech;
+		}
 	}
 
 	if (!selected_mechanism)
@@ -555,20 +561,22 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 		goto error;
 	}
 
+	Assert(conn->sasl);
+
 	/*
 	 * Initialize the SASL state information with all the information gathered
 	 * during the initial exchange.
 	 *
 	 * Note: Only tls-unique is supported for the moment.
 	 */
-	conn->sasl_state = pg_fe_scram_init(conn,
+	conn->sasl_state = conn->sasl->init(conn,
 										password,
 										selected_mechanism);
 	if (!conn->sasl_state)
 		goto oom_error;
 
 	/* Get the mechanism-specific Initial Client Response, if any */
-	pg_fe_scram_exchange(conn->sasl_state,
+	conn->sasl->exchange(conn->sasl_state,
 						 NULL, -1,
 						 &initialresponse, &initialresponselen,
 						 &done, &success);
@@ -649,7 +657,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
 	/* For safety and convenience, ensure the buffer is NULL-terminated. */
 	challenge[payloadlen] = '\0';
 
-	pg_fe_scram_exchange(conn->sasl_state,
+	conn->sasl->exchange(conn->sasl_state,
 						 challenge, payloadlen,
 						 &output, &outputlen,
 						 &done, &success);
@@ -830,7 +838,7 @@ check_expected_areq(AuthRequest areq, PGconn *conn)
 			case AUTH_REQ_SASL_FIN:
 				break;
 			case AUTH_REQ_OK:
-				if (!pg_fe_scram_channel_bound(conn->sasl_state))
+				if (!conn->sasl || !conn->sasl->channel_bound(conn->sasl_state))
 				{
 					appendPQExpBufferStr(&conn->errorMessage,
 										 libpq_gettext("channel binding required, but server authenticated client without channel binding\n"));
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 7877dcbd09..63927480ee 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -22,15 +22,8 @@
 extern int	pg_fe_sendauth(AuthRequest areq, int payloadlen, PGconn *conn);
 extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
-/* Prototypes for functions in fe-auth-scram.c */
-extern void *pg_fe_scram_init(PGconn *conn,
-							  const char *password,
-							  const char *sasl_mechanism);
-extern bool pg_fe_scram_channel_bound(void *opaq);
-extern void pg_fe_scram_free(void *opaq);
-extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
-								 char **output, int *outputlen,
-								 bool *done, bool *success);
+/* Mechanisms in fe-auth-scram.c */
+extern const pg_fe_sasl_mech pg_scram_mech;
 extern char *pg_fe_scram_build_secret(const char *password);
 
 #endif							/* FE_AUTH_H */
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 80703698b8..10d007582c 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -517,11 +517,7 @@ pqDropConnection(PGconn *conn, bool flushInput)
 #endif
 	if (conn->sasl_state)
 	{
-		/*
-		 * XXX: if support for more authentication mechanisms is added, this
-		 * needs to call the right 'free' function.
-		 */
-		pg_fe_scram_free(conn->sasl_state);
+		conn->sasl->free(conn->sasl_state);
 		conn->sasl_state = NULL;
 	}
 }
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index e81dc37906..99eaff50a0 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -339,6 +339,19 @@ typedef struct pg_conn_host
 								 * found in password file. */
 } pg_conn_host;
 
+typedef void *(*pg_fe_sasl_init_func)(PGconn *, const char *, const char *);
+typedef void  (*pg_fe_sasl_exchange_func)(void *, char *, int, char **, int *, bool *, bool *);
+typedef bool  (*pg_fe_sasl_channel_bound_func)(void *);
+typedef void  (*pg_fe_sasl_free_func)(void *);
+
+typedef struct
+{
+	pg_fe_sasl_init_func			init;
+	pg_fe_sasl_exchange_func		exchange;
+	pg_fe_sasl_channel_bound_func	channel_bound;
+	pg_fe_sasl_free_func			free;
+} pg_fe_sasl_mech;
+
 /*
  * PGconn stores all the state data associated with a single connection
  * to a backend.
@@ -500,6 +513,7 @@ struct pg_conn
 	PGresult   *next_result;	/* next result (used in single-row mode) */
 
 	/* Assorted state for SASL, SSL, GSS, etc */
+	const pg_fe_sasl_mech *sasl;
 	void	   *sasl_state;
 
 	/* SSL structures */
-- 
2.25.1
