Hi Hackers,

We are trying to implement AAD (Azure AD) support in PostgreSQL, which
can be achieved with support for the OAuth method. To support AAD on
top of OAuth in a generic fashion (i.e., in a way that also works for
all other OAuth providers), we are proposing this patch. It exposes two
new hooks (one for error reporting and one for OAuth-provider-specific
token validation) and passes the OAuth bearer token to the backend. It
also adds support for the OAuth client credentials flow, in addition to
the device code flow that Jacob has proposed.

The changes for each component are summarized below.

1.     Provider-specific extension:
        Each OAuth provider implements its own token validator as an
extension. The extension registers an OAuth provider hook, which is
matched by name to a line in the HBA file.

2.     Add support for passing the OAuth bearer token directly. In this
mode, obtaining the bearer token is left to a third-party application
or to the user.

        ./psql -U <username> -d 'dbname=postgres oauth_client_id=<client_id> oauth_bearer_token=<token>'

3.     HBA: An additional parameter, ‘provider’, is added for the oauth method.
        An HBA line specifies "oauth" as the method and passes the
provider, the issuer endpoint, and the expected audience:

        * * * * oauth   provider=<token validation extension> issuer=.... scope=....

4.     Engine Backend:
        Support for the generic OAUTHBEARER type: the backend requests a
token from the client and passes the token to the provider-specific extension.

5.     Engine Frontend: Two-tiered approach.
           a)      libpq transparently passes the token received from a
third-party client to the backend as is.
           b)      libpq can optionally be compiled for clients that
explicitly need libpq to orchestrate OAuth communication with the
issuer. (This depends heavily on the third-party library iddawc, as
Jacob already pointed out; the library seems to support all the OAuth
flows.)

Please let us know your thoughts. The proposed method supports
different OAuth flows through provider-specific hooks, and we think
the proposal would be useful for various OAuth providers.

Thanks,
Mahendrakar.


On Tue, 20 Sept 2022 at 10:18, Jacob Champion <pchamp...@vmware.com> wrote:
>
> On Tue, 2021-06-22 at 23:22 +0000, Jacob Champion wrote:
> > On Fri, 2021-06-18 at 11:31 +0300, Heikki Linnakangas wrote:
> > >
> > > A few small things caught my eye in the backend oauth_exchange function:
> > >
> > > > +       /* Handle the client's initial message. */
> > > > +       p = strdup(input);
> > >
> > > this strdup() should be pstrdup().
> >
> > Thanks, I'll fix that in the next re-roll.
> >
> > > In the same function, there are a bunch of reports like this:
> > >
> > > >                    ereport(ERROR,
> > > > +                          (errcode(ERRCODE_PROTOCOL_VIOLATION),
> > > > +                           errmsg("malformed OAUTHBEARER message"),
> > > > +                           errdetail("Comma expected, but found 
> > > > character \"%s\".",
> > > > +                                     sanitize_char(*p))));
> > >
> > > I don't think the double quotes are needed here, because sanitize_char
> > > will return quotes if it's a single character. So it would end up
> > > looking like this: ... found character "'x'".
> >
> > I'll fix this too. Thanks!
>
> v2, attached, incorporates Heikki's suggested fixes and also rebases on
> top of latest HEAD, which had the SASL refactoring changes committed
> last month.
>
> The biggest change from the last patchset is 0001, an attempt at
> enabling jsonapi in the frontend without the use of palloc(), based on
> suggestions by Michael and Tom from last commitfest. I've also made
> some improvements to the pytest suite. No major changes to the OAuth
> implementation yet.
>
> --Jacob
diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c
index c47211132c..86f820482b 100644
--- a/src/backend/libpq/auth-oauth.c
+++ b/src/backend/libpq/auth-oauth.c
@@ -24,7 +24,9 @@
 #include "libpq/hba.h"
 #include "libpq/oauth.h"
 #include "libpq/sasl.h"
+#include "miscadmin.h"
 #include "storage/fd.h"
+#include "utils/memutils.h"
 
 /* GUC */
 char *oauth_validator_command;
