On Thu, 2022-10-13 at 10:57 +0200, Peter Eisentraut wrote:
> It's a bit confusing that arguments must be NUL-terminated, but the
> length is still specified.  Maybe another sentence to explain that
> would be helpful.

Added a comment. It was a little frustrating trying to get a perfectly
clean API, because the callers do some buffer manipulation and
optimizations of their own. I think this is an improvement, but
suggestions are welcome.

If win32 is used with UTF-8 and wcscoll, it ends up allocating some
extra stack space for the temporary buffers, whereas previously it used
the buffers on the stack of varstr_cmp(). I'm not sure if that's a
problem or not.

> The length arguments ought to be of type size_t, I think.

Changed.

Thank you.


-- 
Jeff Davis
PostgreSQL Contributor Team - AWS


From 4d5664552a8a86418a94c37fd4ab8ca3a665c1cd Mon Sep 17 00:00:00 2001
From: Jeff Davis <j...@j-davis.com>
Date: Thu, 6 Oct 2022 10:46:36 -0700
Subject: [PATCH] Refactor: introduce pg_strcoll().

Isolate collation routines into pg_locale.c and simplify varlena.c.
---
 src/backend/utils/adt/pg_locale.c | 180 ++++++++++++++++++++++++++
 src/backend/utils/adt/varlena.c   | 207 +-----------------------------
 src/include/utils/pg_locale.h     |   3 +-
 3 files changed, 184 insertions(+), 206 deletions(-)

diff --git a/src/backend/utils/adt/pg_locale.c b/src/backend/utils/adt/pg_locale.c
index 2b42d9ccd8..3eb6a67bdc 100644
--- a/src/backend/utils/adt/pg_locale.c
+++ b/src/backend/utils/adt/pg_locale.c
@@ -1639,6 +1639,186 @@ pg_newlocale_from_collation(Oid collid)
 	return cache_entry->locale;
 }
 
