This patch refactors the calls to strcoll, wcscoll, and ucol_strcoll,
isolating them in pg_locale.c, which seems like a better place for them.
Most of the buffer manipulation and equality optimizations are still left
in varlena.c.

Patch attached.

I'm not able to easily test on Windows, but it should be close to
correct, as I just moved some code around.

Is there a good way to look for regressions (performance or correctness)
when making changes in this area, especially on Windows and/or with unusual
collation rules? What I did doesn't look like a problem, but there's
always a chance of a regression.


-- 
Jeff Davis
PostgreSQL Contributor Team - AWS


From 644ee9ac28652884d91f17f4fc7755104538d9f1 Mon Sep 17 00:00:00 2001
From: Jeff Davis <j...@j-davis.com>
Date: Thu, 6 Oct 2022 10:46:36 -0700
Subject: [PATCH] Refactor: introduce pg_strcoll().

Isolate collation routines into pg_locale.c and simplify varlena.c.
---
 src/backend/utils/adt/pg_locale.c | 178 +++++++++++++++++++++++++
 src/backend/utils/adt/varlena.c   | 207 +-----------------------------
 src/include/utils/pg_locale.h     |   3 +-
 3 files changed, 182 insertions(+), 206 deletions(-)

diff --git a/src/backend/utils/adt/pg_locale.c b/src/backend/utils/adt/pg_locale.c
index 2b42d9ccd8..696ff75201 100644
--- a/src/backend/utils/adt/pg_locale.c
+++ b/src/backend/utils/adt/pg_locale.c
@@ -1639,6 +1639,184 @@ pg_newlocale_from_collation(Oid collid)
 	return cache_entry->locale;
 }
 
