On Sat, 2021-06-26 at 09:47 +0900, Michael Paquier wrote:
> On Fri, Jun 25, 2021 at 11:40:33PM +0000, Jacob Champion wrote:
> > I can definitely move it (into, say, auth-sasl.c?). I'll probably do
> > that in a second commit, though, since keeping it in place during the
> > refactor makes the review easier IMO.
> 
> auth-sasl.c is a name consistent with the existing practice.
> 
> > Can do. Does libpq-int-sasl.h work as a filename? This should not be
> > exported to applications.
> 
> I would still with the existing naming used by fe-gssapi-common.h, so
> that would be fe-auth-sasl.c and fe-auth-sasl.h, with the header
> remaining internal.  Not strongly wedded to this name, of course, that
> just seems consistent.

Done in v3, with a second patch for the code motion.

I added a first pass at API documentation as well. This exposed some
additional front-end TODOs that I added inline, but they should
probably be dealt with independently of the refactor:

- Zero-length client responses are legal in the SASL framework;
currently we use zero as a sentinel for "don't send a response".

- I don't think it's legal for a client to refuse a challenge from the
server without aborting the exchange, so we should probably check to
make sure that client responses are non-NULL in the success case.

--Jacob
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 8d1d16b0fc..6d385fd6a4 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,6 +15,7 @@ include $(top_builddir)/src/Makefile.global
 # be-fsstubs is here for historical reasons, probably belongs elsewhere
 
 OBJS = \
+       auth-sasl.o \
        auth-scram.o \
        auth.o \
        be-fsstubs.o \
diff --git a/src/backend/libpq/auth-sasl.c b/src/backend/libpq/auth-sasl.c
new file mode 100644
index 0000000000..b7cdb2ecf6
--- /dev/null
+++ b/src/backend/libpq/auth-sasl.c
@@ -0,0 +1,198 @@
+/*-------------------------------------------------------------------------
+ *
+ * auth-sasl.c
+ *       Routines to handle network authentication via SASL
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ *       src/backend/libpq/auth-sasl.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include "libpq/auth.h"
+#include "libpq/libpq.h"
+#include "libpq/pqformat.h"
+#include "libpq/sasl.h"
+
+/*
+ * Perform a SASL exchange with a libpq client, using a specific mechanism
+ * implementation.
+ *
+ * shadow_pass is an optional pointer to the shadow entry for the client's
+ * presented user name. For mechanisms that use shadowed passwords, a NULL
+ * pointer here means that an entry could not be found for the user (or the user
+ * does not exist), and the mechanism should fail the authentication exchange.
+ *
+ * Mechanisms must take care not to reveal to the client that a user entry does
+ * not exist; ideally, the external failure mode is identical to that of an
+ * incorrect password. Mechanisms may instead use the logdetail output parameter
+ * to internally differentiate between failure cases and assist debugging by the
+ * server admin.
+ *
+ * A mechanism is not required to utilize a shadow entry, or even a password
+ * system at all; for these cases, shadow_pass may be ignored and the caller
+ * should just pass NULL.
+ *
+ * Returns STATUS_OK if the mechanism reports a successful exchange, STATUS_EOF
+ * if the connection is closed where the next client message was expected, and
+ * STATUS_ERROR otherwise.  Protocol violations are reported with ereport(ERROR).
+ */
+int
+CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
+                         char **logdetail)
+{
+       StringInfoData sasl_mechs;
+       int                     mtype;
+       StringInfoData buf;
+       void       *opaq = NULL;
+       char       *output = NULL;
+       int                     outputlen = 0;
+       const char *input;
+       int                     inputlen;
+       int                     result;
+       bool            initial;
+
+       /*
+        * Send the SASL authentication request to user.  It includes the list of
+        * authentication mechanisms that are supported.
+        */
+       initStringInfo(&sasl_mechs);
+
+       mech->get_mechanisms(port, &sasl_mechs);
+       /* Put another '\0' to mark that list is finished. */
+       appendStringInfoChar(&sasl_mechs, '\0');
+
+       sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs.data, sasl_mechs.len);
+       pfree(sasl_mechs.data);
+
+       /*
+        * Loop through SASL message exchange.  This exchange can consist of
+        * multiple messages sent in both directions.  First message is always
+        * from the client.  All messages from client to server are password
+        * packets (type 'p').
+        */
+       initial = true;
+       do
+       {
+               pq_startmsgread();
+               mtype = pq_getbyte();
+               if (mtype != 'p')
+               {
+                       /* Only log error if client didn't disconnect. */
+                       if (mtype != EOF)
+                       {
+                               ereport(ERROR,
+                                               (errcode(ERRCODE_PROTOCOL_VIOLATION),
+                                                errmsg("expected SASL response, got message type %d",
+                                                               mtype)));
+                       }
+                       else
+                               return STATUS_EOF;
+               }
+
+               /* Get the actual SASL message */
+               initStringInfo(&buf);
+               if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
+               {
+                       /* EOF - pq_getmessage already logged error */
+                       pfree(buf.data);
+                       return STATUS_ERROR;
+               }
+
+               elog(DEBUG4, "processing received SASL response of length %d", buf.len);
+
+               /*
+                * The first SASLInitialResponse message is different from the others.
+                * It indicates which SASL mechanism the client selected, and contains
+                * an optional Initial Client Response payload.  The subsequent
+                * SASLResponse messages contain just the SASL payload.
+                */
+               if (initial)
+               {
+                       const char *selected_mech;
+
+                       selected_mech = pq_getmsgrawstring(&buf);
+
+                       /*
+                        * Initialize the status tracker for message exchanges.
+                        *
+                        * If the user doesn't exist, or doesn't have a valid password, or
+                        * it's expired, we still go through the motions of SASL
+                        * authentication, but tell the authentication method that the
+                        * authentication is "doomed". That is, it's going to fail, no
+                        * matter what.
+                        *
+                        * This is because we don't want to reveal to an attacker what
+                        * usernames are valid, nor which users have a valid password.
+                        */
+                       opaq = mech->init(port, selected_mech, shadow_pass);
+
+                       inputlen = pq_getmsgint(&buf, 4);
+                       if (inputlen == -1)
+                               input = NULL;
+                       else
+                               input = pq_getmsgbytes(&buf, inputlen);
+
+                       initial = false;
+               }
+               else
+               {
+                       inputlen = buf.len;
+                       input = pq_getmsgbytes(&buf, buf.len);
+               }
+               pq_getmsgend(&buf);
+
+               /*
+                * The StringInfo guarantees that there's a \0 byte after the
+                * response.
+                */
+               Assert(input == NULL || input[inputlen] == '\0');
+
+               /*
+                * Hand the incoming message to the mechanism implementation.
+                */
+               result = mech->exchange(opaq, input, inputlen,
+                                                               &output, &outputlen,
+                                                               logdetail);
+
+               /* input buffer no longer used */
+               pfree(buf.data);
+
+               if (output)
+               {
+                       /*
+                        * SASL forbids sending additional data with a failure outcome,
+                        * so a mechanism returning PG_SASL_EXCHANGE_FAILURE must not
+                        * produce any output.  Catch misbehaving implementations here.
+                        */
+                       if (result == PG_SASL_EXCHANGE_FAILURE)
+                               elog(ERROR, "output message found after SASL exchange failure");
+
+                       /*
+                        * Negotiation generated data to be sent to the client.
+                        */
+                       elog(DEBUG4, "sending SASL challenge of length %d", outputlen);
+
+                       if (result == PG_SASL_EXCHANGE_SUCCESS)
+                               sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
+                       else
+                               sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
+
+                       pfree(output);
+               }
+       } while (result == PG_SASL_EXCHANGE_CONTINUE);
+
+       /* Oops, Something bad happened */
+       if (result != PG_SASL_EXCHANGE_SUCCESS)
+       {
+               return STATUS_ERROR;
+       }
+
+       return STATUS_OK;
+}
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 82f043a343..ac6fe4a747 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -45,19 +45,10 @@
  * Global authentication functions
  *----------------------------------------------------------------
  */
-static void sendAuthRequest(Port *port, AuthRequest areq, const char 
*extradata,
-                                                       int extralen);
 static void auth_failed(Port *port, int status, char *logdetail);
 static char *recv_password_packet(Port *port);
 static void set_authn_id(Port *port, const char *id);
 
-/*----------------------------------------------------------------
- * SASL common authentication
- *----------------------------------------------------------------
- */
-static int     SASL_exchange(const pg_be_sasl_mech *mech, Port *port,
-                                                 char *shadow_pass, char 
**logdetail);
-
 
 /*----------------------------------------------------------------
  * Password-based authentication methods (password, md5, and scram-sha-256)
@@ -67,7 +58,6 @@ static int    CheckPasswordAuth(Port *port, char **logdetail);
 static int     CheckPWChallengeAuth(Port *port, char **logdetail);
 
 static int     CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail);
-static int     CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail);
 
 
 /*----------------------------------------------------------------
@@ -231,14 +221,6 @@ static int PerformRadiusTransaction(const char *server, 
const char *secret, cons
  */
 #define PG_MAX_AUTH_TOKEN_LENGTH       65535
 