@@ -34,6 +36,13 @@ static void *oauth_init(Port *port, const char *selected_mech, const char *shado
 static int   oauth_exchange(void *opaq, const char *input, int inputlen,
 							char **output, int *outputlen, char **logdetail);
 
+/*
+ * Registry of OAuth token-validation providers added by RegisterOAuthProvider;
+ * oauth_provider is the single registered provider, or NULL if there is none.
+ */
+static OAuthProvider *oauth_provider = NULL;
+static List *oauth_providers = NIL;
+
 /* Mechanism declaration */
 const pg_be_sasl_mech pg_be_oauth_mech = {
 	oauth_get_mechanisms,
@@ -63,15 +72,90 @@ static char *sanitize_char(char c);
 static char *parse_kvpairs_for_auth(char **input);
 static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
 static bool validate(Port *port, const char *auth, char **logdetail);
-static bool run_validator_command(Port *port, const char *token);
+static const char* run_validator_command(Port *port, const char *token);
 static bool check_exit(FILE **fh, const char *command);
 static bool unset_cloexec(int fd);
-static bool username_ok_for_shell(const char *username);
 
 #define KVSEP 0x01
 #define AUTH_KEY "auth"
 #define BEARER_SCHEME "Bearer "
 
+/*----------------------------------------------------------------
+ * OAuth Token Validator
+ *----------------------------------------------------------------
+ */
+
+/*
+ * RegisterOAuthProvider registers an OAuth token validator ("provider").
+ * It records the provider name and its hooks in the list of loaded
+ * providers so that the provider named in pg_hba.conf can be looked up.
+ * It does not validate any token itself.  Only one provider may be
+ * registered per server; a second registration raises an ERROR.
+ *
+ * This function must be called from _PG_init() of a library listed in
+ * shared_preload_libraries.
+ */
+void
+RegisterOAuthProvider(const char *provider_name,
+					  OAuthProviderCheck_hook_type OAuthProviderCheck_hook,
+					  OAuthProviderError_hook_type OAuthProviderError_hook)
+{
+	MemoryContext oldcxt;
+
+	if (!process_shared_preload_libraries_in_progress)
+		ereport(ERROR,
+				(errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
+				 errmsg("RegisterOAuthProvider can only be called by a shared_preload_library")));
+
+	/* A provider without a name or validator could never authenticate. */
+	if (provider_name == NULL || provider_name[0] == '\0' ||
+		OAuthProviderCheck_hook == NULL)
+		ereport(ERROR,
+				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+				 errmsg("OAuth provider must have a name and a validation hook")));
+
+	/*
+	 * The rest of the file keeps a single oauth_provider pointer, so refuse
+	 * a second registration instead of silently ignoring it or letting one
+	 * provider shadow another.
+	 */
+	if (oauth_provider != NULL)
+		ereport(ERROR,
+				(errcode(ERRCODE_DUPLICATE_OBJECT),
+				 errmsg("OAuth provider \"%s\" is already loaded.",
+						oauth_provider->name)));
+
+	/* TopMemoryContext: the entry must survive for the process lifetime. */
+	oldcxt = MemoryContextSwitchTo(TopMemoryContext);
+
+	oauth_provider = palloc(sizeof(OAuthProvider));
+	oauth_provider->name = pstrdup(provider_name);
+	oauth_provider->oauth_provider_hook = OAuthProviderCheck_hook;
+	oauth_provider->oauth_error_hook = OAuthProviderError_hook;
+	oauth_providers = lappend(oauth_providers, oauth_provider);
+
+	MemoryContextSwitchTo(oldcxt);
+}
+
+/*
+ * Return the registered OAuth provider (name plus hooks) called "name",
+ * or NULL if no such provider is registered or name is NULL.
+ */
+OAuthProvider *
+get_provider_by_name(const char *name)
+{
+	ListCell   *lc;
+
+	foreach(lc, oauth_providers)
+	{
+		OAuthProvider *provider = (OAuthProvider *) lfirst(lc);
+
+		if (name != NULL && strcmp(provider->name, name) == 0)
+			return provider;
+	}
+	return NULL;
+}
+
 static void
 oauth_get_mechanisms(Port *port, StringInfo buf)
 {
@@ -494,17 +578,17 @@ validate(Port *port, const char *auth, char **logdetail)
 	}
 
 	/* Have the validator check the token. */
-	if (!run_validator_command(port, token))
+	if (run_validator_command(port, token) == NULL)
 		return false;
-
+	
 	if (port->hba->oauth_skip_usermap)
 	{
 		/*
-		 * If the validator is our authorization authority, we're done.
-		 * Authentication may or may not have been performed depending on the
-		 * validator implementation; all that matters is that the validator says
-		 * the user can log in with the target role.
-		 */
+	 	* If the validator is our authorization authority, we're done.
+	 	* Authentication may or may not have been performed depending on the
+	 	* validator implementation; all that matters is that the validator says
+	 	* the user can log in with the target role.
+	 	*/
 		return true;
 	}
 
@@ -524,193 +608,26 @@ validate(Port *port, const char *auth, char **logdetail)
 	return (ret == STATUS_OK);
 }
 
