cfbot's Windows build wasn't happy with a couple of casts. I applied a fix similar to c6a43c2 in v2. The patch is still very much a work in progress.
-- 
Nathan Bossart
Amazon Web Services: https://aws.amazon.com
>From 055717233c47518ae48119938ebd203cc55f7f3c Mon Sep 17 00:00:00 2001
From: Nathan Bossart <nathandboss...@gmail.com>
Date: Sat, 4 Mar 2023 23:16:07 -0800
Subject: [PATCH v2 1/1] speed up several list functions with SIMD

---
 src/backend/nodes/list.c    | 262 +++++++++++++++++++++++++++++++-----
 src/include/port/pg_lfind.h | 189 ++++++++++++++++++++++++++
 src/include/port/simd.h     | 103 ++++++++++++++
 3 files changed, 521 insertions(+), 33 deletions(-)

diff --git a/src/backend/nodes/list.c b/src/backend/nodes/list.c
index a709d23ef1..02ddbeb3f2 100644
--- a/src/backend/nodes/list.c
+++ b/src/backend/nodes/list.c
@@ -19,6 +19,7 @@
 
 #include "nodes/pg_list.h"
 #include "port/pg_bitutils.h"
+#include "port/pg_lfind.h"
 #include "utils/memdebug.h"
 #include "utils/memutils.h"
 
@@ -680,11 +681,25 @@ list_member(const List *list, const void *datum)
 bool
 list_member_ptr(const List *list, const void *datum)
 {
+#ifdef USE_NO_SIMD
     const ListCell *cell;
+#endif
 
     Assert(IsPointerList(list));
     check_list_invariants(list);
 
+#ifndef USE_NO_SIMD
+
+    Assert(sizeof(ListCell) == 8);
+    Assert(sizeof(void *) == 8);
+
+    if (list == NIL)
+        return false;
+
+    return pg_lfind64((uint64) datum, (uint64 *) list->elements, list->length);
+
+#else
+
     foreach(cell, list)
     {
         if (lfirst(cell) == datum)
@@ -692,46 +707,125 @@ list_member_ptr(const List *list, const void *datum)
     }
 
     return false;
+
+#endif
 }
 
 /*
- * Return true iff the integer 'datum' is a member of the list.
+ * Optimized linear search routine (using SIMD intrinsics where available) for
+ * lists with inline 32-bit data.
  */
-bool
-list_member_int(const List *list, int datum)
+static inline bool
+list_member_inline_internal(const List *list, uint32 datum)
 {
+    uint32      i = 0;
     const ListCell *cell;
 
-    Assert(IsIntegerList(list));
-    check_list_invariants(list);
+#ifndef USE_NO_SIMD
 
-    foreach(cell, list)
+    /*
+     * For better instruction-level parallelism, each loop iteration operates
+     * on a block of four registers.
+     */
+    const Vector32 keys = vector32_broadcast(datum);    /* load copies of key */
+    const uint32 nelem_per_vector = sizeof(Vector32) / sizeof(uint32);
+    const uint32 nelem_per_iteration = 4 * nelem_per_vector;
+#ifdef USE_NEON
+    const Vector32 mask = (Vector32) vector64_broadcast(UINT64CONST(0xFFFFFFFF));
+#else
+    const Vector32 mask = vector64_broadcast(UINT64CONST(0xFFFFFFFF));
+#endif
+    const uint32 *elements = (const uint32 *) list->elements;
+
+    /* round down to multiple of elements per iteration */
+    const uint32 tail_idx = (list->length * 2) & ~(nelem_per_iteration - 1);
+
+    /*
+     * The SIMD-optimized portion of this routine is written with the
+     * expectation that the 32-bit datum we are searching for only takes up
+     * half of a ListCell.  If that changes, this routine must change, too.
+     */
+    Assert(sizeof(ListCell) == 8);
+
+    for (i = 0; i < tail_idx; i += nelem_per_iteration)
+    {
+        Vector32    vals1,
+                    vals2,
+                    vals3,
+                    vals4,
+                    result1,
+                    result2,
+                    result3,
+                    result4,
+                    tmp1,
+                    tmp2,
+                    result,
+                    masked;
+
+        /* load the next block into 4 registers */
+        vector32_load(&vals1, &elements[i]);
+        vector32_load(&vals2, &elements[i + nelem_per_vector]);
+        vector32_load(&vals3, &elements[i + nelem_per_vector * 2]);
+        vector32_load(&vals4, &elements[i + nelem_per_vector * 3]);
+
+        /* compare each value to the key */
+        result1 = vector32_eq(keys, vals1);
+        result2 = vector32_eq(keys, vals2);
+        result3 = vector32_eq(keys, vals3);
+        result4 = vector32_eq(keys, vals4);
+
+        /* combine the results into a single variable */
+        tmp1 = vector32_or(result1, result2);
+        tmp2 = vector32_or(result3, result4);
+        result = vector32_or(tmp1, tmp2);
+
+        /* filter out matches in space between data */
+        masked = vector32_and(result, mask);
+
+        /* see if there was a match */
+        if (vector32_is_highbit_set(masked))
+            return true;
+    }
+
+#endif                          /* ! USE_NO_SIMD */
+
+    for_each_from(cell, list, i / 2)
     {
-        if (lfirst_int(cell) == datum)
+        if (lfirst_int(cell) == (int) datum)
             return true;
     }
 
     return false;
 }
 
+/*
+ * Return true iff the integer 'datum' is a member of the list.
+ */
+bool
+list_member_int(const List *list, int datum)
+{
+    Assert(IsIntegerList(list));
+    check_list_invariants(list);
+
+    if (list == NIL)
+        return false;
+
+    return list_member_inline_internal(list, datum);
+}
+
 /*
  * Return true iff the OID 'datum' is a member of the list.
  */
 bool
 list_member_oid(const List *list, Oid datum)
 {
-    const ListCell *cell;
-
     Assert(IsOidList(list));
     check_list_invariants(list);
 
-    foreach(cell, list)
-    {
-        if (lfirst_oid(cell) == datum)
-            return true;
-    }
+    if (list == NIL)
+        return false;
 
-    return false;
+    return list_member_inline_internal(list, datum);
 }
 
 /*
@@ -740,18 +834,13 @@ list_member_oid(const List *list, Oid datum)
 bool
 list_member_xid(const List *list, TransactionId datum)
 {
-    const ListCell *cell;
-
     Assert(IsXidList(list));
     check_list_invariants(list);
 
-    foreach(cell, list)
-    {
-        if (lfirst_xid(cell) == datum)
-            return true;
-    }
+    if (list == NIL)
+        return false;
 
-    return false;
+    return list_member_inline_internal(list, datum);
 }
 
 /*
@@ -875,16 +964,121 @@ list_delete_ptr(List *list, void *datum)
     Assert(IsPointerList(list));
     check_list_invariants(list);
 
+#ifndef USE_NO_SIMD
+
+    Assert(sizeof(ListCell) == 8);
+    Assert(sizeof(void *) == 8);
+
+    if (list == NIL)
+        return NIL;
+
+    cell = (ListCell *) pg_lfind64_idx((uint64) datum,
+                                       (uint64 *) list->elements,
+                                       list->length);
+    if (cell != NULL)
+        return list_delete_cell(list, cell);
+
+#else
+
     foreach(cell, list)
     {
         if (lfirst(cell) == datum)
             return list_delete_cell(list, cell);
     }
 
+#endif
+
     /* Didn't find a match: return the list unmodified */
     return list;
 }
 
+/*
+ * Optimized linear search routine (using SIMD intrinsics where available) for
+ * lists with inline 32-bit data.
+ */
+static inline ListCell *
+list_member_inline_internal_idx(const List *list, uint32 datum)
+{
+    uint32      i = 0;
+    ListCell   *cell;
+
+#ifndef USE_NO_SIMD
+
+    /*
+     * For better instruction-level parallelism, each loop iteration operates
+     * on a block of four registers.
+     */
+    const Vector32 keys = vector32_broadcast(datum);    /* load copies of key */
+    const uint32 nelem_per_vector = sizeof(Vector32) / sizeof(uint32);
+    const uint32 nelem_per_iteration = 4 * nelem_per_vector;
+#ifdef USE_NEON
+    const Vector32 mask = (Vector32) vector64_broadcast(UINT64CONST(0xFFFFFFFF));
+#else
+    const Vector32 mask = vector64_broadcast(UINT64CONST(0xFFFFFFFF));
+#endif
+    const uint32 *elements = (const uint32 *) list->elements;
+
+    /* round down to multiple of elements per iteration */
+    const uint32 tail_idx = (list->length * 2) & ~(nelem_per_iteration - 1);
+
+    /*
+     * The SIMD-optimized portion of this routine is written with the
+     * expectation that the 32-bit datum we are searching for only takes up
+     * half of a ListCell.  If that changes, this routine must change, too.
+     */
+    Assert(sizeof(ListCell) == 8);
+
+    for (i = 0; i < tail_idx; i += nelem_per_iteration)
+    {
+        Vector32    vals1,
+                    vals2,
+                    vals3,
+                    vals4,
+                    result1,
+                    result2,
+                    result3,
+                    result4,
+                    tmp1,
+                    tmp2,
+                    result,
+                    masked;
+
+        /* load the next block into 4 registers */
+        vector32_load(&vals1, &elements[i]);
+        vector32_load(&vals2, &elements[i + nelem_per_vector]);
+        vector32_load(&vals3, &elements[i + nelem_per_vector * 2]);
+        vector32_load(&vals4, &elements[i + nelem_per_vector * 3]);
+
+        /* compare each value to the key */
+        result1 = vector32_eq(keys, vals1);
+        result2 = vector32_eq(keys, vals2);
+        result3 = vector32_eq(keys, vals3);
+        result4 = vector32_eq(keys, vals4);
+
+        /* combine the results into a single variable */
+        tmp1 = vector32_or(result1, result2);
+        tmp2 = vector32_or(result3, result4);
+        result = vector32_or(tmp1, tmp2);
+
+        /* filter out matches in space between data */
+        masked = vector32_and(result, mask);
+
+        /* break out and find the exact element if there was a match */
+        if (vector32_is_highbit_set(masked))
+            break;
+    }
+
+#endif                          /* ! USE_NO_SIMD */
+
+    for_each_from(cell, list, i / 2)
+    {
+        if (lfirst_int(cell) == (int) datum)
+            return cell;
+    }
+
+    return NULL;
+}
+
 /* As above, but for integers */
 List *
 list_delete_int(List *list, int datum)
@@ -894,11 +1088,12 @@
     Assert(IsIntegerList(list));
     check_list_invariants(list);
 
-    foreach(cell, list)
-    {
-        if (lfirst_int(cell) == datum)
-            return list_delete_cell(list, cell);
-    }
+    if (list == NIL)
+        return NIL;
+
+    cell = list_member_inline_internal_idx(list, datum);
+    if (cell)
+        return list_delete_cell(list, cell);
 
     /* Didn't find a match: return the list unmodified */
     return list;
@@ -913,11 +1108,12 @@
     Assert(IsOidList(list));
     check_list_invariants(list);
 
-    foreach(cell, list)
-    {
-        if (lfirst_oid(cell) == datum)
-            return list_delete_cell(list, cell);
-    }
+    if (list == NIL)
+        return NIL;
+
+    cell = list_member_inline_internal_idx(list, datum);
+    if (cell)
+        return list_delete_cell(list, cell);
 
     /* Didn't find a match: return the list unmodified */
     return list;
diff --git a/src/include/port/pg_lfind.h b/src/include/port/pg_lfind.h
index 59aa8245ed..05d0b12463 100644
--- a/src/include/port/pg_lfind.h
+++ b/src/include/port/pg_lfind.h
@@ -177,4 +177,193 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem)
     return false;
 }
 
+/*
+ * pg_lfind64
+ *
+ * Return true if there is an element in 'base' that equals 'key', otherwise
+ * return false.
+ */
+static inline bool
+pg_lfind64(uint64 key, uint64 *base, uint32 nelem)
+{
+    uint32      i = 0;
+
+#ifndef USE_NO_SIMD
+
+    /*
+     * For better instruction-level parallelism, each loop iteration operates
+     * on a block of four registers.
+     */
+    const Vector64 keys = vector64_broadcast(key);  /* load copies of key */
+    const uint32 nelem_per_vector = sizeof(Vector64) / sizeof(uint64);
+    const uint32 nelem_per_iteration = 4 * nelem_per_vector;
+
+    /* round down to multiple of elements per iteration */
+    const uint32 tail_idx = nelem & ~(nelem_per_iteration - 1);
+
+#if defined(USE_ASSERT_CHECKING)
+    bool        assert_result = false;
+
+    /* pre-compute the result for assert checking */
+    for (i = 0; i < nelem; i++)
+    {
+        if (key == base[i])
+        {
+            assert_result = true;
+            break;
+        }
+    }
+#endif
+
+    for (i = 0; i < tail_idx; i += nelem_per_iteration)
+    {
+        Vector64    vals1,
+                    vals2,
+                    vals3,
+                    vals4,
+                    result1,
+                    result2,
+                    result3,
+                    result4,
+                    tmp1,
+                    tmp2,
+                    result;
+
+        /* load the next block into 4 registers */
+        vector64_load(&vals1, &base[i]);
+        vector64_load(&vals2, &base[i + nelem_per_vector]);
+        vector64_load(&vals3, &base[i + nelem_per_vector * 2]);
+        vector64_load(&vals4, &base[i + nelem_per_vector * 3]);
+
+        /* compare each value to the key */
+        result1 = vector64_eq(keys, vals1);
+        result2 = vector64_eq(keys, vals2);
+        result3 = vector64_eq(keys, vals3);
+        result4 = vector64_eq(keys, vals4);
+
+        /* combine the results into a single variable */
+        tmp1 = vector64_or(result1, result2);
+        tmp2 = vector64_or(result3, result4);
+        result = vector64_or(tmp1, tmp2);
+
+        /* see if there was a match */
+        if (vector64_is_highbit_set(result))
+        {
+            Assert(assert_result == true);
+            return true;
+        }
+    }
+#endif                          /* ! USE_NO_SIMD */
+
+    /* Process the remaining elements one at a time. */
+    for (; i < nelem; i++)
+    {
+        if (key == base[i])
+        {
+#ifndef USE_NO_SIMD
+            Assert(assert_result == true);
+#endif
+            return true;
+        }
+    }
+
+#ifndef USE_NO_SIMD
+    Assert(assert_result == false);
+#endif
+    return false;
+}
+
+/*
+ * pg_lfind64_idx
+ *
+ * Return the address of the first element in 'base' that equals 'key', or NULL
+ * if no match is found.
+ */
+static inline uint64 *
+pg_lfind64_idx(uint64 key, uint64 *base, uint32 nelem)
+{
+    uint32      i = 0;
+
+#ifndef USE_NO_SIMD
+
+    /*
+     * For better instruction-level parallelism, each loop iteration operates
+     * on a block of four registers.
+     */
+    const Vector64 keys = vector64_broadcast(key);  /* load copies of key */
+    const uint32 nelem_per_vector = sizeof(Vector64) / sizeof(uint64);
+    const uint32 nelem_per_iteration = 4 * nelem_per_vector;
+
+    /* round down to multiple of elements per iteration */
+    const uint32 tail_idx = nelem & ~(nelem_per_iteration - 1);
+
+#if defined(USE_ASSERT_CHECKING)
+    uint64     *assert_result = NULL;
+
+    /* pre-compute the result for assert checking */
+    for (i = 0; i < nelem; i++)
+    {
+        if (key == base[i])
+        {
+            assert_result = &base[i];
+            break;
+        }
+    }
+#endif
+
+    for (i = 0; i < tail_idx; i += nelem_per_iteration)
+    {
+        Vector64    vals1,
+                    vals2,
+                    vals3,
+                    vals4,
+                    result1,
+                    result2,
+                    result3,
+                    result4,
+                    tmp1,
+                    tmp2,
+                    result;
+
+        /* load the next block into 4 registers */
+        vector64_load(&vals1, &base[i]);
+        vector64_load(&vals2, &base[i + nelem_per_vector]);
+        vector64_load(&vals3, &base[i + nelem_per_vector * 2]);
+        vector64_load(&vals4, &base[i + nelem_per_vector * 3]);
+
+        /* compare each value to the key */
+        result1 = vector64_eq(keys, vals1);
+        result2 = vector64_eq(keys, vals2);
+        result3 = vector64_eq(keys, vals3);
+        result4 = vector64_eq(keys, vals4);
+
+        /* combine the results into a single variable */
+        tmp1 = vector64_or(result1, result2);
+        tmp2 = vector64_or(result3, result4);
+        result = vector64_or(tmp1, tmp2);
+
+        /* if there was a match, break out to find the matching element */
+        if (vector64_is_highbit_set(result))
+            break;
+    }
+#endif                          /* ! USE_NO_SIMD */
+
+    /* Process the remaining elements one at a time. */
+    for (; i < nelem; i++)
+    {
+        if (key == base[i])
+        {
+#ifndef USE_NO_SIMD
+            Assert(assert_result == &base[i]);
+#endif
+            return &base[i];
+        }
+    }
+
+#ifndef USE_NO_SIMD
+    Assert(assert_result == NULL);
+#endif
+    return NULL;
+}
+
 #endif                          /* PG_LFIND_H */
diff --git a/src/include/port/simd.h b/src/include/port/simd.h
index 1fa6c3bc6c..805b99cd61 100644
--- a/src/include/port/simd.h
+++ b/src/include/port/simd.h
@@ -32,6 +32,7 @@
 #define USE_SSE2
 typedef __m128i Vector8;
 typedef __m128i Vector32;
+typedef __m128i Vector64;
 
 #elif defined(__aarch64__) && defined(__ARM_NEON)
 /*
@@ -46,6 +47,7 @@ typedef __m128i Vector32;
 #define USE_NEON
 typedef uint8x16_t Vector8;
 typedef uint32x4_t Vector32;
+typedef uint64x2_t Vector64;
 
 #else
 /*
@@ -64,12 +66,14 @@ typedef uint64 Vector8;
 static inline void vector8_load(Vector8 *v, const uint8 *s);
 #ifndef USE_NO_SIMD
 static inline void vector32_load(Vector32 *v, const uint32 *s);
+static inline void vector64_load(Vector64 *v, const uint64 *s);
 #endif
 
 /* assignment operations */
 static inline Vector8 vector8_broadcast(const uint8 c);
 #ifndef USE_NO_SIMD
 static inline Vector32 vector32_broadcast(const uint32 c);
+static inline Vector64 vector64_broadcast(const uint64 c);
 #endif
 
 /* element-wise comparisons to a scalar */
@@ -79,12 +83,16 @@ static inline bool vector8_has_le(const Vector8 v, const uint8 c);
 static inline bool vector8_is_highbit_set(const Vector8 v);
 #ifndef USE_NO_SIMD
 static inline bool vector32_is_highbit_set(const Vector32 v);
+static inline bool vector64_is_highbit_set(const Vector64 v);
 #endif
 
 /* arithmetic operations */
 static inline Vector8 vector8_or(const Vector8 v1, const Vector8 v2);
 #ifndef USE_NO_SIMD
 static inline Vector32 vector32_or(const Vector32 v1, const Vector32 v2);
+static inline Vector64 vector64_or(const Vector64 v1, const Vector64 v2);
+static inline Vector32 vector32_and(const Vector32 v1, const Vector32 v2);
+static inline Vector64 vector64_and(const Vector64 v1, const Vector64 v2);
 static inline Vector8 vector8_ssub(const Vector8 v1, const Vector8 v2);
 #endif
 
@@ -97,6 +105,7 @@ static inline Vector8 vector8_ssub(const Vector8 v1, const Vector8 v2);
 #ifndef USE_NO_SIMD
 static inline Vector8 vector8_eq(const Vector8 v1, const Vector8 v2);
 static inline Vector32 vector32_eq(const Vector32 v1, const Vector32 v2);
+static inline Vector64 vector64_eq(const Vector64 v1, const Vector64 v2);
 #endif
 
 /*
@@ -126,6 +135,18 @@ vector32_load(Vector32 *v, const uint32 *s)
 }
 #endif                          /* ! USE_NO_SIMD */
 
+#ifndef USE_NO_SIMD
+static inline void
+vector64_load(Vector64 *v, const uint64 *s)
+{
+#ifdef USE_SSE2
+    *v = _mm_loadu_si128((const __m128i *) s);
+#elif defined(USE_NEON)
+    *v = vld1q_u64((const uint64_t *) s);
+#endif
+}
+#endif                          /* ! USE_NO_SIMD */
+
 /*
  * Create a vector with all elements set to the same value.
  */
@@ -153,6 +174,18 @@ vector32_broadcast(const uint32 c)
 }
 #endif                          /* ! USE_NO_SIMD */
 
+#ifndef USE_NO_SIMD
+static inline Vector64
+vector64_broadcast(const uint64 c)
+{
+#ifdef USE_SSE2
+    return _mm_set1_epi64x(c);
+#elif defined(USE_NEON)
+    return vdupq_n_u64(c);
+#endif
+}
+#endif                          /* ! USE_NO_SIMD */
+
 /*
  * Return true if any elements in the vector are equal to the given scalar.
  */
@@ -299,6 +332,23 @@ vector32_is_highbit_set(const Vector32 v)
 }
 #endif                          /* ! USE_NO_SIMD */
 
+/*
+ * Exactly like vector8_is_highbit_set except for the input type, so it
+ * looks at each byte separately.  See the comment above
+ * vector32_is_highbit_set for more information.
+ */
+#ifndef USE_NO_SIMD
+static inline bool
+vector64_is_highbit_set(const Vector64 v)
+{
+#if defined(USE_NEON)
+    return vector8_is_highbit_set((Vector8) v);
+#else
+    return vector8_is_highbit_set(v);
+#endif
+}
+#endif                          /* ! USE_NO_SIMD */
+
 /*
  * Return the bitwise OR of the inputs
  */
@@ -326,6 +376,42 @@ vector32_or(const Vector32 v1, const Vector32 v2)
 }
 #endif                          /* ! USE_NO_SIMD */
 
+#ifndef USE_NO_SIMD
+static inline Vector64
+vector64_or(const Vector64 v1, const Vector64 v2)
+{
+#ifdef USE_SSE2
+    return _mm_or_si128(v1, v2);
+#elif defined(USE_NEON)
+    return vorrq_u64(v1, v2);
+#endif
+}
+#endif                          /* ! USE_NO_SIMD */
+
+#ifndef USE_NO_SIMD
+static inline Vector32
+vector32_and(const Vector32 v1, const Vector32 v2)
+{
+#ifdef USE_SSE2
+    return _mm_and_si128(v1, v2);
+#elif defined(USE_NEON)
+    return vandq_u32(v1, v2);
+#endif
+}
+#endif                          /* ! USE_NO_SIMD */
+
+#ifndef USE_NO_SIMD
+static inline Vector64
+vector64_and(const Vector64 v1, const Vector64 v2)
+{
+#ifdef USE_SSE2
+    return _mm_and_si128(v1, v2);
+#elif defined(USE_NEON)
+    return vandq_u64(v1, v2);
+#endif
+}
+#endif                          /* ! USE_NO_SIMD */
+
 /*
  * Return the result of subtracting the respective elements of the input
  * vectors using saturation (i.e., if the operation would yield a value less
@@ -372,4 +458,21 @@ vector32_eq(const Vector32 v1, const Vector32 v2)
 }
 #endif                          /* ! USE_NO_SIMD */
 
+#ifndef USE_NO_SIMD
+static inline Vector64
+vector64_eq(const Vector64 v1, const Vector64 v2)
+{
+#ifdef USE_SSE2
+    /* We have to work around SSE2's lack of _mm_cmpeq_epi64. */
+    const Vector32 cmp = vector32_eq(v1, v2);
+    const Vector32 hi = _mm_shuffle_epi32(cmp, _MM_SHUFFLE(3, 3, 1, 1));
+    const Vector32 lo = _mm_shuffle_epi32(cmp, _MM_SHUFFLE(2, 2, 0, 0));
+
+    return vector32_and(hi, lo);
+#elif defined(USE_NEON)
+    return vceqq_u64(v1, v2);
+#endif
+}
+#endif                          /* ! USE_NO_SIMD */
+
 #endif                          /* SIMD_H */
-- 
2.25.1
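
A note for anyone skimming the patch: the "filter out matches in space
between data" masking in list_member_inline_internal is there because, on
64-bit builds, each 32-bit datum occupies only half of an 8-byte ListCell,
so the cell array read as uint32s alternates real values with unrelated
upper halves. Here is a minimal standalone sketch of that trick using raw
SSE2 intrinsics; contains_u32 and the harness are illustrative names of my
own, and it assumes little-endian 8-byte slots rather than the patch's
Vector32 wrappers:

#include <emmintrin.h>          /* SSE2 */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/*
 * Each 8-byte slot stores a 32-bit value in its low half; the high half is
 * unrelated bits.  Read as uint32s, the array is value, junk, value, junk,
 * ...  ANDing the comparison result with 0x00000000FFFFFFFF per 64-bit
 * lane throws away any hits on the junk halves.
 */
static bool
contains_u32(const uint64_t *slots, uint32_t nslots, uint32_t key)
{
    const __m128i keys = _mm_set1_epi32((int) key);
    const __m128i mask = _mm_set1_epi64x(0xFFFFFFFFLL);
    uint32_t    i;

    /* two 8-byte slots (four 32-bit lanes) per 128-bit register */
    for (i = 0; i + 2 <= nslots; i += 2)
    {
        const __m128i vals = _mm_loadu_si128((const __m128i *) &slots[i]);
        const __m128i hits = _mm_and_si128(_mm_cmpeq_epi32(keys, vals), mask);

        if (_mm_movemask_epi8(hits) != 0)
            return true;
    }

    /* scalar tail for an odd slot count */
    for (; i < nslots; i++)
    {
        if ((uint32_t) slots[i] == key)
            return true;
    }
    return false;
}

int
main(void)
{
    /* high halves deliberately poisoned with the key we search for */
    const uint64_t slots[] = {
        ((uint64_t) 7 << 32) | 1,
        ((uint64_t) 7 << 32) | 2,
        ((uint64_t) 7 << 32) | 3,
    };

    printf("contains 7? %d\n", (int) contains_u32(slots, 3, 7));
    printf("contains 2? %d\n", (int) contains_u32(slots, 3, 2));
    return 0;
}

Without the mask, the search for 7 above would report a false positive
from the poisoned upper halves; with it, only the search for 2 matches.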
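
Also worth spelling out, since it's the least obvious part of the simd.h
changes: SSE2 has no 64-bit lane equality instruction (_mm_cmpeq_epi64 only
arrived with SSE4.1), so vector64_eq emulates it with a 32-bit compare plus
two shuffles that pair each lane's halves before ANDing them. A
self-contained sketch of just that workaround (the main harness is mine,
for illustration only):

#include <emmintrin.h>          /* SSE2 */
#include <stdint.h>
#include <stdio.h>

/*
 * Emulate _mm_cmpeq_epi64 (SSE4.1) with SSE2 only, as vector64_eq does in
 * the patch: compare 32-bit lanes, then AND each 64-bit lane's two halves
 * together so the lane is all-ones only if both halves matched.
 */
static __m128i
cmpeq_epi64_sse2(__m128i a, __m128i b)
{
    const __m128i cmp32 = _mm_cmpeq_epi32(a, b);
    const __m128i hi = _mm_shuffle_epi32(cmp32, _MM_SHUFFLE(3, 3, 1, 1));
    const __m128i lo = _mm_shuffle_epi32(cmp32, _MM_SHUFFLE(2, 2, 0, 0));

    return _mm_and_si128(hi, lo);
}

int
main(void)
{
    /* lane 0 matches; lane 1 differs only in its low 32 bits */
    const __m128i a = _mm_set_epi64x(0x1BADC0DE00000001LL, 42);
    const __m128i b = _mm_set_epi64x(0x1BADC0DE00000002LL, 42);
    uint64_t    out[2];

    _mm_storeu_si128((__m128i *) out, cmpeq_epi64_sse2(a, b));
    /* expected: lane0=ffffffffffffffff lane1=0000000000000000 */
    printf("lane0=%016llx lane1=%016llx\n",
           (unsigned long long) out[0], (unsigned long long) out[1]);
    return 0;
}

A bare 32-bit compare would have left lane 1 half-set (its high halves
match); the shuffle/AND step collapses that to all-zeroes as required.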