-/*
- * Maximum accepted size of SASL messages.
- *
- * The messages that the server or libpq generate are much smaller than this,
- * but have some headroom.
- */
-#define PG_MAX_SASL_MESSAGE_LENGTH     1024
-
 /*----------------------------------------------------------------
  * Global authentication functions
  *----------------------------------------------------------------
@@ -675,7 +657,7 @@ ClientAuthentication(Port *port)
 /*
  * Send an authentication request packet to the frontend.
  */
-static void
+void
 sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int 
extralen)
 {
        StringInfoData buf;
@@ -855,12 +837,13 @@ CheckPWChallengeAuth(Port *port, char **logdetail)
         * SCRAM secret, we must do SCRAM authentication.
         *
         * If MD5 authentication is not allowed, always use SCRAM.  If the user
-        * had an MD5 password, CheckSCRAMAuth() will fail.
+        * had an MD5 password, the SCRAM mechanism will fail.
         */
        if (port->hba->auth_method == uaMD5 && pwtype == PASSWORD_TYPE_MD5)
                auth_result = CheckMD5Auth(port, shadow_pass, logdetail);
        else
-               auth_result = CheckSCRAMAuth(port, shadow_pass, logdetail);
+               auth_result = CheckSASLAuth(&pg_be_scram_mech, port, 
shadow_pass,
+                                                                       
logdetail);
 
        if (shadow_pass)
                pfree(shadow_pass);
@@ -918,159 +901,6 @@ CheckMD5Auth(Port *port, char *shadow_pass, char 
**logdetail)
        return result;
 }
 
-static int
-SASL_exchange(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
-                         char **logdetail)
-{
-       StringInfoData sasl_mechs;
-       int                     mtype;
-       StringInfoData buf;
-       void       *opaq = NULL;
-       char       *output = NULL;
-       int                     outputlen = 0;
-       const char *input;
-       int                     inputlen;
-       int                     result;
-       bool            initial;
-
-       /*
-        * Send the SASL authentication request to user.  It includes the list 
of
-        * authentication mechanisms that are supported.
-        */
-       initStringInfo(&sasl_mechs);
-
-       mech->get_mechanisms(port, &sasl_mechs);
-       /* Put another '\0' to mark that list is finished. */
-       appendStringInfoChar(&sasl_mechs, '\0');
-
-       sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs.data, sasl_mechs.len);
-       pfree(sasl_mechs.data);
-
-       /*
-        * Loop through SASL message exchange.  This exchange can consist of
-        * multiple messages sent in both directions.  First message is always
-        * from the client.  All messages from client to server are password
-        * packets (type 'p').
-        */
-       initial = true;
-       do
-       {
-               pq_startmsgread();
-               mtype = pq_getbyte();
-               if (mtype != 'p')
-               {
-                       /* Only log error if client didn't disconnect. */
-                       if (mtype != EOF)
-                       {
-                               ereport(ERROR,
-                                               
(errcode(ERRCODE_PROTOCOL_VIOLATION),
-                                                errmsg("expected SASL 
response, got message type %d",
-                                                               mtype)));
-                       }
-                       else
-                               return STATUS_EOF;
-               }
-
-               /* Get the actual SASL message */
-               initStringInfo(&buf);
-               if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
-               {
-                       /* EOF - pq_getmessage already logged error */
-                       pfree(buf.data);
-                       return STATUS_ERROR;
-               }
-
-               elog(DEBUG4, "processing received SASL response of length %d", 
buf.len);
-
-               /*
-                * The first SASLInitialResponse message is different from the 
others.
-                * It indicates which SASL mechanism the client selected, and 
contains
-                * an optional Initial Client Response payload.  The subsequent
-                * SASLResponse messages contain just the SASL payload.
-                */
-               if (initial)
-               {
-                       const char *selected_mech;
-
-                       selected_mech = pq_getmsgrawstring(&buf);
-
-                       /*
-                        * Initialize the status tracker for message exchanges.
-                        *
-                        * If the user doesn't exist, or doesn't have a valid 
password, or
-                        * it's expired, we still go through the motions of SASL
-                        * authentication, but tell the authentication method 
that the
-                        * authentication is "doomed". That is, it's going to 
fail, no
-                        * matter what.
-                        *
-                        * This is because we don't want to reveal to an 
attacker what
-                        * usernames are valid, nor which users have a valid 
password.
-                        */
-                       opaq = mech->init(port, selected_mech, shadow_pass);
-
-                       inputlen = pq_getmsgint(&buf, 4);
-                       if (inputlen == -1)
-                               input = NULL;
-                       else
-                               input = pq_getmsgbytes(&buf, inputlen);
-
-                       initial = false;
-               }
-               else
-               {
-                       inputlen = buf.len;
-                       input = pq_getmsgbytes(&buf, buf.len);
-               }
-               pq_getmsgend(&buf);
-
-               /*
-                * The StringInfo guarantees that there's a \0 byte after the
-                * response.
-                */
-               Assert(input == NULL || input[inputlen] == '\0');
-
-               /*
-                * Hand the incoming message to the mechanism implementation.
-                */
-               result = mech->exchange(opaq, input, inputlen,
-                                                               &output, 
&outputlen,
-                                                               logdetail);
-
-               /* input buffer no longer used */
-               pfree(buf.data);
-
-               if (output)
-               {
-                       /*
-                        * Negotiation generated data to be sent to the client.
-                        */
-                       elog(DEBUG4, "sending SASL challenge of length %u", 
outputlen);
-
-                       /* TODO: PG_SASL_EXCHANGE_FAILURE with output is 
forbidden in SASL */
-                       if (result == PG_SASL_EXCHANGE_SUCCESS)
-                               sendAuthRequest(port, AUTH_REQ_SASL_FIN, 
output, outputlen);
-                       else
-                               sendAuthRequest(port, AUTH_REQ_SASL_CONT, 
output, outputlen);
-
-                       pfree(output);
-               }
-       } while (result == PG_SASL_EXCHANGE_CONTINUE);
-
-       /* Oops, Something bad happened */
-       if (result != PG_SASL_EXCHANGE_SUCCESS)
-       {
-               return STATUS_ERROR;
-       }
-
-       return STATUS_OK;
-}
-
-static int
-CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
-{
-       return SASL_exchange(&pg_be_scram_mech, port, shadow_pass, logdetail);
-}
-
 
 /*----------------------------------------------------------------
  * GSSAPI authentication system
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 3610fae3ff..3d6734f253 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -21,6 +21,8 @@ extern bool pg_krb_caseins_users;
 extern char *pg_krb_realm;
 
 extern void ClientAuthentication(Port *port);
+extern void sendAuthRequest(Port *port, AuthRequest areq, const char 
*extradata,
+                                                       int extralen);
 
 /* Hook for plugins to get control in ClientAuthentication() */
 typedef void (*ClientAuthentication_hook_type) (Port *, int);
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
index 1afabf843d..dad04d8ecd 100644
--- a/src/include/libpq/sasl.h
+++ b/src/include/libpq/sasl.h
@@ -1,6 +1,10 @@
 /*-------------------------------------------------------------------------
  *
  * sasl.h
+ *     Defines the SASL mechanism interface for the libpq backend. Each SASL
+ *     mechanism defines a frontend and a backend callback structure.
+ *
+ *     See src/interfaces/libpq/fe-auth-sasl.h for the frontend counterpart.
  *
  * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
  * Portions Copyright (c) 1994, Regents of the University of California
@@ -12,6 +16,7 @@
 #ifndef PG_SASL_H
 #define PG_SASL_H
 
+#include "lib/stringinfo.h"
 #include "libpq/libpq-be.h"
 
 /* Status codes for message exchange */
@@ -19,10 +24,107 @@
 #define PG_SASL_EXCHANGE_SUCCESS               1
 #define PG_SASL_EXCHANGE_FAILURE               2
 
