On Tue, Jan 7, 2025 at 2:24 PM Jacob Champion
<jacob.champ...@enterprisedb.com> wrote:
> Along those lines, though, Michael Paquier suggested that maybe I
> could pull the require_auth prefactoring up to the front of the
> patchset. That might look a bit odd until OAuth support lands, since
> it won't be adding any new useful value, but I will give it a shot.

While I take a look at the async patch from upthread, here is my
attempt at pulling the require_auth change out.

Note that there's a dead branch that cannot be exercised until OAuth
lands. We're not going to process the SASL mechanism name at all if no
mechanisms are allowed to begin with, and right now SASL is synonymous
with SCRAM. I can change that by always allowing AuthenticationSASL
messages -- even if none of the allowed authentication types use SASL
-- but that approach didn't seem to generate excitement on- or
off-list the last time I proposed it [1].

Thanks,
--Jacob

[1] 
https://postgr.es/m/CAAWbhmg%2BGzNMK5Li182BKSbzoFVaKk_dDJ628NnuV80GqYgFFg%40mail.gmail.com
commit f8bbb0f2a2e0f4840bdeb5c9a7b9a35797280aaf
Author: Jacob Champion <jacob.champ...@enterprisedb.com>
Date:   Mon Dec 16 13:57:14 2024 -0800

    require_auth: prepare for multiple SASL mechanisms
    
    Prior to this patch, the require_auth implementation assumed that the
    AuthenticationSASL protocol message was synonymous with SCRAM-SHA-256.
    In preparation for the OAUTHBEARER SASL mechanism, split the
    implementation into two tiers: the first checks the acceptable
    AUTH_REQ_* codes, and the second checks acceptable mechanisms if
    AUTH_REQ_SASL et al. are permitted.
    
    conn->allowed_sasl_mechs is the list of pointers to acceptable
    mechanisms. (Since we'll support only a small number of mechanisms, this
    is an array of static length to minimize bookkeeping.) pg_SASL_init()
    will bail if the selected mechanism isn't contained in this array.
    
    Since there's only one mechanism supported right now, one branch of the
    second tier cannot be exercised yet (it's marked with Assert(false)).
    This assertion will need to be removed when the next mechanism is added.

diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 14a9a862f51..722bb47ee14 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -543,6 +543,35 @@ pg_SASL_init(PGconn *conn, int payloadlen)
                goto error;
        }
 
+       /* Make sure require_auth is satisfied. */
+       if (conn->require_auth)
+       {
+               bool            allowed = false;
+
+               for (int i = 0; i < lengthof(conn->allowed_sasl_mechs); i++)
+               {
+                       if (conn->sasl == conn->allowed_sasl_mechs[i])
+                       {
+                               allowed = true;
+                               break;
+                       }
+               }
+
+               if (!allowed)
+               {
+                       /*
+                        * TODO: this is dead code until a second SASL mechanism is added;
+                        * the connection can't have proceeded past check_expected_areq()
+                        * if no SASL methods are allowed.
+                        */
+                       Assert(false);
+
+                       libpq_append_conn_error(conn, "authentication method requirement \"%s\" failed: server requested %s authentication",
+                                                                       conn->require_auth, selected_mechanism);
+                       goto error;
+               }
+       }
+
        if (conn->channel_binding[0] == 'r' &&  /* require */
                strcmp(selected_mechanism, SCRAM_SHA_256_PLUS_NAME) != 0)
        {
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 8f211821eb2..6f262706b0a 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -1110,6 +1110,56 @@ libpq_prng_init(PGconn *conn)
        pg_prng_seed(&conn->prng_state, rseed);
 }
 