+/*
+ * win32_utf8_wcscoll
+ *
+ * Convert UTF-8 arguments to UTF-16 and compare them with wcscoll_l() (when a
+ * locale is given and HAVE_LOCALE_T) or wcscoll(). Large inputs are palloc'd.
+ */
+#ifdef WIN32
+#define TEXTBUFLEN		1024
+static int
+win32_utf8_wcscoll(const char *arg1, int len1, const char *arg2, int len2,
+				   pg_locale_t locale)
+{
+	char		a1buf[TEXTBUFLEN];
+	char		a2buf[TEXTBUFLEN];
+	char	   *a1p,
+			   *a2p;
+	int			a1len;
+	int			a2len;
+	int			r;
+	int			result;
+
+	if (len1 >= TEXTBUFLEN / 2)
+	{
+		a1len = len1 * 2 + 2;
+		a1p = palloc(a1len);
+	}
+	else
+	{
+		a1len = TEXTBUFLEN;
+		a1p = a1buf;
+	}
+	if (len2 >= TEXTBUFLEN / 2)
+	{
+		a2len = len2 * 2 + 2;
+		a2p = palloc(a2len);
+	}
+	else
+	{
+		a2len = TEXTBUFLEN;
+		a2p = a2buf;
+	}
+
+	/* MultiByteToWideChar() does not work for zero-length input */
+	if (len1 == 0)
+		r = 0;
+	else
+	{
+		r = MultiByteToWideChar(CP_UTF8, 0, arg1, len1,
+								(LPWSTR) a1p, a1len / 2);
+		if (!r)
+			ereport(ERROR,
+					(errmsg("could not convert string to UTF-16: error code %lu",
+							GetLastError())));
+	}
+	((LPWSTR) a1p)[r] = 0;
+
+	if (len2 == 0)
+		r = 0;
+	else
+	{
+		r = MultiByteToWideChar(CP_UTF8, 0, arg2, len2,
+								(LPWSTR) a2p, a2len / 2);
+		if (!r)
+			ereport(ERROR,
+					(errmsg("could not convert string to UTF-16: error code %lu",
+							GetLastError())));
+	}
+	((LPWSTR) a2p)[r] = 0;
+
+	errno = 0;
+#ifdef HAVE_LOCALE_T
+	if (locale)
+		result = wcscoll_l((LPWSTR) a1p, (LPWSTR) a2p, locale->info.lt);
+	else
+#endif
+		result = wcscoll((LPWSTR) a1p, (LPWSTR) a2p);
+	if (result == 2147483647)	/* _NLSCMPERROR; missing from mingw
+								 * headers */
+		ereport(ERROR,
+				(errmsg("could not compare Unicode strings: %m")));
+
+	if (a1p != a1buf)
+		pfree(a1p);
+	if (a2p != a2buf)
+		pfree(a2p);
+
+	return result;
+}
+#endif
+
+/*
+ * pg_strcoll
+ *
+ * Call ucol_strcollUTF8(), ucol_strcoll(), strcoll(), strcoll_l(), wcscoll(),
+ * or wcscoll_l() as appropriate for the given locale, platform, and database
+ * encoding. Arguments must be NUL-terminated. If the locale is not specified,
+ * use the database collation.
+ *
+ * If the collation is deterministic (or no locale is given), break ties
+ * with strcmp() on the NUL-terminated arguments.
+ */
+int
+pg_strcoll(const char *arg1, int len1, const char *arg2, int len2,
+		   pg_locale_t locale)
+{
+	int result;
+
+#ifdef WIN32
+	/* Win32 does not have UTF-8, so we need to map to UTF-16 */
+	if (GetDatabaseEncoding() == PG_UTF8
+		&& (!locale || locale->provider == COLLPROVIDER_LIBC))
+	{
+		result = win32_utf8_wcscoll(arg1, len1, arg2, len2, locale);
+	}
+	else
+#endif							/* WIN32 */
+	if (locale)
+	{
+		if (locale->provider == COLLPROVIDER_ICU)
+		{
+#ifdef USE_ICU
+#ifdef HAVE_UCOL_STRCOLLUTF8
+			if (GetDatabaseEncoding() == PG_UTF8)
+			{
+				UErrorCode	status;
+
+				status = U_ZERO_ERROR;
+				result = ucol_strcollUTF8(locale->info.icu.ucol,
+										  arg1, len1,
+										  arg2, len2,
+										  &status);
+				if (U_FAILURE(status))
+					ereport(ERROR,
+							(errmsg("collation failed: %s", u_errorName(status))));
+			}
+			else
+#endif
+			{
+				int32_t		ulen1,
+							ulen2;
+				UChar	   *uchar1,
+						   *uchar2;
+
+				ulen1 = icu_to_uchar(&uchar1, arg1, len1);
+				ulen2 = icu_to_uchar(&uchar2, arg2, len2);
+
+				result = ucol_strcoll(locale->info.icu.ucol,
+									  uchar1, ulen1,
+									  uchar2, ulen2);
+
+				pfree(uchar1);
+				pfree(uchar2);
+			}
+#else							/* not USE_ICU */
+			/* shouldn't happen */
+			elog(ERROR, "unsupported collprovider: %c", locale->provider);
+#endif							/* not USE_ICU */
+		}
+		else
+		{
+#ifdef HAVE_LOCALE_T
+			result = strcoll_l(arg1, arg2, locale->info.lt);
+#else
+			/* shouldn't happen */
+			elog(ERROR, "unsupported collprovider: %c", locale->provider);
+#endif
+		}
+	}
+	else
+		result = strcoll(arg1, arg2);
+
+	/* Break tie if necessary. */
+	if (result == 0 && (!locale || locale->deterministic))
+		result = strcmp(arg1, arg2);
+
+	return result;
+}
+
 /*
  * Get provider-specific collation version string for the given collation from
  * the operating system/library.
diff --git a/src/backend/utils/adt/varlena.c b/src/backend/utils/adt/varlena.c
index 1f6e090821..5be7eaee9f 100644
--- a/src/backend/utils/adt/varlena.c
+++ b/src/backend/utils/adt/varlena.c
@@ -1555,93 +1555,6 @@ varstr_cmp(const char *arg1, int len1, const char *arg2, int len2, Oid collid)
 		if (len1 == len2 && memcmp(arg1, arg2, len1) == 0)
 			return 0;
 
-#ifdef WIN32
-		/* Win32 does not have UTF-8, so we need to map to UTF-16 */
-		if (GetDatabaseEncoding() == PG_UTF8
-			&& (!mylocale || mylocale->provider == COLLPROVIDER_LIBC))
-		{
-			int			a1len;
-			int			a2len;
-			int			r;
-
-			if (len1 >= TEXTBUFLEN / 2)
-			{
-				a1len = len1 * 2 + 2;
-				a1p = palloc(a1len);
-			}
-			else
-			{
-				a1len = TEXTBUFLEN;
-				a1p = a1buf;
-			}
-			if (len2 >= TEXTBUFLEN / 2)
-			{
-				a2len = len2 * 2 + 2;
-				a2p = palloc(a2len);
-			}
-			else
-			{
-				a2len = TEXTBUFLEN;
-				a2p = a2buf;
-			}
-
-			/* stupid Microsloth API does not work for zero-length input */
-			if (len1 == 0)
-				r = 0;
-			else
-			{
-				r = MultiByteToWideChar(CP_UTF8, 0, arg1, len1,
-										(LPWSTR) a1p, a1len / 2);
-				if (!r)
-					ereport(ERROR,
-							(errmsg("could not convert string to UTF-16: error code %lu",
-									GetLastError())));
-			}
-			((LPWSTR) a1p)[r] = 0;
-
-			if (len2 == 0)
-				r = 0;
-			else
-			{
-				r = MultiByteToWideChar(CP_UTF8, 0, arg2, len2,
-										(LPWSTR) a2p, a2len / 2);
-				if (!r)
-					ereport(ERROR,
-							(errmsg("could not convert string to UTF-16: error code %lu",
-									GetLastError())));
-			}
-			((LPWSTR) a2p)[r] = 0;
-
-			errno = 0;
-#ifdef HAVE_LOCALE_T
-			if (mylocale)
-				result = wcscoll_l((LPWSTR) a1p, (LPWSTR) a2p, mylocale->info.lt);
-			else
-#endif
-				result = wcscoll((LPWSTR) a1p, (LPWSTR) a2p);
-			if (result == 2147483647)	/* _NLSCMPERROR; missing from mingw
-										 * headers */
-				ereport(ERROR,
-						(errmsg("could not compare Unicode strings: %m")));
-
-			/* Break tie if necessary. */
-			if (result == 0 &&
-				(!mylocale || mylocale->deterministic))
-			{
-				result = memcmp(arg1, arg2, Min(len1, len2));
-				if ((result == 0) && (len1 != len2))
-					result = (len1 < len2) ? -1 : 1;
-			}
-
-			if (a1p != a1buf)
-				pfree(a1p);
-			if (a2p != a2buf)
-				pfree(a2p);
-
-			return result;
-		}
-#endif							/* WIN32 */
-
 		if (len1 >= TEXTBUFLEN)
 			a1p = (char *) palloc(len1 + 1);
 		else