-/* Backend mechanism API */
-typedef void  (*pg_be_sasl_mechanism_func)(Port *, StringInfo);
-typedef void *(*pg_be_sasl_init_func)(Port *, const char *, const char *);
-typedef int   (*pg_be_sasl_exchange_func)(void *, const char *, int, char **, int *, char **);
+/*
+ * Maximum accepted size of SASL messages.
+ *
+ * The messages that the server or libpq generate are much smaller than this,
+ * but have some headroom.
+ */
+#define PG_MAX_SASL_MESSAGE_LENGTH     1024
+
+/*
+ * Backend mechanism API
+ *
+ * To implement a backend mechanism, declare a pg_be_sasl_mech struct with
+ * appropriate callback implementations. Then pass the mechanism to
+ * CheckSASLAuth() during ClientAuthentication(), once the server has decided
+ * which authentication method to use.
+ */
+
+/*
+ * mech.get_mechanisms()
+ *
+ * Retrieves the list of SASL mechanism names supported by this implementation.
+ * The names are appended into the provided buffer.
+ *
+ * Input parameters:
+ *
+ *   port: the client Port
+ *
+ * Output parameters:
+ *
+ *   buf: a StringInfo buffer that the callback should populate with supported
+ *        mechanism names. Null-terminated names should be printed to the buffer
+ *        using appendStringInfo*().
+ */
+typedef void  (*pg_be_sasl_mechanism_func)(Port *port, StringInfo buf);
+
+/*
+ * mech.init()
+ *
+ * Initializes mechanism-specific state for a connection. This callback must
+ * return a pointer to its allocated state, which will be passed as-is as the
+ * first argument to the other callbacks.
+ *
+ * Input parameters:
+ *
+ *   port:        the client Port
+ *
+ *   mech:        the actual mechanism name in use by the client
+ *
+ *   shadow_pass: the shadow entry for the user being authenticated, or NULL if
+ *                one does not exist. Mechanisms that do not use shadow entries
+ *                may ignore this parameter. If a mechanism uses shadow entries
+ *                but shadow_pass is NULL, the implementation must continue the
+ *                exchange as if the user existed and the password did not
+ *                match, to avoid disclosing valid user names.
+ */
+typedef void *(*pg_be_sasl_init_func)(Port *port, const char *mech,
+                                                                         const char *shadow_pass);
+
+/*
+ * mech.exchange()
+ *
+ * Produces a server challenge to be sent to the client. The callback must
+ * return one of the PG_SASL_EXCHANGE_* values, depending on whether the
+ * exchange must continue, has finished successfully, or has failed.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ *
+ *   input:    the response data sent by the client, or NULL if the mechanism is
+ *             client-first but the client did not send an initial response.
+ *             (This can only happen during the first message from the client.)
+ *             This is guaranteed to be null-terminated for safety, but SASL
+ *             allows embedded nulls in responses, so mechanisms must be careful
+ *             to check inputlen.
+ *
+ *   inputlen: the length of the response data sent by the client, or -1 if the
+ *             client did not send an initial response
+ *
+ * Output parameters, to be set by the callback function:
+ *
+ *   output:    a palloc'd buffer containing either the server's next challenge
+ *              (if PG_SASL_EXCHANGE_CONTINUE is returned) or the server's
+ *              outcome data (if PG_SASL_EXCHANGE_SUCCESS is returned and the
+ *              mechanism requires data to be sent during a successful outcome).
+ *              The callback should set this to NULL if the exchange is over and
+ *              no output should be sent, which should correspond to either
+ *              PG_SASL_EXCHANGE_FAILURE or a PG_SASL_EXCHANGE_SUCCESS with no
+ *              outcome data.
+ *
+ *   outputlen: the length of the output data. Ignored if *output is NULL.
+ *
+ *   logdetail: set to an optional DETAIL message to be printed to the server
+ *              log, to disambiguate failure modes. (The client will only ever
+ *              see the same generic authentication failure message.) Ignored if
+ *              the exchange is completed with PG_SASL_EXCHANGE_SUCCESS.
+ */
+typedef int   (*pg_be_sasl_exchange_func)(void *state,
+                                                                                 const char *input, int inputlen,
+                                                                                 char **output, int *outputlen,
+                                                                                 char **logdetail);
 
 typedef struct
 {
@@ -31,4 +133,8 @@ typedef struct
        pg_be_sasl_exchange_func        exchange;
 } pg_be_sasl_mech;
 
+/* Common implementation for auth.c */
+extern int CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port,
+                                                char *shadow_pass, char **logdetail);
+
 #endif /* PG_SASL_H */
diff --git a/src/interfaces/libpq/fe-auth-sasl.h 
b/src/interfaces/libpq/fe-auth-sasl.h
new file mode 100644
index 0000000000..1409e51287
--- /dev/null
+++ b/src/interfaces/libpq/fe-auth-sasl.h
@@ -0,0 +1,131 @@
+/*-------------------------------------------------------------------------
+ *
+ * fe-auth-sasl.h
+ *    Defines the SASL mechanism interface for the libpq frontend. Each SASL
+ *    mechanism defines a frontend and a backend callback structure. This is not
+ *    part of the public API for applications.
+ *
+ *    See src/include/libpq/sasl.h for the backend counterpart.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/interfaces/libpq/fe-auth-sasl.h
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#ifndef FE_AUTH_SASL_H
+#define FE_AUTH_SASL_H
+
+#include "libpq-fe.h"
+
+/*
+ * Frontend mechanism API
+ *
+ * To implement a frontend mechanism, declare a pg_fe_sasl_mech struct with
+ * appropriate callback implementations, then hook it into conn->sasl during
+ * pg_SASL_init()'s mechanism negotiation.
+ */
+
+/*
+ * mech.init()
+ *
+ * Initializes mechanism-specific state for a connection. This callback must
+ * return a pointer to its allocated state, which will be passed as-is as the
+ * first argument to the other callbacks. mech.free() will be called to release
+ * any state resources.
+ *
+ * If state allocation fails, the implementation should return NULL to fail the
+ * authentication exchange.
+ *
+ * Input parameters:
+ *
+ *   conn:     the connection to the server
+ *
+ *   password: the user's supplied password for the current connection
+ *
+ *   mech:     the mechanism name in use, for implementations that may advertise
+ *             more than one name (such as *-PLUS variants)
+ */
+typedef void *(*pg_fe_sasl_init_func)(PGconn *conn, const char *password,
+                                                                         const char *mech);
+
+/*
+ * mech.exchange()
+ *
+ * Produces a client response to a server challenge. As a special case for
+ * client-first SASL mechanisms, exchange() is called with a NULL server
+ * response once at the start of the authentication exchange to generate an
+ * initial response.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ *
+ *   input:    the challenge data sent by the server, or NULL when generating a
+ *             client-first initial response (that is, when the server expects
+ *             the client to send a message to start the exchange). This is
+ *             guaranteed to be null-terminated for safety, but SASL allows
+ *             embedded nulls in challenges, so mechanisms must be careful to
+ *             check inputlen.
+ *
+ *   inputlen: the length of the challenge data sent by the server, or -1
+ *             during client-first initial response generation.
+ *
+ * Output parameters, to be set by the callback function:
+ *
+ *   output:    a malloc'd buffer containing the client's response to the
+ *              server, or NULL if the exchange should be aborted. (*success
+ *              should be set to false in the latter case.)
+ *
+ *   outputlen: the length of the client response buffer, or zero if no data
+ *              should be sent due to an exchange failure
+ *
+ *   done:      set to true if the SASL exchange should not continue, because
+ *              the exchange is either complete or failed
+ *
+ *   success:   set to true if the SASL exchange completed successfully. Ignored
+ *              if *done is false.
+ */
+typedef void  (*pg_fe_sasl_exchange_func)(void *state,
+                                                                                 char *input, int inputlen,
+                                                                                 char **output, int *outputlen,
+                                                                                 bool *done, bool *success);
+
+/*
+ * mech.channel_bound()
+ *
+ * Returns true if the connection has an established channel binding. A
+ * mechanism implementation must ensure that a SASL exchange has actually been
+ * completed, in addition to checking that channel binding is in use.
+ *
+ * Mechanisms that do not implement channel binding may simply return false.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ */
+typedef bool  (*pg_fe_sasl_channel_bound_func)(void *state);
+
+/*
+ * mech.free()
+ *
+ * Frees the state allocated by mech.init(). This is called when the connection
+ * is dropped, not when the exchange is completed.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ */
+typedef void  (*pg_fe_sasl_free_func)(void *state);
+
+typedef struct
+{
+       pg_fe_sasl_init_func                    init;
+       pg_fe_sasl_exchange_func                exchange;
+       pg_fe_sasl_channel_bound_func   channel_bound;
+       pg_fe_sasl_free_func                    free;
+} pg_fe_sasl_mech;
+
+#endif /* FE_AUTH_SASL_H */
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index d5cbac108e..f299e72e7e 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -41,6 +41,7 @@
 #include "common/md5.h"
 #include "common/scram-common.h"
 #include "fe-auth.h"
+#include "fe-auth-sasl.h"
 #include "libpq-fe.h"
 
 #ifdef ENABLE_GSS
@@ -672,6 +673,11 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
                                                         
libpq_gettext("AuthenticationSASLFinal received from server, but SASL 
authentication was not completed\n"));
                return STATUS_ERROR;
        }