+/*
+ * Fills the connection's allowed_sasl_mechs list with all supported SASL
+ * mechanisms.
+ */
+static inline void
+fill_allowed_sasl_mechs(PGconn *conn)
+{
+       /*---
+        * We only support one mechanism at the moment, so rather than deal with a
+        * linked list, conn->allowed_sasl_mechs is an array of static length. We
+        * rely on the compile-time assertion here to keep us honest.
+        *
+        * To add a new mechanism to require_auth,
+        * - update the length of conn->allowed_sasl_mechs,
+        * - add the new pg_fe_sasl_mech pointer to this function, and
+        * - handle the new mechanism name in the require_auth portion of
+        *   pqConnectOptions2(), below.
+        */
+       StaticAssertDecl(lengthof(conn->allowed_sasl_mechs) == 1,
+                                        "fill_allowed_sasl_mechs() must be updated when resizing conn->allowed_sasl_mechs[]");
+
+       conn->allowed_sasl_mechs[0] = &pg_scram_mech;
+}
+
+/*
+ * Clears the connection's allowed_sasl_mechs list.
+ */
+static inline void
+clear_allowed_sasl_mechs(PGconn *conn)
+{
+       for (int i = 0; i < lengthof(conn->allowed_sasl_mechs); i++)
+               conn->allowed_sasl_mechs[i] = NULL;
+}
+
+/*
+ * Helper routine that searches the static allowed_sasl_mechs list for a
+ * specific mechanism.
+ */
+static inline int
+index_of_allowed_sasl_mech(PGconn *conn, const pg_fe_sasl_mech *mech)
+{
+       for (int i = 0; i < lengthof(conn->allowed_sasl_mechs); i++)
+       {
+               if (conn->allowed_sasl_mechs[i] == mech)
+                       return i;
+       }
+
+       return -1;
+}
+
 /*
  *             pqConnectOptions2
  *
@@ -1351,17 +1401,19 @@ pqConnectOptions2(PGconn *conn)
                bool            negated = false;
 
                /*
-                * By default, start from an empty set of allowed options and add to
-                * it.
+                * By default, start from an empty set of allowed methods and
+                * mechanisms, and add to it.
                 */
                conn->auth_required = true;
                conn->allowed_auth_methods = 0;
+               clear_allowed_sasl_mechs(conn);
 
                for (first = true, more = true; more; first = false)
                {
                        char       *method,
                                           *part;
-                       uint32          bits;
+                       uint32          bits = 0;
+                       const pg_fe_sasl_mech *mech = NULL;
 
                        part = parse_comma_separated_list(&s, &more);
                        if (part == NULL)
@@ -1377,11 +1429,12 @@ pqConnectOptions2(PGconn *conn)
                                if (first)
                                {
                                        /*
-                                        * Switch to a permissive set of allowed options, and
-                                        * subtract from it.
+                                        * Switch to a permissive set of allowed methods and
+                                        * mechanisms, and subtract from it.
                                         */
                                        conn->auth_required = false;
                                        conn->allowed_auth_methods = -1;
+                                       fill_allowed_sasl_mechs(conn);
                                }
                                else if (!negated)
                                {
@@ -1406,6 +1459,10 @@ pqConnectOptions2(PGconn *conn)
                                return false;
                        }
 
+                       /*
+                        * First group: methods that can be handled solely with the
+                        * authentication request codes.
+                        */
                        if (strcmp(method, "password") == 0)
                        {
                                bits = (1 << AUTH_REQ_PASSWORD);
@@ -1424,13 +1481,22 @@ pqConnectOptions2(PGconn *conn)
                                bits = (1 << AUTH_REQ_SSPI);
                                bits |= (1 << AUTH_REQ_GSS_CONT);
                        }
+
+                       /*
+                        * Next group: SASL mechanisms. All of these use the same request
+                        * codes, so the list of allowed mechanisms is tracked separately.
+                        *
+                        * fill_allowed_sasl_mechs() must be updated when adding a new
+                        * mechanism here!
+                        */
                        else if (strcmp(method, "scram-sha-256") == 0)
                        {
-                               /* This currently assumes that SCRAM is the only SASL method. */
-                               bits = (1 << AUTH_REQ_SASL);
-                               bits |= (1 << AUTH_REQ_SASL_CONT);
-                               bits |= (1 << AUTH_REQ_SASL_FIN);
+                               mech = &pg_scram_mech;
                        }
+
+                       /*
+                        * Final group: meta-options.
+                        */
                        else if (strcmp(method, "none") == 0)
                        {
                                /*
@@ -1466,20 +1532,68 @@ pqConnectOptions2(PGconn *conn)
                                return false;
                        }
 
-                       /* Update the bitmask. */
-                       if (negated)
+                       if (mech)
                        {
-                               if ((conn->allowed_auth_methods & bits) == 0)
-                                       goto duplicate;
+                               /*
+                                * Update the mechanism set only. The method bitmask will be
+                                * updated for SASL further down.
+                                */
+                               Assert(!bits);
+
+                               if (negated)
+                               {
+                                       /* Remove the existing mechanism from the list. */
+                                       i = index_of_allowed_sasl_mech(conn, mech);
+                                       if (i < 0)
+                                               goto duplicate;
 
-                               conn->allowed_auth_methods &= ~bits;
+                                       conn->allowed_sasl_mechs[i] = NULL;
+                               }
+                               else
+                               {
+                                       /*
+                                        * Find a space to put the new mechanism (after making
+                                        * sure it's not already there).
+                                        */
+                                       i = index_of_allowed_sasl_mech(conn, mech);
+                                       if (i >= 0)
+                                               goto duplicate;
+
+                                       i = index_of_allowed_sasl_mech(conn, NULL);
+                                       if (i < 0)
+                                       {
+                                               /* Should not happen; the pointer list is corrupted. */
+                                               Assert(false);
+
+                                               conn->status = CONNECTION_BAD;
+                                               libpq_append_conn_error(conn,
+                                                                                               "internal error: no space in allowed_sasl_mechs");
+                                               free(part);
+                                               return false;
+                                       }
+
+                                       conn->allowed_sasl_mechs[i] = mech;
+                               }
                        }
                        else
                        {
-                               if ((conn->allowed_auth_methods & bits) == bits)
-                                       goto duplicate;
+                               /* Update the method bitmask. */
+                               Assert(bits);
+
+                               if (negated)
+                               {
+                                       if ((conn->allowed_auth_methods & bits) == 0)
+                                               goto duplicate;
+
+                                       conn->allowed_auth_methods &= ~bits;
+                               }
+                               else
+                               {
+                                       if ((conn->allowed_auth_methods & bits) == bits)
+                                               goto duplicate;
 
-                               conn->allowed_auth_methods |= bits;
+                                       conn->allowed_auth_methods |= bits;
+                               }
                        }
 
                        free(part);
@@ -1498,6 +1612,36 @@ pqConnectOptions2(PGconn *conn)
                        free(part);
                        return false;
                }
+
+               /*
+                * Finally, allow SASL authentication requests if (and only if) we've
+                * allowed any mechanisms.
+                */
+               {
+                       bool            allowed = false;
+                       const uint32 sasl_bits =
+                               (1 << AUTH_REQ_SASL)
+                               | (1 << AUTH_REQ_SASL_CONT)
+                               | (1 << AUTH_REQ_SASL_FIN);
+
+                       for (i = 0; i < lengthof(conn->allowed_sasl_mechs); i++)
+                       {
+                               if (conn->allowed_sasl_mechs[i])
+                               {
+                                       allowed = true;
+                                       break;
+                               }
+                       }
+
+                       /*
+                        * For the standard case, add the SASL bits to the (default-empty)
+                        * set if needed. For the negated case, remove them.
+                        */
+                       if (!negated && allowed)
+                               conn->allowed_auth_methods |= sasl_bits;
+                       else if (negated && !allowed)
+                               conn->allowed_auth_methods &= ~sasl_bits;
+               }
        }
 
        /*
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 4a5a7c8b5e3..d372276c486 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -501,6 +501,8 @@ struct pg_conn
                                                                 * the server? */
        uint32          allowed_auth_methods;   /* bitmask of acceptable AuthRequest
                                                                                 * codes */