+/*
+ * win32_utf8_wcscoll
+ *
+ * Convert UTF-8 arguments to UTF-16 and compare them with wcscoll_l() or
+ * wcscoll(). Inputs of TEXTBUFLEN / 2 bytes or more use palloc'd buffers.
+ */
+#ifdef WIN32
+#define TEXTBUFLEN		1024
+static int
+win32_utf8_wcscoll(const char *arg1, size_t len1, const char *arg2,
+				   size_t len2, pg_locale_t locale)
+{
+	char		a1buf[TEXTBUFLEN];
+	char		a2buf[TEXTBUFLEN];
+	char	   *a1p,
+			   *a2p;
+	int			a1len;
+	int			a2len;
+	int			r;
+	int			result;
+
+	if (len1 >= TEXTBUFLEN / 2)
+	{
+		a1len = len1 * 2 + 2;
+		a1p = palloc(a1len);
+	}
+	else
+	{
+		a1len = TEXTBUFLEN;
+		a1p = a1buf;
+	}
+	if (len2 >= TEXTBUFLEN / 2)
+	{
+		a2len = len2 * 2 + 2;
+		a2p = palloc(a2len);
+	}
+	else
+	{
+		a2len = TEXTBUFLEN;
+		a2p = a2buf;
+	}
+
+	/* MultiByteToWideChar() does not work for zero-length input */
+	if (len1 == 0)
+		r = 0;
+	else
+	{
+		r = MultiByteToWideChar(CP_UTF8, 0, arg1, len1,
+								(LPWSTR) a1p, a1len / 2);
+		if (!r)
+			ereport(ERROR,
+					(errmsg("could not convert string to UTF-16: error code %lu",
+							GetLastError())));
+	}
+	((LPWSTR) a1p)[r] = 0;
+
+	if (len2 == 0)
+		r = 0;
+	else
+	{
+		r = MultiByteToWideChar(CP_UTF8, 0, arg2, len2,
+								(LPWSTR) a2p, a2len / 2);
+		if (!r)
+			ereport(ERROR,
+					(errmsg("could not convert string to UTF-16: error code %lu",
+							GetLastError())));
+	}
+	((LPWSTR) a2p)[r] = 0;
+
+	errno = 0;
+#ifdef HAVE_LOCALE_T
+	if (locale)
+		result = wcscoll_l((LPWSTR) a1p, (LPWSTR) a2p, locale->info.lt);
+	else
+#endif
+		result = wcscoll((LPWSTR) a1p, (LPWSTR) a2p);
+	if (result == 2147483647)	/* _NLSCMPERROR; missing from mingw
+								 * headers */
+		ereport(ERROR,
+				(errmsg("could not compare Unicode strings: %m")));
+
+	if (a1p != a1buf)
+		pfree(a1p);
+	if (a2p != a2buf)
+		pfree(a2p);
+
+	return result;
+}
+#endif
+
+/*
+ * pg_strcoll
+ *
+ * Call ucol_strcollUTF8(), ucol_strcoll(), strcoll(), strcoll_l(), wcscoll(),
+ * or wcscoll_l() as appropriate for the given locale, platform, and database
+ * encoding. If the locale is not specified, use the database collation.
+ *
+ * Arguments must be NUL-terminated so they can be passed directly to
+ * strcoll(); the lengths are still needed by the ICU and Win32 code paths.
+ *
+ * If the collation is deterministic, break ties with strcmp(), which is safe
+ * because both arguments are NUL-terminated.
+ */
+int
+pg_strcoll(const char *arg1, size_t len1, const char *arg2, size_t len2,
+		   pg_locale_t locale)
+{
+	int result;
+
+#ifdef WIN32
+	/* Win32 does not have UTF-8, so we need to map to UTF-16 */
+	if (GetDatabaseEncoding() == PG_UTF8
+		&& (!locale || locale->provider == COLLPROVIDER_LIBC))
+	{
+		result = win32_utf8_wcscoll(arg1, len1, arg2, len2, locale);
+	}
+	else
+#endif							/* WIN32 */
+	if (locale)
+	{
+		if (locale->provider == COLLPROVIDER_ICU)
+		{
+#ifdef USE_ICU
+#ifdef HAVE_UCOL_STRCOLLUTF8
+			if (GetDatabaseEncoding() == PG_UTF8)
+			{
+				UErrorCode	status;
+
+				status = U_ZERO_ERROR;
+				result = ucol_strcollUTF8(locale->info.icu.ucol,
+										  arg1, len1,
+										  arg2, len2,
+										  &status);
+				if (U_FAILURE(status))
+					ereport(ERROR,
+							(errmsg("collation failed: %s", u_errorName(status))));
+			}
+			else
+#endif
+			{
+				int32_t		ulen1,
+							ulen2;
+				UChar	   *uchar1,
+						   *uchar2;
+
+				ulen1 = icu_to_uchar(&uchar1, arg1, len1);
+				ulen2 = icu_to_uchar(&uchar2, arg2, len2);
+
+				result = ucol_strcoll(locale->info.icu.ucol,
+									  uchar1, ulen1,
+									  uchar2, ulen2);
+
+				pfree(uchar1);
+				pfree(uchar2);
+			}
+#else							/* not USE_ICU */
+			/* shouldn't happen */
+			elog(ERROR, "unsupported collprovider: %c", locale->provider);
+#endif							/* not USE_ICU */
+		}
+		else
+		{
+#ifdef HAVE_LOCALE_T
+			result = strcoll_l(arg1, arg2, locale->info.lt);
+#else
+			/* shouldn't happen */
+			elog(ERROR, "unsupported collprovider: %c", locale->provider);
+#endif
+		}
+	}
+	else
+		result = strcoll(arg1, arg2);
+
+	/* Break tie if necessary. */
+	if (result == 0 && (!locale || locale->deterministic))
+		result = strcmp(arg1, arg2);
+
+	return result;
+}
+
 /*
  * Get provider-specific collation version string for the given collation from
  * the operating system/library.
diff --git a/src/backend/utils/adt/varlena.c b/src/backend/utils/adt/varlena.c
index 1f6e090821..5be7eaee9f 100644
--- a/src/backend/utils/adt/varlena.c
+++ b/src/backend/utils/adt/varlena.c
@@ -1555,93 +1555,6 @@ varstr_cmp(const char *arg1, int len1, const char *arg2, int len2, Oid collid)
 		if (len1 == len2 && memcmp(arg1, arg2, len1) == 0)
 			return 0;
 
-#ifdef WIN32
-		/* Win32 does not have UTF-8, so we need to map to UTF-16 */
-		if (GetDatabaseEncoding() == PG_UTF8
-			&& (!mylocale || mylocale->provider == COLLPROVIDER_LIBC))
-		{
-			int			a1len;
-			int			a2len;
-			int			r;
-
-			if (len1 >= TEXTBUFLEN / 2)
-			{
-				a1len = len1 * 2 + 2;
-				a1p = palloc(a1len);
-			}
-			else
-			{
-				a1len = TEXTBUFLEN;
-				a1p = a1buf;
-			}
-			if (len2 >= TEXTBUFLEN / 2)
-			{
-				a2len = len2 * 2 + 2;
-				a2p = palloc(a2len);
-			}
-			else
-			{
-				a2len = TEXTBUFLEN;
-				a2p = a2buf;
-			}
-
-			/* stupid Microsloth API does not work for zero-length input */
-			if (len1 == 0)
-				r = 0;
-			else
-			{
-				r = MultiByteToWideChar(CP_UTF8, 0, arg1, len1,
-										(LPWSTR) a1p, a1len / 2);
-				if (!r)
-					ereport(ERROR,
-							(errmsg("could not convert string to UTF-16: error code %lu",
-									GetLastError())));
-			}
-			((LPWSTR) a1p)[r] = 0;
-
-			if (len2 == 0)
-				r = 0;
-			else
-			{
-				r = MultiByteToWideChar(CP_UTF8, 0, arg2, len2,
-										(LPWSTR) a2p, a2len / 2);
-				if (!r)
-					ereport(ERROR,
-							(errmsg("could not convert string to UTF-16: error code %lu",
-									GetLastError())));
-			}
-			((LPWSTR) a2p)[r] = 0;
-
-			errno = 0;
-#ifdef HAVE_LOCALE_T
-			if (mylocale)
-				result = wcscoll_l((LPWSTR) a1p, (LPWSTR) a2p, mylocale->info.lt);
-			else
-#endif
-				result = wcscoll((LPWSTR) a1p, (LPWSTR) a2p);
-			if (result == 2147483647)	/* _NLSCMPERROR; missing from mingw
-										 * headers */
-				ereport(ERROR,
-						(errmsg("could not compare Unicode strings: %m")));
-
-			/* Break tie if necessary. */
-			if (result == 0 &&
-				(!mylocale || mylocale->deterministic))
-			{
-				result = memcmp(arg1, arg2, Min(len1, len2));
-				if ((result == 0) && (len1 != len2))
-					result = (len1 < len2) ? -1 : 1;
-			}
-
-			if (a1p != a1buf)
-				pfree(a1p);
-			if (a2p != a2buf)
-				pfree(a2p);
-
-			return result;
-		}
-#endif							/* WIN32 */
-
 		if (len1 >= TEXTBUFLEN)
 			a1p = (char *) palloc(len1 + 1);
 		else