+       /*
+        * TODO SASL requires us to accommodate zero-length responses.
+        * TODO is it legal for a client not to send a response to a server
+        * challenge, if the exchange isn't being aborted?
+        */
        if (outputlen != 0)
        {
                /*
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 3ebf111158..e9f214b61b 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -41,6 +41,7 @@
 #include "getaddrinfo.h"
 #include "libpq/pqcomm.h"
 /* include stuff found in fe only */
+#include "fe-auth-sasl.h"
 #include "pqexpbuffer.h"
 
 #ifdef ENABLE_GSS
@@ -339,19 +340,6 @@ typedef struct pg_conn_host
                                                                 * found in 
password file. */
 } pg_conn_host;
 
-typedef void *(*pg_fe_sasl_init_func)(PGconn *, const char *, const char *);
-typedef void  (*pg_fe_sasl_exchange_func)(void *, char *, int, char **, int *, 
bool *, bool *);
-typedef bool  (*pg_fe_sasl_channel_bound_func)(void *);
-typedef void  (*pg_fe_sasl_free_func)(void *);
-
-typedef struct
-{
-       pg_fe_sasl_init_func                    init;
-       pg_fe_sasl_exchange_func                exchange;
-       pg_fe_sasl_channel_bound_func   channel_bound;
-       pg_fe_sasl_free_func                    free;
-} pg_fe_sasl_mech;
-
 /*
  * PGconn stores all the state data associated with a single connection
  * to a backend.
From 22cd26de5266880d2cc5419ce80428ec5c25bf5f Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchamp...@vmware.com>
Date: Tue, 13 Apr 2021 10:25:48 -0700
Subject: [PATCH v3 1/2] auth: generalize SASL mechanisms

Split the SASL logic out from the SCRAM implementation, so that it can
be reused by other mechanisms.  New implementations will implement both
a pg_fe_sasl_mech and a pg_be_sasl_mech.
---
 src/backend/libpq/auth-scram.c       |  48 ++++++----
 src/backend/libpq/auth.c             |  40 +++++---
 src/include/libpq/sasl.h             | 127 ++++++++++++++++++++++++++
 src/include/libpq/scram.h            |  13 +--
 src/interfaces/libpq/fe-auth-sasl.h  | 131 +++++++++++++++++++++++++++
 src/interfaces/libpq/fe-auth-scram.c |  40 +++++---
 src/interfaces/libpq/fe-auth.c       |  22 ++++-
 src/interfaces/libpq/fe-auth.h       |  11 +--
 src/interfaces/libpq/fe-connect.c    |   6 +-
 src/interfaces/libpq/libpq-int.h     |   2 +
 10 files changed, 367 insertions(+), 73 deletions(-)
 create mode 100644 src/include/libpq/sasl.h
 create mode 100644 src/interfaces/libpq/fe-auth-sasl.h

diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index f9e1026a12..2965ea2ddb 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -101,11 +101,25 @@
 #include "common/sha2.h"
 #include "libpq/auth.h"
 #include "libpq/crypt.h"
+#include "libpq/sasl.h"
 #include "libpq/scram.h"
 #include "miscadmin.h"
 #include "utils/builtins.h"
 #include "utils/timestamp.h"
 
+static void  scram_get_mechanisms(Port *port, StringInfo buf);
+static void *scram_init(Port *port, const char *selected_mech,
+						const char *shadow_pass);
+static int   scram_exchange(void *opaq, const char *input, int inputlen,
+							char **output, int *outputlen, char **logdetail);
+
+/* Mechanism declaration */
+const pg_be_sasl_mech pg_be_scram_mech = {
+	scram_get_mechanisms,
+	scram_init,
+	scram_exchange,
+};
+
 /*
  * Status data for a SCRAM authentication exchange.  This should be kept
  * internal to this file.
@@ -170,16 +184,14 @@ static char *sanitize_str(const char *s);
 static char *scram_mock_salt(const char *username);
 
 /*
- * pg_be_scram_get_mechanisms
- *
  * Get a list of SASL mechanisms that this module supports.
  *
  * For the convenience of building the FE/BE packet that lists the
  * mechanisms, the names are appended to the given StringInfo buffer,
  * separated by '\0' bytes.
  */
-void
-pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
+static void
+scram_get_mechanisms(Port *port, StringInfo buf)
 {
 	/*
 	 * Advertise the mechanisms in decreasing order of importance.  So the
@@ -199,8 +211,6 @@ pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
 }
 
 /*
- * pg_be_scram_init
- *
  * Initialize a new SCRAM authentication exchange status tracker.  This
  * needs to be called before doing any exchange.  It will be filled later
  * after the beginning of the exchange with authentication information.
@@ -215,10 +225,8 @@ pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
  * an authentication exchange, but it will fail, as if an incorrect password
  * was given.
  */
-void *
-pg_be_scram_init(Port *port,
-				 const char *selected_mech,
-				 const char *shadow_pass)
+static void *
+scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
 {
 	scram_state *state;
 	bool		got_secret;
@@ -325,9 +333,9 @@ pg_be_scram_init(Port *port,
  * string at *logdetail that will be sent to the postmaster log (but not
  * the client).
  */
-int
-pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
-					 char **output, int *outputlen, char **logdetail)
+static int
+scram_exchange(void *opaq, const char *input, int inputlen,
+			   char **output, int *outputlen, char **logdetail)
 {
 	scram_state *state = (scram_state *) opaq;
 	int			result;
@@ -346,7 +354,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 
 		*output = pstrdup("");
 		*outputlen = 0;
-		return SASL_EXCHANGE_CONTINUE;
+		return PG_SASL_EXCHANGE_CONTINUE;
 	}
 
 	/*
@@ -379,7 +387,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			*output = build_server_first_message(state);
 
 			state->state = SCRAM_AUTH_SALT_SENT;
-			result = SASL_EXCHANGE_CONTINUE;
+			result = PG_SASL_EXCHANGE_CONTINUE;
 			break;
 
 		case SCRAM_AUTH_SALT_SENT:
@@ -408,7 +416,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			 * erroring out in an application-specific way.  We choose to do
 			 * the latter, so that the error message for invalid password is
 			 * the same for all authentication methods.  The caller will call
-			 * ereport(), when we return SASL_EXCHANGE_FAILURE with no output.
+			 * ereport(), when we return PG_SASL_EXCHANGE_FAILURE with no output.
 			 *
 			 * NB: the order of these checks is intentional.  We calculate the
 			 * client proof even in a mock authentication, even though it's
@@ -417,7 +425,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			 */
 			if (!verify_client_proof(state) || state->doomed)
 			{
-				result = SASL_EXCHANGE_FAILURE;
+				result = PG_SASL_EXCHANGE_FAILURE;
 				break;
 			}
 
@@ -425,16 +433,16 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			*output = build_server_final_message(state);
 
 			/* Success! */
-			result = SASL_EXCHANGE_SUCCESS;
+			result = PG_SASL_EXCHANGE_SUCCESS;
 			state->state = SCRAM_AUTH_FINISHED;
 			break;
 
 		default:
 			elog(ERROR, "invalid SCRAM exchange state");
-			result = SASL_EXCHANGE_FAILURE;
+			result = PG_SASL_EXCHANGE_FAILURE;
 	}
 
-	if (result == SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
+	if (result == PG_SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
 		*logdetail = state->logdetail;
 
 	if (*output)
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 967b5ef73c..82f043a343 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -26,11 +26,11 @@
 #include "commands/user.h"
 #include "common/ip.h"
 #include "common/md5.h"
-#include "common/scram-common.h"
 #include "libpq/auth.h"
 #include "libpq/crypt.h"
 #include "libpq/libpq.h"
 #include "libpq/pqformat.h"
+#include "libpq/sasl.h"
 #include "libpq/scram.h"
 #include "miscadmin.h"
 #include "port/pg_bswap.h"
@@ -51,6 +51,13 @@ static void auth_failed(Port *port, int status, char *logdetail);
 static char *recv_password_packet(Port *port);
 static void set_authn_id(Port *port, const char *id);
 
+/*----------------------------------------------------------------
+ * SASL common authentication
+ *----------------------------------------------------------------
+ */
+static int	SASL_exchange(const pg_be_sasl_mech *mech, Port *port,
+						  char *shadow_pass, char **logdetail);
+
 
 /*----------------------------------------------------------------
  * Password-based authentication methods (password, md5, and scram-sha-256)
@@ -912,12 +919,13 @@ CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail)
 }
 
 static int
-CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
+SASL_exchange(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
+			  char **logdetail)
 {
 	StringInfoData sasl_mechs;
 	int			mtype;
 	StringInfoData buf;
-	void	   *scram_opaq = NULL;
+	void	   *opaq = NULL;
 	char	   *output = NULL;
 	int			outputlen = 0;
 	const char *input;
@@ -931,7 +939,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 	 */
 	initStringInfo(&sasl_mechs);
 
-	pg_be_scram_get_mechanisms(port, &sasl_mechs);
+	mech->get_mechanisms(port, &sasl_mechs);
 	/* Put another '\0' to mark that list is finished. */
 	appendStringInfoChar(&sasl_mechs, '\0');
 
@@ -998,7 +1006,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 			 * This is because we don't want to reveal to an attacker what
 			 * usernames are valid, nor which users have a valid password.
 			 */
-			scram_opaq = pg_be_scram_init(port, selected_mech, shadow_pass);
+			opaq = mech->init(port, selected_mech, shadow_pass);
 
 			inputlen = pq_getmsgint(&buf, 4);
 			if (inputlen == -1)
@@ -1022,12 +1030,11 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 		Assert(input == NULL || input[inputlen] == '\0');
 
 		/*
-		 * we pass 'logdetail' as NULL when doing a mock authentication,
-		 * because we should already have a better error message in that case
+		 * Hand the incoming message to the mechanism implementation.
 		 */
-		result = pg_be_scram_exchange(scram_opaq, input, inputlen,
-									  &output, &outputlen,
-									  logdetail);
+		result = mech->exchange(opaq, input, inputlen,
+								&output, &outputlen,
+								logdetail);
 
 		/* input buffer no longer used */
 		pfree(buf.data);
@@ -1039,17 +1046,18 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 			 */
 			elog(DEBUG4, "sending SASL challenge of length %u", outputlen);
 
-			if (result == SASL_EXCHANGE_SUCCESS)
+			/* TODO: PG_SASL_EXCHANGE_FAILURE with output is forbidden in SASL */
+			if (result == PG_SASL_EXCHANGE_SUCCESS)
 				sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
 			else
 				sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
 
 			pfree(output);
 		}
-	} while (result == SASL_EXCHANGE_CONTINUE);
+	} while (result == PG_SASL_EXCHANGE_CONTINUE);
 
 	/* Oops, Something bad happened */
