On Wed, Aug 10, 2022 at 10:48 PM Drouvot, Bertrand <bdrou...@amazon.com> wrote:
> What do you think about adding a second field in ClientConnectionInfo
> for the auth method (as suggested by Michael upthread)?

Sure -- without a followup patch, it's not really tested, though.

v2 adjusts set_authn_id() to copy the auth_method over as well. It
"passes tests" but is otherwise unexercised.

Thanks,
--Jacob
commit 69cacd5e0869b18d64ff4233ef6a73123c513496
Author: Jacob Champion <jchamp...@timescale.com>
Date:   Thu Aug 11 15:16:15 2022 -0700

    squash! Allow parallel workers to read authn_id
    
    Add a copy of hba->auth_method to ClientConnectionInfo when
    set_authn_id() is called.

diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 313a6ea701..9113f04189 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -333,9 +333,9 @@ auth_failed(Port *port, int status, const char *logdetail)
 
 
 /*
- * Sets the authenticated identity for the current user.  The provided string
- * will be copied into the TopMemoryContext.  The ID will be logged if
- * log_connections is enabled.
+ * Sets the authenticated identity for the current user. The provided string
+ * will be stored into MyClientConnectionInfo, alongside the current HBA method
+ * in use. The ID will be logged if log_connections is enabled.
  *
  * Auth methods should call this routine exactly once, as soon as the user is
  * successfully authenticated, even if they have reasons to know that
@@ -365,6 +365,7 @@ set_authn_id(Port *port, const char *id)
        }
 
        MyClientConnectionInfo.authn_id = MemoryContextStrdup(TopMemoryContext, 
id);
+       MyClientConnectionInfo.auth_method = port->hba->auth_method;
 
        if (Log_connections)
        {
@@ -372,8 +373,8 @@ set_authn_id(Port *port, const char *id)
                                errmsg("connection authenticated: 
identity=\"%s\" method=%s "
                                           "(%s:%d)",
                                           MyClientConnectionInfo.authn_id,
-                                          
hba_authname(port->hba->auth_method), HbaFileName,
-                                          port->hba->linenumber));
+                                          
hba_authname(MyClientConnectionInfo.auth_method),
+                                          HbaFileName, port->hba->linenumber));
        }
 }
 
diff --git a/src/backend/utils/init/miscinit.c 
b/src/backend/utils/init/miscinit.c
index 973103374b..155ba92c67 100644
--- a/src/backend/utils/init/miscinit.c
+++ b/src/backend/utils/init/miscinit.c
@@ -954,6 +954,8 @@ EstimateClientConnectionInfoSpace(void)
        if (MyClientConnectionInfo.authn_id)
                size = add_size(size, strlen(MyClientConnectionInfo.authn_id) + 
1);
 
+       size = add_size(size, sizeof(UserAuth));
+
        return size;
 }
 
@@ -981,6 +983,15 @@ SerializeClientConnectionInfo(Size maxsize, char 
*start_address)
                maxsize -= len;
                start_address += len;
        }
+
+       {
+               UserAuth           *auth_method = (UserAuth*) start_address;
+
+               Assert(sizeof(*auth_method) <= maxsize);
+               *auth_method = MyClientConnectionInfo.auth_method;
+               maxsize -= sizeof(*auth_method);
+               start_address += sizeof(*auth_method);
+       }
 }
 
 /*
@@ -1001,6 +1012,13 @@ RestoreClientConnectionInfo(char *conninfo)
                                                                                
                                          conninfo);
                conninfo += strlen(conninfo) + 1;
        }
+
+       {
+               UserAuth           *auth_method = (UserAuth*) conninfo;
+
+               MyClientConnectionInfo.auth_method = *auth_method;
+               conninfo += sizeof(*auth_method);
+       }
 }
 
 
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index c900411fdd..0643733765 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -111,7 +111,7 @@ typedef struct
 {
        /*
         * Authenticated identity.  The meaning of this identifier is dependent 
on
-        * hba->auth_method; it is the identity (if any) that the user presented
+        * auth_method; it is the identity (if any) that the user presented
         * during the authentication cycle, before they were assigned a database
         * role.  (It is effectively the "SYSTEM-USERNAME" of a pg_ident usermap
         * -- though the exact string in use may be different, depending on 
pg_hba
@@ -121,6 +121,12 @@ typedef struct
         * example if the "trust" auth method is in use.
         */
        const char *authn_id;