-static bool
+/*
+ * Check token with the pg_hba.conf-selected provider's hook.  On success,
+ * set and return the authenticated identity; otherwise log and return NULL.
+ */
+static const char *
 run_validator_command(Port *port, const char *token)
 {
-	bool		success = false;
-	int			rc;
-	int			pipefd[2];
-	int			rfd = -1;
-	int			wfd = -1;
-
-	StringInfoData command = { 0 };
-	char	   *p;
-	FILE	   *fh = NULL;
-
-	ssize_t		written;
-	char	   *line = NULL;
-	size_t		size = 0;
-	ssize_t		len;
-
-	Assert(oauth_validator_command);
-
-	if (!oauth_validator_command[0])
-	{
-		ereport(COMMERROR,
-				(errmsg("oauth_validator_command is not set"),
-				 errhint("To allow OAuth authenticated connections, set "
-						 "oauth_validator_command in postgresql.conf.")));
-		return false;
-	}
-
-	/*
-	 * Since popen() is unidirectional, open up a pipe for the other direction.
-	 * Use CLOEXEC to ensure that our write end doesn't accidentally get copied
-	 * into child processes, which would prevent us from closing it cleanly.
-	 *
-	 * XXX this is ugly. We should just read from the child process's stdout,
-	 * but that's a lot more code.
-	 * XXX by bypassing the popen API, we open the potential of process
-	 * deadlock. Clearly document child process requirements (i.e. the child
-	 * MUST read all data off of the pipe before writing anything).
-	 * TODO: port to Windows using _pipe().
-	 */
-	rc = pipe2(pipefd, O_CLOEXEC);
-	if (rc < 0)
-	{
-		ereport(COMMERROR,
-				(errcode_for_file_access(),
-				 errmsg("could not create child pipe: %m")));
-		return false;
-	}
-
-	rfd = pipefd[0];
-	wfd = pipefd[1];
-
-	/* Allow the read pipe be passed to the child. */
-	if (!unset_cloexec(rfd))
-	{
-		/* error message was already logged */
-		goto cleanup;
-	}
-
-	/*
-	 * Construct the command, substituting any recognized %-specifiers:
-	 *
-	 *   %f: the file descriptor of the input pipe
-	 *   %r: the role that the client wants to assume (port->user_name)
-	 *   %%: a literal '%'
-	 */
-	initStringInfo(&command);
-
-	for (p = oauth_validator_command; *p; p++)
-	{
-		if (p[0] == '%')
-		{
-			switch (p[1])
-			{
-				case 'f':
-					appendStringInfo(&command, "%d", rfd);
-					p++;
-					break;
-				case 'r':
-					/*
-					 * TODO: decide how this string should be escaped. The role
-					 * is controlled by the client, so if we don't escape it,
-					 * command injections are inevitable.
-					 *
-					 * This is probably an indication that the role name needs
-					 * to be communicated to the validator process in some other
-					 * way. For this proof of concept, just be incredibly strict
-					 * about the characters that are allowed in user names.
-					 */
-					if (!username_ok_for_shell(port->user_name))
-						goto cleanup;
-
-					appendStringInfoString(&command, port->user_name);
-					p++;
-					break;
-				case '%':
-					appendStringInfoChar(&command, '%');
-					p++;
-					break;
-				default:
-					appendStringInfoChar(&command, p[0]);
-			}
-		}
-		else
-			appendStringInfoChar(&command, p[0]);
-	}
-
-	/* Execute the command. */
-	fh = OpenPipeStream(command.data, "re");
-	/* TODO: handle failures */
-
-	/* We don't need the read end of the pipe anymore. */
-	close(rfd);
-	rfd = -1;
-
-	/* Give the command the token to validate. */
-	written = write(wfd, token, strlen(token));
-	if (written != strlen(token))
-	{
-		/* TODO must loop for short writes, EINTR et al */
-		ereport(COMMERROR,
-				(errcode_for_file_access(),
-				 errmsg("could not write token to child pipe: %m")));
-		goto cleanup;
-	}
-
-	close(wfd);
-	wfd = -1;
-
-	/*
-	 * Read the command's response.
-	 *
-	 * TODO: getline() is probably too new to use, unfortunately.
-	 * TODO: loop over all lines
-	 */
-	if ((len = getline(&line, &size, fh)) >= 0)
-	{
-		/* TODO: fail if the authn_id doesn't end with a newline */
-		if (len > 0)
-			line[len - 1] = '\0';
-
-		set_authn_id(port, line);
-	}
-	else if (ferror(fh))
-	{
-		ereport(COMMERROR,
-				(errcode_for_file_access(),
-				 errmsg("could not read from command \"%s\": %m",
-						command.data)));
-		goto cleanup;
-	}
-
-	/* Make sure the command exits cleanly. */
-	if (!check_exit(&fh, command.data))
-	{
-		/* error message already logged */
-		goto cleanup;
-	}
-
-	/* Done. */
-	success = true;
-
-cleanup:
-	if (line)
-		free(line);
-
-	/*
-	 * In the successful case, the pipe fds are already closed. For the error
-	 * case, always close out the pipe before waiting for the command, to
-	 * prevent deadlock.
-	 */
-	if (rfd >= 0)
-		close(rfd);
-	if (wfd >= 0)
-		close(wfd);
-
-	if (fh)
-	{
-		Assert(!success);
-		check_exit(&fh, command.data);
-	}
-
-	if (command.data)
-		pfree(command.data);
-
-	return success;
+	OAuthProvider *provider = get_provider_by_name(port->hba->oauth_provider);
+	const char *id = NULL;
+
+	if (provider != NULL && provider->oauth_provider_hook != NULL)
+		id = provider->oauth_provider_hook(port, token);
+	if (id == NULL)
+	{
+		ereport(LOG,
+				(errmsg("OAuth bearer token validation failed" )));
+		return NULL;
+	}
+	set_authn_id(port, id);
+	return id;
 }
 
 static bool
