Here is a new patch set. I've split it into two patches: one for the 64-bit functions, and one for the 32-bit functions. I've also added tests for pg_lfind64/pg_lfind64_idx and deduplicated the code a bit.
-- Nathan Bossart Amazon Web Services: https://aws.amazon.com
>From 7dbbf61e8a2546a73413c3768a49da318f7de7f5 Mon Sep 17 00:00:00 2001 From: Nathan Bossart <nathandboss...@gmail.com> Date: Mon, 10 Apr 2023 09:10:47 -0700 Subject: [PATCH v3 1/2] speed up list_member_ptr and list_delete_ptr with SIMD --- src/backend/nodes/list.c | 35 ++++++ src/include/port/pg_lfind.h | 112 ++++++++++++++++++ src/include/port/simd.h | 96 ++++++++++++++- .../test_lfind/expected/test_lfind.out | 12 ++ .../modules/test_lfind/sql/test_lfind.sql | 2 + .../modules/test_lfind/test_lfind--1.0.sql | 8 ++ src/test/modules/test_lfind/test_lfind.c | 62 ++++++++++ 7 files changed, 324 insertions(+), 3 deletions(-) diff --git a/src/backend/nodes/list.c b/src/backend/nodes/list.c index a709d23ef1..92bc48de17 100644 --- a/src/backend/nodes/list.c +++ b/src/backend/nodes/list.c @@ -19,6 +19,7 @@ #include "nodes/pg_list.h" #include "port/pg_bitutils.h" +#include "port/pg_lfind.h" #include "utils/memdebug.h" #include "utils/memutils.h" @@ -680,11 +681,15 @@ list_member(const List *list, const void *datum) bool list_member_ptr(const List *list, const void *datum) { +#ifdef USE_NO_SIMD const ListCell *cell; +#endif Assert(IsPointerList(list)); check_list_invariants(list); +#ifdef USE_NO_SIMD + foreach(cell, list) { if (lfirst(cell) == datum) @@ -692,6 +697,18 @@ list_member_ptr(const List *list, const void *datum) } return false; + +#else + + Assert(sizeof(ListCell) == 8); + Assert(sizeof(void *) == 8); + + if (list == NIL) + return false; + + return pg_lfind64((uint64) datum, (uint64 *) list->elements, list->length); + +#endif } /* @@ -875,12 +892,30 @@ list_delete_ptr(List *list, void *datum) Assert(IsPointerList(list)); check_list_invariants(list); +#ifdef USE_NO_SIMD + foreach(cell, list) { if (lfirst(cell) == datum) return list_delete_cell(list, cell); } +#else + + Assert(sizeof(ListCell) == 8); + Assert(sizeof(void *) == 8); + + if (list == NIL) + return NIL; + + cell = (ListCell *) pg_lfind64_idx((uint64) datum, + (uint64 *) list->elements, + list->length); + 
if (cell != NULL) + return list_delete_cell(list, cell); + +#endif + /* Didn't find a match: return the list unmodified */ return list; } diff --git a/src/include/port/pg_lfind.h b/src/include/port/pg_lfind.h index 59aa8245ed..d08ad91ae8 100644 --- a/src/include/port/pg_lfind.h +++ b/src/include/port/pg_lfind.h @@ -177,4 +177,116 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem) return false; } +/* + * pg_lfind64_internal + * + * Workhorse for pg_lfind64 and pg_lfind64_idx. + */ +static inline bool +pg_lfind64_internal(uint64 key, uint64 *base, uint32 nelem, uint32 *i) +{ +#ifdef USE_NO_SIMD + + *i = 0; + +#else + + /* + * For better instruction-level parallelism, each loop iteration operates + * on a block of four registers. + */ + const Vector64 keys = vector64_broadcast(key); /* load copies of key */ + const uint32 nelem_per_vector = sizeof(Vector64) / sizeof(uint64); + const uint32 nelem_per_iteration = 4 * nelem_per_vector; + + /* round down to multiple of elements per iteration */ + const uint32 tail_idx = nelem & ~(nelem_per_iteration - 1); + + for (*i = 0; *i < tail_idx; *i += nelem_per_iteration) + { + Vector64 vals1, + vals2, + vals3, + vals4, + result1, + result2, + result3, + result4, + tmp1, + tmp2, + result; + + /* load the next block into 4 registers */ + vector64_load(&vals1, &base[*i]); + vector64_load(&vals2, &base[*i + nelem_per_vector]); + vector64_load(&vals3, &base[*i + nelem_per_vector * 2]); + vector64_load(&vals4, &base[*i + nelem_per_vector * 3]); + + /* compare each value to the key */ + result1 = vector64_eq(keys, vals1); + result2 = vector64_eq(keys, vals2); + result3 = vector64_eq(keys, vals3); + result4 = vector64_eq(keys, vals4); + + /* combine the results into a single variable */ + tmp1 = vector64_or(result1, result2); + tmp2 = vector64_or(result3, result4); + result = vector64_or(tmp1, tmp2); + + /* see if there was a match */ + if (vector64_is_highbit_set(result)) + return true; + } +#endif /* ! 
USE_NO_SIMD */ + + return false; +} + +/* + * pg_lfind64 + * + * Return true if there is an element in 'base' that equals 'key', otherwise + * return false. + */ +static inline bool +pg_lfind64(uint64 key, uint64 *base, uint32 nelem) +{ + uint32 i = 0; + + if (pg_lfind64_internal(key, base, nelem, &i)) + return true; + + /* Process the remaining elements one at a time. */ + for (; i < nelem; i++) + { + if (key == base[i]) + return true; + } + + return false; +} + +/* + * pg_lfind64_idx + * + * Return the address of the first element in 'base' that equals 'key', or NULL + * if no match is found. + */ +static inline uint64 * +pg_lfind64_idx(uint64 key, uint64 *base, uint32 nelem) +{ + uint32 i = 0; + + (void) pg_lfind64_internal(key, base, nelem, &i); + + /* Process the remaining elements one at a time. */ + for (; i < nelem; i++) + { + if (key == base[i]) + return &base[i]; + } + + return NULL; +} + #endif /* PG_LFIND_H */ diff --git a/src/include/port/simd.h b/src/include/port/simd.h index 1fa6c3bc6c..b851d3d89d 100644 --- a/src/include/port/simd.h +++ b/src/include/port/simd.h @@ -32,6 +32,7 @@ #define USE_SSE2 typedef __m128i Vector8; typedef __m128i Vector32; +typedef __m128i Vector64; #elif defined(__aarch64__) && defined(__ARM_NEON) /* @@ -46,15 +47,16 @@ typedef __m128i Vector32; #define USE_NEON typedef uint8x16_t Vector8; typedef uint32x4_t Vector32; +typedef uint64x2_t Vector64; #else /* * If no SIMD instructions are available, we can in some cases emulate vector * operations using bitwise operations on unsigned integers. Note that many * of the functions in this file presently do not have non-SIMD - * implementations. In particular, none of the functions involving Vector32 - * are implemented without SIMD since it's likely not worthwhile to represent - * two 32-bit integers using a uint64. + * implementations. 
For example, none of the functions involving Vector32 are + * implemented without SIMD since it's likely not worthwhile to represent two + * 32-bit integers using a uint64. */ #define USE_NO_SIMD typedef uint64 Vector8; @@ -64,12 +66,14 @@ typedef uint64 Vector8; static inline void vector8_load(Vector8 *v, const uint8 *s); #ifndef USE_NO_SIMD static inline void vector32_load(Vector32 *v, const uint32 *s); +static inline void vector64_load(Vector64 *v, const uint64 *s); #endif /* assignment operations */ static inline Vector8 vector8_broadcast(const uint8 c); #ifndef USE_NO_SIMD static inline Vector32 vector32_broadcast(const uint32 c); +static inline Vector64 vector64_broadcast(const uint64 c); #endif /* element-wise comparisons to a scalar */ @@ -79,12 +83,15 @@ static inline bool vector8_has_le(const Vector8 v, const uint8 c); static inline bool vector8_is_highbit_set(const Vector8 v); #ifndef USE_NO_SIMD static inline bool vector32_is_highbit_set(const Vector32 v); +static inline bool vector64_is_highbit_set(const Vector64 v); #endif /* arithmetic operations */ static inline Vector8 vector8_or(const Vector8 v1, const Vector8 v2); #ifndef USE_NO_SIMD static inline Vector32 vector32_or(const Vector32 v1, const Vector32 v2); +static inline Vector64 vector64_or(const Vector64 v1, const Vector64 v2); +static inline Vector32 vector32_and(const Vector32 v1, const Vector32 v2); static inline Vector8 vector8_ssub(const Vector8 v1, const Vector8 v2); #endif @@ -97,6 +104,7 @@ static inline Vector8 vector8_ssub(const Vector8 v1, const Vector8 v2); #ifndef USE_NO_SIMD static inline Vector8 vector8_eq(const Vector8 v1, const Vector8 v2); static inline Vector32 vector32_eq(const Vector32 v1, const Vector32 v2); +static inline Vector64 vector64_eq(const Vector64 v1, const Vector64 v2); #endif /* @@ -126,6 +134,18 @@ vector32_load(Vector32 *v, const uint32 *s) } #endif /* ! 
USE_NO_SIMD */ +#ifndef USE_NO_SIMD +static inline void +vector64_load(Vector64 *v, const uint64 *s) +{ +#ifdef USE_SSE2 + *v = _mm_loadu_si128((const __m128i *) s); +#elif defined(USE_NEON) + *v = vld1q_u64((const uint64_t *) s); +#endif +} +#endif /* ! USE_NO_SIMD */ + /* * Create a vector with all elements set to the same value. */ @@ -153,6 +173,18 @@ vector32_broadcast(const uint32 c) } #endif /* ! USE_NO_SIMD */ +#ifndef USE_NO_SIMD +static inline Vector64 +vector64_broadcast(const uint64 c) +{ +#ifdef USE_SSE2 + return _mm_set1_epi64x(c); +#elif defined(USE_NEON) + return vdupq_n_u64(c); +#endif +} +#endif /* ! USE_NO_SIMD */ + /* * Return true if any elements in the vector are equal to the given scalar. */ @@ -299,6 +331,23 @@ vector32_is_highbit_set(const Vector32 v) } #endif /* ! USE_NO_SIMD */ +/* + * Exactly like vector8_is_highbit_set except for the input type, so it + * looks at each byte separately. See the comment above + * vector32_is_highbit_set for more information. + */ +#ifndef USE_NO_SIMD +static inline bool +vector64_is_highbit_set(const Vector64 v) +{ +#if defined(USE_NEON) + return vector8_is_highbit_set((Vector8) v); +#else + return vector8_is_highbit_set(v); +#endif +} +#endif /* ! USE_NO_SIMD */ + /* * Return the bitwise OR of the inputs */ @@ -326,6 +375,30 @@ vector32_or(const Vector32 v1, const Vector32 v2) } #endif /* ! USE_NO_SIMD */ +#ifndef USE_NO_SIMD +static inline Vector64 +vector64_or(const Vector64 v1, const Vector64 v2) +{ +#ifdef USE_SSE2 + return _mm_or_si128(v1, v2); +#elif defined(USE_NEON) + return vorrq_u64(v1, v2); +#endif +} +#endif /* ! USE_NO_SIMD */ + +#ifndef USE_NO_SIMD +static inline Vector32 +vector32_and(const Vector32 v1, const Vector32 v2) +{ +#ifdef USE_SSE2 + return _mm_and_si128(v1, v2); +#elif defined(USE_NEON) + return vandq_u32(v1, v2); +#endif +} +#endif /* ! 
USE_NO_SIMD */ + /* * Return the result of subtracting the respective elements of the input * vectors using saturation (i.e., if the operation would yield a value less @@ -372,4 +445,21 @@ vector32_eq(const Vector32 v1, const Vector32 v2) } #endif /* ! USE_NO_SIMD */ +#ifndef USE_NO_SIMD +static inline Vector64 +vector64_eq(const Vector64 v1, const Vector64 v2) +{ +#ifdef USE_SSE2 + /* We have to work around SSE2's lack of _mm_cmpeq_epi64. */ + const Vector32 cmp = vector32_eq(v1, v2); + const Vector32 hi = _mm_shuffle_epi32(cmp, _MM_SHUFFLE(3, 3, 1, 1)); + const Vector32 lo = _mm_shuffle_epi32(cmp, _MM_SHUFFLE(2, 2, 0, 0)); + + return vector32_and(hi, lo); +#elif defined(USE_NEON) + return vceqq_u64(v1, v2); +#endif +} +#endif /* ! USE_NO_SIMD */ + #endif /* SIMD_H */ diff --git a/src/test/modules/test_lfind/expected/test_lfind.out b/src/test/modules/test_lfind/expected/test_lfind.out index 1d4b14e703..6488c7f700 100644 --- a/src/test/modules/test_lfind/expected/test_lfind.out +++ b/src/test/modules/test_lfind/expected/test_lfind.out @@ -22,3 +22,15 @@ SELECT test_lfind32(); (1 row) +SELECT test_lfind64(); + test_lfind64 +-------------- + +(1 row) + +SELECT test_lfind64_idx(); + test_lfind64_idx +------------------ + +(1 row) + diff --git a/src/test/modules/test_lfind/sql/test_lfind.sql b/src/test/modules/test_lfind/sql/test_lfind.sql index 766c640831..0b5b6eea9e 100644 --- a/src/test/modules/test_lfind/sql/test_lfind.sql +++ b/src/test/modules/test_lfind/sql/test_lfind.sql @@ -8,3 +8,5 @@ CREATE EXTENSION test_lfind; SELECT test_lfind8(); SELECT test_lfind8_le(); SELECT test_lfind32(); +SELECT test_lfind64(); +SELECT test_lfind64_idx(); diff --git a/src/test/modules/test_lfind/test_lfind--1.0.sql b/src/test/modules/test_lfind/test_lfind--1.0.sql index 81801926ae..55c5fd2087 100644 --- a/src/test/modules/test_lfind/test_lfind--1.0.sql +++ b/src/test/modules/test_lfind/test_lfind--1.0.sql @@ -3,6 +3,14 @@ -- complain if script is sourced in psql, rather than via 
CREATE EXTENSION \echo Use "CREATE EXTENSION test_lfind" to load this file. \quit +CREATE FUNCTION test_lfind64() + RETURNS pg_catalog.void + AS 'MODULE_PATHNAME' LANGUAGE C; + +CREATE FUNCTION test_lfind64_idx() + RETURNS pg_catalog.void + AS 'MODULE_PATHNAME' LANGUAGE C; + CREATE FUNCTION test_lfind32() RETURNS pg_catalog.void AS 'MODULE_PATHNAME' LANGUAGE C; diff --git a/src/test/modules/test_lfind/test_lfind.c b/src/test/modules/test_lfind/test_lfind.c index e2e8b7389f..906e40300e 100644 --- a/src/test/modules/test_lfind/test_lfind.c +++ b/src/test/modules/test_lfind/test_lfind.c @@ -146,3 +146,65 @@ test_lfind32(PG_FUNCTION_ARGS) PG_RETURN_VOID(); } + +PG_FUNCTION_INFO_V1(test_lfind64); +Datum +test_lfind64(PG_FUNCTION_ARGS) +{ + uint64 test_array[TEST_ARRAY_SIZE] = {0}; + + test_array[8] = 1; + test_array[64] = 2; + test_array[TEST_ARRAY_SIZE - 1] = 3; + + if (pg_lfind64(1, test_array, 4)) + elog(ERROR, "pg_lfind64() found nonexistent element"); + if (!pg_lfind64(1, test_array, TEST_ARRAY_SIZE)) + elog(ERROR, "pg_lfind64() did not find existing element"); + + if (pg_lfind64(2, test_array, 32)) + elog(ERROR, "pg_lfind64() found nonexistent element"); + if (!pg_lfind64(2, test_array, TEST_ARRAY_SIZE)) + elog(ERROR, "pg_lfind64() did not find existing element"); + + if (pg_lfind64(3, test_array, 96)) + elog(ERROR, "pg_lfind64() found nonexistent element"); + if (!pg_lfind64(3, test_array, TEST_ARRAY_SIZE)) + elog(ERROR, "pg_lfind64() did not find existing element"); + + if (pg_lfind64(4, test_array, TEST_ARRAY_SIZE)) + elog(ERROR, "pg_lfind64() found nonexistent element"); + + PG_RETURN_VOID(); +} + +PG_FUNCTION_INFO_V1(test_lfind64_idx); +Datum +test_lfind64_idx(PG_FUNCTION_ARGS) +{ + uint64 test_array[TEST_ARRAY_SIZE] = {0}; + + test_array[8] = 1; + test_array[64] = 2; + test_array[TEST_ARRAY_SIZE - 1] = 3; + + if (pg_lfind64_idx(1, test_array, 4) != NULL) + elog(ERROR, "pg_lfind64_idx() found nonexistent element"); + if (pg_lfind64_idx(1, test_array, 
TEST_ARRAY_SIZE) != &test_array[8]) + elog(ERROR, "pg_lfind64_idx() did not find existing element"); + + if (pg_lfind64_idx(2, test_array, 32) != NULL) + elog(ERROR, "pg_lfind64_idx() found nonexistent element"); + if (pg_lfind64_idx(2, test_array, TEST_ARRAY_SIZE) != &test_array[64]) + elog(ERROR, "pg_lfind64_idx() did not find existing element"); + + if (pg_lfind64_idx(3, test_array, 96) != NULL) + elog(ERROR, "pg_lfind64_idx() found nonexistent element"); + if (pg_lfind64_idx(3, test_array, TEST_ARRAY_SIZE) != &test_array[TEST_ARRAY_SIZE - 1]) + elog(ERROR, "pg_lfind64_idx() did not find existing element"); + + if (pg_lfind64_idx(4, test_array, TEST_ARRAY_SIZE) != NULL) + elog(ERROR, "pg_lfind64_idx() found nonexistent element"); + + PG_RETURN_VOID(); +} -- 2.25.1
>From ccb3d6c072899063bc6e47388ef9a222a19be324 Mon Sep 17 00:00:00 2001 From: Nathan Bossart <nathandboss...@gmail.com> Date: Mon, 17 Apr 2023 12:29:10 -0700 Subject: [PATCH v3 2/2] speed up several functions for lists with inline 32-bit data using SIMD --- src/backend/nodes/list.c | 192 +++++++++++++++++++++++++++++++-------- 1 file changed, 155 insertions(+), 37 deletions(-) diff --git a/src/backend/nodes/list.c b/src/backend/nodes/list.c index 92bc48de17..fa06caf553 100644 --- a/src/backend/nodes/list.c +++ b/src/backend/nodes/list.c @@ -57,6 +57,9 @@ #define IsOidList(l) ((l) == NIL || IsA((l), OidList)) #define IsXidList(l) ((l) == NIL || IsA((l), XidList)) +static inline ListCell *list_member_inline_internal_idx(const List *list, uint32 datum); +static inline bool list_member_inline_internal(const List *list, uint32 datum); + #ifdef USE_ASSERT_CHECKING /* * Check that the specified List is valid (so far as we can tell). @@ -717,18 +720,10 @@ list_member_ptr(const List *list, const void *datum) bool list_member_int(const List *list, int datum) { - const ListCell *cell; - Assert(IsIntegerList(list)); check_list_invariants(list); - foreach(cell, list) - { - if (lfirst_int(cell) == datum) - return true; - } - - return false; + return list_member_inline_internal(list, datum); } /* @@ -737,18 +732,10 @@ list_member_int(const List *list, int datum) bool list_member_oid(const List *list, Oid datum) { - const ListCell *cell; - Assert(IsOidList(list)); check_list_invariants(list); - foreach(cell, list) - { - if (lfirst_oid(cell) == datum) - return true; - } - - return false; + return list_member_inline_internal(list, datum); } /* @@ -757,18 +744,10 @@ list_member_oid(const List *list, Oid datum) bool list_member_xid(const List *list, TransactionId datum) { - const ListCell *cell; - Assert(IsXidList(list)); check_list_invariants(list); - foreach(cell, list) - { - if (lfirst_xid(cell) == datum) - return true; - } - - return false; + return 
list_member_inline_internal(list, datum); } /* @@ -929,11 +908,9 @@ list_delete_int(List *list, int datum) Assert(IsIntegerList(list)); check_list_invariants(list); - foreach(cell, list) - { - if (lfirst_int(cell) == datum) - return list_delete_cell(list, cell); - } + cell = list_member_inline_internal_idx(list, datum); + if (cell != NULL) + return list_delete_cell(list, cell); /* Didn't find a match: return the list unmodified */ return list; @@ -948,11 +925,9 @@ list_delete_oid(List *list, Oid datum) Assert(IsOidList(list)); check_list_invariants(list); - foreach(cell, list) - { - if (lfirst_oid(cell) == datum) - return list_delete_cell(list, cell); - } + cell = list_member_inline_internal_idx(list, datum); + if (cell != NULL) + return list_delete_cell(list, cell); /* Didn't find a match: return the list unmodified */ return list; @@ -1749,3 +1724,146 @@ list_oid_cmp(const ListCell *p1, const ListCell *p2) return 1; return 0; } + +/* + * list_member_inline_helper + * + * Workhorse for list_member_inline_internal and + * list_member_inline_internal_idx. + */ +static inline bool +list_member_inline_helper(const List *list, uint32 datum, uint32 *i) +{ +#ifdef USE_NO_SIMD + + *i = 0; + +#else + + /* + * For better instruction-level parallelism, each loop iteration operates + * on a block of four registers. 
+ */ + const Vector32 keys = vector32_broadcast(datum); /* load copies of key */ + const uint32 nelem_per_vector = sizeof(Vector32) / sizeof(uint32); + const uint32 nelem_per_iteration = 4 * nelem_per_vector; +#ifdef USE_NEON + const Vector32 mask = (Vector32) vector64_broadcast(UINT64CONST(0xFFFFFFFF)); +#else + const Vector32 mask = vector64_broadcast(UINT64CONST(0xFFFFFFFF)); +#endif + const uint32 *elements = (const uint32 *) list->elements; + + /* round down to multiple of elements per iteration */ + const uint32 tail_idx = (list->length * 2) & ~(nelem_per_iteration - 1); + + /* + * The SIMD optimized portion of this routine is written with the + * expectation that the 32-bit datum we are searching for only takes up + * half of a ListCell. If that changes, this routine must change, too. + */ + Assert(sizeof(ListCell) == 8); + + for (*i = 0; *i < tail_idx; *i += nelem_per_iteration) + { + Vector32 vals1, + vals2, + vals3, + vals4, + result1, + result2, + result3, + result4, + tmp1, + tmp2, + result, + masked; + + /* load the next block into 4 registers */ + vector32_load(&vals1, &elements[*i]); + vector32_load(&vals2, &elements[*i + nelem_per_vector]); + vector32_load(&vals3, &elements[*i + nelem_per_vector * 2]); + vector32_load(&vals4, &elements[*i + nelem_per_vector * 3]); + + /* compare each value to the key */ + result1 = vector32_eq(keys, vals1); + result2 = vector32_eq(keys, vals2); + result3 = vector32_eq(keys, vals3); + result4 = vector32_eq(keys, vals4); + + /* combine the results into a single variable */ + tmp1 = vector32_or(result1, result2); + tmp2 = vector32_or(result3, result4); + result = vector32_or(tmp1, tmp2); + + /* filter out matches in space between data */ + masked = vector32_and(result, mask); + + /* break out and find the exact element if there was a match */ + if (vector32_is_highbit_set(masked)) + { + *i /= 2; + return true; + } + } + +#endif /* ! 
USE_NO_SIMD */ + + *i /= 2; + return false; +} + +/* + * list_member_inline_internal + * + * Optimized linear search routine (using SIMD intrinsics where available) for + * lists with inline 32-bit data. + */ +static inline bool +list_member_inline_internal(const List *list, uint32 datum) +{ + uint32 i = 0; + const ListCell *cell; + + if (list == NIL) + return false; + + if (list_member_inline_helper(list, datum, &i)) + return true; + + /* Process the remaining elements one at a time. */ + for_each_from(cell, list, i) + { + if (lfirst_int(cell) == (int) datum) + return true; + } + + return false; +} + +/* + * list_member_inline_internal_idx + * + * Optimized linear search routine (using SIMD intrinsics where available) for + * lists with inline 32-bit data. + */ +static inline ListCell * +list_member_inline_internal_idx(const List *list, uint32 datum) +{ + uint32 i = 0; + ListCell *cell; + + if (list == NIL) + return NULL; + + (void) list_member_inline_helper(list, datum, &i); + + /* Process the remaining elements one at a time. */ + for_each_from(cell, list, i) + { + if (lfirst_int(cell) == (int) datum) + return cell; + } + + return NULL; +} -- 2.25.1