On Mon, Jul 01, 2019 at 11:11:43PM +0200, Daniel Gustafsson wrote:
> I very much agree that functions operating on a buffer like this should have
> the size of the buffer in order to safeguard against overflow, so +1 on the
> general concept.

Thanks for the review!

> A few small comments:
> 
> In src/common/scram-common.c there are a few instances like this.  Shouldn't
> we also free the result buffer in these cases?
> 
> +#ifdef FRONTEND
> +               return NULL;
> +#else

Fixed.

> In the passage below, we leave the destination buffer holding a partially
> encoded string.  Should we memset the buffer to zero to avoid the
> risk that code which fails to check the return value believes it has
> a valid encoded string?

Hmm.  Good point.  I had not thought of that, and your suggestion
makes sense.

Another question is whether we would want to use explicit_bzero()
here instead, but that could be discussed on this other thread, unless
the patch discussed there is merged first:
https://www.postgresql.org/message-id/42d26bde-5d5b-c90d-87ae-6cab875f7...@2ndquadrant.com

Attached is an updated patch.
--
Michael
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index 6b60abe1dd..3a31afc7b7 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -510,9 +510,11 @@ scram_verify_plain_password(const char *username, const char *password,
 		return false;
 	}
 
-	salt = palloc(pg_b64_dec_len(strlen(encoded_salt)));
-	saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt);
-	if (saltlen == -1)
+	saltlen = pg_b64_dec_len(strlen(encoded_salt));
+	salt = palloc(saltlen);
+	saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
+							saltlen);
+	if (saltlen < 0)
 	{
 		ereport(LOG,
 				(errmsg("invalid SCRAM verifier for user \"%s\"", username)));
@@ -596,9 +598,10 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt,
 	 * Verify that the salt is in Base64-encoded format, by decoding it,
 	 * although we return the encoded version to the caller.
 	 */
-	decoded_salt_buf = palloc(pg_b64_dec_len(strlen(salt_str)));
+	decoded_len = pg_b64_dec_len(strlen(salt_str));
+	decoded_salt_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(salt_str, strlen(salt_str),
-								decoded_salt_buf);
+								decoded_salt_buf, decoded_len);
 	if (decoded_len < 0)
 		goto invalid_verifier;
 	*salt = pstrdup(salt_str);
@@ -606,16 +609,18 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt,
 	/*
 	 * Decode StoredKey and ServerKey.
 	 */
-	decoded_stored_buf = palloc(pg_b64_dec_len(strlen(storedkey_str)));
+	decoded_len = pg_b64_dec_len(strlen(storedkey_str));
+	decoded_stored_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
-								decoded_stored_buf);
+								decoded_stored_buf, decoded_len);
 	if (decoded_len != SCRAM_KEY_LEN)
 		goto invalid_verifier;
 	memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
 
-	decoded_server_buf = palloc(pg_b64_dec_len(strlen(serverkey_str)));
+	decoded_len = pg_b64_dec_len(strlen(serverkey_str));
+	decoded_server_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
-								decoded_server_buf);
+								decoded_server_buf, decoded_len);
 	if (decoded_len != SCRAM_KEY_LEN)
 		goto invalid_verifier;
 	memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
@@ -649,8 +654,18 @@ mock_scram_verifier(const char *username, int *iterations, char **salt,
 	/* Generate deterministic salt */
 	raw_salt = scram_mock_salt(username);
 
-	encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
-	encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
+	encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN);
+	/* don't forget the zero-terminator */
+	encoded_salt = (char *) palloc(encoded_len + 1);
+	encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt,
+								encoded_len);
+
+	/*
+	 * Note that we cannot reveal any information to an attacker here so the
+	 * error message needs to remain generic.
+	 */
+	if (encoded_len < 0)
+		elog(ERROR, "could not encode salt");
 	encoded_salt[encoded_len] = '\0';
 
 	*salt = encoded_salt;