@@ -1656,65 +1569,7 @@ varstr_cmp(const char *arg1, int len1, const char *arg2, int len2, Oid collid)
 		memcpy(a2p, arg2, len2);
 		a2p[len2] = '\0';
 
-		if (mylocale)
-		{
-			if (mylocale->provider == COLLPROVIDER_ICU)
-			{
-#ifdef USE_ICU
-#ifdef HAVE_UCOL_STRCOLLUTF8
-				if (GetDatabaseEncoding() == PG_UTF8)
-				{
-					UErrorCode	status;
-
-					status = U_ZERO_ERROR;
-					result = ucol_strcollUTF8(mylocale->info.icu.ucol,
-											  arg1, len1,
-											  arg2, len2,
-											  &status);
-					if (U_FAILURE(status))
-						ereport(ERROR,
-								(errmsg("collation failed: %s", u_errorName(status))));
-				}
-				else
-#endif
-				{
-					int32_t		ulen1,
-								ulen2;
-					UChar	   *uchar1,
-							   *uchar2;
-
-					ulen1 = icu_to_uchar(&uchar1, arg1, len1);
-					ulen2 = icu_to_uchar(&uchar2, arg2, len2);
-
-					result = ucol_strcoll(mylocale->info.icu.ucol,
-										  uchar1, ulen1,
-										  uchar2, ulen2);
-
-					pfree(uchar1);
-					pfree(uchar2);
-				}
-#else							/* not USE_ICU */
-				/* shouldn't happen */
-				elog(ERROR, "unsupported collprovider: %c", mylocale->provider);
-#endif							/* not USE_ICU */
-			}
-			else
-			{
-#ifdef HAVE_LOCALE_T
-				result = strcoll_l(a1p, a2p, mylocale->info.lt);
-#else
-				/* shouldn't happen */
-				elog(ERROR, "unsupported collprovider: %c", mylocale->provider);
-#endif
-			}
-		}
-		else
-			result = strcoll(a1p, a2p);
-
-		/* Break tie if necessary. */
-		if (result == 0 &&
-			(!mylocale || mylocale->deterministic))
-			result = strcmp(a1p, a2p);
+		result = pg_strcoll(a1p, len1, a2p, len2, mylocale);
 
 		if (a1p != a1buf)
 			pfree(a1p);
