Hi all,
(Adding Daniel and Jonathan per recent threads)

While investigating what it would take to extend SCRAM to use new
hash methods (for example, the RFC draft for SCRAM-SHA-512), I was
quickly reminded of the limitations created by SCRAM_KEY_LEN, the key
length we use in the HMAC and hash computations when creating a SCRAM
verifier or during a SASL exchange.

Back in v10, when SCRAM was implemented, we kept the implementation
simple and decided to rely heavily on buffers with a static size of
SCRAM_KEY_LEN during the exchanges.  This was a good choice back then,
because we did not really have a way to handle errors, and there was no
need to worry about things like OOMs or even SHA computation errors.
It was also incorrect in its own ways, because we did not go through
OpenSSL for the hash/HMAC computations, which would be an issue for
FIPS certification.  This led to the introduction of the new
cryptohash and HMAC APIs, and, as of 87ae969 and e6bdfd9, the SCRAM
code is able to detect any errors (OOM, hash/HMAC computation, OpenSSL
issue) and pass them up through its layers.
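This is why the error plumbing was a prerequisite.  When the key length
is only known at runtime, a fixed array like "uint8
ClientKey[SCRAM_KEY_LEN]" has to become an allocation.  In the frontend,
which cannot elog(ERROR), an allocation failure then has to travel back
to the caller, for example through errstr as the patch does.  Below is a
minimal sketch of the ClientKey recovery step from verify_client_proof()
done that way.  recover_client_key() is a hypothetical name, and plain
malloc() stands in for palloc():

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

/*
 * Hypothetical helper: recover ClientKey = ClientProof XOR ClientSignature
 * for a key of key_length bytes.  The buffer is sized at runtime instead
 * of being a SCRAM_KEY_LEN stack array, so allocation failure is reported
 * to the caller through *errstr.
 *
 * Returns a malloc'd buffer of key_length bytes, or NULL on failure.
 */
static uint8_t *
recover_client_key(const uint8_t *client_proof,
				   const uint8_t *client_signature,
				   int key_length, const char **errstr)
{
	uint8_t    *client_key;

	assert(key_length > 0);

	client_key = malloc((size_t) key_length);
	if (client_key == NULL)
	{
		*errstr = "out of memory";
		return NULL;
	}

	/* XOR byte by byte over the runtime key length */
	for (int i = 0; i < key_length; i++)
		client_key[i] = client_proof[i] ^ client_signature[i];

	return client_key;
}
```

The caller owns the result and must free() it.  The backend would use
palloc() and rely on memory contexts instead.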

Taking advantage of the error-handling stack and logic introduced
previously, it becomes rather straightforward to remove the hardcoded
assumptions about SHA-256 in the SCRAM code path.  Attached is a patch
that does so:
- SCRAM_KEY_LEN is removed from all the internal SCRAM routines.  It is
replaced by logic where the hash type and the key length are stored in
fe_scram_state for the frontend and in scram_state for the backend.
- The frontend assigns the hash type and the key length depending on
its choice of SASL mechanism in scram_init()@fe-auth-scram.c.
- The backend assigns the hash type and the key length based on the
parsed SCRAM entry from pg_authid, which works nicely when we need to
handle a raw password for a SASL exchange, for example.  That's
basically what we do now, except that scram_state is changed to carry
these values.
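On the frontend side, the mechanism-to-hash mapping could look roughly
like the following sketch.  The type and function names are
hypothetical stand-ins for pg_cryptohash_type and the new
fe_scram_state fields, not the patch's actual code:

```c
#include <assert.h>
#include <string.h>

/* Hypothetical stand-in for pg_cryptohash_type */
typedef enum
{
	SKETCH_SHA256
} sketch_hash_type;

/* Hypothetical subset of fe_scram_state: only the new fields */
typedef struct
{
	sketch_hash_type hash_type;
	int			key_length;
} sketch_scram_state;

/*
 * Fill in the hash type and key length from the SASL mechanism chosen
 * by the client.  Returns 0 on success, -1 for an unsupported mechanism.
 */
static int
sketch_scram_init(sketch_scram_state *state, const char *sasl_mechanism)
{
	assert(state != NULL && sasl_mechanism != NULL);

	/* Channel binding does not change the hash: both use SHA-256 */
	if (strcmp(sasl_mechanism, "SCRAM-SHA-256") == 0 ||
		strcmp(sasl_mechanism, "SCRAM-SHA-256-PLUS") == 0)
	{
		state->hash_type = SKETCH_SHA256;
		state->key_length = 32;	/* PG_SHA256_DIGEST_LENGTH */
		return 0;
	}

	return -1;
}
```

Once this choice is made, everything downstream reads hash_type and
key_length from the state.  A new mechanism would then mostly mean one
more branch here.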

HEAD currently has 68 references to SCRAM_KEY_LEN.  This patch brings
them down to 6, which cannot really be avoided because we still need to
handle SCRAM-SHA-256 one way or another:
- When parsing a SCRAM password from pg_authid (to get a password type
or fill in scram_state in the backend).
- For mock authentication, where SHA-256 is forced.  The exchange is
going to fail anyway, so any hash would be fine as long as the user is
informed of the failure transparently.
- When initializing the exchange in libpq based on the SASL mechanism
choice.
- scram-common.h itself.
- When building a verifier in the be-fe interfaces.  These could be
changed as well, but I did not see the point of doing so yet.
- SCRAM_KEY_LEN is renamed to SCRAM_SHA_256_KEY_LEN to reflect its
dependency on SHA-256.
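For the pg_authid case, the key length is derived from the scheme that
prefixes the stored secret
("SCRAM-SHA-256$<iterations>:<salt>$<StoredKey>:<ServerKey>").  Here is
a simplified sketch of just that prefix parsing, using hypothetical
names.  The real parse_scram_secret() also decodes the salt and both
keys and checks that their decoded length matches key_length:

```c
#include <assert.h>
#include <errno.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical: leading fields of a SCRAM secret from pg_authid */
typedef struct
{
	int			key_length;		/* digest length implied by the scheme */
	long		iterations;
} sketch_secret_header;

/*
 * Parse the "<scheme>$<iterations>:" prefix of a SCRAM secret and derive
 * the key length from the scheme.  Returns 0 on success, -1 otherwise.
 */
static int
sketch_parse_secret_header(const char *secret, sketch_secret_header *hdr)
{
	const char *dollar;
	char	   *end;

	assert(secret != NULL && hdr != NULL);

	dollar = strchr(secret, '$');
	if (dollar == NULL)
		return -1;

	/* Only SCRAM-SHA-256 is recognized, as in parse_scram_secret() */
	if ((size_t) (dollar - secret) == strlen("SCRAM-SHA-256") &&
		strncmp(secret, "SCRAM-SHA-256", strlen("SCRAM-SHA-256")) == 0)
		hdr->key_length = 32;	/* SCRAM_SHA_256_KEY_LEN */
	else
		return -1;

	/* Iteration count: digits followed by ':', strictly positive */
	errno = 0;
	hdr->iterations = strtol(dollar + 1, &end, 10);
	if (errno != 0 || end == dollar + 1 || *end != ':' ||
		hdr->iterations <= 0)
		return -1;

	return 0;
}
```

Supporting another scheme would mean adding one more branch that maps
its name to its digest length.  The key-length checks further down
would stay as they are.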

With this patch in place, the SCRAM internals for SASL are basically
able to handle any hash method.  There is much more to consider, like
how we would need to treat uaSCRAM (separate code for each additional
hash method, or the same code for all of them) or HBA entries.  Still,
this removes what I consider to be 70-80% of the pain in terms of
extensibility with the current code, and it is enough to stand on its
own as a first step towards more methods.  I am planning to tackle more
pluggability work after what's done here, btw :)
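The core of that claim is the PBKDF2-style Hi() loop in
scram_SaltedPassword(), which only needs a PRF and its output length.
Below is a minimal sketch of the loop parameterized that way.  It
assumes a hypothetical function-pointer PRF type in place of the
pg_hmac_* calls on a pg_cryptohash_type.  toy_prf() is a deliberately
non-cryptographic placeholder (not HMAC) that exists only so the snippet
runs without OpenSSL:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical PRF producing out_len bytes; stands in for HMAC(hash) */
typedef void (*sketch_prf_fn) (const uint8_t *key, size_t key_len,
							   const uint8_t *data, size_t data_len,
							   uint8_t *out, size_t out_len);

/* Toy PRF for illustration only: NOT cryptographic, NOT HMAC */
static void
toy_prf(const uint8_t *key, size_t key_len,
		const uint8_t *data, size_t data_len,
		uint8_t *out, size_t out_len)
{
	unsigned int sum = 0;

	for (size_t i = 0; i < key_len; i++)
		sum += key[i];
	for (size_t i = 0; i < data_len; i++)
		sum = sum * 31 + data[i];
	for (size_t i = 0; i < out_len; i++)
		out[i] = (uint8_t) (sum + i);
}

/*
 * Hi(password, salt, iterations) as in RFC 5802, i.e. PBKDF2 with a PRF
 * producing key_length bytes.  Writes key_length bytes into result.
 * Returns 0 on success, -1 with *errstr set on allocation failure.
 */
static int
sketch_salted_password(sketch_prf_fn prf, size_t key_length,
					   const char *password,
					   const uint8_t *salt, size_t saltlen,
					   int iterations, uint8_t *result,
					   const char **errstr)
{
	size_t		pwlen = strlen(password);
	uint8_t    *first = malloc(saltlen + 4);	/* salt || INT(1) */
	uint8_t    *ui = malloc(key_length);
	uint8_t    *ui_prev = malloc(key_length);
	int			ret = -1;

	assert(key_length > 0 && iterations >= 1);

	if (first == NULL || ui == NULL || ui_prev == NULL)
	{
		*errstr = "out of memory";
		goto done;
	}

	/* U1 = PRF(password, salt || INT(1)), INT(1) being big-endian 1 */
	memcpy(first, salt, saltlen);
	first[saltlen] = 0;
	first[saltlen + 1] = 0;
	first[saltlen + 2] = 0;
	first[saltlen + 3] = 1;
	prf((const uint8_t *) password, pwlen, first, saltlen + 4,
		ui_prev, key_length);
	memcpy(result, ui_prev, key_length);

	/* Ui = PRF(password, Ui-1); result = U1 XOR U2 XOR ... XOR Un */
	for (int i = 2; i <= iterations; i++)
	{
		prf((const uint8_t *) password, pwlen, ui_prev, key_length,
			ui, key_length);
		for (size_t j = 0; j < key_length; j++)
			result[j] ^= ui[j];
		memcpy(ui_prev, ui, key_length);
	}
	ret = 0;

done:
	free(first);
	free(ui);
	free(ui_prev);
	return ret;
}
```

Calling this with key_length 32 or 64 runs the same loop.  Only the
PRF and the buffer sizes change, which is the property the patch gives
to the real routines.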

This patch passes check-world and the CI is green.  I have also tested
the patch with SCRAM verifiers created by a server running HEAD, so it
looks pretty solid from here.  I was mainly careful about memory leaks
in the frontend.

Thoughts or comments?
--
Michael
From 378c86619933d9c712730e3d6a105a79854660cf Mon Sep 17 00:00:00 2001
From: Michael Paquier <mich...@paquier.xyz>
Date: Wed, 14 Dec 2022 11:35:33 +0900
Subject: [PATCH] Remove dependency on hash type and key length in internal
 SCRAM code

SCRAM_KEY_LEN had a hard dependency on SHA-256, which made it difficult
to add more hash methods to SCRAM, one problem being the many
statically-sized buffers.
---
 src/include/common/scram-common.h    |  25 ++--
 src/include/libpq/scram.h            |   8 +-
 src/backend/libpq/auth-scram.c       | 165 ++++++++++++++---------
 src/backend/libpq/crypt.c            |  10 +-
 src/common/scram-common.c            | 189 ++++++++++++++++++---------
 src/interfaces/libpq/fe-auth-scram.c | 175 +++++++++++++++++--------
 6 files changed, 380 insertions(+), 192 deletions(-)

diff --git a/src/include/common/scram-common.h b/src/include/common/scram-common.h
index 4acf2a78ad..5b647e4b81 100644
--- a/src/include/common/scram-common.h
+++ b/src/include/common/scram-common.h
@@ -21,7 +21,7 @@
 #define SCRAM_SHA_256_PLUS_NAME "SCRAM-SHA-256-PLUS"	/* with channel binding */
 
 /* Length of SCRAM keys (client and server) */
-#define SCRAM_KEY_LEN				PG_SHA256_DIGEST_LENGTH
+#define SCRAM_SHA_256_KEY_LEN				PG_SHA256_DIGEST_LENGTH
 
 /*
  * Size of random nonce generated in the authentication exchange.  This
@@ -43,17 +43,22 @@
  */
 #define SCRAM_DEFAULT_ITERATIONS	4096
 
-extern int	scram_SaltedPassword(const char *password, const char *salt,
-								 int saltlen, int iterations, uint8 *result,
-								 const char **errstr);
-extern int	scram_H(const uint8 *input, int len, uint8 *result,
+extern int	scram_SaltedPassword(const char *password,
+								 pg_cryptohash_type hash_type, int key_length,
+								 const char *salt, int saltlen, int iterations,
+								 uint8 *result, const char **errstr);
+extern int	scram_H(const uint8 *input, pg_cryptohash_type hash_type,
+					int key_length, uint8 *result,
 					const char **errstr);
-extern int	scram_ClientKey(const uint8 *salted_password, uint8 *result,
-							const char **errstr);
-extern int	scram_ServerKey(const uint8 *salted_password, uint8 *result,
-							const char **errstr);
+extern int	scram_ClientKey(const uint8 *salted_password,
+							pg_cryptohash_type hash_type, int key_length,
+							uint8 *result, const char **errstr);
+extern int	scram_ServerKey(const uint8 *salted_password,
+							pg_cryptohash_type hash_type, int key_length,
+							uint8 *result, const char **errstr);
 
-extern char *scram_build_secret(const char *salt, int saltlen, int iterations,
+extern char *scram_build_secret(pg_cryptohash_type hash_type, int key_length,
+								const char *salt, int saltlen, int iterations,
 								const char *password, const char **errstr);
 
 #endif							/* SCRAM_COMMON_H */
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index c51e848c24..2662c9d703 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -13,6 +13,7 @@
 #ifndef PG_SCRAM_H
 #define PG_SCRAM_H
 
+#include "common/cryptohash.h"
 #include "lib/stringinfo.h"
 #include "libpq/libpq-be.h"
 #include "libpq/sasl.h"
@@ -22,8 +23,11 @@ extern PGDLLIMPORT const pg_be_sasl_mech pg_be_scram_mech;
 
 /* Routines to handle and check SCRAM-SHA-256 secret */
 extern char *pg_be_scram_build_secret(const char *password);
-extern bool parse_scram_secret(const char *secret, int *iterations, char **salt,
-							   uint8 *stored_key, uint8 *server_key);
+extern bool parse_scram_secret(const char *secret,
+							   int *iterations,
+							   pg_cryptohash_type *hash_type,
+							   int *key_length, char **salt,
+							   uint8 **stored_key, uint8 **server_key);
 extern bool scram_verify_plain_password(const char *username,
 										const char *password, const char *secret);
 
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index ee7f52218a..d083439c13 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -141,10 +141,14 @@ typedef struct
 	Port	   *port;
 	bool		channel_binding_in_use;
 
+	/* State data depending on the hash type */
+	pg_cryptohash_type	hash_type;
+	int			key_length;
+
 	int			iterations;
 	char	   *salt;			/* base64-encoded */
-	uint8		StoredKey[SCRAM_KEY_LEN];
-	uint8		ServerKey[SCRAM_KEY_LEN];
+	uint8	   *stored_key;		/* size of key_length */
+	uint8	   *server_key;		/* size of key_length */
 
 	/* Fields of the first message from client */
 	char		cbind_flag;
@@ -155,7 +159,7 @@ typedef struct
 	/* Fields from the last message from client */
 	char	   *client_final_message_without_proof;
 	char	   *client_final_nonce;
-	char		ClientProof[SCRAM_KEY_LEN];
+	char	   *client_proof;	/* size of key_length */
 
 	/* Fields generated in the server */
 	char	   *server_first_message;
@@ -177,12 +181,15 @@ static char *build_server_first_message(scram_state *state);
 static char *build_server_final_message(scram_state *state);
 static bool verify_client_proof(scram_state *state);
 static bool verify_final_nonce(scram_state *state);
-static void mock_scram_secret(const char *username, int *iterations,
-							  char **salt, uint8 *stored_key, uint8 *server_key);
+static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
+							  int *iterations, int *key_length, char **salt,
+							  uint8 **stored_key, uint8 **server_key);
 static bool is_scram_printable(char *p);
 static char *sanitize_char(char c);
 static char *sanitize_str(const char *s);
-static char *scram_mock_salt(const char *username);
+static char *scram_mock_salt(const char *username,
+							 pg_cryptohash_type hash_type,
+							 int key_length);
 
 /*
  * Get a list of SASL mechanisms that this module supports.
@@ -266,8 +273,11 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
 
 		if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
 		{
-			if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt,
-								   state->StoredKey, state->ServerKey))
+			if (parse_scram_secret(shadow_pass, &state->iterations,
+								   &state->hash_type, &state->key_length,
+								   &state->salt,
+								   &state->stored_key,
+								   &state->server_key))
 				got_secret = true;
 			else
 			{
@@ -310,8 +320,10 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
 	 */
 	if (!got_secret)
 	{
-		mock_scram_secret(state->port->user_name, &state->iterations,
-						  &state->salt, state->StoredKey, state->ServerKey);
+		mock_scram_secret(state->port->user_name, &state->hash_type,
+						  &state->iterations, &state->key_length,
+						  &state->salt,
+						  &state->stored_key, &state->server_key);
 		state->doomed = true;
 	}
 
@@ -482,7 +494,8 @@ pg_be_scram_build_secret(const char *password)
 				(errcode(ERRCODE_INTERNAL_ERROR),
 				 errmsg("could not generate random salt")));
 
-	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
+	result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN,
+								saltbuf, SCRAM_DEFAULT_SALT_LEN,
 								SCRAM_DEFAULT_ITERATIONS, password,
 								&errstr);
 
@@ -505,16 +518,18 @@ scram_verify_plain_password(const char *username, const char *password,
 	char	   *salt;
 	int			saltlen;
 	int			iterations;
-	uint8		salted_password[SCRAM_KEY_LEN];
-	uint8		stored_key[SCRAM_KEY_LEN];
-	uint8		server_key[SCRAM_KEY_LEN];
-	uint8		computed_key[SCRAM_KEY_LEN];
+	int			key_length = 0;
+	pg_cryptohash_type hash_type;
+	uint8	   *salted_password = NULL;	/* size of key_length */
+	uint8	   *stored_key = NULL;		/* size of key_length */
+	uint8	   *server_key = NULL;		/* size of key_length */
+	uint8	   *computed_key = NULL;	/* size of key_length */
 	char	   *prep_password;
 	pg_saslprep_rc rc;
 	const char *errstr = NULL;
 
-	if (!parse_scram_secret(secret, &iterations, &encoded_salt,
-							stored_key, server_key))
+	if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length,
+							&encoded_salt, &stored_key, &server_key))
 	{
 		/*
 		 * The password looked like a SCRAM secret, but could not be parsed.
@@ -524,6 +539,11 @@ scram_verify_plain_password(const char *username, const char *password,
 		return false;
 	}
 
+	/* allocated by parse_scram_secret() */
+	Assert(stored_key && server_key);
+	salted_password = (uint8 *) palloc(key_length * sizeof(uint8));
+	computed_key = (uint8 *) palloc(key_length * sizeof(uint8));
+
 	saltlen = pg_b64_dec_len(strlen(encoded_salt));
 	salt = palloc(saltlen);
 	saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
@@ -541,9 +561,11 @@ scram_verify_plain_password(const char *username, const char *password,
 		password = prep_password;
 
 	/* Compute Server Key based on the user-supplied plaintext password */
-	if (scram_SaltedPassword(password, salt, saltlen, iterations,
+	if (scram_SaltedPassword(password, hash_type, key_length,
+							 salt, saltlen, iterations,
 							 salted_password, &errstr) < 0 ||
-		scram_ServerKey(salted_password, computed_key, &errstr) < 0)
+		scram_ServerKey(salted_password, hash_type, key_length,
+						computed_key, &errstr) < 0)
 	{
 		elog(ERROR, "could not compute server key: %s", errstr);
 	}
@@ -555,24 +577,25 @@ scram_verify_plain_password(const char *username, const char *password,
 	 * Compare the secret's Server Key with the one computed from the
 	 * user-supplied password.
 	 */
-	return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
+	return memcmp(computed_key, server_key, key_length) == 0;
 }
 
 
 /*
  * Parse and validate format of given SCRAM secret.
  *
- * On success, the iteration count, salt, stored key, and server key are
- * extracted from the secret, and returned to the caller.  For 'stored_key'
- * and 'server_key', the caller must pass pre-allocated buffers of size
- * SCRAM_KEY_LEN.  Salt is returned as a base64-encoded, null-terminated
- * string.  The buffer for the salt is palloc'd by this function.
+ * On success, the iteration count, salt, key length, stored key, and
+ * server key are extracted from the secret, and returned to the caller.
+ * 'stored_key' and 'server_key' are palloc'd with a size of 'key_length'.
+ * Salt is returned as a base64-encoded, null-terminated string.  The buffer
+ * for the salt is palloc'd by this function.
  *
  * Returns true if the SCRAM secret has been parsed, and false otherwise.
  */
 bool
-parse_scram_secret(const char *secret, int *iterations, char **salt,
-				   uint8 *stored_key, uint8 *server_key)
+parse_scram_secret(const char *secret, int *iterations,
+				   pg_cryptohash_type *hash_type, int *key_length,
+				   char **salt, uint8 **stored_key, uint8 **server_key)
 {
 	char	   *v;
 	char	   *p;
@@ -606,6 +629,8 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
 	/* Parse the fields */
 	if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
 		goto invalid_secret;
+	*hash_type = PG_SHA256;
+	*key_length = SCRAM_SHA_256_KEY_LEN;
 
 	errno = 0;
 	*iterations = strtol(iterations_str, &p, 10);
@@ -631,17 +656,19 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
 	decoded_stored_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
 								decoded_stored_buf, decoded_len);
-	if (decoded_len != SCRAM_KEY_LEN)
+	if (decoded_len != *key_length)
 		goto invalid_secret;
-	memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
+	*stored_key = (uint8 *) palloc(*key_length * sizeof(uint8));
+	memcpy(*stored_key, decoded_stored_buf, *key_length * sizeof(uint8));
 
 	decoded_len = pg_b64_dec_len(strlen(serverkey_str));
 	decoded_server_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
 								decoded_server_buf, decoded_len);
-	if (decoded_len != SCRAM_KEY_LEN)
+	if (decoded_len != *key_length)
 		goto invalid_secret;
-	memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
+	*server_key = (uint8 *) palloc(*key_length * sizeof(uint8));
+	memcpy(*server_key, decoded_server_buf, *key_length * sizeof(uint8));
 
 	return true;
 
@@ -655,20 +682,25 @@ invalid_secret:
  *
  * In a normal authentication, these are extracted from the secret
  * stored in the server.  This function generates values that look
- * realistic, for when there is no stored secret.
+ * realistic, for when there is no stored secret, using SCRAM-SHA-256.
  *
- * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
- * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and
- * the buffer for the salt is palloc'd by this function.
+ * 'stored_key' and 'server_key' are palloc'd by this function with
+ * an arbitrary key length guessed from the hash type, and the buffer
+ * for the salt is palloc'd by this function.
  */
 static void
-mock_scram_secret(const char *username, int *iterations, char **salt,
-				  uint8 *stored_key, uint8 *server_key)
+mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
+				  int *iterations, int *key_length, char **salt,
+				  uint8 **stored_key, uint8 **server_key)
 {
 	char	   *raw_salt;
 	char	   *encoded_salt;
 	int			encoded_len;
 
+	/* Enforce the use of SHA-256, which would be realistic enough */
+	*hash_type = PG_SHA256;
+	*key_length = SCRAM_SHA_256_KEY_LEN;
+
 	/*
 	 * Generate deterministic salt.
 	 *
@@ -677,7 +709,7 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
 	 * as the salt generated for mock authentication uses the cluster's nonce
 	 * value.
 	 */
-	raw_salt = scram_mock_salt(username);
+	raw_salt = scram_mock_salt(username, *hash_type, *key_length);
 	if (raw_salt == NULL)
 		elog(ERROR, "could not encode salt");
 
@@ -695,8 +727,8 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
 	*iterations = SCRAM_DEFAULT_ITERATIONS;
 
 	/* StoredKey and ServerKey are not used in a doomed authentication */
-	memset(stored_key, 0, SCRAM_KEY_LEN);
-	memset(server_key, 0, SCRAM_KEY_LEN);
+	*stored_key = (uint8 *) palloc0(*key_length * sizeof(uint8));
+	*server_key = (uint8 *) palloc0(*key_length * sizeof(uint8));
 }
 
 /*
@@ -1111,10 +1143,13 @@ verify_final_nonce(scram_state *state)
 static bool
 verify_client_proof(scram_state *state)
 {
-	uint8		ClientSignature[SCRAM_KEY_LEN];
-	uint8		ClientKey[SCRAM_KEY_LEN];
-	uint8		client_StoredKey[SCRAM_KEY_LEN];
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	uint8	   *ClientSignature = (uint8 *) palloc(state->key_length *
+												   sizeof(uint8));
+	uint8	   *ClientKey = (uint8 *) palloc(state->key_length *
+											 sizeof(uint8));
+	uint8	   *client_StoredKey = (uint8 *) palloc(state->key_length *
+													sizeof(uint8));
+	pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
 	int			i;
 	const char *errstr = NULL;
 
@@ -1123,7 +1158,7 @@ verify_client_proof(scram_state *state)
 	 * here even when processing the calculations as this could involve a mock
 	 * authentication.
 	 */
-	if (pg_hmac_init(ctx, state->StoredKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, state->stored_key, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -1135,7 +1170,7 @@ verify_client_proof(scram_state *state)
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_final_message_without_proof,
 					   strlen(state->client_final_message_without_proof)) < 0 ||
-		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
+		pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
 	{
 		elog(ERROR, "could not calculate client signature: %s",
 			 pg_hmac_error(ctx));
@@ -1144,14 +1179,15 @@ verify_client_proof(scram_state *state)
 	pg_hmac_free(ctx);
 
 	/* Extract the ClientKey that the client calculated from the proof */
-	for (i = 0; i < SCRAM_KEY_LEN; i++)
-		ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
+	for (i = 0; i < state->key_length; i++)
+		ClientKey[i] = state->client_proof[i] ^ ClientSignature[i];
 
 	/* Hash it one more time, and compare with StoredKey */
-	if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0)
+	if (scram_H(ClientKey, state->hash_type, state->key_length,
+				client_StoredKey, &errstr) < 0)
 		elog(ERROR, "could not hash stored key: %s", errstr);
 
-	if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
+	if (memcmp(client_StoredKey, state->stored_key, state->key_length) != 0)
 		return false;
 
 	return true;
@@ -1349,12 +1385,13 @@ read_client_final_message(scram_state *state, const char *input)
 	client_proof_len = pg_b64_dec_len(strlen(value));
 	client_proof = palloc(client_proof_len);
 	if (pg_b64_decode(value, strlen(value), client_proof,
-					  client_proof_len) != SCRAM_KEY_LEN)
+					  client_proof_len) != state->key_length)
 		ereport(ERROR,
 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
 				 errmsg("malformed SCRAM message"),
 				 errdetail("Malformed proof in client-final-message.")));
-	memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN);
+	state->client_proof = palloc(state->key_length);
+	memcpy(state->client_proof, client_proof, state->key_length);
 	pfree(client_proof);
 
 	if (*p != '\0')
@@ -1374,13 +1411,14 @@ read_client_final_message(scram_state *state, const char *input)
 static char *
 build_server_final_message(scram_state *state)
 {
-	uint8		ServerSignature[SCRAM_KEY_LEN];
+	uint8	   *ServerSignature = (uint8 *) palloc(state->key_length *
+												   sizeof(uint8));
 	char	   *server_signature_base64;
 	int			siglen;
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
 
 	/* calculate ServerSignature */
-	if (pg_hmac_init(ctx, state->ServerKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, state->server_key, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -1392,7 +1430,7 @@ build_server_final_message(scram_state *state)
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_final_message_without_proof,
 					   strlen(state->client_final_message_without_proof)) < 0 ||
-		pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0)
+		pg_hmac_final(ctx, ServerSignature, state->key_length) < 0)
 	{
 		elog(ERROR, "could not calculate server signature: %s",
 			 pg_hmac_error(ctx));
@@ -1400,11 +1438,11 @@ build_server_final_message(scram_state *state)
 
 	pg_hmac_free(ctx);
 
-	siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
+	siglen = pg_b64_enc_len(state->key_length);
 	/* don't forget the zero-terminator */
 	server_signature_base64 = palloc(siglen + 1);
 	siglen = pg_b64_encode((const char *) ServerSignature,
-						   SCRAM_KEY_LEN, server_signature_base64,
+						   state->key_length, server_signature_base64,
 						   siglen);
 	if (siglen < 0)
 		elog(ERROR, "could not encode server signature");
@@ -1431,12 +1469,15 @@ build_server_final_message(scram_state *state)
  * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
  */
 static char *
-scram_mock_salt(const char *username)
+scram_mock_salt(const char *username, pg_cryptohash_type hash_type,
+				int key_length)
 {
 	pg_cryptohash_ctx *ctx;
-	static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH];
+	uint8	   *sha_digest = (uint8 *) palloc(key_length * sizeof(uint8));
 	char	   *mock_auth_nonce = GetMockAuthenticationNonce();
 
+	Assert(hash_type == PG_SHA256);
+
 	/*
 	 * Generate salt using a SHA256 hash of the username and the cluster's
 	 * mock authentication nonce.  (This works as long as the salt length is
@@ -1446,11 +1487,11 @@ scram_mock_salt(const char *username)
 	StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
 					 "salt length greater than SHA256 digest length");
 
-	ctx = pg_cryptohash_create(PG_SHA256);
+	ctx = pg_cryptohash_create(hash_type);
 	if (pg_cryptohash_init(ctx) < 0 ||
 		pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 ||
 		pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 ||
-		pg_cryptohash_final(ctx, sha_digest, sizeof(sha_digest)) < 0)
+		pg_cryptohash_final(ctx, sha_digest, key_length) < 0)
 	{
 		pg_cryptohash_free(ctx);
 		return NULL;
diff --git a/src/backend/libpq/crypt.c b/src/backend/libpq/crypt.c
index 1ff8b0507d..4e2b7c99fe 100644
--- a/src/backend/libpq/crypt.c
+++ b/src/backend/libpq/crypt.c
@@ -90,15 +90,17 @@ get_password_type(const char *shadow_pass)
 {
 	char	   *encoded_salt;
 	int			iterations;
-	uint8		stored_key[SCRAM_KEY_LEN];
-	uint8		server_key[SCRAM_KEY_LEN];
+	int			key_length = 0;
+	pg_cryptohash_type hash_type;
+	uint8	   *stored_key;		/* size of key_length */
+	uint8	   *server_key;		/* size of key_length */
 
 	if (strncmp(shadow_pass, "md5", 3) == 0 &&
 		strlen(shadow_pass) == MD5_PASSWD_LEN &&
 		strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3)
 		return PASSWORD_TYPE_MD5;
-	if (parse_scram_secret(shadow_pass, &iterations, &encoded_salt,
-						   stored_key, server_key))
+	if (parse_scram_secret(shadow_pass, &iterations, &hash_type, &key_length,
+						   &encoded_salt, &stored_key, &server_key))
 		return PASSWORD_TYPE_SCRAM_SHA_256;
 	return PASSWORD_TYPE_PLAINTEXT;
 }
diff --git a/src/common/scram-common.c b/src/common/scram-common.c
index 1268625929..d41be27ca4 100644
--- a/src/common/scram-common.c
+++ b/src/common/scram-common.c
@@ -33,6 +33,7 @@
  */
 int
 scram_SaltedPassword(const char *password,
+					 pg_cryptohash_type hash_type, int key_length,
 					 const char *salt, int saltlen, int iterations,
 					 uint8 *result, const char **errstr)
 {
@@ -40,9 +41,9 @@ scram_SaltedPassword(const char *password,
 	uint32		one = pg_hton32(1);
 	int			i,
 				j;
-	uint8		Ui[SCRAM_KEY_LEN];
-	uint8		Ui_prev[SCRAM_KEY_LEN];
-	pg_hmac_ctx *hmac_ctx = pg_hmac_create(PG_SHA256);
+	uint8	   *Ui;			/* size of key_length */
+	uint8	   *Ui_prev;	/* size of key_length */
+	pg_hmac_ctx *hmac_ctx = pg_hmac_create(hash_type);
 
 	if (hmac_ctx == NULL)
 	{
@@ -50,6 +51,19 @@ scram_SaltedPassword(const char *password,
 		return -1;
 	}
 
+#ifdef FRONTEND
+	Ui = (uint8 *) malloc(key_length * sizeof(uint8));
+	Ui_prev = (uint8 *) malloc(key_length * sizeof(uint8));
+	if (Ui == NULL || Ui_prev == NULL)
+	{
+		*errstr = _("out of memory");
+		goto error;
+	}
+#else
+	Ui = (uint8 *) palloc(key_length * sizeof(uint8));
+	Ui_prev = (uint8 *) palloc(key_length * sizeof(uint8));
+#endif
+
 	/*
 	 * Iterate hash calculation of HMAC entry using given salt.  This is
 	 * essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom
@@ -60,48 +74,70 @@ scram_SaltedPassword(const char *password,
 	if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
 		pg_hmac_update(hmac_ctx, (uint8 *) salt, saltlen) < 0 ||
 		pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
-		pg_hmac_final(hmac_ctx, Ui_prev, sizeof(Ui_prev)) < 0)
+		pg_hmac_final(hmac_ctx, Ui_prev, key_length) < 0)
 	{
 		*errstr = pg_hmac_error(hmac_ctx);
-		pg_hmac_free(hmac_ctx);
-		return -1;
+		goto error;
 	}
 
-	memcpy(result, Ui_prev, SCRAM_KEY_LEN);
+	memcpy(result, Ui_prev, key_length);
 
 	/* Subsequent iterations */
 	for (i = 2; i <= iterations; i++)
 	{
 		if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
-			pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, SCRAM_KEY_LEN) < 0 ||
-			pg_hmac_final(hmac_ctx, Ui, sizeof(Ui)) < 0)
+			pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, key_length) < 0 ||
+			pg_hmac_final(hmac_ctx, Ui, key_length) < 0)
 		{
 			*errstr = pg_hmac_error(hmac_ctx);
-			pg_hmac_free(hmac_ctx);
-			return -1;
+			goto error;
 		}
 
-		for (j = 0; j < SCRAM_KEY_LEN; j++)
+		for (j = 0; j < key_length; j++)
 			result[j] ^= Ui[j];
-		memcpy(Ui_prev, Ui, SCRAM_KEY_LEN);
+		memcpy(Ui_prev, Ui, key_length);
 	}
 
+#ifdef FRONTEND
+	free(Ui);
+	free(Ui_prev);
+#else
+	pfree(Ui);
+	pfree(Ui_prev);
+#endif
 	pg_hmac_free(hmac_ctx);
 	return 0;
+
+error:
+#ifdef FRONTEND
+	if (Ui)
+		free(Ui);
+	if (Ui_prev)
+		free(Ui_prev);
+#else
+	if (Ui)
+		pfree(Ui);
+	if (Ui_prev)
+		pfree(Ui_prev);
+#endif
+
+	pg_hmac_free(hmac_ctx);
+	return -1;
 }
 
 
 /*
- * Calculate SHA-256 hash for a NULL-terminated string. (The NULL terminator is
+ * Calculate hash for a NULL-terminated string. (The NULL terminator is
  * not included in the hash).  Returns 0 on success, -1 on failure with *errstr
  * pointing to a message about the error details.
  */
 int
-scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
+scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length,
+		uint8 *result, const char **errstr)
 {
 	pg_cryptohash_ctx *ctx;
 
-	ctx = pg_cryptohash_create(PG_SHA256);
+	ctx = pg_cryptohash_create(hash_type);
 	if (ctx == NULL)
 	{
 		*errstr = pg_cryptohash_error(NULL);	/* returns OOM */
@@ -109,8 +145,8 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
 	}
 
 	if (pg_cryptohash_init(ctx) < 0 ||
-		pg_cryptohash_update(ctx, input, len) < 0 ||
-		pg_cryptohash_final(ctx, result, SCRAM_KEY_LEN) < 0)
+		pg_cryptohash_update(ctx, input, key_length) < 0 ||
+		pg_cryptohash_final(ctx, result, key_length) < 0)
 	{
 		*errstr = pg_cryptohash_error(ctx);
 		pg_cryptohash_free(ctx);
@@ -126,10 +162,11 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
  * pointing to a message about the error details.
  */
 int
-scram_ClientKey(const uint8 *salted_password, uint8 *result,
-				const char **errstr)
+scram_ClientKey(const uint8 *salted_password,
+				pg_cryptohash_type hash_type, int key_length,
+				uint8 *result, const char **errstr)
 {
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
 
 	if (ctx == NULL)
 	{
@@ -137,9 +174,9 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result,
 		return -1;
 	}
 
-	if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
 		pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
-		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
+		pg_hmac_final(ctx, result, key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
@@ -155,10 +192,11 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result,
  * pointing to a message about the error details.
  */
 int
-scram_ServerKey(const uint8 *salted_password, uint8 *result,
-				const char **errstr)
+scram_ServerKey(const uint8 *salted_password,
+				pg_cryptohash_type hash_type, int key_length,
+				uint8 *result, const char **errstr)
 {
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
 
 	if (ctx == NULL)
 	{
@@ -166,9 +204,9 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result,
 		return -1;
 	}
 
-	if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
 		pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
-		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
+		pg_hmac_final(ctx, result, key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
@@ -192,13 +230,14 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result,
  * error details.
  */
 char *
-scram_build_secret(const char *salt, int saltlen, int iterations,
+scram_build_secret(pg_cryptohash_type hash_type, int key_length,
+				   const char *salt, int saltlen, int iterations,
 				   const char *password, const char **errstr)
 {
-	uint8		salted_password[SCRAM_KEY_LEN];
-	uint8		stored_key[SCRAM_KEY_LEN];
-	uint8		server_key[SCRAM_KEY_LEN];
-	char	   *result;
+	uint8	   *salted_password = NULL;
+	uint8	   *stored_key = NULL;
+	uint8	   *server_key = NULL;
+	char	   *result = NULL;
 	char	   *p;
 	int			maxlen;
 	int			encoded_salt_len;
@@ -206,19 +245,42 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	int			encoded_server_len;
 	int			encoded_result;
 
+	Assert(hash_type == PG_SHA256);
+
 	if (iterations <= 0)
 		iterations = SCRAM_DEFAULT_ITERATIONS;
 
+#ifdef FRONTEND
+	salted_password = (uint8 *) malloc(key_length * sizeof(uint8));
+	stored_key = (uint8 *) malloc(key_length * sizeof(uint8));
+	server_key = (uint8 *) malloc(key_length * sizeof(uint8));
+	if (salted_password == NULL ||
+		stored_key == NULL ||
+		server_key == NULL)
+	{
+		*errstr = _("out of memory");
+		goto error;
+	}
+#else
+	salted_password = (uint8 *) palloc(key_length * sizeof(uint8));
+	stored_key = (uint8 *) palloc(key_length * sizeof(uint8));
+	server_key = (uint8 *) palloc(key_length * sizeof(uint8));
+#endif
+
 	/* Calculate StoredKey and ServerKey */
-	if (scram_SaltedPassword(password, salt, saltlen, iterations,
+	if (scram_SaltedPassword(password, hash_type, key_length,
+							 salt, saltlen, iterations,
 							 salted_password, errstr) < 0 ||
-		scram_ClientKey(salted_password, stored_key, errstr) < 0 ||
-		scram_H(stored_key, SCRAM_KEY_LEN, stored_key, errstr) < 0 ||
-		scram_ServerKey(salted_password, server_key, errstr) < 0)
+		scram_ClientKey(salted_password, hash_type, key_length,
+						stored_key, errstr) < 0 ||
+		scram_H(stored_key, hash_type, key_length,
+				stored_key, errstr) < 0 ||
+		scram_ServerKey(salted_password, hash_type, key_length,
+						server_key, errstr) < 0)
 	{
 		/* errstr is filled already here */
 #ifdef FRONTEND
-		return NULL;
+		goto error;
 #else
 		elog(ERROR, "could not calculate stored key and server key: %s",
 			 *errstr);
@@ -231,8 +293,8 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	 *----------
 	 */
 	encoded_salt_len = pg_b64_enc_len(saltlen);
-	encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
-	encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+	encoded_stored_len = pg_b64_enc_len(key_length);
+	encoded_server_len = pg_b64_enc_len(key_length);
 
 	maxlen = strlen("SCRAM-SHA-256") + 1
 		+ 10 + 1				/* iteration count */
@@ -245,7 +307,7 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	if (!result)
 	{
 		*errstr = _("out of memory");
-		return NULL;
+		goto error;
 	}
 #else
 	result = palloc(maxlen);
@@ -258,45 +320,30 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	if (encoded_result < 0)
 	{
 		*errstr = _("could not encode salt");
-#ifdef FRONTEND
-		free(result);
-		return NULL;
-#else
-		elog(ERROR, "%s", *errstr);
-#endif
+		goto error;
 	}
 	p += encoded_result;
 	*(p++) = '$';
 
 	/* stored key */
-	encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
+	encoded_result = pg_b64_encode((char *) stored_key, key_length, p,
 								   encoded_stored_len);
 	if (encoded_result < 0)
 	{
 		*errstr = _("could not encode stored key");
-#ifdef FRONTEND
-		free(result);
-		return NULL;
-#else
-		elog(ERROR, "%s", *errstr);
-#endif
+		goto error;
 	}
 
 	p += encoded_result;
 	*(p++) = ':';
 
 	/* server key */
-	encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
+	encoded_result = pg_b64_encode((char *) server_key, key_length, p,
 								   encoded_server_len);
 	if (encoded_result < 0)
 	{
 		*errstr = _("could not encode server key");
-#ifdef FRONTEND
-		free(result);
-		return NULL;
-#else
-		elog(ERROR, "%s", *errstr);
-#endif
+		goto error;
 	}
 
 	p += encoded_result;
@@ -304,5 +351,25 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 
 	Assert(p - result <= maxlen);
 
+#ifdef FRONTEND
+	free(salted_password);
+	free(stored_key);
+	free(server_key);
+#endif
 	return result;
+
+error:
+#ifdef FRONTEND
+	if (salted_password)
+		free(salted_password);
+	if (stored_key)
+		free(stored_key);
+	if (server_key)
+		free(server_key);
+	if (result)
+		free(result);
+#else
+	elog(ERROR, "%s", *errstr);
+#endif
+	return NULL;
 }
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index c500bea9e7..3b20062484 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -58,8 +58,12 @@ typedef struct
 	char	   *password;
 	char	   *sasl_mechanism;
 
+	/* State data depending on the hash type */
+	pg_cryptohash_type	hash_type;
+	int			key_length;
+
 	/* We construct these */
-	uint8		SaltedPassword[SCRAM_KEY_LEN];
+	uint8	   *salted_password;	/* size of key_length */
 	char	   *client_nonce;
 	char	   *client_first_message_bare;
 	char	   *client_final_message_without_proof;
@@ -73,7 +77,7 @@ typedef struct
 
 	/* These come from the server-final message */
 	char	   *server_final_message;
-	char		ServerSignature[SCRAM_KEY_LEN];
+	char	   *server_signature;	/* size of key_length */
 } fe_scram_state;
 
 static bool read_server_first_message(fe_scram_state *state, char *input);
@@ -106,35 +110,47 @@ scram_init(PGconn *conn,
 	memset(state, 0, sizeof(fe_scram_state));
 	state->conn = conn;
 	state->state = FE_SCRAM_INIT;
-	state->sasl_mechanism = strdup(sasl_mechanism);
+	state->key_length = SCRAM_SHA_256_KEY_LEN;
+	state->hash_type = PG_SHA256;
 
+	state->sasl_mechanism = strdup(sasl_mechanism);
 	if (!state->sasl_mechanism)
-	{
-		free(state);
-		return NULL;
-	}
+		goto oom_error;
+
+	state->salted_password = (uint8 *) malloc(state->key_length * sizeof(uint8));
+	if (state->salted_password == NULL)
+		goto oom_error;
+	state->server_signature = (char *) malloc(state->key_length * sizeof(char));
+	if (state->server_signature == NULL)
+		goto oom_error;
 
 	/* Normalize the password with SASLprep, if possible */
 	rc = pg_saslprep(password, &prep_password);
 	if (rc == SASLPREP_OOM)
-	{
-		free(state->sasl_mechanism);
-		free(state);
-		return NULL;
-	}
+		goto oom_error;
+
 	if (rc != SASLPREP_SUCCESS)
 	{
 		prep_password = strdup(password);
 		if (!prep_password)
-		{
-			free(state->sasl_mechanism);
-			free(state);
-			return NULL;
-		}
+			goto oom_error;
 	}
 	state->password = prep_password;
 
 	return state;
+
+oom_error:
+	if (state->salted_password)
+		free(state->salted_password);
+	if (state->server_signature)
+		free(state->server_signature);
+	if (state->password)
+		free(state->password);
+	if (state->sasl_mechanism)
+		free(state->sasl_mechanism);
+	/* state was already dereferenced above, so it cannot be NULL here */
+	free(state);
+	return NULL;
 }
 
 /*
@@ -178,6 +194,7 @@ scram_free(void *opaq)
 	free(state->sasl_mechanism);
 
 	/* client messages */
+	free(state->salted_password);
 	free(state->client_nonce);
 	free(state->client_first_message_bare);
 	free(state->client_final_message_without_proof);
@@ -189,6 +206,7 @@ scram_free(void *opaq)
 
 	/* final message from server */
 	free(state->server_final_message);
+	free(state->server_signature);
 
 	free(state);
 }
@@ -450,13 +468,17 @@ build_client_final_message(fe_scram_state *state)
 {
 	PQExpBufferData buf;
 	PGconn	   *conn = state->conn;
-	uint8		client_proof[SCRAM_KEY_LEN];
+	uint8	   *client_proof = NULL;	/* size of key_length */
 	char	   *result;
 	int			encoded_len;
 	const char *errstr = NULL;
 
 	initPQExpBuffer(&buf);
 
+	client_proof = (uint8 *) malloc(state->key_length * sizeof(uint8));
+	if (client_proof == NULL)
+		goto oom_error;
+
 	/*
 	 * Construct client-final-message-without-proof.  We need to remember it
 	 * for verifying the server proof in the final step of authentication.
@@ -565,11 +587,11 @@ build_client_final_message(fe_scram_state *state)
 	}
 
 	appendPQExpBufferStr(&buf, ",p=");
-	encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+	encoded_len = pg_b64_enc_len(state->key_length);
 	if (!enlargePQExpBuffer(&buf, encoded_len))
 		goto oom_error;
 	encoded_len = pg_b64_encode((char *) client_proof,
-								SCRAM_KEY_LEN,
+								state->key_length,
 								buf.data + buf.len,
 								encoded_len);
 	if (encoded_len < 0)
@@ -590,6 +612,7 @@ build_client_final_message(fe_scram_state *state)
 
 oom_error:
 	termPQExpBuffer(&buf);
+	free(client_proof);
 	libpq_append_conn_error(conn, "out of memory");
 	return NULL;
 }
@@ -738,13 +761,14 @@ read_server_final_message(fe_scram_state *state, char *input)
 										 strlen(encoded_server_signature),
 										 decoded_server_signature,
 										 server_signature_len);
-	if (server_signature_len != SCRAM_KEY_LEN)
+	if (server_signature_len != state->key_length)
 	{
 		free(decoded_server_signature);
 		libpq_append_conn_error(conn, "malformed SCRAM message (invalid server signature)");
 		return false;
 	}
-	memcpy(state->ServerSignature, decoded_server_signature, SCRAM_KEY_LEN);
+	memcpy(state->server_signature, decoded_server_signature,
+		   state->key_length);
 	free(decoded_server_signature);
 
 	return true;
@@ -760,35 +784,48 @@ calculate_client_proof(fe_scram_state *state,
 					   const char *client_final_message_without_proof,
 					   uint8 *result, const char **errstr)
 {
-	uint8		StoredKey[SCRAM_KEY_LEN];
-	uint8		ClientKey[SCRAM_KEY_LEN];
-	uint8		ClientSignature[SCRAM_KEY_LEN];
+	uint8	   *StoredKey = NULL;
+	uint8	   *ClientKey = NULL;
+	uint8	   *ClientSignature = NULL;
 	int			i;
-	pg_hmac_ctx *ctx;
+	pg_hmac_ctx *ctx = NULL;
 
-	ctx = pg_hmac_create(PG_SHA256);
+	StoredKey = malloc(state->key_length * sizeof(uint8));
+	ClientKey = malloc(state->key_length * sizeof(uint8));
+	ClientSignature = malloc(state->key_length * sizeof(uint8));
+	if (StoredKey == NULL ||
+		ClientKey == NULL ||
+		ClientSignature == NULL)
+	{
+		*errstr = libpq_gettext("out of memory");
+		goto error;
+	}
+
+	ctx = pg_hmac_create(state->hash_type);
 	if (ctx == NULL)
 	{
 		*errstr = pg_hmac_error(NULL);	/* returns OOM */
-		return false;
+		goto error;
 	}
 
 	/*
 	 * Calculate SaltedPassword, and store it in 'state' so that we can reuse
 	 * it later in verify_server_signature.
 	 */
-	if (scram_SaltedPassword(state->password, state->salt, state->saltlen,
-							 state->iterations, state->SaltedPassword,
+	if (scram_SaltedPassword(state->password, state->hash_type,
+							 state->key_length, state->salt, state->saltlen,
+							 state->iterations, state->salted_password,
 							 errstr) < 0 ||
-		scram_ClientKey(state->SaltedPassword, ClientKey, errstr) < 0 ||
-		scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey, errstr) < 0)
+		scram_ClientKey(state->salted_password, state->hash_type,
+						state->key_length, ClientKey, errstr) < 0 ||
+		scram_H(ClientKey, state->hash_type, state->key_length,
+				StoredKey, errstr) < 0)
 	{
 		/* errstr is already filled here */
-		pg_hmac_free(ctx);
-		return false;
+		goto error;
 	}
 
-	if (pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, StoredKey, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -800,18 +837,30 @@ calculate_client_proof(fe_scram_state *state,
 		pg_hmac_update(ctx,
 					   (uint8 *) client_final_message_without_proof,
 					   strlen(client_final_message_without_proof)) < 0 ||
-		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
+		pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
-		pg_hmac_free(ctx);
-		return false;
+		goto error;
 	}
 
-	for (i = 0; i < SCRAM_KEY_LEN; i++)
+	for (i = 0; i < state->key_length; i++)
 		result[i] = ClientKey[i] ^ ClientSignature[i];
 
+	free(StoredKey);
+	free(ClientKey);
+	free(ClientSignature);
 	pg_hmac_free(ctx);
 	return true;
+
+error:
+	if (StoredKey)
+		free(StoredKey);
+	if (ClientKey)
+		free(ClientKey);
+	if (ClientSignature)
+		free(ClientSignature);
+	pg_hmac_free(ctx);
+	return false;
 }
 
 /*
@@ -825,26 +874,35 @@ static bool
 verify_server_signature(fe_scram_state *state, bool *match,
 						const char **errstr)
 {
-	uint8		expected_ServerSignature[SCRAM_KEY_LEN];
-	uint8		ServerKey[SCRAM_KEY_LEN];
-	pg_hmac_ctx *ctx;
+	uint8	   *expected_ServerSignature = NULL;
+	uint8	   *ServerKey = NULL;
+	pg_hmac_ctx *ctx = NULL;
 
-	ctx = pg_hmac_create(PG_SHA256);
+	ServerKey = (uint8 *) malloc(state->key_length * sizeof(uint8));
+	expected_ServerSignature = (uint8 *) malloc(state->key_length * sizeof(uint8));
+
+	if (ServerKey == NULL || expected_ServerSignature == NULL)
+	{
+		*errstr = libpq_gettext("out of memory");
+		goto error;
+	}
+
+	ctx = pg_hmac_create(state->hash_type);
 	if (ctx == NULL)
 	{
 		*errstr = pg_hmac_error(NULL);	/* returns OOM */
-		return false;
+		goto error;
 	}
 
-	if (scram_ServerKey(state->SaltedPassword, ServerKey, errstr) < 0)
+	if (scram_ServerKey(state->salted_password, state->hash_type,
+						state->key_length, ServerKey, errstr) < 0)
 	{
 		/* errstr is filled already */
-		pg_hmac_free(ctx);
-		return false;
+		goto error;
 	}
 
 	/* calculate ServerSignature */
-	if (pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, ServerKey, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -857,22 +915,32 @@ verify_server_signature(fe_scram_state *state, bool *match,
 					   (uint8 *) state->client_final_message_without_proof,
 					   strlen(state->client_final_message_without_proof)) < 0 ||
 		pg_hmac_final(ctx, expected_ServerSignature,
-					  sizeof(expected_ServerSignature)) < 0)
+					  state->key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
-		pg_hmac_free(ctx);
-		return false;
+		goto error;
 	}
 
 	pg_hmac_free(ctx);
 
 	/* signature processed, so now check after it */
-	if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0)
+	if (memcmp(expected_ServerSignature, state->server_signature,
+			   state->key_length) != 0)
 		*match = false;
 	else
 		*match = true;
 
+	free(ServerKey);
+	free(expected_ServerSignature);
 	return true;
+
+error:
+	if (ServerKey)
+		free(ServerKey);
+	if (expected_ServerSignature)
+		free(expected_ServerSignature);
+	pg_hmac_free(ctx);
+	return false;
 }
 
 /*
@@ -912,7 +980,8 @@ pg_fe_scram_build_secret(const char *password, const char **errstr)
 		return NULL;
 	}
 
-	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
+	result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN, saltbuf,
+								SCRAM_DEFAULT_SALT_LEN,
 								SCRAM_DEFAULT_ITERATIONS, password,
 								errstr);
 
-- 
2.38.1
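
calculate_client_proof(), verify_server_signature() and scram_build_secret() in the patch each carry two lists of free() calls: one before the successful return and one under the error label. Below is a minimal sketch of how both exits could share a single cleanup label. It assumes plain malloc(). The hypothetical fake_derive() stands in for scram_ClientKey()/scram_H(). It only XORs bytes with a constant and is not a real key derivation.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>

/*
 * Hypothetical stand-in for the HMAC/hash helpers: writes key_length bytes
 * derived from "in" into "out".  Not cryptographic, sketch only.
 */
static bool
fake_derive(const uint8_t *in, int key_length, uint8_t *out)
{
	for (int i = 0; i < key_length; i++)
		out[i] = in[i] ^ 0x5c;
	return true;
}

/*
 * Same shape as calculate_client_proof(): two intermediates sized by the
 * runtime key_length, XORed into "result".  Success and failure both leave
 * through the same "cleanup" label, so the free() list is written once.
 */
static bool
compute_proof(const uint8_t *salted_password, int key_length,
			  uint8_t *result, const char **errstr)
{
	uint8_t    *client_key = NULL;
	uint8_t    *client_signature = NULL;
	bool		ok = false;

	assert(key_length > 0);

	client_key = malloc(key_length);
	client_signature = malloc(key_length);
	if (client_key == NULL || client_signature == NULL)
	{
		*errstr = "out of memory";
		goto cleanup;
	}

	if (!fake_derive(salted_password, key_length, client_key) ||
		!fake_derive(client_key, key_length, client_signature))
	{
		*errstr = "could not derive keys";
		goto cleanup;
	}

	for (int i = 0; i < key_length; i++)
		result[i] = client_key[i] ^ client_signature[i];
	ok = true;

cleanup:
	/* free(NULL) is a no-op, so no NULL guards are needed */
	free(client_key);
	free(client_signature);
	return ok;
}
```

Because free(NULL) is defined as a no-op in C, the "if (ptr) free(ptr)" guards in the patch's error paths could be dropped as well. The backend branches would keep their palloc()/elog(ERROR) handling as they are.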

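scram_init() in this patch still hardcodes PG_SHA256 and SCRAM_SHA_256_KEY_LEN. One way the choice could later follow the negotiated SASL mechanism is a small lookup table. This is a sketch only. sketch_hash_type is a local stand-in for pg_cryptohash_type. The SCRAM-SHA-512 entries are hypothetical and follow the RFC draft rather than anything in the patch.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Local stand-in for pg_cryptohash_type, to keep the sketch self-contained */
typedef enum
{
	SKETCH_SHA256,
	SKETCH_SHA512
} sketch_hash_type;

typedef struct
{
	const char *mechanism;		/* SASL mechanism name */
	sketch_hash_type hash_type;
	int			key_length;		/* digest size in bytes */
} scram_mech_info;

static const scram_mech_info scram_mechs[] = {
	{"SCRAM-SHA-256", SKETCH_SHA256, 32},
	{"SCRAM-SHA-256-PLUS", SKETCH_SHA256, 32},
	/* hypothetical entries, following the SCRAM-SHA-512 draft */
	{"SCRAM-SHA-512", SKETCH_SHA512, 64},
	{"SCRAM-SHA-512-PLUS", SKETCH_SHA512, 64},
};

/* Returns the entry for "mechanism", or NULL if it is not supported */
static const scram_mech_info *
lookup_scram_mech(const char *mechanism)
{
	assert(mechanism != NULL);

	for (size_t i = 0; i < sizeof(scram_mechs) / sizeof(scram_mechs[0]); i++)
	{
		if (strcmp(scram_mechs[i].mechanism, mechanism) == 0)
			return &scram_mechs[i];
	}
	return NULL;
}
```

With such a table, scram_init() would copy hash_type and key_length from the matching entry before allocating salted_password and server_signature. An unknown mechanism would simply be reported as an error.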