@@ -1144,8 +1159,15 @@ build_server_first_message(scram_state *state)
 				(errcode(ERRCODE_INTERNAL_ERROR),
 				 errmsg("could not generate random nonce")));
 
-	state->server_nonce = palloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
-	encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->server_nonce);
+	encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
+	/* don't forget the zero-terminator */
+	state->server_nonce = palloc(encoded_len + 1);
+	encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
+								state->server_nonce, encoded_len);
+	if (encoded_len < 0)
+		ereport(ERROR,
+				(errcode(ERRCODE_INTERNAL_ERROR),
+				 errmsg("could not encode random nonce")));
 	state->server_nonce[encoded_len] = '\0';
 
 	state->server_first_message =
@@ -1170,6 +1192,7 @@ read_client_final_message(scram_state *state, const char *input)
 			   *proof;
 	char	   *p;
 	char	   *client_proof;
+	int			client_proof_len;
 
 	begin = p = pstrdup(input);
 
@@ -1234,9 +1257,13 @@ read_client_final_message(scram_state *state, const char *input)
 		snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,");
 		memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
 
-		b64_message = palloc(pg_b64_enc_len(cbind_input_len) + 1);
+		b64_message_len = pg_b64_enc_len(cbind_input_len);
+		/* don't forget the zero-terminator */
+		b64_message = palloc(b64_message_len + 1);
 		b64_message_len = pg_b64_encode(cbind_input, cbind_input_len,
-										b64_message);
+										b64_message, b64_message_len);
+		if (b64_message_len < 0)
+			elog(ERROR, "could not encode channel binding data");
 		b64_message[b64_message_len] = '\0';
 
 		/*
@@ -1276,8 +1303,10 @@ read_client_final_message(scram_state *state, const char *input)
 		value = read_any_attr(&p, &attr);
 	} while (attr != 'p');
 
-	client_proof = palloc(pg_b64_dec_len(strlen(value)));
-	if (pg_b64_decode(value, strlen(value), client_proof) != SCRAM_KEY_LEN)
+	client_proof_len = pg_b64_dec_len(strlen(value));
+	client_proof = palloc(client_proof_len);
+	if (pg_b64_decode(value, strlen(value), client_proof,
+					  client_proof_len) != SCRAM_KEY_LEN)
 		ereport(ERROR,
 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
 				 errmsg("malformed SCRAM message"),
@@ -1322,9 +1351,14 @@ build_server_final_message(scram_state *state)
 					  strlen(state->client_final_message_without_proof));
 	scram_HMAC_final(ServerSignature, &ctx);
 
-	server_signature_base64 = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
+	siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
+	/* don't forget the zero-terminator */
+	server_signature_base64 = palloc(siglen + 1);
 	siglen = pg_b64_encode((const char *) ServerSignature,
-						   SCRAM_KEY_LEN, server_signature_base64);
+						   SCRAM_KEY_LEN, server_signature_base64,
+						   siglen);
+	if (siglen < 0)
+		elog(ERROR, "could not encode server signature");
 	server_signature_base64[siglen] = '\0';
 
 	/*------
diff --git a/src/common/base64.c b/src/common/base64.c
index 55c8983f97..57ec06c3a9 100644
--- a/src/common/base64.c
+++ b/src/common/base64.c
@@ -42,10 +42,11 @@ static const int8 b64lookup[128] = {
  * pg_b64_encode
  *
  * Encode into base64 the given string.  Returns the length of the encoded
- * string.
+ * string, and -1 in the event of an error with the result buffer zeroed
+ * for safety.
  */
 int