@@ -1656,65 +1569,7 @@ varstr_cmp(const char *arg1, int len1, const char *arg2, int len2, Oid collid)
 		memcpy(a2p, arg2, len2);
 		a2p[len2] = '\0';
 
-		if (mylocale)
-		{
-			if (mylocale->provider == COLLPROVIDER_ICU)
-			{
-#ifdef USE_ICU
-#ifdef HAVE_UCOL_STRCOLLUTF8
-				if (GetDatabaseEncoding() == PG_UTF8)
-				{
-					UErrorCode	status;
-
-					status = U_ZERO_ERROR;
-					result = ucol_strcollUTF8(mylocale->info.icu.ucol,
-											  arg1, len1,
-											  arg2, len2,
-											  &status);
-					if (U_FAILURE(status))
-						ereport(ERROR,
-								(errmsg("collation failed: %s", u_errorName(status))));
-				}
-				else
-#endif
-				{
-					int32_t		ulen1,
-								ulen2;
-					UChar	   *uchar1,
-							   *uchar2;
-
-					ulen1 = icu_to_uchar(&uchar1, arg1, len1);
-					ulen2 = icu_to_uchar(&uchar2, arg2, len2);
-
-					result = ucol_strcoll(mylocale->info.icu.ucol,
-										  uchar1, ulen1,
-										  uchar2, ulen2);
-
-					pfree(uchar1);
-					pfree(uchar2);
-				}
-#else							/* not USE_ICU */
-				/* shouldn't happen */
-				elog(ERROR, "unsupported collprovider: %c", mylocale->provider);
-#endif							/* not USE_ICU */
-			}
-			else
-			{
-#ifdef HAVE_LOCALE_T
-				result = strcoll_l(a1p, a2p, mylocale->info.lt);
-#else
-				/* shouldn't happen */
-				elog(ERROR, "unsupported collprovider: %c", mylocale->provider);
-#endif
-			}
-		}
-		else
-			result = strcoll(a1p, a2p);
-
-		/* Break tie if necessary. */
-		if (result == 0 &&
-			(!mylocale || mylocale->deterministic))
-			result = strcmp(a1p, a2p);
+		result = pg_strcoll(a1p, len1, a2p, len2, mylocale);
 
 		if (a1p != a1buf)
 			pfree(a1p);