+
+       /*
+        * The HBA method that determined the above authn_id. This only has 
meaning
+        * if authn_id is not NULL; otherwise it's undefined.
+        */
+       UserAuth        auth_method;
 } ClientConnectionInfo;
 
 /*
From 32d465527678ad6ef2f177287c797cd87feba585 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchamp...@vmware.com>
Date: Wed, 23 Mar 2022 15:07:05 -0700
Subject: [PATCH v2] Allow parallel workers to read authn_id

Move authn_id into a new global, MyClientConnectionInfo, which is
intended to hold all the client information that needs to be shared
between the backend and any parallel workers. MyClientConnectionInfo is
serialized and restored using a new parallel key.

Additionally, make a copy of hba->auth_method in ClientConnectionInfo
when set_authn_id() is called, for use by SYSTEM_USER.
---
 src/backend/access/transam/parallel.c | 19 +++++-
 src/backend/libpq/auth.c              | 25 ++++----
 src/backend/utils/init/miscinit.c     | 90 +++++++++++++++++++++++++++
 src/include/libpq/libpq-be.h          | 45 ++++++++++----
 src/include/miscadmin.h               |  4 ++
 5 files changed, 158 insertions(+), 25 deletions(-)

diff --git a/src/backend/access/transam/parallel.c b/src/backend/access/transam/parallel.c
index df0cd77558..bc93101ff7 100644
--- a/src/backend/access/transam/parallel.c
+++ b/src/backend/access/transam/parallel.c
@@ -76,6 +76,7 @@
 #define PARALLEL_KEY_REINDEX_STATE			UINT64CONST(0xFFFFFFFFFFFF000C)
 #define PARALLEL_KEY_RELMAPPER_STATE		UINT64CONST(0xFFFFFFFFFFFF000D)
 #define PARALLEL_KEY_UNCOMMITTEDENUMS		UINT64CONST(0xFFFFFFFFFFFF000E)
+#define PARALLEL_KEY_CLIENTCONNINFO			UINT64CONST(0xFFFFFFFFFFFF000F)
 
 /* Fixed-size parallel state. */
 typedef struct FixedParallelState
@@ -212,6 +213,7 @@ InitializeParallelDSM(ParallelContext *pcxt)
 	Size		reindexlen = 0;
 	Size		relmapperlen = 0;
 	Size		uncommittedenumslen = 0;
+	Size		clientconninfolen = 0;
 	Size		segsize = 0;
 	int			i;
 	FixedParallelState *fps;
@@ -272,8 +274,10 @@ InitializeParallelDSM(ParallelContext *pcxt)
 		shm_toc_estimate_chunk(&pcxt->estimator, relmapperlen);
 		uncommittedenumslen = EstimateUncommittedEnumsSpace();
 		shm_toc_estimate_chunk(&pcxt->estimator, uncommittedenumslen);
+		clientconninfolen = EstimateClientConnectionInfoSpace();
+		shm_toc_estimate_chunk(&pcxt->estimator, clientconninfolen);
 		/* If you add more chunks here, you probably need to add keys. */