-pg_b64_encode(const char *src, int len, char *dst)
+pg_b64_encode(const char *src, int len, char *dst, int dstlen)
 {
 	char	   *p;
 	const char *s,
@@ -65,6 +66,13 @@ pg_b64_encode(const char *src, int len, char *dst)
 		/* write it out */
 		if (pos < 0)
 		{
+			/*
+			 * Leave if there is an overflow in the area allocated for the
+			 * encoded string.
+			 */
+			if ((p - dst + 4) > dstlen)
+				goto error;
+
 			*p++ = _base64[(buf >> 18) & 0x3f];
 			*p++ = _base64[(buf >> 12) & 0x3f];
 			*p++ = _base64[(buf >> 6) & 0x3f];
@@ -76,23 +84,36 @@ pg_b64_encode(const char *src, int len, char *dst)
 	}
 	if (pos != 2)
 	{
+		/*
+		 * Leave if there is an overflow in the area allocated for the encoded
+		 * string.
+		 */
+		if ((p - dst + 4) > dstlen)
+			goto error;
+
 		*p++ = _base64[(buf >> 18) & 0x3f];
 		*p++ = _base64[(buf >> 12) & 0x3f];
 		*p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '=';
 		*p++ = '=';
 	}
 
+	Assert((p - dst) <= dstlen);
 	return p - dst;
+
+error:
+	memset(dst, 0, dstlen);
+	return -1;
 }
 
 /*
  * pg_b64_decode
  *
  * Decode the given base64 string.  Returns the length of the decoded
- * string on success, and -1 in the event of an error.
+ * string on success, and -1 in the event of an error with the result
+ * buffer zeroed for safety.
  */
 int
-pg_b64_decode(const char *src, int len, char *dst)
+pg_b64_decode(const char *src, int len, char *dst, int dstlen)
 {
 	const char *srcend = src + len,
 			   *s = src;
@@ -109,7 +130,7 @@ pg_b64_decode(const char *src, int len, char *dst)
 
 		/* Leave if a whitespace is found */
 		if (c == ' ' || c == '\t' || c == '\n' || c == '\r')
-			return -1;
+			goto error;
 
 		if (c == '=')
 		{
@@ -126,7 +147,7 @@ pg_b64_decode(const char *src, int len, char *dst)
 					 * Unexpected "=" character found while decoding base64
 					 * sequence.
 					 */
-					return -1;
+					goto error;
 				}
 			}
 			b = 0;
@@ -139,7 +160,7 @@ pg_b64_decode(const char *src, int len, char *dst)
 			if (b < 0)
 			{
 				/* invalid symbol found */
-				return -1;
+				goto error;
 			}
 		}
 		/* add it to buffer */
@@ -147,11 +168,28 @@ pg_b64_decode(const char *src, int len, char *dst)
 		pos++;
 		if (pos == 4)
 		{
+			/*
+			 * Leave if there is an overflow in the area allocated for the
+			 * decoded string.
+			 */
+			if ((p - dst + 1) > dstlen)
+				goto error;
 			*p++ = (buf >> 16) & 255;
+
 			if (end == 0 || end > 1)
+			{
+				/* overflow check */
+				if ((p - dst + 1) > dstlen)
+					goto error;
 				*p++ = (buf >> 8) & 255;
+			}
 			if (end == 0 || end > 2)
+			{
+				/* overflow check */
+				if ((p - dst + 1) > dstlen)
+					goto error;
 				*p++ = buf & 255;
+			}
 			buf = 0;
 			pos = 0;
 		}
@@ -163,10 +201,15 @@ pg_b64_decode(const char *src, int len, char *dst)
 		 * base64 end sequence is invalid.  Input data is missing padding, is
 		 * truncated or is otherwise corrupted.
 		 */
-		return -1;
+		goto error;
 	}
 
+	Assert((p - dst) <= dstlen);
 	return p - dst;
+
+error:
+	memset(dst, 0, dstlen);
+	return -1;
 }
 
 /*
diff --git a/src/common/scram-common.c b/src/common/scram-common.c
index c30dfc97dc..dff9723e67 100644
--- a/src/common/scram-common.c
+++ b/src/common/scram-common.c
@@ -198,6 +198,10 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
 	char	   *result;
 	char	   *p;
 	int			maxlen;
+	int			encoded_salt_len;
+	int			encoded_stored_len;
+	int			encoded_server_len;
+	int			encoded_result;
 
 	if (iterations <= 0)
 		iterations = SCRAM_DEFAULT_ITERATIONS;
@@ -215,11 +219,15 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
 	 * SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey>
 	 *----------
 	 */