@@ -769,29 +686,3 @@ unset_cloexec(int fd)
 
 	return true;
 }
-
-/*
- * XXX This should go away eventually and be replaced with either a proper
- * escape or a different strategy for communication with the validator command.
- */
-static bool
-username_ok_for_shell(const char *username)
-{
-	/* This set is borrowed from fe_utils' appendShellStringNoError(). */
-	static const char * const allowed = "abcdefghijklmnopqrstuvwxyz"
-										"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
-										"0123456789-_./:";
-	size_t	span;
-
-	Assert(username && username[0]); /* should have already been checked */
-
-	span = strspn(username, allowed);
-	if (username[span] != '\0')
-	{
-		ereport(COMMERROR,
-				(errmsg("PostgreSQL user name contains unsafe characters and cannot be passed to the OAuth validator")));
-		return false;
-	}
-
-	return true;
-}
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 333051ad3c..0bbcf231d2 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -296,8 +296,14 @@ auth_failed(Port *port, int status, const char *logdetail)
 			errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
 			break;
 		case uaOAuth:
-			errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
-			break;
+			{
+				OAuthProvider *provider = get_provider_by_name(port->hba->oauth_provider);
+				if(provider->oauth_error_hook)
+					errstr = provider->oauth_error_hook(port);
+				else
+					errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
+				break;
+			}
 		default:
 			errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method");
 			break;
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 943e78ddff..94fb5d434d 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -1663,6 +1663,14 @@ parse_hba_line(TokenizedAuthLine *tok_line, int elevel)
 		parsedline->clientcert = clientCertFull;
 	}
 