-		shm_toc_estimate_keys(&pcxt->estimator, 11);
+		shm_toc_estimate_keys(&pcxt->estimator, 12);
 
 		/* Estimate space need for error queues. */
 		StaticAssertStmt(BUFFERALIGN(PARALLEL_ERROR_QUEUE_SIZE) ==
@@ -352,6 +356,7 @@ InitializeParallelDSM(ParallelContext *pcxt)
 		char	   *session_dsm_handle_space;
 		char	   *entrypointstate;
 		char	   *uncommittedenumsspace;
+		char	   *clientconninfospace;
 		Size		lnamelen;
 
 		/* Serialize shared libraries we have loaded. */
@@ -422,6 +427,12 @@ InitializeParallelDSM(ParallelContext *pcxt)
 		shm_toc_insert(pcxt->toc, PARALLEL_KEY_UNCOMMITTEDENUMS,
 					   uncommittedenumsspace);
 
+		/* Serialize our ClientConnectionInfo. */
+		clientconninfospace = shm_toc_allocate(pcxt->toc, clientconninfolen);
+		SerializeClientConnectionInfo(clientconninfolen, clientconninfospace);
+		shm_toc_insert(pcxt->toc, PARALLEL_KEY_CLIENTCONNINFO,
+					   clientconninfospace);
+
 		/* Allocate space for worker information. */
 		pcxt->worker = palloc0(sizeof(ParallelWorkerInfo) * pcxt->nworkers);
 
@@ -1270,6 +1281,7 @@ ParallelWorkerMain(Datum main_arg)
 	char	   *reindexspace;
 	char	   *relmapperspace;
 	char	   *uncommittedenumsspace;
+	char	   *clientconninfospace;
 	StringInfoData msgbuf;
 	char	   *session_dsm_handle_space;
 	Snapshot	tsnapshot;
@@ -1479,6 +1491,11 @@ ParallelWorkerMain(Datum main_arg)
 										   false);
 	RestoreUncommittedEnums(uncommittedenumsspace);
 
