Hi all,

After the issues behind CVE-2019-10164, it seems that we can do much
better with the current interface of the base64 decoding and encoding
functions in src/common/.

The root issue is that callers of pg_b64_decode() and pg_b64_encode()
provide a buffer in which the result gets stored, allocated using
pg_b64_dec_len() and pg_b64_enc_len() respectively (those routines
overestimate the allocation on purpose).  However, callers have no way
to pass the length of the allocated buffer, so those routines lack the
sanity checks needed to make sure that the input cannot overflow the
result buffer.

One thing I have noticed is that many projects on the net include this
code for their own purposes, and I suspect that many other projects
link to the code from Postgres and make use of it.  That is rather
scary.

Attached is a refactoring patch for those interfaces, which introduces
a set of overflow checks so that we cannot repeat the errors of the
past.  This adds to pg_b64_decode() and pg_b64_encode() one argument
for the size of the result buffer, and the code uses it to report an
error in case of an overflow.  For simplicity, the error is reported
with the status code -1, which is already used for other errors.  One
thing to note is that the decoding path can already report some
errors (basically an incorrectly shaped encoded string), but the
encoding path has no error cases yet, so we need to make sure that
all the existing callers of pg_b64_encode() handle errors correctly
with the new interface.

I am adding this to the next CommitFest (CF) for v13.

Any thoughts?
--
Michael
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index 6b60abe1dd..91ed71391d 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -510,9 +510,11 @@ scram_verify_plain_password(const char *username, const char *password,
 		return false;
 	}
 
-	salt = palloc(pg_b64_dec_len(strlen(encoded_salt)));
-	saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt);
-	if (saltlen == -1)
+	saltlen = pg_b64_dec_len(strlen(encoded_salt));
+	salt = palloc(saltlen);
+	saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
+							saltlen);
+	if (saltlen < 0)
 	{
 		ereport(LOG,
 				(errmsg("invalid SCRAM verifier for user \"%s\"", username)));
@@ -596,9 +598,10 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt,
 	 * Verify that the salt is in Base64-encoded format, by decoding it,
 	 * although we return the encoded version to the caller.
 	 */
-	decoded_salt_buf = palloc(pg_b64_dec_len(strlen(salt_str)));
+	decoded_len = pg_b64_dec_len(strlen(salt_str));
+	decoded_salt_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(salt_str, strlen(salt_str),
-								decoded_salt_buf);
+								decoded_salt_buf, decoded_len);
 	if (decoded_len < 0)
 		goto invalid_verifier;
 	*salt = pstrdup(salt_str);
@@ -606,16 +609,18 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt,
 	/*
 	 * Decode StoredKey and ServerKey.
 	 */
-	decoded_stored_buf = palloc(pg_b64_dec_len(strlen(storedkey_str)));
+	decoded_len = pg_b64_dec_len(strlen(storedkey_str));
+	decoded_stored_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
-								decoded_stored_buf);
+								decoded_stored_buf, decoded_len);
 	if (decoded_len != SCRAM_KEY_LEN)
 		goto invalid_verifier;
 	memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
 
-	decoded_server_buf = palloc(pg_b64_dec_len(strlen(serverkey_str)));
+	decoded_len = pg_b64_dec_len(strlen(serverkey_str));
+	decoded_server_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
-								decoded_server_buf);
+								decoded_server_buf, decoded_len);
 	if (decoded_len != SCRAM_KEY_LEN)
 		goto invalid_verifier;
 	memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
@@ -649,8 +654,18 @@ mock_scram_verifier(const char *username, int *iterations, char **salt,
 	/* Generate deterministic salt */
 	raw_salt = scram_mock_salt(username);
 
-	encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
-	encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
+	encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN);
+	/* don't forget the zero-terminator */
+	encoded_salt = (char *) palloc(encoded_len + 1);
+	encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt,
+								encoded_len);
+
+	/*
+	 * Note that we cannot reveal any information to an attacker here
+	 * so the error message needs to remain generic.
+	 */
+	if (encoded_len < 0)
+		elog(ERROR, "could not encode salt");
 	encoded_salt[encoded_len] = '\0';
 
 	*salt = encoded_salt;
@@ -1144,8 +1159,15 @@ build_server_first_message(scram_state *state)
 				(errcode(ERRCODE_INTERNAL_ERROR),
 				 errmsg("could not generate random nonce")));
 
-	state->server_nonce = palloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
-	encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->server_nonce);
+	encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
+	/* don't forget the zero-terminator */
+	state->server_nonce = palloc(encoded_len + 1);
+	encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
+								state->server_nonce, encoded_len);
+	if (encoded_len < 0)
+		ereport(ERROR,
+				(errcode(ERRCODE_INTERNAL_ERROR),
+				 errmsg("could not encode random nonce")));
 	state->server_nonce[encoded_len] = '\0';
 
 	state->server_first_message =