+       const pg_fe_sasl_mech *allowed_sasl_mechs[1];   /* and acceptable SASL
+                                                                                                        * mechanisms */
        bool            client_finished_auth;   /* have we finished our half of the
                                                                                 * authentication exchange? */
        char            current_auth_response;  /* used by pqTraceOutputMessage to
diff --git a/src/test/authentication/t/001_password.pl b/src/test/authentication/t/001_password.pl
index 773238b76fd..1357f806b6f 100644
--- a/src/test/authentication/t/001_password.pl
+++ b/src/test/authentication/t/001_password.pl
@@ -277,6 +277,16 @@ $node->connect_fails(
        "require_auth methods cannot be duplicated, !none case",
        expected_stderr =>
          qr/require_auth method "!none" is specified more than once/);
+$node->connect_fails(
+       "user=scram_role require_auth=scram-sha-256,scram-sha-256",
+       "require_auth methods cannot be duplicated, scram-sha-256 case",
+       expected_stderr =>
+         qr/require_auth method "scram-sha-256" is specified more than once/);
+$node->connect_fails(
+       "user=scram_role require_auth=!scram-sha-256,!scram-sha-256",
+       "require_auth methods cannot be duplicated, !scram-sha-256 case",
+       expected_stderr =>
+         qr/require_auth method "!scram-sha-256" is specified more than once/);
 
 # Unknown value defined in require_auth.
 $node->connect_fails(

Reply via email to