+	/*
+	 * Ensure that the token validation provider name is specified as provider for oauth method.
+	 */
+	if (parsedline->auth_method == uaOAuth)
+	{
+		MANDATORY_AUTH_ARG(parsedline->oauth_provider, "provider", "oauth");
+	}
+
 	return parsedline;
 }
 
@@ -2095,6 +2103,31 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 		else
 			hbaline->oauth_skip_usermap = false;
 	}
+	else if (strcmp(name, "provider") == 0)
+	{
+		REQUIRE_AUTH_OPTION(uaOAuth, "provider", "oauth");
+		if (hbaline->auth_method != uaOAuth)
+			INVALID_AUTH_OPTION("provider", gettext_noop("oauth"));
+		/*
+		 * Verify that the token validation mentioned is loaded via shared_preload_libraries.
+		 */
+		if (get_provider_by_name(val) == NULL)
+		{
+			ereport(elevel,
+					(errcode(ERRCODE_CONFIG_FILE_ERROR),
+					 errmsg("cannot use oauth provider %s",val),
+					 errhint("Load provider token validation via shared_preload_libraries."),
+					 errcontext("line %d of configuration file \"%s\"",
+								line_num, HbaFileName)));
+			*err_msg = psprintf("cannot use oauth provider %s", val);
+
+			return false;
+		}
+		else
+		{
+			hbaline->oauth_provider = pstrdup(val);
+		}
+	}
 	else
 	{
 		ereport(elevel,
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 485e48970e..938ac399dc 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -44,4 +44,29 @@ extern void set_authn_id(Port *port, const char *id);
 typedef void (*ClientAuthentication_hook_type) (Port *, int);
 extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook;
 
+/* Check hook: returns the authenticated identity for token, or NULL */
+typedef const char *(*OAuthProviderCheck_hook_type) (Port *port, const char *token);
+
+/* Error hook: supplies the errstr that auth_failed() reports (may be unset) */
+typedef const char *(*OAuthProviderError_hook_type) (Port *port);
+
+/* Option-validation hook; NOTE(review): not referenced by OAuthProvider yet */
+typedef bool (*OAuthProviderValidateOptions_hook_type)
+			(char *name, char *val, HbaLine *hbaline, char **err_msg);
+
+typedef struct OAuthProvider
+{
+	const char *name;			/* matched against pg_hba.conf "provider" */
+	OAuthProviderCheck_hook_type oauth_provider_hook;	/* validates tokens */
+	OAuthProviderError_hook_type oauth_error_hook;	/* optional; may be NULL */
+} OAuthProvider;
+
+/* Must be called from _PG_init() of a shared_preload_libraries module */
+extern void RegisterOAuthProvider(const char *provider_name,
+								  OAuthProviderCheck_hook_type OAuthProviderCheck_hook,
+								  OAuthProviderError_hook_type OAuthProviderError_hook);
+
+/* Returns the registered provider with the given name, or NULL */
+extern OAuthProvider *get_provider_by_name(const char *name);
+
 #endif							/* AUTH_H */
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index c1b1313989..d65395cc22 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -123,6 +123,7 @@ typedef struct HbaLine
 	char	   *radiusports_s;
 	char	   *oauth_issuer;
 	char	   *oauth_scope;
+	char       *oauth_provider;
 	bool		oauth_skip_usermap;
 } HbaLine;
 
diff --git a/src/interfaces/libpq/fe-auth-oauth.c b/src/interfaces/libpq/fe-auth-oauth.c
index 91d2c69f16..61a0b80b7e 100644
--- a/src/interfaces/libpq/fe-auth-oauth.c
+++ b/src/interfaces/libpq/fe-auth-oauth.c
@@ -174,6 +174,16 @@ get_auth_token(PGconn *conn)
 	if (!token_buf)
 		goto cleanup;
 
+	if(conn->oauth_bearer_token)
+	{
+		appendPQExpBufferStr(token_buf, "Bearer ");
+		appendPQExpBufferStr(token_buf, conn->oauth_bearer_token);
+		if (PQExpBufferBroken(token_buf))
+			goto cleanup;
+		token = strdup(token_buf->data);
+		goto cleanup;
+	}
+
 	err = i_set_str_parameter(&session, I_OPT_OPENID_CONFIG_ENDPOINT, conn->oauth_discovery_uri);
 	if (err)
 	{
@@ -201,18 +211,22 @@ get_auth_token(PGconn *conn)
 							 libpq_gettext("issuer does not support device authorization"));
 		goto cleanup;
 	}