@@ -2377,65 +2232,7 @@ varstrfastcmp_locale(char *a1p, int len1, char *a2p, int len2, SortSupport ssup)
 		return sss->last_returned;
 	}
 
-	if (sss->locale)
-	{
-		if (sss->locale->provider == COLLPROVIDER_ICU)
-		{
-#ifdef USE_ICU
-#ifdef HAVE_UCOL_STRCOLLUTF8
-			if (GetDatabaseEncoding() == PG_UTF8)
-			{
-				UErrorCode	status;
-
-				status = U_ZERO_ERROR;
-				result = ucol_strcollUTF8(sss->locale->info.icu.ucol,
-										  a1p, len1,
-										  a2p, len2,
-										  &status);
-				if (U_FAILURE(status))
-					ereport(ERROR,
-							(errmsg("collation failed: %s", u_errorName(status))));
-			}
-			else
-#endif
-			{
-				int32_t		ulen1,
-							ulen2;
-				UChar	   *uchar1,
-						   *uchar2;
-
-				ulen1 = icu_to_uchar(&uchar1, a1p, len1);
-				ulen2 = icu_to_uchar(&uchar2, a2p, len2);
-
-				result = ucol_strcoll(sss->locale->info.icu.ucol,
-									  uchar1, ulen1,
-									  uchar2, ulen2);
-
-				pfree(uchar1);
-				pfree(uchar2);
-			}
-#else							/* not USE_ICU */
-			/* shouldn't happen */
-			elog(ERROR, "unsupported collprovider: %c", sss->locale->provider);
-#endif							/* not USE_ICU */
-		}
-		else
-		{
-#ifdef HAVE_LOCALE_T
-			result = strcoll_l(sss->buf1, sss->buf2, sss->locale->info.lt);
-#else
-			/* shouldn't happen */
-			elog(ERROR, "unsupported collprovider: %c", sss->locale->provider);
-#endif
-		}
-	}
-	else
-		result = strcoll(sss->buf1, sss->buf2);
-
-	/* Break tie if necessary. */
-	if (result == 0 &&
-		(!sss->locale || sss->locale->deterministic))
-		result = strcmp(sss->buf1, sss->buf2);
+	result = pg_strcoll(sss->buf1, len1, sss->buf2, len2, sss->locale);
 
 	/* Cache result, perhaps saving an expensive strcoll() call next time */
 	sss->cache_blob = false;
diff --git a/src/include/utils/pg_locale.h b/src/include/utils/pg_locale.h
index a875942123..d75b16a27f 100644
--- a/src/include/utils/pg_locale.h
+++ b/src/include/utils/pg_locale.h
@@ -98,7 +98,8 @@ extern void make_icu_collator(const char *iculocstr,
 							  struct pg_locale_struct *resultp);
 
 extern pg_locale_t pg_newlocale_from_collation(Oid collid);
-
+extern int pg_strcoll(const char *arg1, int len1, const char *arg2, int len2,
+					  pg_locale_t locale);
 extern char *get_collation_actual_version(char collprovider, const char *collcollate);
 
 #ifdef USE_ICU
-- 
2.34.1

Reply via email to