On Thu, 2022-10-13 at 10:57 +0200, Peter Eisentraut wrote:
> It's a bit confusing that arguments must be NUL-terminated, but the
> length is still specified. Maybe another sentence to explain that
> would be helpful.
Added a comment. It was a little frustrating to get a perfectly clean
API, because the callers do some buffer manipulation and optimizations
of their own. I think this is an improvement, but suggestions welcome.

If win32 is used with UTF-8 and wcscoll, it ends up allocating some
extra stack space for the temporary buffers, whereas previously it used
the buffers on the stack of varstr_cmp(). I'm not sure if that's a
problem or not.

> The length arguments ought to be of type size_t, I think.

Changed. Thank you.

--
Jeff Davis
PostgreSQL Contributor Team - AWS
From 4d5664552a8a86418a94c37fd4ab8ca3a665c1cd Mon Sep 17 00:00:00 2001 From: Jeff Davis <j...@j-davis.com> Date: Thu, 6 Oct 2022 10:46:36 -0700 Subject: [PATCH] Refactor: introduce pg_strcoll(). Isolate collation routines into pg_locale.c and simplify varlena.c. --- src/backend/utils/adt/pg_locale.c | 180 ++++++++++++++++++++++++++ src/backend/utils/adt/varlena.c | 207 +----------------------------- src/include/utils/pg_locale.h | 3 +- 3 files changed, 184 insertions(+), 206 deletions(-) diff --git a/src/backend/utils/adt/pg_locale.c b/src/backend/utils/adt/pg_locale.c index 2b42d9ccd8..3eb6a67bdc 100644 --- a/src/backend/utils/adt/pg_locale.c +++ b/src/backend/utils/adt/pg_locale.c @@ -1639,6 +1639,186 @@ pg_newlocale_from_collation(Oid collid) return cache_entry->locale; } +/* + * win32_utf8_wcscoll + * + * Convert UTF8 arguments to wide characters and invoke wcscoll() or + * wcscoll_l(). Will allocate on large input. + */ +#ifdef WIN32 +#define TEXTBUFLEN 1024 +static int +win32_utf8_wcscoll(const char *arg1, size_t len1, const char *arg2, + size_t len2, pg_locale_t locale) +{ + char a1buf[TEXTBUFLEN]; + char a2buf[TEXTBUFLEN]; + char *a1p, + *a2p; + int a1len; + int a2len; + int r; + int result; + + if (len1 >= TEXTBUFLEN / 2) + { + a1len = len1 * 2 + 2; + a1p = palloc(a1len); + } + else + { + a1len = TEXTBUFLEN; + a1p = a1buf; + } + if (len2 >= TEXTBUFLEN / 2) + { + a2len = len2 * 2 + 2; + a2p = palloc(a2len); + } + else + { + a2len = TEXTBUFLEN; + a2p = a2buf; + } + + /* API does not work for zero-length input */ + if (len1 == 0) + r = 0; + else + { + r = MultiByteToWideChar(CP_UTF8, 0, arg1, len1, + (LPWSTR) a1p, a1len / 2); + if (!r) + ereport(ERROR, + (errmsg("could not convert string to UTF-16: error code %lu", + GetLastError()))); + } + ((LPWSTR) a1p)[r] = 0; + + if (len2 == 0) + r = 0; + else + { + r = MultiByteToWideChar(CP_UTF8, 0, arg2, len2, + (LPWSTR) a2p, a2len / 2); + if (!r) + ereport(ERROR, + (errmsg("could not convert string to UTF-16: error 
code %lu", + GetLastError()))); + } + ((LPWSTR) a2p)[r] = 0; + + errno = 0; +#ifdef HAVE_LOCALE_T + if (locale) + result = wcscoll_l((LPWSTR) a1p, (LPWSTR) a2p, locale->info.lt); + else +#endif + result = wcscoll((LPWSTR) a1p, (LPWSTR) a2p); + if (result == 2147483647) /* _NLSCMPERROR; missing from mingw + * headers */ + ereport(ERROR, + (errmsg("could not compare Unicode strings: %m"))); + + if (a1p != a1buf) + pfree(a1p); + if (a2p != a2buf) + pfree(a2p); + + return result; +} +#endif + +/* + * pg_strcoll + * + * Call ucol_strcollUTF8(), ucol_strcoll(), strcoll(), strcoll_l(), wcscoll(), + * or wcscoll_l() as appropriate for the given locale, platform, and database + * encoding. If the locale is not specified, use the database collation. + * + * Arguments must be NUL-terminated so they can be passed directly to + * strcoll(); but we also need the lengths to pass to ucol_strcoll(). + * + * If the collation is deterministic, break ties with memcmp(), and then with + * the string length. 
+ */ +int +pg_strcoll(const char *arg1, size_t len1, const char *arg2, size_t len2, + pg_locale_t locale) +{ + int result; + +#ifdef WIN32 + /* Win32 does not have UTF-8, so we need to map to UTF-16 */ + if (GetDatabaseEncoding() == PG_UTF8 + && (!locale || locale->provider == COLLPROVIDER_LIBC)) + { + result = win32_utf8_wcscoll(arg1, len1, arg2, len2, locale); + } + else +#endif /* WIN32 */ + if (locale) + { + if (locale->provider == COLLPROVIDER_ICU) + { +#ifdef USE_ICU +#ifdef HAVE_UCOL_STRCOLLUTF8 + if (GetDatabaseEncoding() == PG_UTF8) + { + UErrorCode status; + + status = U_ZERO_ERROR; + result = ucol_strcollUTF8(locale->info.icu.ucol, + arg1, len1, + arg2, len2, + &status); + if (U_FAILURE(status)) + ereport(ERROR, + (errmsg("collation failed: %s", u_errorName(status)))); + } + else +#endif + { + int32_t ulen1, + ulen2; + UChar *uchar1, + *uchar2; + + ulen1 = icu_to_uchar(&uchar1, arg1, len1); + ulen2 = icu_to_uchar(&uchar2, arg2, len2); + + result = ucol_strcoll(locale->info.icu.ucol, + uchar1, ulen1, + uchar2, ulen2); + + pfree(uchar1); + pfree(uchar2); + } +#else /* not USE_ICU */ + /* shouldn't happen */ + elog(ERROR, "unsupported collprovider: %c", locale->provider); +#endif /* not USE_ICU */ + } + else + { +#ifdef HAVE_LOCALE_T + result = strcoll_l(arg1, arg2, locale->info.lt); +#else + /* shouldn't happen */ + elog(ERROR, "unsupported collprovider: %c", locale->provider); +#endif + } + } + else + result = strcoll(arg1, arg2); + + /* Break tie if necessary. */ + if (result == 0 && (!locale || locale->deterministic)) + result = strcmp(arg1, arg2); + + return result; +} + /* * Get provider-specific collation version string for the given collation from * the operating system/library. 
diff --git a/src/backend/utils/adt/varlena.c b/src/backend/utils/adt/varlena.c index 1f6e090821..5be7eaee9f 100644 --- a/src/backend/utils/adt/varlena.c +++ b/src/backend/utils/adt/varlena.c @@ -1555,93 +1555,6 @@ varstr_cmp(const char *arg1, int len1, const char *arg2, int len2, Oid collid) if (len1 == len2 && memcmp(arg1, arg2, len1) == 0) return 0; -#ifdef WIN32 - /* Win32 does not have UTF-8, so we need to map to UTF-16 */ - if (GetDatabaseEncoding() == PG_UTF8 - && (!mylocale || mylocale->provider == COLLPROVIDER_LIBC)) - { - int a1len; - int a2len; - int r; - - if (len1 >= TEXTBUFLEN / 2) - { - a1len = len1 * 2 + 2; - a1p = palloc(a1len); - } - else - { - a1len = TEXTBUFLEN; - a1p = a1buf; - } - if (len2 >= TEXTBUFLEN / 2) - { - a2len = len2 * 2 + 2; - a2p = palloc(a2len); - } - else - { - a2len = TEXTBUFLEN; - a2p = a2buf; - } - - /* stupid Microsloth API does not work for zero-length input */ - if (len1 == 0) - r = 0; - else - { - r = MultiByteToWideChar(CP_UTF8, 0, arg1, len1, - (LPWSTR) a1p, a1len / 2); - if (!r) - ereport(ERROR, - (errmsg("could not convert string to UTF-16: error code %lu", - GetLastError()))); - } - ((LPWSTR) a1p)[r] = 0; - - if (len2 == 0) - r = 0; - else - { - r = MultiByteToWideChar(CP_UTF8, 0, arg2, len2, - (LPWSTR) a2p, a2len / 2); - if (!r) - ereport(ERROR, - (errmsg("could not convert string to UTF-16: error code %lu", - GetLastError()))); - } - ((LPWSTR) a2p)[r] = 0; - - errno = 0; -#ifdef HAVE_LOCALE_T - if (mylocale) - result = wcscoll_l((LPWSTR) a1p, (LPWSTR) a2p, mylocale->info.lt); - else -#endif - result = wcscoll((LPWSTR) a1p, (LPWSTR) a2p); - if (result == 2147483647) /* _NLSCMPERROR; missing from mingw - * headers */ - ereport(ERROR, - (errmsg("could not compare Unicode strings: %m"))); - - /* Break tie if necessary. */ - if (result == 0 && - (!mylocale || mylocale->deterministic)) - { - result = memcmp(arg1, arg2, Min(len1, len2)); - if ((result == 0) && (len1 != len2)) - result = (len1 < len2) ? 
-1 : 1; - } - - if (a1p != a1buf) - pfree(a1p); - if (a2p != a2buf) - pfree(a2p); - - return result; - } -#endif /* WIN32 */ - if (len1 >= TEXTBUFLEN) a1p = (char *) palloc(len1 + 1); else @@ -1656,65 +1569,7 @@ varstr_cmp(const char *arg1, int len1, const char *arg2, int len2, Oid collid) memcpy(a2p, arg2, len2); a2p[len2] = '\0'; - if (mylocale) - { - if (mylocale->provider == COLLPROVIDER_ICU) - { -#ifdef USE_ICU -#ifdef HAVE_UCOL_STRCOLLUTF8 - if (GetDatabaseEncoding() == PG_UTF8) - { - UErrorCode status; - - status = U_ZERO_ERROR; - result = ucol_strcollUTF8(mylocale->info.icu.ucol, - arg1, len1, - arg2, len2, - &status); - if (U_FAILURE(status)) - ereport(ERROR, - (errmsg("collation failed: %s", u_errorName(status)))); - } - else -#endif - { - int32_t ulen1, - ulen2; - UChar *uchar1, - *uchar2; - - ulen1 = icu_to_uchar(&uchar1, arg1, len1); - ulen2 = icu_to_uchar(&uchar2, arg2, len2); - - result = ucol_strcoll(mylocale->info.icu.ucol, - uchar1, ulen1, - uchar2, ulen2); - - pfree(uchar1); - pfree(uchar2); - } -#else /* not USE_ICU */ - /* shouldn't happen */ - elog(ERROR, "unsupported collprovider: %c", mylocale->provider); -#endif /* not USE_ICU */ - } - else - { -#ifdef HAVE_LOCALE_T - result = strcoll_l(a1p, a2p, mylocale->info.lt); -#else - /* shouldn't happen */ - elog(ERROR, "unsupported collprovider: %c", mylocale->provider); -#endif - } - } - else - result = strcoll(a1p, a2p); - - /* Break tie if necessary. 
*/ - if (result == 0 && - (!mylocale || mylocale->deterministic)) - result = strcmp(a1p, a2p); + result = pg_strcoll(a1p, len1, a2p, len2, mylocale); if (a1p != a1buf) pfree(a1p); @@ -2377,65 +2232,7 @@ varstrfastcmp_locale(char *a1p, int len1, char *a2p, int len2, SortSupport ssup) return sss->last_returned; } - if (sss->locale) - { - if (sss->locale->provider == COLLPROVIDER_ICU) - { -#ifdef USE_ICU -#ifdef HAVE_UCOL_STRCOLLUTF8 - if (GetDatabaseEncoding() == PG_UTF8) - { - UErrorCode status; - - status = U_ZERO_ERROR; - result = ucol_strcollUTF8(sss->locale->info.icu.ucol, - a1p, len1, - a2p, len2, - &status); - if (U_FAILURE(status)) - ereport(ERROR, - (errmsg("collation failed: %s", u_errorName(status)))); - } - else -#endif - { - int32_t ulen1, - ulen2; - UChar *uchar1, - *uchar2; - - ulen1 = icu_to_uchar(&uchar1, a1p, len1); - ulen2 = icu_to_uchar(&uchar2, a2p, len2); - - result = ucol_strcoll(sss->locale->info.icu.ucol, - uchar1, ulen1, - uchar2, ulen2); - - pfree(uchar1); - pfree(uchar2); - } -#else /* not USE_ICU */ - /* shouldn't happen */ - elog(ERROR, "unsupported collprovider: %c", sss->locale->provider); -#endif /* not USE_ICU */ - } - else - { -#ifdef HAVE_LOCALE_T - result = strcoll_l(sss->buf1, sss->buf2, sss->locale->info.lt); -#else - /* shouldn't happen */ - elog(ERROR, "unsupported collprovider: %c", sss->locale->provider); -#endif - } - } - else - result = strcoll(sss->buf1, sss->buf2); - - /* Break tie if necessary. 
*/ - if (result == 0 && - (!sss->locale || sss->locale->deterministic)) - result = strcmp(sss->buf1, sss->buf2); + result = pg_strcoll(sss->buf1, len1, sss->buf2, len2, sss->locale); /* Cache result, perhaps saving an expensive strcoll() call next time */ sss->cache_blob = false; diff --git a/src/include/utils/pg_locale.h b/src/include/utils/pg_locale.h index a875942123..59a4c9ad0d 100644 --- a/src/include/utils/pg_locale.h +++ b/src/include/utils/pg_locale.h @@ -98,7 +98,8 @@ extern void make_icu_collator(const char *iculocstr, struct pg_locale_struct *resultp); extern pg_locale_t pg_newlocale_from_collation(Oid collid); - +extern int pg_strcoll(const char *arg1, size_t len1, const char *arg2, + size_t len2, pg_locale_t locale); extern char *get_collation_actual_version(char collprovider, const char *collcollate); #ifdef USE_ICU -- 2.34.1