-	if (result != SASL_EXCHANGE_SUCCESS)
+	if (result != PG_SASL_EXCHANGE_SUCCESS)
 	{
 		return STATUS_ERROR;
 	}
@@ -1057,6 +1065,12 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 	return STATUS_OK;
 }
 
+static int
+CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
+{
+	return SASL_exchange(&pg_be_scram_mech, port, shadow_pass, logdetail);
+}
+
 
 /*----------------------------------------------------------------
  * GSSAPI authentication system
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
new file mode 100644
index 0000000000..c732f35564
--- /dev/null
+++ b/src/include/libpq/sasl.h
@@ -0,0 +1,128 @@
+/*-------------------------------------------------------------------------
+ *
+ * sasl.h
+ *     Defines the SASL mechanism interface for the libpq backend. Each SASL
+ *     mechanism defines a frontend and a backend callback structure.
+ *
+ *     See src/interfaces/libpq/fe-auth-sasl.h for the frontend counterpart.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/libpq/sasl.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef PG_SASL_H
+#define PG_SASL_H
+
+#include "lib/stringinfo.h"
+#include "libpq/libpq-be.h"
+
+/* Status codes for message exchange */
+#define PG_SASL_EXCHANGE_CONTINUE		0
+#define PG_SASL_EXCHANGE_SUCCESS		1
+#define PG_SASL_EXCHANGE_FAILURE		2
+
+/*
+ * Backend mechanism API
+ *
+ * To implement a backend mechanism, declare a pg_be_sasl_mech struct with
+ * appropriate callback implementations. Then pass the mechanism to
+ * CheckSASLAuth() during ClientAuthentication(), once the server has decided
+ * which authentication method to use.
+ */
+
+/*
+ * mech.get_mechanisms()
+ *
+ * Retrieves the list of SASL mechanism names supported by this implementation.
+ * The names are appended into the provided buffer.
+ *
+ * Input parameters:
+ *
+ *   port: the client Port
+ *
+ * Output parameters:
+ *
+ *   buf: a StringInfo buffer that the callback should populate with supported
+ *        mechanism names. Null-terminated names should be printed to the buffer
+ *        using appendStringInfo*().
+ */
+typedef void  (*pg_be_sasl_mechanism_func)(Port *port, StringInfo buf);
+
+/*
+ * mech.init()
+ *
+ * Initializes mechanism-specific state for a connection. This callback must
+ * return a pointer to its allocated state, which will be passed as-is as the
+ * first argument to the other callbacks.
+ *
+ * Input parameters:
+ *
+ *   port:        the client Port
+ *
+ *	 mech:        the actual mechanism name in use by the client
+ *
+ *	 shadow_pass: the shadow entry for the user being authenticated, or NULL if
+ *	              one does not exist. Mechanisms that do not use shadow entries
+ *	              may ignore this parameter. If a mechanism uses shadow entries
+ *	              but shadow_pass is NULL, the implementation must continue the
+ *	              exchange as if the user existed and the password did not
+ *	              match, to avoid disclosing valid user names.
+ */
+typedef void *(*pg_be_sasl_init_func)(Port *port, const char *mech,
+									  const char *shadow_pass);
+
+/*
+ * mech.exchange()
+ *
+ * Produces a server challenge to be sent to the client. The callback must
+ * return one of the PG_SASL_EXCHANGE_* values, depending on whether the
+ * exchange must continue, has finished successfully, or has failed.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ *
+ *   input:    the response data sent by the client, or NULL if the mechanism is
+ *             client-first but the client did not send an initial response.
+ *             (This can only happen during the first message from the client.)
+ *             This is guaranteed to be null-terminated for safety, but SASL
+ *             allows embedded nulls in responses, so mechanisms must be careful
+ *             to check inputlen.
+ *
+ *   inputlen: the length of the response data sent by the client, or -1 if the
+ *             client did not send an initial response
+ *
+ * Output parameters, to be set by the callback function:
+ *
+ *   output:    a palloc'd buffer containing either the server's next challenge
+ *              (if PG_SASL_EXCHANGE_CONTINUE is returned) or the server's
+ *              outcome data (if PG_SASL_EXCHANGE_SUCCESS is returned and the
+ *              mechanism requires data to be sent during a successful outcome).
+ *              The callback should set this to NULL if the exchange is over and
+ *              no output should be sent, which should correspond to either
+ *              PG_SASL_EXCHANGE_FAILURE or a PG_SASL_EXCHANGE_SUCCESS with no
+ *              outcome data.
+ *
+ *   outputlen: the length of the data in *output. Ignored if *output is NULL.
+ *
+ *   logdetail: set to an optional DETAIL message to be printed to the server
+ *              log, to disambiguate failure modes. (The client will only ever
+ *              see the same generic authentication failure message.) Ignored if
+ *              the exchange is completed with PG_SASL_EXCHANGE_SUCCESS.
+ */
+typedef int   (*pg_be_sasl_exchange_func)(void *state,
+										  const char *input, int inputlen,
+										  char **output, int *outputlen,
+										  char **logdetail);
+
+typedef struct
+{
+	pg_be_sasl_mechanism_func	get_mechanisms;
+	pg_be_sasl_init_func		init;
+	pg_be_sasl_exchange_func	exchange;
+} pg_be_sasl_mech;
+
+#endif /* PG_SASL_H */
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index 2c879150da..9e4540bde3 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -15,17 +15,10 @@
 
 #include "lib/stringinfo.h"
 #include "libpq/libpq-be.h"
+#include "libpq/sasl.h"
 
-/* Status codes for message exchange */
-#define SASL_EXCHANGE_CONTINUE		0
-#define SASL_EXCHANGE_SUCCESS		1
-#define SASL_EXCHANGE_FAILURE		2
-
-/* Routines dedicated to authentication */
-extern void pg_be_scram_get_mechanisms(Port *port, StringInfo buf);
-extern void *pg_be_scram_init(Port *port, const char *selected_mech, const char *shadow_pass);
-extern int	pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
-								 char **output, int *outputlen, char **logdetail);
+/* Implementation */
+extern const pg_be_sasl_mech pg_be_scram_mech;
 
 /* Routines to handle and check SCRAM-SHA-256 secret */
 extern char *pg_be_scram_build_secret(const char *password);