+	/* Restore the ClientConnectionInfo. */
+	clientconninfospace = shm_toc_lookup(toc, PARALLEL_KEY_CLIENTCONNINFO,
+										 false);
+	RestoreClientConnectionInfo(clientconninfospace);
+
 	/* Attach to the leader's serializable transaction, if SERIALIZABLE. */
 	AttachSerializableXact(fps->serializable_xact_handle);
 
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 2d9ab7edce..9113f04189 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -333,24 +333,24 @@ auth_failed(Port *port, int status, const char *logdetail)
 
 
 /*
- * Sets the authenticated identity for the current user.  The provided string
- * will be copied into the TopMemoryContext.  The ID will be logged if
- * log_connections is enabled.
+ * Sets the authenticated identity for the current user. The provided string
+ * will be stored into MyClientConnectionInfo, alongside the current HBA method
+ * in use. The ID will be logged if log_connections is enabled.
  *
  * Auth methods should call this routine exactly once, as soon as the user is
  * successfully authenticated, even if they have reasons to know that
  * authorization will fail later.
  *
  * The provided string will be copied into TopMemoryContext, to match the
- * lifetime of the Port, so it is safe to pass a string that is managed by an
- * external library.
+ * lifetime of MyClientConnectionInfo, so it is safe to pass a string that is
+ * managed by an external library.
  */
 static void
 set_authn_id(Port *port, const char *id)
 {
 	Assert(id);
 
-	if (port->authn_id)
+	if (MyClientConnectionInfo.authn_id)
 	{
 		/*
 		 * An existing authn_id should never be overwritten; that means two
@@ -361,18 +361,20 @@ set_authn_id(Port *port, const char *id)
 		ereport(FATAL,
 				(errmsg("authentication identifier set more than once"),
 				 errdetail_log("previous identifier: \"%s\"; new identifier: \"%s\"",
-							   port->authn_id, id)));
+							   MyClientConnectionInfo.authn_id, id)));
 	}
 
-	port->authn_id = MemoryContextStrdup(TopMemoryContext, id);
+	MyClientConnectionInfo.authn_id = MemoryContextStrdup(TopMemoryContext, id);
+	MyClientConnectionInfo.auth_method = port->hba->auth_method;
 
 	if (Log_connections)
 	{
 		ereport(LOG,
 				errmsg("connection authenticated: identity=\"%s\" method=%s "
 					   "(%s:%d)",
-					   port->authn_id, hba_authname(port->hba->auth_method), HbaFileName,
-					   port->hba->linenumber));
+					   MyClientConnectionInfo.authn_id,
+					   hba_authname(MyClientConnectionInfo.auth_method),
+					   HbaFileName, port->hba->linenumber));
 	}
 }
 
@@ -1908,7 +1910,8 @@ auth_peer(hbaPort *port)
 	 */
 	set_authn_id(port, pw->pw_name);
 
-	ret = check_usermap(port->hba->usermap, port->user_name, port->authn_id, false);
+	ret = check_usermap(port->hba->usermap, port->user_name,
+						MyClientConnectionInfo.authn_id, false);
 
 	return ret;
 #else
diff --git a/src/backend/utils/init/miscinit.c b/src/backend/utils/init/miscinit.c
index eb43b2c5e5..155ba92c67 100644
--- a/src/backend/utils/init/miscinit.c
+++ b/src/backend/utils/init/miscinit.c
@@ -931,6 +931,111 @@ GetUserNameFromId(Oid roleid, bool noerr)
 	return result;
 }
 
+/* ------------------------------------------------------------------------
+ *				Parallel connection state
+ *
+ * ClientConnectionInfo contains pieces of information about the client that
+ * need to be synced to parallel workers when they initialize. Over time, this
+ * list will probably grow, and may subsume some of the "user state" variables
+ * above.
+ *
+ * The serialized form shared with parallel workers is:
+ *
+ *	1. one flag byte: 1 if authn_id is non-NULL, 0 if it is NULL;
+ *	2. if the flag is 1, authn_id as a NUL-terminated string;
+ *	3. the bytes of auth_method (a UserAuth).
+ *
+ * Since authn_id has variable length, auth_method is not necessarily aligned
+ * for a UserAuth and must be copied with memcpy(), never accessed through a
+ * cast pointer.
+ *-------------------------------------------------------------------------
+ */
+
+ClientConnectionInfo MyClientConnectionInfo;
+
+/*
+ * Calculate the space needed to serialize MyClientConnectionInfo.
+ */
+Size
+EstimateClientConnectionInfoSpace(void)
+{
+	/* flag byte recording whether authn_id is set */
+	Size		size = 1;
+
+	/* authn_id, including its terminating NUL */
+	if (MyClientConnectionInfo.authn_id)
+		size = add_size(size, strlen(MyClientConnectionInfo.authn_id) + 1);
+
+	/* auth_method */
+	size = add_size(size, sizeof(UserAuth));
+
+	return size;
+}
+
+/*
+ * Serialize MyClientConnectionInfo for use by parallel workers.
+ *
+ * start_address must have room for maxsize bytes; maxsize is expected to be
+ * the result of EstimateClientConnectionInfoSpace().
+ */
+void
+SerializeClientConnectionInfo(Size maxsize, char *start_address)
+{
+	/*
+	 * First byte is an indication of whether or not authn_id has been set to
+	 * non-NULL, to differentiate that case from the empty string.
+	 */
+	Assert(maxsize > 0);
+	start_address[0] = MyClientConnectionInfo.authn_id ? 1 : 0;
+	start_address++;
+	maxsize--;
+
+	if (MyClientConnectionInfo.authn_id)
+	{
+		Size		len;
+
+		len = strlcpy(start_address, MyClientConnectionInfo.authn_id, maxsize) + 1;
+		Assert(len <= maxsize);
+		maxsize -= len;
+		start_address += len;
+	}
+
+	/*
+	 * After the variable-length string, start_address may not be suitably
+	 * aligned for a UserAuth, so copy the bytes rather than storing through
+	 * a cast pointer.
+	 */
+	Assert(sizeof(UserAuth) <= maxsize);
+	memcpy(start_address, &MyClientConnectionInfo.auth_method,
+		   sizeof(UserAuth));
+}
+
+/*
+ * Restore MyClientConnectionInfo from its serialized representation, as
+ * written by SerializeClientConnectionInfo().  A non-NULL authn_id is copied
+ * into TopMemoryContext.
+ */
+void
+RestoreClientConnectionInfo(char *conninfo)
+{
+	/* leading flag byte: zero means authn_id was NULL when serialized */
+	bool		has_authn_id = (conninfo[0] != 0);
+
+	conninfo++;
+
+	if (has_authn_id)
+	{
+		MyClientConnectionInfo.authn_id = MemoryContextStrdup(TopMemoryContext,
+															  conninfo);
+		conninfo += strlen(conninfo) + 1;
+	}
+	else
+		MyClientConnectionInfo.authn_id = NULL;
+
+	/* auth_method may be unaligned; see SerializeClientConnectionInfo() */
+	memcpy(&MyClientConnectionInfo.auth_method, conninfo, sizeof(UserAuth));
+}
+
 
 /*-------------------------------------------------------------------------
  *				Interlock-file support
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index 90c20da22b..0643733765 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -98,6 +98,37 @@ typedef struct
 } pg_gssinfo;
 #endif
 
+/*
+ * Fields describing the client connection that also need to be copied over
+ * to parallel workers go into ClientConnectionInfo rather than Port.  The
+ * same rules apply for allocations here as for Port (must be malloc'd or
+ * palloc'd in TopMemoryContext).
+ *
+ * If you add a struct member here, also handle it in the Estimate-,
+ * Serialize- and RestoreClientConnectionInfo() functions in miscinit.c.
+ */
+typedef struct
+{
+	/*
+	 * Authenticated identity.  The meaning of this identifier is dependent on
+	 * auth_method; it is the identity (if any) that the user presented
+	 * during the authentication cycle, before they were assigned a database
+	 * role.  (It is effectively the "SYSTEM-USERNAME" of a pg_ident usermap
+	 * -- though the exact string in use may be different, depending on pg_hba
+	 * options.)
+	 *
+	 * authn_id is NULL if the user has not actually been authenticated, for
+	 * example if the "trust" auth method is in use.
+	 */
+	const char *authn_id;
+
+	/*
+	 * The HBA method that determined the above authn_id. This only has meaning
+	 * if authn_id is not NULL; otherwise it's undefined.
+	 */
+	UserAuth	auth_method;
+} ClientConnectionInfo;
+
 /*
  * This is used by the postmaster in its communication with frontends.  It
  * contains all state information needed during this communication before the
@@ -158,19 +189,6 @@ typedef struct Port
 	 */
 	HbaLine    *hba;
 
-	/*
-	 * Authenticated identity.  The meaning of this identifier is dependent on
-	 * hba->auth_method; it is the identity (if any) that the user presented
-	 * during the authentication cycle, before they were assigned a database
-	 * role.  (It is effectively the "SYSTEM-USERNAME" of a pg_ident usermap
-	 * -- though the exact string in use may be different, depending on pg_hba
-	 * options.)
-	 *
-	 * authn_id is NULL if the user has not actually been authenticated, for
-	 * example if the "trust" auth method is in use.
-	 */
-	const char *authn_id;
-
 	/*
 	 * TCP keepalive and user timeout settings.
 	 *
@@ -327,6 +345,7 @@ extern ssize_t be_gssapi_write(Port *port, void *ptr, size_t len);
 #endif							/* ENABLE_GSS */
 
 extern PGDLLIMPORT ProtocolVersion FrontendProtocol;
+extern PGDLLIMPORT ClientConnectionInfo MyClientConnectionInfo;
 
 /* TCP keepalives configuration. These are no-ops on an AF_UNIX socket. */
 
diff --git a/src/include/miscadmin.h b/src/include/miscadmin.h
index 067b729d5a..3e9297e399 100644
--- a/src/include/miscadmin.h
+++ b/src/include/miscadmin.h
@@ -481,6 +481,10 @@ extern bool has_rolreplication(Oid roleid);
 typedef void (*shmem_request_hook_type) (void);
 extern PGDLLIMPORT shmem_request_hook_type shmem_request_hook;
 
+extern Size EstimateClientConnectionInfoSpace(void);
+extern void SerializeClientConnectionInfo(Size maxsize, char *start_address);
+extern void RestoreClientConnectionInfo(char *conninfo);
+
 /* in executor/nodeHash.c */
 extern size_t get_hash_memory_limit(void);
 
-- 
2.25.1

Reply via email to