From b25fa01a5a9624c44574d2f00e2795f1e1a7058e Mon Sep 17 00:00:00 2001
From: Florents Tselai <florents.tselai@gmail.com>
Date: Wed, 4 Jun 2025 10:40:31 +0300
Subject: [PATCH v4 2/3] Extract pg_base64_{en,de}code_internal with an
 additional bool url param, to be used by other functions.

---
 src/backend/utils/adt/encode.c | 206 +++++++++++++++------------------
 1 file changed, 91 insertions(+), 115 deletions(-)

diff --git a/src/backend/utils/adt/encode.c b/src/backend/utils/adt/encode.c
index 9522eecd4be..3f2dd448e2a 100644
--- a/src/backend/utils/adt/encode.c
+++ b/src/backend/utils/adt/encode.c
@@ -273,6 +273,9 @@ hex_dec_len(const char *src, size_t srclen)
 static const char _base64[] =
 "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
 
+static const char _base64url[] =
+"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
+
 static const int8 b64lookup[128] = {
 	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
 	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
@@ -285,17 +288,15 @@ static const int8 b64lookup[128] = {
 };
 
 static uint64
-pg_base64_encode(const char *src, size_t len, char *dst)
+pg_base64_encode_internal(const char *src, size_t len, char *dst, bool url)
 {
-	char	   *p,
-			   *lend = dst + 76;
-	const char *s,
-			   *end = src + len;
-	int			pos = 2;
-	uint32		buf = 0;
-
-	s = src;
-	p = dst;
+	const char *alphabet = url ? _base64url : _base64;
+	const char *end = src + len;
+	const char *s = src;
+	char *p = dst;
+	int pos = 2;
+	uint32 buf = 0;
+	char *lend = dst + 76;
 
 	while (s < end)
 	{
@@ -306,53 +307,84 @@ pg_base64_encode(const char *src, size_t len, char *dst)
 		/* write it out */
 		if (pos < 0)
 		{
-			*p++ = _base64[(buf >> 18) & 0x3f];
-			*p++ = _base64[(buf >> 12) & 0x3f];
-			*p++ = _base64[(buf >> 6) & 0x3f];
-			*p++ = _base64[buf & 0x3f];
+			*p++ = alphabet[(buf >> 18) & 0x3f];
+			*p++ = alphabet[(buf >> 12) & 0x3f];
+			*p++ = alphabet[(buf >> 6) & 0x3f];
+			*p++ = alphabet[buf & 0x3f];
 
 			pos = 2;
 			buf = 0;
-		}
-		if (p >= lend)
-		{
-			*p++ = '\n';
-			lend = p + 76;
+
+			if (!url && p >= lend)
+			{
+				*p++ = '\n';
+				lend = p + 76;
+			}
 		}
 	}
+
+	/* handle remainder */
 	if (pos != 2)
 	{
-		*p++ = _base64[(buf >> 18) & 0x3f];
-		*p++ = _base64[(buf >> 12) & 0x3f];
-		*p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '=';
-		*p++ = '=';
+		*p++ = alphabet[(buf >> 18) & 0x3f];
+		*p++ = alphabet[(buf >> 12) & 0x3f];
+
+		if (pos == 0)
+		{
+			*p++ = alphabet[(buf >> 6) & 0x3f];
+			if (!url)
+				*p++ = '=';
+		}
+		else
+		{
+			if (!url)
+			{
+				*p++ = '=';
+				*p++ = '=';
+			}
+		}
 	}
 
 	return p - dst;
 }
 
 static uint64
-pg_base64_decode(const char *src, size_t len, char *dst)
+pg_base64_encode(const char *src, size_t len, char *dst)
+{
+	return pg_base64_encode_internal(src, len, dst, false);
+}
+
+static uint64
+pg_base64_decode_internal(const char *src, size_t len, char *dst, bool url)
 {
-	const char *srcend = src + len,
-			   *s = src;
-	char	   *p = dst;
-	char		c;
-	int			b = 0;
-	uint32		buf = 0;
-	int			pos = 0,
-				end = 0;
+	const char *srcend = src + len;
+	const char *s = src;
+	char *p = dst;
+	char c;
+	int b = 0;
+	uint32 buf = 0;
+	int pos = 0;
+	int end = 0;
 
 	while (s < srcend)
 	{
 		c = *s++;
 
+		/* skip whitespace */
 		if (c == ' ' || c == '\t' || c == '\n' || c == '\r')
 			continue;
 
+		/* convert Base64URL to Base64 if needed */
+		if (url)
+		{
+			if (c == '-')
+				c = '+';
+			else if (c == '_')
+				c = '/';
+		}
+
 		if (c == '=')
 		{
-			/* end sequence */
 			if (!end)
 			{
 				if (pos == 2)
@@ -377,30 +409,49 @@ pg_base64_decode(const char *src, size_t len, char *dst)
 						 errmsg("invalid symbol \"%.*s\" found while decoding base64 sequence",
 								pg_mblen(s - 1), s - 1)));
 		}
-		/* add it to buffer */
+
 		buf = (buf << 6) + b;
 		pos++;
+
 		if (pos == 4)
 		{
-			*p++ = (buf >> 16) & 255;
+			*p++ = (buf >> 16) & 0xFF;
 			if (end == 0 || end > 1)
-				*p++ = (buf >> 8) & 255;
+				*p++ = (buf >> 8) & 0xFF;
 			if (end == 0 || end > 2)
-				*p++ = buf & 255;
+				*p++ = buf & 0xFF;
 			buf = 0;
 			pos = 0;
 		}
 	}
 
-	if (pos != 0)
+	if (pos == 2)
+	{
+		buf <<= 12;  /* 2 * 6 = 12 bits, pad remaining to 24 */
+		*p++ = (buf >> 16) & 0xFF;
+	}
+	else if (pos == 3)
+	{
+		buf <<= 6;  /* 3 * 6 = 18 bits */
+		*p++ = (buf >> 16) & 0xFF;
+		*p++ = (buf >> 8) & 0xFF;
+	}
+	else if (pos != 0)
+	{
 		ereport(ERROR,
 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
 				 errmsg("invalid base64 end sequence"),
 				 errhint("Input data is missing padding, is truncated, or is otherwise corrupted.")));
+	}
 
 	return p - dst;
 }
 