diff --git a/src/interfaces/libpq/fe-auth-sasl.h b/src/interfaces/libpq/fe-auth-sasl.h
new file mode 100644
index 0000000000..1409e51287
--- /dev/null
+++ b/src/interfaces/libpq/fe-auth-sasl.h
@@ -0,0 +1,131 @@
+/*-------------------------------------------------------------------------
+ *
+ * fe-auth-sasl.h
+ *    Defines the SASL mechanism interface for the libpq frontend. Each SASL
+ *    mechanism defines a frontend and a backend callback structure. This is not
+ *    part of the public API for applications.
+ *
+ *    See src/include/libpq/sasl.h for the backend counterpart.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/interfaces/libpq/fe-auth-sasl.h
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#ifndef FE_AUTH_SASL_H
+#define FE_AUTH_SASL_H
+
+#include "libpq-fe.h"
+
+/*
+ * Frontend mechanism API
+ *
+ * To implement a frontend mechanism, declare a pg_fe_sasl_mech struct with
+ * appropriate callback implementations, then hook it into conn->sasl during
+ * pg_SASL_init()'s mechanism negotiation.
+ */
+
+/*
+ * mech.init()
+ *
+ * Initializes mechanism-specific state for a connection. This callback must
+ * return a pointer to its allocated state, which will be passed as-is as the
+ * first argument to the other callbacks. mech.free() will be called to release
+ * any state resources.
+ *
+ * If state allocation fails, the implementation should return NULL to fail the
+ * authentication exchange.
+ *
+ * Input parameters:
+ *
+ *   conn:     the connection to the server
+ *
+ *   password: the user's supplied password for the current connection
+ *
+ *   mech:     the mechanism name in use, for implementations that may advertise
+ *             more than one name (such as *-PLUS variants)
+ */
+typedef void *(*pg_fe_sasl_init_func)(PGconn *conn, const char *password,
+									  const char *mech);
+
+/*
+ * mech.exchange()
+ *
+ * Produces a client response to a server challenge. As a special case for
+ * client-first SASL mechanisms, exchange() is called with a NULL server
+ * response once at the start of the authentication exchange to generate an
+ * initial response.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ *
+ *   input:    the challenge data sent by the server, or NULL when generating a
+ *             client-first initial response (that is, when the server expects
+ *             the client to send a message to start the exchange). This is
+ *             guaranteed to be null-terminated for safety, but SASL allows
+ *             embedded nulls in challenges, so mechanisms must be careful to
+ *             check inputlen.
+ *
+ *   inputlen: the length of the challenge data sent by the server, or -1
+ *             during client-first initial response generation.
+ *
+ * Output parameters, to be set by the callback function:
+ *
+ *   output:    a malloc'd buffer containing the client's response to the
+ *              server, or NULL if the exchange should be aborted. (*success
+ *              should be set to false in the latter case.)
+ *
+ *   outputlen: the length of the client response buffer, or zero if no data
+ *              should be sent due to an exchange failure
+ *
+ *   done:      set to true if the SASL exchange should not continue, because
+ *              the exchange is either complete or failed
+ *
+ *   success:   set to true if the SASL exchange completed successfully. Ignored
+ *              if *done is false.
+ */
+typedef void  (*pg_fe_sasl_exchange_func)(void *state,
+										  char *input, int inputlen,
+										  char **output, int *outputlen,
+										  bool *done, bool *success);
+
+/*
+ * mech.channel_bound()
+ *
+ * Returns true if the connection has an established channel binding. A
+ * mechanism implementation must ensure that a SASL exchange has actually been
+ * completed, in addition to checking that channel binding is in use.
+ *
+ * Mechanisms that do not implement channel binding may simply return false.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ */
+typedef bool  (*pg_fe_sasl_channel_bound_func)(void *);
+
+/*
+ * mech.free()
+ *
+ * Frees the state allocated by mech.init(). This is called when the connection
+ * is dropped, not when the exchange is completed.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ */
+typedef void  (*pg_fe_sasl_free_func)(void *);
+
+typedef struct
+{
+	pg_fe_sasl_init_func			init;
+	pg_fe_sasl_exchange_func		exchange;
+	pg_fe_sasl_channel_bound_func	channel_bound;
+	pg_fe_sasl_free_func			free;
+} pg_fe_sasl_mech;
+
+#endif /* FE_AUTH_SASL_H */
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index 5881386e37..515ef66f37 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -21,6 +21,22 @@
 #include "fe-auth.h"
 
 
+/* The exported SCRAM callback mechanism. */
+static void *scram_init(PGconn *conn, const char *password,
+						const char *sasl_mechanism);
+static void scram_exchange(void *opaq, char *input, int inputlen,
+						   char **output, int *outputlen,
+						   bool *done, bool *success);
+static bool scram_channel_bound(void *opaq);
+static void scram_free(void *opaq);
+
+const pg_fe_sasl_mech pg_scram_mech = {
+	scram_init,
+	scram_exchange,
+	scram_channel_bound,
+	scram_free,
+};
+
 /*
  * Status of exchange messages used for SCRAM authentication via the
  * SASL protocol.
@@ -72,10 +88,10 @@ static bool calculate_client_proof(fe_scram_state *state,
 /*
  * Initialize SCRAM exchange status.
  */