@@ -2377,65 +2232,7 @@ varstrfastcmp_locale(char *a1p, int len1, char *a2p, int len2, SortSupport ssup)
 		return sss->last_returned;
 	}
 
-	if (sss->locale)
-	{
-		if (sss->locale->provider == COLLPROVIDER_ICU)
-		{
-#ifdef USE_ICU
-#ifdef HAVE_UCOL_STRCOLLUTF8
-			if (GetDatabaseEncoding() == PG_UTF8)
-			{
-				UErrorCode	status;
-
-				status = U_ZERO_ERROR;
-				result = ucol_strcollUTF8(sss->locale->info.icu.ucol,
-										  a1p, len1,
-										  a2p, len2,
-										  &status);
-				if (U_FAILURE(status))
-					ereport(ERROR,
-							(errmsg("collation failed: %s", u_errorName(status))));
-			}
-			else
-#endif
-			{
-				int32_t		ulen1,
-							ulen2;
-				UChar	   *uchar1,
-						   *uchar2;
-
-				ulen1 = icu_to_uchar(&uchar1, a1p, len1);
-				ulen2 = icu_to_uchar(&uchar2, a2p, len2);
-
-				result = ucol_strcoll(sss->locale->info.icu.ucol,
-									  uchar1, ulen1,
-									  uchar2, ulen2);
-
-				pfree(uchar1);
-				pfree(uchar2);
-			}
-#else							/* not USE_ICU */
-			/* shouldn't happen */
-			elog(ERROR, "unsupported collprovider: %c", sss->locale->provider);
-#endif							/* not USE_ICU */
-		}
-		else
-		{
-#ifdef HAVE_LOCALE_T
-			result = strcoll_l(sss->buf1, sss->buf2, sss->locale->info.lt);
-#else
-			/* shouldn't happen */
-			elog(ERROR, "unsupported collprovider: %c", sss->locale->provider);
-#endif
-		}
-	}
-	else
-		result = strcoll(sss->buf1, sss->buf2);
-
-	/* Break tie if necessary. */
-	if (result == 0 &&
-		(!sss->locale || sss->locale->deterministic))
-		result = strcmp(sss->buf1, sss->buf2);
+	result = pg_strcoll(sss->buf1, len1, sss->buf2, len2, sss->locale);
 
 	/* Cache result, perhaps saving an expensive strcoll() call next time */
 	sss->cache_blob = false;
diff --git a/src/include/utils/pg_locale.h b/src/include/utils/pg_locale.h
index a875942123..59a4c9ad0d 100644
--- a/src/include/utils/pg_locale.h
+++ b/src/include/utils/pg_locale.h
@@ -98,7 +98,8 @@ extern void make_icu_collator(const char *iculocstr,
 							  struct pg_locale_struct *resultp);
 
 extern pg_locale_t pg_newlocale_from_collation(Oid collid);
-
+extern int pg_strcoll(const char *arg1, size_t len1, const char *arg2,
+					  size_t len2, pg_locale_t locale);
 extern char *get_collation_actual_version(char collprovider, const char *collcollate);
 
 #ifdef USE_ICU
-- 
2.34.1

Reply via email to