+	encoded_salt_len = pg_b64_enc_len(saltlen);
+	encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+	encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+
 	maxlen = strlen("SCRAM-SHA-256") + 1
 		+ 10 + 1				/* iteration count */
-		+ pg_b64_enc_len(saltlen) + 1	/* Base64-encoded salt */
-		+ pg_b64_enc_len(SCRAM_KEY_LEN) + 1 /* Base64-encoded StoredKey */
-		+ pg_b64_enc_len(SCRAM_KEY_LEN) + 1;	/* Base64-encoded ServerKey */
+		+ encoded_salt_len + 1	/* Base64-encoded salt */
+		+ encoded_stored_len + 1	/* Base64-encoded StoredKey */
+		+ encoded_server_len + 1;	/* Base64-encoded ServerKey */
 
 #ifdef FRONTEND
 	result = malloc(maxlen);
@@ -231,11 +239,50 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
 
 	p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations);
 
-	p += pg_b64_encode(salt, saltlen, p);
+	/* salt */
+	encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
+	if (encoded_result < 0)
+	{
+#ifdef FRONTEND
+		free(result);
+		return NULL;
+#else
+		elog(ERROR, "could not encode salt");
+#endif
+	}
+	p += encoded_result;
 	*(p++) = '$';
-	p += pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p);
+
+	/* stored key */
+	encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
+								   encoded_stored_len);
+	if (encoded_result < 0)
+	{
+#ifdef FRONTEND
+		free(result);
+		return NULL;
+#else
+		elog(ERROR, "could not encode stored key");
+#endif
+	}
+
+	p += encoded_result;
 	*(p++) = ':';
-	p += pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p);
+
+	/* server key */
+	encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
+								   encoded_server_len);
+	if (encoded_result < 0)
+	{
+#ifdef FRONTEND
+		free(result);
+		return NULL;
+#else
+		elog(ERROR, "could not encode server key");
+#endif
+	}
+
+	p += encoded_result;
 	*(p++) = '\0';
 
 	Assert(p - result <= maxlen);
diff --git a/src/include/common/base64.h b/src/include/common/base64.h
index 1bae5ec966..c30b173483 100644
--- a/src/include/common/base64.h
+++ b/src/include/common/base64.h
@@ -11,8 +11,8 @@
 #define BASE64_H
 
 /* base 64 */
-extern int	pg_b64_encode(const char *src, int len, char *dst);
-extern int	pg_b64_decode(const char *src, int len, char *dst);
+extern int	pg_b64_encode(const char *src, int len, char *dst, int dstlen);
+extern int	pg_b64_decode(const char *src, int len, char *dst, int dstlen);
 extern int	pg_b64_enc_len(int srclen);
 extern int	pg_b64_dec_len(int srclen);
 
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index babdc06198..249cb1901c 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -321,14 +321,23 @@ build_client_first_message(fe_scram_state *state)
 		return NULL;
 	}
 
-	state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
+	encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
+	/* don't forget the zero-terminator */
+	state->client_nonce = malloc(encoded_len + 1);
 	if (state->client_nonce == NULL)
 	{
 		printfPQExpBuffer(&conn->errorMessage,
 						  libpq_gettext("out of memory\n"));
 		return NULL;
 	}
-	encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->client_nonce);
+	encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
+								state->client_nonce, encoded_len);
+	if (encoded_len < 0)
+	{
+		printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("could not encode nonce\n"));
+		return NULL;
+	}
 	state->client_nonce[encoded_len] = '\0';
 
 	/*
@@ -406,6 +415,7 @@ build_client_final_message(fe_scram_state *state)
 	PGconn	   *conn = state->conn;
 	uint8		client_proof[SCRAM_KEY_LEN];
 	char	   *result;
+	int			encoded_len;
 
 	initPQExpBuffer(&buf);
 
@@ -425,6 +435,7 @@ build_client_final_message(fe_scram_state *state)
 		size_t		cbind_header_len;
 		char	   *cbind_input;
 		size_t		cbind_input_len;
+		int			encoded_cbind_len;
 
 		/* Fetch hash data of server's SSL certificate */
 		cbind_data =