-void *
-pg_fe_scram_init(PGconn *conn,
-				 const char *password,
-				 const char *sasl_mechanism)
+static void *
+scram_init(PGconn *conn,
+		   const char *password,
+		   const char *sasl_mechanism)
 {
 	fe_scram_state *state;
 	char	   *prep_password;
@@ -128,8 +144,8 @@ pg_fe_scram_init(PGconn *conn,
  * Note that the caller must also ensure that the exchange was actually
  * successful.
  */
-bool
-pg_fe_scram_channel_bound(void *opaq)
+static bool
+scram_channel_bound(void *opaq)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 
@@ -152,8 +168,8 @@ pg_fe_scram_channel_bound(void *opaq)
 /*
  * Free SCRAM exchange status
  */
-void
-pg_fe_scram_free(void *opaq)
+static void
+scram_free(void *opaq)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 
@@ -188,10 +204,10 @@ pg_fe_scram_free(void *opaq)
 /*
  * Exchange a SCRAM message with backend.
  */
-void
-pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
-					 char **output, int *outputlen,
-					 bool *done, bool *success)
+static void
+scram_exchange(void *opaq, char *input, int inputlen,
+			   char **output, int *outputlen,
+			   bool *done, bool *success)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 	PGconn	   *conn = state->conn;
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index e8062647e6..f299e72e7e 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -41,6 +41,7 @@
 #include "common/md5.h"
 #include "common/scram-common.h"
 #include "fe-auth.h"
+#include "fe-auth-sasl.h"
 #include "libpq-fe.h"
 
 #ifdef ENABLE_GSS
@@ -482,7 +483,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 				 * channel_binding is not disabled.
 				 */
 				if (conn->channel_binding[0] != 'd')	/* disable */
+				{
 					selected_mechanism = SCRAM_SHA_256_PLUS_NAME;
+					conn->sasl = &pg_scram_mech;
+				}
 #else
 				/*
 				 * The client does not support channel binding.  If it is
@@ -516,7 +520,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 		}
 		else if (strcmp(mechanism_buf.data, SCRAM_SHA_256_NAME) == 0 &&
 				 !selected_mechanism)
+		{
 			selected_mechanism = SCRAM_SHA_256_NAME;
+			conn->sasl = &pg_scram_mech;
+		}
 	}
 
 	if (!selected_mechanism)
@@ -555,20 +562,22 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 		goto error;
 	}
 
+	Assert(conn->sasl);
+
 	/*
 	 * Initialize the SASL state information with all the information gathered
 	 * during the initial exchange.
 	 *
 	 * Note: Only tls-unique is supported for the moment.
 	 */
-	conn->sasl_state = pg_fe_scram_init(conn,
+	conn->sasl_state = conn->sasl->init(conn,
 										password,
 										selected_mechanism);
 	if (!conn->sasl_state)
 		goto oom_error;
 
 	/* Get the mechanism-specific Initial Client Response, if any */
-	pg_fe_scram_exchange(conn->sasl_state,
+	conn->sasl->exchange(conn->sasl_state,
 						 NULL, -1,
 						 &initialresponse, &initialresponselen,
 						 &done, &success);
@@ -649,7 +658,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
 	/* For safety and convenience, ensure the buffer is NULL-terminated. */
 	challenge[payloadlen] = '\0';
 
-	pg_fe_scram_exchange(conn->sasl_state,
+	conn->sasl->exchange(conn->sasl_state,
 						 challenge, payloadlen,
 						 &output, &outputlen,
 						 &done, &success);
@@ -664,6 +673,11 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
 							 libpq_gettext("AuthenticationSASLFinal received from server, but SASL authentication was not completed\n"));
 		return STATUS_ERROR;
 	}
+	/*
+	 * TODO SASL requires us to accommodate zero-length responses.
+	 * TODO is it legal for a client not to send a response to a server
+	 * challenge, if the exchange isn't being aborted?
+	 */
 	if (outputlen != 0)
 	{
 		/*
@@ -830,7 +844,7 @@ check_expected_areq(AuthRequest areq, PGconn *conn)
 			case AUTH_REQ_SASL_FIN:
 				break;
 			case AUTH_REQ_OK:
-				if (!pg_fe_scram_channel_bound(conn->sasl_state))
+				if (!conn->sasl || !conn->sasl->channel_bound(conn->sasl_state))
 				{
 					appendPQExpBufferStr(&conn->errorMessage,
 										 libpq_gettext("channel binding required, but server authenticated client without channel binding\n"));
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 7877dcbd09..63927480ee 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -22,15 +22,8 @@
 extern int	pg_fe_sendauth(AuthRequest areq, int payloadlen, PGconn *conn);
 extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
-/* Prototypes for functions in fe-auth-scram.c */
-extern void *pg_fe_scram_init(PGconn *conn,
-							  const char *password,
-							  const char *sasl_mechanism);
-extern bool pg_fe_scram_channel_bound(void *opaq);
-extern void pg_fe_scram_free(void *opaq);
-extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
-								 char **output, int *outputlen,
-								 bool *done, bool *success);
+/* Mechanisms in fe-auth-scram.c */
+extern const pg_fe_sasl_mech pg_scram_mech;
 extern char *pg_fe_scram_build_secret(const char *password);
 
 #endif							/* FE_AUTH_H */
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index fc65e490ef..e950b41374 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -516,11 +516,7 @@ pqDropConnection(PGconn *conn, bool flushInput)
 #endif
 	if (conn->sasl_state)
 	{
-		/*
-		 * XXX: if support for more authentication mechanisms is added, this
-		 * needs to call the right 'free' function.
-		 */
-		pg_fe_scram_free(conn->sasl_state);
+		conn->sasl->free(conn->sasl_state);
 		conn->sasl_state = NULL;
 	}
 }
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 6b7fd2c267..e9f214b61b 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -41,6 +41,7 @@
 #include "getaddrinfo.h"
 #include "libpq/pqcomm.h"
 /* include stuff found in fe only */
+#include "fe-auth-sasl.h"
 #include "pqexpbuffer.h"
 
 #ifdef ENABLE_GSS
@@ -500,6 +501,7 @@ struct pg_conn
 	PGresult   *next_result;	/* next result (used in single-row mode) */
 
 	/* Assorted state for SASL, SSL, GSS, etc */
+	const pg_fe_sasl_mech *sasl;
 	void	   *sasl_state;
 
 	/* SSL structures */
-- 
2.25.1

From fbd17c7b77251ed66eed00d80efc58abb5eeb84a Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchamp...@vmware.com>
Date: Wed, 30 Jun 2021 09:27:40 -0700
Subject: [PATCH v3 2/2] auth: pull backend SASL exchange into its own file

This code motion is pulled into a separate commit to ease review.

Move SASL_exchange to its own file and rename it to CheckSASLAuth, which
is now called directly from ClientAuthentication(). This replaces the
CheckSCRAMAuth() wrapper.
---
 src/backend/libpq/Makefile    |   1 +
 src/backend/libpq/auth-sasl.c | 187 ++++++++++++++++++++++++++++++++++
 src/backend/libpq/auth.c      | 178 +-------------------------------
 src/include/libpq/auth.h      |   2 +
 src/include/libpq/sasl.h      |  13 +++
 5 files changed, 207 insertions(+), 174 deletions(-)
 create mode 100644 src/backend/libpq/auth-sasl.c

diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 8d1d16b0fc..6d385fd6a4 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,6 +15,7 @@ include $(top_builddir)/src/Makefile.global
 # be-fsstubs is here for historical reasons, probably belongs elsewhere
 
 OBJS = \
+	auth-sasl.o \
 	auth-scram.o \
 	auth.o \
 	be-fsstubs.o \
diff --git a/src/backend/libpq/auth-sasl.c b/src/backend/libpq/auth-sasl.c
new file mode 100644
index 0000000000..b7cdb2ecf6
--- /dev/null
+++ b/src/backend/libpq/auth-sasl.c
@@ -0,0 +1,187 @@
+/*-------------------------------------------------------------------------
+ *
+ * auth-sasl.c
+ *	  Routines to handle network authentication via SASL
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ *	  src/backend/libpq/auth-sasl.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include "libpq/auth.h"
+#include "libpq/libpq.h"
+#include "libpq/pqformat.h"
+#include "libpq/sasl.h"
+
+/*
+ * Perform a SASL exchange with a libpq client, using a specific mechanism
+ * implementation.  Returns STATUS_OK on success, STATUS_EOF if the client
+ * disconnected instead of sending a response message, and STATUS_ERROR on
+ * any other failure; protocol violations are reported with ereport(ERROR).
+ *
+ * shadow_pass is an optional pointer to the shadow entry for the client's
+ * presented user name. For mechanisms that use shadowed passwords, a NULL
+ * pointer here means that an entry could not be found for the user (or the user
+ * does not exist), and the mechanism should fail the authentication exchange.
+ *
+ * Mechanisms must take care not to reveal to the client that a user entry does
+ * not exist; ideally, the external failure mode is identical to that of an
+ * incorrect password. Mechanisms may instead use the logdetail output parameter
+ * to internally differentiate between failure cases for the server admin.
+ *
+ * A mechanism is not required to utilize a shadow entry, or even a password
+ * system at all; then shadow_pass may be ignored and the caller passes NULL.
+ */
+int
+CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
+			  char **logdetail)
+{
+	StringInfoData sasl_mechs;
+	int			mtype;
+	StringInfoData buf;
+	void	   *opaq = NULL;
+	char	   *output = NULL;
+	int			outputlen = 0;
+	const char *input;
+	int			inputlen;
+	int			result;
+	bool		initial;
+
+	/*
+	 * Send the SASL authentication request to user.  It includes the list of
+	 * authentication mechanisms that are supported.
+	 */
+	initStringInfo(&sasl_mechs);
+
+	mech->get_mechanisms(port, &sasl_mechs);
+	/* Put another '\0' to mark that list is finished. */
+	appendStringInfoChar(&sasl_mechs, '\0');
+
+	sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs.data, sasl_mechs.len);
+	pfree(sasl_mechs.data);
+
+	/*
+	 * Loop through SASL message exchange.  This exchange can consist of
+	 * multiple messages sent in both directions.  First message is always
+	 * from the client.  All messages from client to server are password
+	 * packets (type 'p').
+	 */
+	initial = true;
+	do
+	{
+		pq_startmsgread();
+		mtype = pq_getbyte();
+		if (mtype != 'p')
+		{
+			/* Only log error if client didn't disconnect. */
+			if (mtype != EOF)
+			{
+				ereport(ERROR,
+						(errcode(ERRCODE_PROTOCOL_VIOLATION),
+						 errmsg("expected SASL response, got message type %d",
+								mtype)));
+			}
+			else
+				return STATUS_EOF;
+		}
+
+		/* Get the actual SASL message */
+		initStringInfo(&buf);
+		if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
+		{
+			/* EOF - pq_getmessage already logged error */
+			pfree(buf.data);
+			return STATUS_ERROR;
+		}
+
+		elog(DEBUG4, "processing received SASL response of length %d", buf.len);
+
+		/*
+		 * The first SASLInitialResponse message is different from the others.
+		 * It indicates which SASL mechanism the client selected, and contains
+		 * an optional Initial Client Response payload.  The subsequent
+		 * SASLResponse messages contain just the SASL payload.
+		 */
+		if (initial)
+		{
+			const char *selected_mech;
+
+			selected_mech = pq_getmsgrawstring(&buf);
+
+			/*
+			 * Initialize the status tracker for message exchanges.
+			 *
+			 * If the user doesn't exist, or doesn't have a valid password, or
+			 * it's expired, we still go through the motions of SASL
+			 * authentication, but tell the authentication method that the
+			 * authentication is "doomed". That is, it's going to fail, no
+			 * matter what.
+			 *
+			 * This is because we don't want to reveal to an attacker what
+			 * usernames are valid, nor which users have a valid password.
+			 */
+			opaq = mech->init(port, selected_mech, shadow_pass);
+
+			inputlen = pq_getmsgint(&buf, 4);
+			if (inputlen == -1)
+				input = NULL;
+			else
+				input = pq_getmsgbytes(&buf, inputlen);
+
+			initial = false;
+		}
+		else
+		{
+			inputlen = buf.len;
+			input = pq_getmsgbytes(&buf, buf.len);
+		}
+		pq_getmsgend(&buf);
+
+		/*
+		 * The StringInfo guarantees that there's a \0 byte after the
+		 * response.
+		 */
+		Assert(input == NULL || input[inputlen] == '\0');
+
+		/*
+		 * Hand the incoming message to the mechanism implementation.
+		 */
+		result = mech->exchange(opaq, input, inputlen,
+								&output, &outputlen,
+								logdetail);
+
+		/* input buffer no longer used */
+		pfree(buf.data);
+
+		if (output)
+		{
+			/*
+			 * Negotiation generated data to be sent to the client.
+			 */
+			elog(DEBUG4, "sending SASL challenge of length %d", outputlen);
+
+			/* TODO: PG_SASL_EXCHANGE_FAILURE with output is forbidden in SASL */
+			if (result == PG_SASL_EXCHANGE_SUCCESS)
+				sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
+			else
+				sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
+
+			pfree(output);
+			/* don't rely on the next exchange() call to reset this pointer */
+			output = NULL;
+		}
+	} while (result == PG_SASL_EXCHANGE_CONTINUE);
+
+	/* Oops, something bad happened */
+	if (result != PG_SASL_EXCHANGE_SUCCESS)
+		return STATUS_ERROR;
+
+	return STATUS_OK;
+}
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 82f043a343..ac6fe4a747 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -45,19 +45,10 @@
  * Global authentication functions
  *----------------------------------------------------------------
  */