+static uint64
+pg_base64_decode(const char *src, size_t len, char *dst)
+{
+	return pg_base64_decode_internal(src, len, dst, false);
+}
 
 static uint64
 pg_base64_enc_len(const char *src, size_t srclen)
@@ -436,7 +487,6 @@ pg_base64url_enc_len(const char *src, size_t srclen)
 	return result;
 }
 
-
 static uint64
 pg_base64url_dec_len(const char *src, size_t srclen)
 {
@@ -452,87 +502,13 @@ pg_base64url_dec_len(const char *src, size_t srclen)
 static uint64
 pg_base64url_encode(const char *src, size_t len, char *dst)
 {
-	uint64 encoded_len;
-	if (len == 0)
-		return 0;
-
-	encoded_len = pg_base64_encode(src, len, dst);
-
-	/* Convert Base64 to Base64URL */
-	for (uint64 i = 0; i < encoded_len; i++) {
-		if (dst[i] == '+')
-			dst[i] = '-';
-		else if (dst[i] == '/')
-			dst[i] = '_';
-	}
-
-	/* Trim '=' padding */
-	while (encoded_len > 0 && dst[encoded_len - 1] == '=')
-		encoded_len--;
-
-	return encoded_len;
+	return pg_base64_encode_internal(src, len, dst, true);
 }
 
 static uint64
 pg_base64url_decode(const char *src, size_t len, char *dst)
 {
-	size_t i, pad_len, base64_len;
-	uint64 decoded_len;
-	char *base64;
-
-	/* Handle empty input specially */
-	if (len == 0)
-		return 0;
-
-	/* Calculate padding needed for standard base64 */
-	pad_len = 0;
-	if (len % 4 != 0)
-		pad_len = 4 - (len % 4);
-
-	/* Allocate memory for converted string */
-	base64_len = len + pad_len;
-	base64 = palloc(base64_len + 1); /* +1 for null terminator */
-
-	/* Convert Base64URL to Base64 */
-	for (i = 0; i < len; i++)
-	{
-		char c = src[i];
-		if (c == '-')
-			base64[i] = '+';  /* Convert '-' to '+' */
-		else if (c == '_')
-			base64[i] = '/';  /* Convert '_' to '/' */
-		else if ((c >= 'A' && c <= 'Z') ||
-				 (c >= 'a' && c <= 'z') ||
-				 (c >= '0' && c <= '9'))
-			base64[i] = c;    /* Keep alphanumeric chars unchanged */
-		else if (c == '=')
-			ereport(ERROR,
-				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-				 errmsg("invalid base64url input"),
-				 errhint("Base64URL encoding should not contain padding '='.")));
-		else if (c == '+' || c == '/')
-			ereport(ERROR,
-				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-				 errmsg("invalid base64url character: '%c'", c),
-				 errhint("Base64URL should use '-' instead of '+' and '_' instead of '/'.")));
-		else
-			ereport(ERROR,
-				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-				 errmsg("invalid base64url character: '%c'", c)));
-	}
-
-	/* Add padding if necessary */
-	for (i = 0; i < pad_len; i++)
-		base64[len + i] = '=';
-
-	base64[base64_len] = '\0';  /* Null-terminate for safety */
-
-	/* Decode using the standard Base64 decoder */
-	decoded_len = pg_base64_decode(base64, base64_len, dst);
-
-	/* Free allocated memory */
-	pfree(base64);
-	return decoded_len;
+	return pg_base64_decode_internal(src, len, dst, true);
 }
 
 /*
-- 
2.49.0

