Hi,

The attached patch enables SCRAM authentication for postgres_fdw connections without requiring a plain-text password in the user mapping options.

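This is achieved by storing the SCRAM ClientKey and ServerKey obtained when the client authenticates to the backend. These keys are then used to complete the SCRAM exchange between the backend and the foreign server, eliminating the need to derive them from a stored plain-text password. Because the keys are derived from the user's SCRAM secret, this only works when the foreign server stores the same secret (same password, salt and iteration count) for that user.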

I think that some documentation updates may be necessary for this change. If so, I plan to submit an updated patch with the relevant documentation changes in the coming days.

This patch is based on a previous WIP patch by Peter Eisentraut [1].

[1] https://github.com/petere/postgresql/commit/90009ccd736e99d65c59b9078d14d76fffc2426a

--
Matheus Alcantara
EDB: https://www.enterprisedb.com
From 65fcb8c9565c7f4ba5c204af775c29c76d474a57 Mon Sep 17 00:00:00 2001
From: Matheus Alcantara <mths....@pm.me>
Date: Tue, 19 Nov 2024 15:37:57 -0300
Subject: [PATCH v1] postgres_fdw: SCRAM authentication pass-through

This commit enables SCRAM authentication for postgres_fdw when connecting
to a foreign server without having to store a plain-text password in the
user mapping options.

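This is done by saving the SCRAM ClientKey and ServerKey from the
client authentication and using those instead of the plain-text password
for the server-side SCRAM exchange. The behavior is enabled with the new
use_scram_passthrough option on the foreign server or user mapping.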
---
 contrib/postgres_fdw/Makefile            |  1 +
 contrib/postgres_fdw/connection.c        | 67 ++++++++++++++++++++++--
 contrib/postgres_fdw/meson.build         |  5 ++
 contrib/postgres_fdw/option.c            |  3 ++
 contrib/postgres_fdw/t/001_auth_scram.pl | 62 ++++++++++++++++++++++
 src/backend/libpq/auth-scram.c           | 16 ++++--
 src/include/libpq/libpq-be.h             |  9 ++++
 src/interfaces/libpq/fe-auth-scram.c     | 29 ++++++++--
 src/interfaces/libpq/fe-auth.c           |  2 +-
 src/interfaces/libpq/fe-connect.c        | 31 +++++++++++
 src/interfaces/libpq/libpq-int.h         |  6 +++
 11 files changed, 217 insertions(+), 14 deletions(-)
 create mode 100644 contrib/postgres_fdw/t/001_auth_scram.pl

diff --git a/contrib/postgres_fdw/Makefile b/contrib/postgres_fdw/Makefile
index 88fdce40d6..6c12c8e925 100644
--- a/contrib/postgres_fdw/Makefile
+++ b/contrib/postgres_fdw/Makefile
@@ -8,6 +8,7 @@ OBJS = \
        option.o \
        postgres_fdw.o \
        shippable.o
+TAP_TESTS = 1
 PGFILEDESC = "postgres_fdw - foreign data wrapper for PostgreSQL"
 
 PG_CPPFLAGS = -I$(libpq_srcdir)
diff --git a/contrib/postgres_fdw/connection.c 
b/contrib/postgres_fdw/connection.c
index 2326f391d3..e0e1ebe0d4 100644
--- a/contrib/postgres_fdw/connection.c
+++ b/contrib/postgres_fdw/connection.c
@@ -19,6 +19,7 @@
 #include "access/xact.h"
 #include "catalog/pg_user_mapping.h"
 #include "commands/defrem.h"
+#include "common/base64.h"
 #include "funcapi.h"
 #include "libpq/libpq-be.h"
 #include "libpq/libpq-be-fe-helpers.h"
@@ -168,6 +169,7 @@ static void pgfdw_finish_abort_cleanup(List 
*pending_entries,
 static void pgfdw_security_check(const char **keywords, const char **values,
                                                                 UserMapping 
*user, PGconn *conn);
 static bool UserMappingPasswordRequired(UserMapping *user);
+static bool UseScramPassthrough(ForeignServer *server, UserMapping *user);
 static bool disconnect_cached_connections(Oid serverid);
 static void postgres_fdw_get_connections_internal(FunctionCallInfo fcinfo,
                                                                                
                  enum pgfdwVersion api_version);
@@ -476,7 +478,7 @@ connect_pg_server(ForeignServer *server, UserMapping *user)
                 * for application_name, fallback_application_name, 
client_encoding,
                 * end marker.
                 */
-               n = list_length(server->options) + list_length(user->options) + 
4;
+               n = list_length(server->options) + list_length(user->options) + 
4 + 2;
                keywords = (const char **) palloc(n * sizeof(char *));
                values = (const char **) palloc(n * sizeof(char *));
 
@@ -545,10 +547,37 @@ connect_pg_server(ForeignServer *server, UserMapping 
*user)
                values[n] = GetDatabaseEncodingName();
                n++;
 
+               if (MyProcPort->has_scram_keys && UseScramPassthrough(server, 
user))
+               {
+                       int                     len;
+
+                       keywords[n] = "scram_client_key";
+                       len = 
pg_b64_enc_len(sizeof(MyProcPort->scram_ClientKey));
+                       /* don't forget the zero-terminator */
+                       values[n] = palloc0(len+1);
+                       pg_b64_encode((const char *) 
MyProcPort->scram_ClientKey,
+                                                 
sizeof(MyProcPort->scram_ClientKey),
+                                                 (char *) values[n], len);
+                       n++;
+
+                       keywords[n] = "scram_server_key";
+                       len = 
pg_b64_enc_len(sizeof(MyProcPort->scram_ServerKey));
+                       /* don't forget the zero-terminator */
+                       values[n] = palloc0(len+1);
+                       pg_b64_encode((const char *) 
MyProcPort->scram_ServerKey,
+                                                 
sizeof(MyProcPort->scram_ServerKey),
+                                                 (char *) values[n], len);
+                       n++;
+               }
+
                keywords[n] = values[n] = NULL;
 
-               /* verify the set of connection parameters */
-               check_conn_params(keywords, values, user);
+               /*
+                * Verify the set of connection parameters only if scram 
pass-through
+                * is not being used because the password is not necessary.
+                */
+               if (!(MyProcPort->has_scram_keys && UseScramPassthrough(server, 
user)))
+                       check_conn_params(keywords, values, user);
 
                /* first time, allocate or get the custom wait event */
                if (pgfdw_we_connect == 0)
@@ -566,8 +595,12 @@ connect_pg_server(ForeignServer *server, UserMapping *user)
                                                        server->servername),
                                         errdetail_internal("%s", 
pchomp(PQerrorMessage(conn)))));
 
-               /* Perform post-connection security checks */
-               pgfdw_security_check(keywords, values, user, conn);
+               /*
+                * Perform post-connection security checks only if scram 
pass-through
+                * is not being used because the password is not necessary.
+                */
+               if (!(MyProcPort->has_scram_keys && UseScramPassthrough(server, 
user)))
+                       pgfdw_security_check(keywords, values, user, conn);
 
                /* Prepare new session for use */
                configure_remote_session(conn);
@@ -620,6 +653,30 @@ UserMappingPasswordRequired(UserMapping *user)
        return true;
 }
 
+/*
+ * UseScramPassthrough
+ *
+ * Return the boolean value of the "use_scram_passthrough" option for the
+ * given foreign server / user mapping pair, or false if neither sets it.
+ *
+ * The foreign server's options are searched first, so a value set on the
+ * server is returned without consulting the user mapping.  The user
+ * mapping's options are only searched when the server does not set the
+ * option.  NOTE(review): this means a user mapping cannot override a
+ * server-level setting -- confirm that this precedence is intended.
+ */
+static bool
+UseScramPassthrough(ForeignServer *server, UserMapping *user)
+{
+       ListCell   *cell;
+
+       /* A setting on the foreign server takes precedence. */
+       foreach(cell, server->options)
+       {
+               DefElem    *def = (DefElem *) lfirst(cell);
+
+               if (strcmp(def->defname, "use_scram_passthrough") == 0)
+                       return defGetBoolean(def);
+       }
+
+       /* Not set on the server; fall back to the user mapping. */
+       foreach(cell, user->options)
+       {
+               DefElem    *def = (DefElem *) lfirst(cell);
+
+               if (strcmp(def->defname, "use_scram_passthrough") == 0)
+                       return defGetBoolean(def);
+       }
+
+       /* Option absent from both: pass-through is off. */
+       return false;
+}
+
 /*
  * For non-superusers, insist that the connstr specify a password or that the
  * user provided their own GSSAPI delegated credentials.  This
diff --git a/contrib/postgres_fdw/meson.build b/contrib/postgres_fdw/meson.build
index 3014086ba6..27d07188fc 100644
--- a/contrib/postgres_fdw/meson.build
+++ b/contrib/postgres_fdw/meson.build
@@ -41,4 +41,9 @@ tests += {
     ],
     'regress_args': ['--dlpath', meson.build_root() / 'src/test/regress'],
   },
+  'tap': {
+      'tests': [
+        't/001_auth_scram.pl',
+      ],
+  },
 }
diff --git a/contrib/postgres_fdw/option.c b/contrib/postgres_fdw/option.c
index 232d85354b..15abc64381 100644
--- a/contrib/postgres_fdw/option.c
+++ b/contrib/postgres_fdw/option.c
@@ -279,6 +279,9 @@ InitPgFdwOptions(void)
                {"analyze_sampling", ForeignServerRelationId, false},
                {"analyze_sampling", ForeignTableRelationId, false},
 
+               {"use_scram_passthrough", ForeignServerRelationId, false},
+               {"use_scram_passthrough", UserMappingRelationId, false},
+
                /*
                 * sslcert and sslkey are in fact libpq options, but we repeat 
them
                 * here to allow them to appear in both foreign server context 
(when
diff --git a/contrib/postgres_fdw/t/001_auth_scram.pl 
b/contrib/postgres_fdw/t/001_auth_scram.pl
new file mode 100644
index 0000000000..388d2179db
--- /dev/null
+++ b/contrib/postgres_fdw/t/001_auth_scram.pl
@@ -0,0 +1,62 @@
+# Copyright (c) 2021-2024, PostgreSQL Global Development Group
+
+# Test SCRAM authentication pass through the intermediary postgres_fdw to the 
server
+
+use strict;
+use warnings FATAL => 'all';
+use PostgreSQL::Test::Utils;
+use PostgreSQL::Test::Cluster;
+use Test::More;
+
+my $node = PostgreSQL::Test::Cluster->new('node');
+my $hostaddr = '127.0.0.1';
+my $user = "user01";
+my $db1 = "db1";
+my $db2 = "db2";
+my $fdw_server = "db2_fdw";
+my $host = $node->host;
+my $port = $node->port;
+my $connstr = $node->connstr($db1) . qq' user=$user';
+
+$node->init;
+$node->start;
+
+# Test setup
+
+$node->safe_psql('postgres', qq'CREATE USER $user WITH password \'pass\' ');
+$node->safe_psql('postgres', qq'CREATE DATABASE $db1');
+$node->safe_psql('postgres', qq'CREATE DATABASE $db2');
+
+$node->safe_psql($db2, 'CREATE TABLE t AS SELECT g,g+1 FROM 
generate_series(1,10) g(g)');
+$node->safe_psql($db2, qq'GRANT USAGE ON SCHEMA public to $user');
+$node->safe_psql($db2, qq'GRANT SELECT ON t to $user');
+
+$node->safe_psql($db1, 'CREATE EXTENSION IF NOT EXISTS postgres_fdw');
+$node->safe_psql($db1, qq'CREATE SERVER $fdw_server FOREIGN DATA WRAPPER 
postgres_fdw options (
+       host \'$host\', port \'$port\', dbname \'$db2\', use_scram_passthrough 
\'true\') ');
+# password not required
+$node->safe_psql($db1, qq'CREATE USER MAPPING FOR $user SERVER $fdw_server 
OPTIONS (user \'$user\');');
+$node->safe_psql($db1, qq'GRANT USAGE ON FOREIGN SERVER $fdw_server to 
$user;');
+$node->safe_psql($db1, qq'GRANT ALL ON SCHEMA public to $user');
+
+unlink($node->data_dir . '/pg_hba.conf');
+$node->append_conf(
+       'pg_hba.conf', qq{
+local   all             all                                     scram-sha-256
+host    all             all             $hostaddr/32            scram-sha-256
+});
+$node->restart;
+
+# End of test setup
+
+$ENV{PGPASSWORD} = "pass";
+
+$node->safe_psql($db1, qq'IMPORT FOREIGN SCHEMA public LIMIT TO(t) FROM SERVER 
$fdw_server INTO public ;',
+       connstr=>$connstr);
+
+my $ret = $node->safe_psql($db1, 'SELECT count(1) FROM t',
+       connstr=>$connstr);
+is($ret, '10', 'SELECT count from fdw server returns 10');
+
+
+done_testing();
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index 8c5b6d9c67..88a15cc0e5 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -101,6 +101,7 @@
 #include "libpq/crypt.h"
 #include "libpq/sasl.h"
 #include "libpq/scram.h"
+#include "miscadmin.h"
 
 static void scram_get_mechanisms(Port *port, StringInfo buf);
 static void *scram_init(Port *port, const char *selected_mech,
@@ -144,6 +145,7 @@ typedef struct
 
        int                     iterations;
        char       *salt;                       /* base64-encoded */
+       uint8           ClientKey[SCRAM_MAX_KEY_LEN];
        uint8           StoredKey[SCRAM_MAX_KEY_LEN];
        uint8           ServerKey[SCRAM_MAX_KEY_LEN];
 
@@ -462,6 +464,13 @@ scram_exchange(void *opaq, const char *input, int inputlen,
        if (*output)
                *outputlen = strlen(*output);
 
+       if (result == PG_SASL_EXCHANGE_SUCCESS && state->state == 
SCRAM_AUTH_FINISHED)
+       {
+               memcpy(MyProcPort->scram_ClientKey, state->ClientKey, 
sizeof(MyProcPort->scram_ClientKey));
+               memcpy(MyProcPort->scram_ServerKey, state->ServerKey, 
sizeof(MyProcPort->scram_ServerKey));
+               MyProcPort->has_scram_keys = true;
+       }
+
        return result;
 }
 
@@ -1140,9 +1149,8 @@ static bool
 verify_client_proof(scram_state *state)
 {
        uint8           ClientSignature[SCRAM_MAX_KEY_LEN];
-       uint8           ClientKey[SCRAM_MAX_KEY_LEN];
        uint8           client_StoredKey[SCRAM_MAX_KEY_LEN];
-       pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
+       pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
        int                     i;
        const char *errstr = NULL;
 
@@ -1173,10 +1181,10 @@ verify_client_proof(scram_state *state)
 
        /* Extract the ClientKey that the client calculated from the proof */
        for (i = 0; i < state->key_length; i++)
-               ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
+               state->ClientKey[i] = state->ClientProof[i] ^ 
ClientSignature[i];
 
        /* Hash it one more time, and compare with StoredKey */
-       if (scram_H(ClientKey, state->hash_type, state->key_length,
+       if (scram_H(state->ClientKey, state->hash_type, state->key_length,
                                client_StoredKey, &errstr) < 0)
                elog(ERROR, "could not hash stored key: %s", errstr);
 
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index 9109b2c334..4eb9e80523 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -18,6 +18,8 @@
 #ifndef LIBPQ_BE_H
 #define LIBPQ_BE_H
 
+#include "common/scram-common.h"
+
 #include <sys/time.h>
 #ifdef USE_OPENSSL
 #include <openssl/ssl.h>
@@ -181,6 +183,13 @@ typedef struct Port
        int                     keepalives_count;
        int                     tcp_user_timeout;
 
+       /*
+        * SCRAM structures.
+        */
+       uint8           scram_ClientKey[SCRAM_MAX_KEY_LEN];
+       uint8           scram_ServerKey[SCRAM_MAX_KEY_LEN];
+       bool            has_scram_keys; /* true if the above two are valid */
+
        /*
         * GSSAPI structures.
         */
diff --git a/src/interfaces/libpq/fe-auth-scram.c 
b/src/interfaces/libpq/fe-auth-scram.c
index 0bb820e0d9..7beb5a9d31 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -119,6 +119,8 @@ scram_init(PGconn *conn,
                return NULL;
        }
 
+       if (password)
+       {
        /* Normalize the password with SASLprep, if possible */
        rc = pg_saslprep(password, &prep_password);
        if (rc == SASLPREP_OOM)
@@ -138,6 +140,7 @@ scram_init(PGconn *conn,
                }
        }
        state->password = prep_password;
+       }
 
        return state;
 }
@@ -775,6 +778,12 @@ calculate_client_proof(fe_scram_state *state,
                return false;
        }
 
+       if (state->conn->scram_client_key_binary)
+       {
+               memcpy(ClientKey, state->conn->scram_client_key_binary, 
SCRAM_MAX_KEY_LEN);
+       }
+       else
+       {
        /*
         * Calculate SaltedPassword, and store it in 'state' so that we can 
reuse
         * it later in verify_server_signature.
@@ -783,15 +792,20 @@ calculate_client_proof(fe_scram_state *state,
                                                         state->key_length, 
state->salt, state->saltlen,
                                                         state->iterations, 
state->SaltedPassword,
                                                         errstr) < 0 ||
-               scram_ClientKey(state->SaltedPassword, state->hash_type,
-                                               state->key_length, ClientKey, 
errstr) < 0 ||
-               scram_H(ClientKey, state->hash_type, state->key_length,
-                               StoredKey, errstr) < 0)
+                       scram_ClientKey(state->SaltedPassword, state->hash_type,
+                                               state->key_length, ClientKey, 
errstr) < 0)
        {
                /* errstr is already filled here */
                pg_hmac_free(ctx);
                return false;
        }
+       }
+
+       if (scram_H(ClientKey, state->hash_type, state->key_length, StoredKey, 
errstr)  < 0)
+       {
+               pg_hmac_free(ctx);
+               return false;
+       }
 
        if (pg_hmac_init(ctx, StoredKey, state->key_length) < 0 ||
                pg_hmac_update(ctx,
@@ -841,6 +855,12 @@ verify_server_signature(fe_scram_state *state, bool *match,
                return false;
        }
 
+       if (state->conn->scram_server_key_binary)
+       {
+               memcpy(ServerKey, state->conn->scram_server_key_binary, 
SCRAM_MAX_KEY_LEN);
+       }
+       else
+       {
        if (scram_ServerKey(state->SaltedPassword, state->hash_type,
                                                state->key_length, ServerKey, 
errstr) < 0)
        {
@@ -848,6 +868,7 @@ verify_server_signature(fe_scram_state *state, bool *match,
                pg_hmac_free(ctx);
                return false;
        }
+       }
 
        /* calculate ServerSignature */
        if (pg_hmac_init(ctx, ServerKey, state->key_length) < 0 ||
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 20d3427e94..ef1c965cd5 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -559,7 +559,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
         * First, select the password to use for the exchange, complaining if
         * there isn't one and the selected SASL mechanism needs it.
         */
-       if (conn->password_needed)
+       if (conn->password_needed && !conn->scram_client_key_binary)
        {
                password = conn->connhost[conn->whichhost].password;
                if (password == NULL)
diff --git a/src/interfaces/libpq/fe-connect.c 
b/src/interfaces/libpq/fe-connect.c
index aaf87e8e88..464cefd901 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -22,6 +22,7 @@
 #include <time.h>
 #include <unistd.h>
 
+#include "common/base64.h"
 #include "common/ip.h"
 #include "common/link-canary.h"
 #include "common/scram-common.h"
@@ -365,6 +366,12 @@ static const internalPQconninfoOption PQconninfoOptions[] 
= {
                "Load-Balance-Hosts", "", 8,    /* sizeof("disable") = 8 */
        offsetof(struct pg_conn, load_balance_hosts)},
 
+       {"scram_client_key", NULL, NULL, NULL, "SCRAM-Client-Key", "D", 
SCRAM_MAX_KEY_LEN * 2,
+       offsetof(struct pg_conn, scram_client_key)},
+
+       {"scram_server_key", NULL, NULL, NULL, "SCRAM-Server-Key", "D", 
SCRAM_MAX_KEY_LEN * 2,
+       offsetof(struct pg_conn, scram_server_key)},
+
        /* Terminating entry --- MUST BE LAST */
        {NULL, NULL, NULL, NULL,
        NULL, NULL, 0}
@@ -1792,6 +1799,28 @@ pqConnectOptions2(PGconn *conn)
        else
                conn->target_server_type = SERVER_TYPE_ANY;
 
+       if (conn->scram_client_key)
+       {
+               int                     len;
+
+               len = pg_b64_dec_len(strlen(conn->scram_client_key));
+               conn->scram_client_key_len = len;
+               conn->scram_client_key_binary = malloc(len);
+               pg_b64_decode(conn->scram_client_key, 
strlen(conn->scram_client_key),
+                                         conn->scram_client_key_binary, len);
+       }
+
+       if (conn->scram_server_key)
+       {
+               int                     len;
+
+               len = pg_b64_dec_len(strlen(conn->scram_server_key));
+               conn->scram_server_key_len = len;
+               conn->scram_server_key_binary = malloc(len);
+               pg_b64_decode(conn->scram_server_key, 
strlen(conn->scram_server_key),
+                                         conn->scram_server_key_binary, len);
+       }
+
        /*
         * validate load_balance_hosts option, and set load_balance_type
         */
@@ -4703,6 +4732,8 @@ freePGconn(PGconn *conn)
        free(conn->rowBuf);
        free(conn->target_session_attrs);
        free(conn->load_balance_hosts);
+       free(conn->scram_client_key);
+       free(conn->scram_server_key);
        termPQExpBuffer(&conn->errorMessage);
        termPQExpBuffer(&conn->workBuffer);
 
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 08cc391cbd..17b81c81f4 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -427,6 +427,8 @@ struct pg_conn
        char       *target_session_attrs;       /* desired session properties */
        char       *require_auth;       /* name of the expected auth method */
        char       *load_balance_hosts; /* load balance over hosts */
+       char       *scram_client_key;
+       char       *scram_server_key;
 
        bool            cancelRequest;  /* true if this connection is used to 
send a
                                                                 * cancel 
request, instead of being a normal
@@ -517,6 +519,10 @@ struct pg_conn
        AddrInfo   *addr;                       /* the array of addresses for 
the currently
                                                                 * tried host */
        bool            send_appname;   /* okay to send application_name? */
+       size_t          scram_client_key_len;
+       char       *scram_client_key_binary;
+       size_t          scram_server_key_len;
+       char       *scram_server_key_binary;
 
        /* Miscellaneous stuff */
        int                     be_pid;                 /* PID of backend --- 
needed for cancels */
-- 
2.39.3 (Apple Git-146)

Reply via email to