-static void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
-							int extralen);
 static void auth_failed(Port *port, int status, char *logdetail);
 static char *recv_password_packet(Port *port);
 static void set_authn_id(Port *port, const char *id);
 
-/*----------------------------------------------------------------
- * SASL common authentication
- *----------------------------------------------------------------
- */
-static int	SASL_exchange(const pg_be_sasl_mech *mech, Port *port,
-						  char *shadow_pass, char **logdetail);
-
 
 /*----------------------------------------------------------------
  * Password-based authentication methods (password, md5, and scram-sha-256)
@@ -67,7 +58,6 @@ static int	CheckPasswordAuth(Port *port, char **logdetail);
 static int	CheckPWChallengeAuth(Port *port, char **logdetail);
 
 static int	CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail);
-static int	CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail);
 
 
 /*----------------------------------------------------------------
@@ -231,14 +221,6 @@ static int	PerformRadiusTransaction(const char *server, const char *secret, cons
  */
 #define PG_MAX_AUTH_TOKEN_LENGTH	65535
 
-/*
- * Maximum accepted size of SASL messages.
- *
- * The messages that the server or libpq generate are much smaller than this,
- * but have some headroom.
- */
-#define PG_MAX_SASL_MESSAGE_LENGTH	1024
-
 /*----------------------------------------------------------------
  * Global authentication functions
  *----------------------------------------------------------------
@@ -675,7 +657,7 @@ ClientAuthentication(Port *port)
 /*
  * Send an authentication request packet to the frontend.
  */
-static void
+void
 sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int extralen)
 {
 	StringInfoData buf;
@@ -855,12 +837,13 @@ CheckPWChallengeAuth(Port *port, char **logdetail)
 	 * SCRAM secret, we must do SCRAM authentication.
 	 *
 	 * If MD5 authentication is not allowed, always use SCRAM.  If the user
-	 * had an MD5 password, CheckSCRAMAuth() will fail.
+	 * had an MD5 password, the SCRAM mechanism will fail.
 	 */
 	if (port->hba->auth_method == uaMD5 && pwtype == PASSWORD_TYPE_MD5)
 		auth_result = CheckMD5Auth(port, shadow_pass, logdetail);
 	else
-		auth_result = CheckSCRAMAuth(port, shadow_pass, logdetail);
+		auth_result = CheckSASLAuth(&pg_be_scram_mech, port, shadow_pass,
+									logdetail);
 
 	if (shadow_pass)
 		pfree(shadow_pass);
@@ -918,159 +901,6 @@ CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail)
 	return result;
 }
 
-static int
-SASL_exchange(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
-			  char **logdetail)
-{
-	StringInfoData sasl_mechs;
-	int			mtype;
-	StringInfoData buf;
-	void	   *opaq = NULL;
-	char	   *output = NULL;
-	int			outputlen = 0;
-	const char *input;
-	int			inputlen;
-	int			result;
-	bool		initial;
-
-	/*
-	 * Send the SASL authentication request to user.  It includes the list of
-	 * authentication mechanisms that are supported.
-	 */
-	initStringInfo(&sasl_mechs);
-
-	mech->get_mechanisms(port, &sasl_mechs);
-	/* Put another '\0' to mark that list is finished. */
-	appendStringInfoChar(&sasl_mechs, '\0');
-
-	sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs.data, sasl_mechs.len);
-	pfree(sasl_mechs.data);
-
-	/*
-	 * Loop through SASL message exchange.  This exchange can consist of
-	 * multiple messages sent in both directions.  First message is always
-	 * from the client.  All messages from client to server are password
-	 * packets (type 'p').
-	 */
-	initial = true;
-	do
-	{
-		pq_startmsgread();
-		mtype = pq_getbyte();
-		if (mtype != 'p')
-		{
-			/* Only log error if client didn't disconnect. */
-			if (mtype != EOF)
-			{
-				ereport(ERROR,
-						(errcode(ERRCODE_PROTOCOL_VIOLATION),
-						 errmsg("expected SASL response, got message type %d",
-								mtype)));
-			}
-			else
-				return STATUS_EOF;
-		}
-
-		/* Get the actual SASL message */
-		initStringInfo(&buf);
-		if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
-		{
-			/* EOF - pq_getmessage already logged error */
-			pfree(buf.data);
-			return STATUS_ERROR;
-		}
-
-		elog(DEBUG4, "processing received SASL response of length %d", buf.len);
-
-		/*
-		 * The first SASLInitialResponse message is different from the others.
-		 * It indicates which SASL mechanism the client selected, and contains
-		 * an optional Initial Client Response payload.  The subsequent
-		 * SASLResponse messages contain just the SASL payload.
-		 */
-		if (initial)
-		{
-			const char *selected_mech;
-
-			selected_mech = pq_getmsgrawstring(&buf);
-
-			/*
-			 * Initialize the status tracker for message exchanges.
-			 *
-			 * If the user doesn't exist, or doesn't have a valid password, or
-			 * it's expired, we still go through the motions of SASL
-			 * authentication, but tell the authentication method that the
-			 * authentication is "doomed". That is, it's going to fail, no
-			 * matter what.
-			 *
-			 * This is because we don't want to reveal to an attacker what
-			 * usernames are valid, nor which users have a valid password.
-			 */
-			opaq = mech->init(port, selected_mech, shadow_pass);
-
-			inputlen = pq_getmsgint(&buf, 4);
-			if (inputlen == -1)
-				input = NULL;
-			else
-				input = pq_getmsgbytes(&buf, inputlen);
-
-			initial = false;
-		}
-		else
-		{
-			inputlen = buf.len;
-			input = pq_getmsgbytes(&buf, buf.len);
-		}
-		pq_getmsgend(&buf);
-
-		/*
-		 * The StringInfo guarantees that there's a \0 byte after the
-		 * response.
-		 */
-		Assert(input == NULL || input[inputlen] == '\0');
-
-		/*
-		 * Hand the incoming message to the mechanism implementation.
-		 */
-		result = mech->exchange(opaq, input, inputlen,
-								&output, &outputlen,
-								logdetail);
-
-		/* input buffer no longer used */
-		pfree(buf.data);
-
-		if (output)
-		{
-			/*
-			 * Negotiation generated data to be sent to the client.
-			 */
-			elog(DEBUG4, "sending SASL challenge of length %u", outputlen);
-
-			/* TODO: PG_SASL_EXCHANGE_FAILURE with output is forbidden in SASL */
-			if (result == PG_SASL_EXCHANGE_SUCCESS)
-				sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
-			else
-				sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
-
-			pfree(output);
-		}
-	} while (result == PG_SASL_EXCHANGE_CONTINUE);
-
-	/* Oops, Something bad happened */
-	if (result != PG_SASL_EXCHANGE_SUCCESS)
-	{
-		return STATUS_ERROR;
-	}
-
-	return STATUS_OK;
-}
-
-static int
-CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
-{
-	return SASL_exchange(&pg_be_scram_mech, port, shadow_pass, logdetail);
-}
-
 
 /*----------------------------------------------------------------
  * GSSAPI authentication system
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 3610fae3ff..3d6734f253 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -21,6 +21,8 @@ extern bool pg_krb_caseins_users;
 extern char *pg_krb_realm;
 
 extern void ClientAuthentication(Port *port);
+extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
+							int extralen);
 
 /* Hook for plugins to get control in ClientAuthentication() */
 typedef void (*ClientAuthentication_hook_type) (Port *, int);
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
index c732f35564..dad04d8ecd 100644
--- a/src/include/libpq/sasl.h
+++ b/src/include/libpq/sasl.h
@@ -16,6 +16,7 @@
 #ifndef PG_SASL_H
 #define PG_SASL_H
 
+#include "lib/stringinfo.h"
 #include "libpq/libpq-be.h"
 
 /* Status codes for message exchange */
@@ -23,6 +24,14 @@
 #define PG_SASL_EXCHANGE_SUCCESS		1
 #define PG_SASL_EXCHANGE_FAILURE		2
 
+/*
+ * Maximum accepted size of SASL messages.
+ *
+ * The messages that the server or libpq generate are much smaller than this,
+ * but have some headroom.
+ */
+#define PG_MAX_SASL_MESSAGE_LENGTH	1024
+
 /*
  * Backend mechanism API
  *
@@ -124,4 +133,8 @@ typedef struct
 	pg_be_sasl_exchange_func	exchange;
 } pg_be_sasl_mech;
 
+/* Common SASL exchange implementation (auth-sasl.c), used by auth.c */
+extern int CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port,
+						 char *shadow_pass, char **logdetail);
+
 #endif /* PG_SASL_H */
-- 
2.25.1

Reply via email to