+	
+	//default device flow
+	int session_response_type = I_RESPONSE_TYPE_DEVICE_CODE; 
+	auth_method = I_TOKEN_AUTH_METHOD_NONE;
+	if (conn->oauth_client_secret && *conn->oauth_client_secret)
+	{
+		auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC;
+	}
 
-	err = i_set_response_type(&session, I_RESPONSE_TYPE_DEVICE_CODE);
+	err = i_set_response_type(&session, session_response_type);
 	if (err)
 	{
 		iddawc_error(conn, err, "failed to set device code response type");
 		goto cleanup;
 	}
 
-	auth_method = I_TOKEN_AUTH_METHOD_NONE;
-	if (conn->oauth_client_secret && *conn->oauth_client_secret)
-		auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC;
-
 	err = i_set_parameter_list(&session,
 		I_OPT_CLIENT_ID, conn->oauth_client_id,
 		I_OPT_CLIENT_SECRET, conn->oauth_client_secret,
@@ -250,6 +264,18 @@ get_auth_token(PGconn *conn)
 		goto cleanup;
 	}
 
+	if (conn->oauth_client_secret && *conn->oauth_client_secret)
+	{
+		session_response_type = I_RESPONSE_TYPE_CLIENT_CREDENTIALS;
+	}
+	
+	err = i_set_response_type(&session, session_response_type);
+	if (err)
+	{
+		iddawc_error(conn, err, "failed to set session response type");
+		goto cleanup;
+	}
+
 	/*
 	 * Poll the token endpoint until either the user logs in and authorizes the
 	 * use of a token, or a hard failure occurs. We perform one ping _before_
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 2ff450ce05..5d804c8c0d 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -361,6 +361,10 @@ static const internalPQconninfoOption PQconninfoOptions[] = {
 		"OAuth-Scope", "", 15,
 	offsetof(struct pg_conn, oauth_scope)},
 
+	{"oauth_bearer_token", NULL, NULL, NULL,
+		"OAuth-Bearer", "", 20,
+	offsetof(struct pg_conn, oauth_bearer_token)},
+
 	/* Terminating entry --- MUST BE LAST */
 	{NULL, NULL, NULL, NULL,
 	NULL, NULL, 0}
@@ -4200,6 +4204,8 @@ freePGconn(PGconn *conn)
 		free(conn->oauth_discovery_uri);
 	if (conn->oauth_client_id)
 		free(conn->oauth_client_id);
+	if(conn->oauth_bearer_token)
+		free(conn->oauth_bearer_token);
 	if (conn->oauth_client_secret)
 		free(conn->oauth_client_secret);
 	if (conn->oauth_scope)
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 1b4de3dff0..91e71afe14 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -402,6 +402,7 @@ struct pg_conn
 	char	   *oauth_client_id;		/* client identifier */
 	char	   *oauth_client_secret;	/* client secret */
 	char	   *oauth_scope;			/* access token scope */
+	char       *oauth_bearer_token;		/* oauth token */
 	bool		oauth_want_retry;		/* should we retry on failure? */
 
 	/* Optional file to write trace info to */

Reply via email to