@@ -1170,6 +1192,7 @@ read_client_final_message(scram_state *state, const char *input)
 			   *proof;
 	char	   *p;
 	char	   *client_proof;
+	int			client_proof_len;
 
 	begin = p = pstrdup(input);
 
@@ -1234,9 +1257,13 @@ read_client_final_message(scram_state *state, const char *input)
 		snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,");
 		memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
 
-		b64_message = palloc(pg_b64_enc_len(cbind_input_len) + 1);
+		b64_message_len = pg_b64_enc_len(cbind_input_len);
+		/* don't forget the zero-terminator */
+		b64_message = palloc(b64_message_len + 1);
 		b64_message_len = pg_b64_encode(cbind_input, cbind_input_len,
-										b64_message);
+										b64_message, b64_message_len);
+		if (b64_message_len < 0)
+			elog(ERROR, "could not encode channel binding data");
 		b64_message[b64_message_len] = '\0';
 
 		/*
@@ -1276,8 +1303,10 @@ read_client_final_message(scram_state *state, const char *input)
 		value = read_any_attr(&p, &attr);
 	} while (attr != 'p');
 
-	client_proof = palloc(pg_b64_dec_len(strlen(value)));
-	if (pg_b64_decode(value, strlen(value), client_proof) != SCRAM_KEY_LEN)
+	client_proof_len = pg_b64_dec_len(strlen(value));
+	client_proof = palloc(client_proof_len);
+	if (pg_b64_decode(value, strlen(value), client_proof,
+					  client_proof_len) != SCRAM_KEY_LEN)
 		ereport(ERROR,
 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
 				 errmsg("malformed SCRAM message"),
@@ -1322,9 +1351,14 @@ build_server_final_message(scram_state *state)
 					  strlen(state->client_final_message_without_proof));
 	scram_HMAC_final(ServerSignature, &ctx);
 
-	server_signature_base64 = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
+	siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
+	/* don't forget the zero-terminator */
+	server_signature_base64 = palloc(siglen + 1);
 	siglen = pg_b64_encode((const char *) ServerSignature,
-						   SCRAM_KEY_LEN, server_signature_base64);
+						   SCRAM_KEY_LEN, server_signature_base64,
+						   siglen);
+	if (siglen < 0)
+		elog(ERROR, "could not encode server signature");
 	server_signature_base64[siglen] = '\0';
 
 	/*------
diff --git a/src/common/base64.c b/src/common/base64.c
index 55c8983f97..e0042a7d92 100644
--- a/src/common/base64.c
+++ b/src/common/base64.c
@@ -42,10 +42,10 @@ static const int8 b64lookup[128] = {
  * pg_b64_encode
  *
  * Encode into base64 the given string.  Returns the length of the encoded
- * string.
+ * string, and -1 in the event of an error.
  */
 int
-pg_b64_encode(const char *src, int len, char *dst)
+pg_b64_encode(const char *src, int len, char *dst, int dstlen)
 {
 	char	   *p;
 	const char *s,
@@ -65,6 +65,13 @@ pg_b64_encode(const char *src, int len, char *dst)
 		/* write it out */
 		if (pos < 0)
 		{
+			/*
+			 * Leave if there is an overflow in the area allocated for
+			 * the encoded string.
+			 */
+			if ((p - dst + 4) > dstlen)
+				return -1;
+
 			*p++ = _base64[(buf >> 18) & 0x3f];
 			*p++ = _base64[(buf >> 12) & 0x3f];
 			*p++ = _base64[(buf >> 6) & 0x3f];
@@ -76,12 +83,20 @@ pg_b64_encode(const char *src, int len, char *dst)
 	}
 	if (pos != 2)
 	{
+		/*
+		 * Leave if there is an overflow in the area allocated for
+		 * the encoded string.
+		 */
+		if ((p - dst + 4) > dstlen)
+			return -1;
+
 		*p++ = _base64[(buf >> 18) & 0x3f];
 		*p++ = _base64[(buf >> 12) & 0x3f];
 		*p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '=';
 		*p++ = '=';
 	}
 
+	Assert((p - dst) <= dstlen);
 	return p - dst;
 }
 
@@ -92,7 +107,7 @@ pg_b64_encode(const char *src, int len, char *dst)
  * string on success, and -1 in the event of an error.
  */
 int