@@ -451,13 +462,26 @@ build_client_final_message(fe_scram_state *state)
 		memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len);
 		memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
 
-		if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len)))
+		encoded_cbind_len = pg_b64_enc_len(cbind_input_len);
+		if (!enlargePQExpBuffer(&buf, encoded_cbind_len))
 		{
 			free(cbind_data);
 			free(cbind_input);
 			goto oom_error;
 		}
-		buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len);
+		encoded_cbind_len = pg_b64_encode(cbind_input, cbind_input_len,
+										  buf.data + buf.len,
+										  encoded_cbind_len);
+		if (encoded_cbind_len < 0)
+		{
+			free(cbind_data);
+			free(cbind_input);
+			termPQExpBuffer(&buf);
+			printfPQExpBuffer(&conn->errorMessage,
+							  "could not encode cbind input for channel binding\n");
+			return NULL;
+		}
+		buf.len += encoded_cbind_len;
 		buf.data[buf.len] = '\0';
 
 		free(cbind_data);
@@ -497,11 +521,21 @@ build_client_final_message(fe_scram_state *state)
 						   client_proof);
 
 	appendPQExpBuffer(&buf, ",p=");
-	if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(SCRAM_KEY_LEN)))
+	encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+	if (!enlargePQExpBuffer(&buf, encoded_len))
 		goto oom_error;
-	buf.len += pg_b64_encode((char *) client_proof,
-							 SCRAM_KEY_LEN,
-							 buf.data + buf.len);
+	encoded_len = pg_b64_encode((char *) client_proof,
+								SCRAM_KEY_LEN,
+								buf.data + buf.len,
+								encoded_len);
+	if (encoded_len < 0)
+	{
+		termPQExpBuffer(&buf);
+		printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("could not encode client proof\n"));
+		return NULL;
+	}
+	buf.len += encoded_len;
 	buf.data[buf.len] = '\0';
 
 	result = strdup(buf.data);
@@ -529,6 +563,7 @@ read_server_first_message(fe_scram_state *state, char *input)
 	char	   *endptr;
 	char	   *encoded_salt;
 	char	   *nonce;
+	int			decoded_salt_len;
 
 	state->server_first_message = strdup(input);
 	if (state->server_first_message == NULL)
@@ -570,7 +605,8 @@ read_server_first_message(fe_scram_state *state, char *input)
 		/* read_attr_value() has generated an error string */
 		return false;
 	}
-	state->salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
+	decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt));
+	state->salt = malloc(decoded_salt_len);
 	if (state->salt == NULL)
 	{
 		printfPQExpBuffer(&conn->errorMessage,
@@ -579,7 +615,8 @@ read_server_first_message(fe_scram_state *state, char *input)
 	}
 	state->saltlen = pg_b64_decode(encoded_salt,
 								   strlen(encoded_salt),
-								   state->salt);
+								   state->salt,
+								   decoded_salt_len);
 	if (state->saltlen < 0)
 	{
 		printfPQExpBuffer(&conn->errorMessage,
@@ -663,7 +700,8 @@ read_server_final_message(fe_scram_state *state, char *input)
 
 	server_signature_len = pg_b64_decode(encoded_server_signature,
 										 strlen(encoded_server_signature),
-										 decoded_server_signature);
+										 decoded_server_signature,
+										 server_signature_len);
 	if (server_signature_len != SCRAM_KEY_LEN)
 	{
 		free(decoded_server_signature);

Attachment: signature.asc
Description: PGP signature

Reply via email to