-pg_b64_decode(const char *src, int len, char *dst)
+pg_b64_decode(const char *src, int len, char *dst, int dstlen)
 {
 	const char *srcend = src + len,
 			   *s = src;
@@ -147,11 +162,28 @@ pg_b64_decode(const char *src, int len, char *dst)
 		pos++;
 		if (pos == 4)
 		{
+			/*
+			 * Leave if there is an overflow in the area allocated for
+			 * the decoded string.
+			 */
+			if ((p - dst + 1) > dstlen)
+				return -1;
 			*p++ = (buf >> 16) & 255;
+
 			if (end == 0 || end > 1)
+			{
+				/* overflow check */
+				if ((p - dst + 1) > dstlen)
+					return -1;
 				*p++ = (buf >> 8) & 255;
+			}
 			if (end == 0 || end > 2)
+			{
+				/* overflow check */
+				if ((p - dst + 1) > dstlen)
+					return -1;
 				*p++ = buf & 255;
+			}
 			buf = 0;
 			pos = 0;
 		}
@@ -166,6 +198,7 @@ pg_b64_decode(const char *src, int len, char *dst)
 		return -1;
 	}
 
+	Assert((p - dst) <= dstlen);
 	return p - dst;
 }
 
diff --git a/src/common/scram-common.c b/src/common/scram-common.c
index c30dfc97dc..5b8f6b2143 100644
--- a/src/common/scram-common.c
+++ b/src/common/scram-common.c
@@ -198,6 +198,10 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
 	char	   *result;
 	char	   *p;
 	int			maxlen;
+	int			encoded_salt_len;
+	int			encoded_stored_len;
+	int			encoded_server_len;
+	int			encoded_result;
 
 	if (iterations <= 0)
 		iterations = SCRAM_DEFAULT_ITERATIONS;
@@ -215,11 +219,15 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
 	 * SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey>
 	 *----------
 	 */
+	encoded_salt_len = pg_b64_enc_len(saltlen);
+	encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+	encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+
 	maxlen = strlen("SCRAM-SHA-256") + 1
 		+ 10 + 1				/* iteration count */
-		+ pg_b64_enc_len(saltlen) + 1	/* Base64-encoded salt */
-		+ pg_b64_enc_len(SCRAM_KEY_LEN) + 1 /* Base64-encoded StoredKey */
-		+ pg_b64_enc_len(SCRAM_KEY_LEN) + 1;	/* Base64-encoded ServerKey */
+		+ encoded_salt_len + 1	/* Base64-encoded salt */
+		+ encoded_stored_len + 1	/* Base64-encoded StoredKey */
+		+ encoded_server_len + 1;	/* Base64-encoded ServerKey */
 
 #ifdef FRONTEND
 	result = malloc(maxlen);
@@ -231,11 +239,41 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
 
 	p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations);
 
-	p += pg_b64_encode(salt, saltlen, p);
+	/* salt */
+	encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
+	if (encoded_result < 0)
+	{
+#ifdef FRONTEND
+		return NULL;
+#else
+		elog(ERROR, "could not encode salt");
+#endif
+	}
+	p += encoded_result;
 	*(p++) = '$';
-	p += pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p);
+
+	/* stored key */
+	encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
+								   encoded_stored_len);
+	if (encoded_result < 0)
+#ifdef FRONTEND
+		return NULL;
+#else
+		elog(ERROR, "could not encode stored key");
+#endif
+	p += encoded_result;
 	*(p++) = ':';
-	p += pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p);
+
+	/* server key */
+	encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
+								   encoded_server_len);
+	if (encoded_result < 0)
+#ifdef FRONTEND
+		return NULL;
+#else
+		elog(ERROR, "could not encode server key");
+#endif
+	p += encoded_result;
 	*(p++) = '\0';
 
 	Assert(p - result <= maxlen);
diff --git a/src/include/common/base64.h b/src/include/common/base64.h
index 1bae5ec966..c30b173483 100644
--- a/src/include/common/base64.h
+++ b/src/include/common/base64.h
@@ -11,8 +11,8 @@
 #define BASE64_H
 
 /* base 64 */
-extern int	pg_b64_encode(const char *src, int len, char *dst);
-extern int	pg_b64_decode(const char *src, int len, char *dst);
+extern int	pg_b64_encode(const char *src, int len, char *dst, int dstlen);
+extern int	pg_b64_decode(const char *src, int len, char *dst, int dstlen);
 extern int	pg_b64_enc_len(int srclen);
 extern int	pg_b64_dec_len(int srclen);
 
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index babdc06198..249cb1901c 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -321,14 +321,23 @@ build_client_first_message(fe_scram_state *state)
 		return NULL;
 	}
 
-	state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
+	encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
+	/* don't forget the zero-terminator */
+	state->client_nonce = malloc(encoded_len + 1);
 	if (state->client_nonce == NULL)
 	{
 		printfPQExpBuffer(&conn->errorMessage,
 						  libpq_gettext("out of memory\n"));
 		return NULL;
 	}
-	encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->client_nonce);
+	encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
+								state->client_nonce, encoded_len);
+	if (encoded_len < 0)
+	{
+		printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("could not encode nonce\n"));
+		return NULL;
+	}
 	state->client_nonce[encoded_len] = '\0';
 
 	/*
@@ -406,6 +415,7 @@ build_client_final_message(fe_scram_state *state)
 	PGconn	   *conn = state->conn;
 	uint8		client_proof[SCRAM_KEY_LEN];
 	char	   *result;
+	int			encoded_len;
 
 	initPQExpBuffer(&buf);
 
@@ -425,6 +435,7 @@ build_client_final_message(fe_scram_state *state)
 		size_t		cbind_header_len;
 		char	   *cbind_input;
 		size_t		cbind_input_len;
+		int			encoded_cbind_len;
 
 		/* Fetch hash data of server's SSL certificate */
 		cbind_data =
@@ -451,13 +462,26 @@ build_client_final_message(fe_scram_state *state)
 		memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len);
 		memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
 
-		if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len)))
+		encoded_cbind_len = pg_b64_enc_len(cbind_input_len);
+		if (!enlargePQExpBuffer(&buf, encoded_cbind_len))
 		{
 			free(cbind_data);
 			free(cbind_input);
 			goto oom_error;
 		}
-		buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len);
+		encoded_cbind_len = pg_b64_encode(cbind_input, cbind_input_len,
+										  buf.data + buf.len,
+										  encoded_cbind_len);
+		if (encoded_cbind_len < 0)
+		{
+			free(cbind_data);
+			free(cbind_input);
+			termPQExpBuffer(&buf);
+			printfPQExpBuffer(&conn->errorMessage,
+							  "could not encode cbind input for channel binding\n");
+			return NULL;
+		}
+		buf.len += encoded_cbind_len;
 		buf.data[buf.len] = '\0';
 
 		free(cbind_data);
@@ -497,11 +521,21 @@ build_client_final_message(fe_scram_state *state)
 						   client_proof);
 
 	appendPQExpBuffer(&buf, ",p=");
-	if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(SCRAM_KEY_LEN)))
+	encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+	if (!enlargePQExpBuffer(&buf, encoded_len))
 		goto oom_error;
-	buf.len += pg_b64_encode((char *) client_proof,
-							 SCRAM_KEY_LEN,
-							 buf.data + buf.len);
+	encoded_len = pg_b64_encode((char *) client_proof,
+								SCRAM_KEY_LEN,
+								buf.data + buf.len,
+								encoded_len);
+	if (encoded_len < 0)
+	{
+		termPQExpBuffer(&buf);
+		printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("could not encode client proof\n"));
+		return NULL;
+	}
+	buf.len += encoded_len;
 	buf.data[buf.len] = '\0';
 
 	result = strdup(buf.data);
@@ -529,6 +563,7 @@ read_server_first_message(fe_scram_state *state, char *input)
 	char	   *endptr;
 	char	   *encoded_salt;
 	char	   *nonce;
+	int			decoded_salt_len;
 
 	state->server_first_message = strdup(input);
 	if (state->server_first_message == NULL)
@@ -570,7 +605,8 @@ read_server_first_message(fe_scram_state *state, char *input)
 		/* read_attr_value() has generated an error string */
 		return false;
 	}
-	state->salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
+	decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt));
+	state->salt = malloc(decoded_salt_len);
 	if (state->salt == NULL)
 	{
 		printfPQExpBuffer(&conn->errorMessage,
@@ -579,7 +615,8 @@ read_server_first_message(fe_scram_state *state, char *input)
 	}
 	state->saltlen = pg_b64_decode(encoded_salt,
 								   strlen(encoded_salt),
-								   state->salt);
+								   state->salt,
+								   decoded_salt_len);
 	if (state->saltlen < 0)
 	{
 		printfPQExpBuffer(&conn->errorMessage,
@@ -663,7 +700,8 @@ read_server_final_message(fe_scram_state *state, char *input)
 
 	server_signature_len = pg_b64_decode(encoded_server_signature,
 										 strlen(encoded_server_signature),
-										 decoded_server_signature);
+										 decoded_server_signature,
+										 server_signature_len);
 	if (server_signature_len != SCRAM_KEY_LEN)
 	{
 		free(decoded_server_signature);

Attachment: signature.asc
Description: PGP signature

Reply via email to