I noticed that several of the List functions do simple linear searches that can be optimized with SIMD intrinsics (as was done for XidInMVCCSnapshot in 37a6e5d). The following table shows the total time for one billion list_member_int() searches of a list of n elements on my x86 laptop.
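The harness itself isn't attached, but the measurement loop was of roughly this shape (simplified sketch; the function name and the NOTICE output are my own scaffolding):

#include "postgres.h"

#include "nodes/pg_list.h"
#include "portability/instr_time.h"

/*
 * Build an integer List of n elements, then time one billion
 * list_member_int() probes for a value that is never in the list,
 * i.e. the worst case that has to scan every element.
 */
static void
bench_list_member_int(int n)
{
	List	   *list = NIL;
	instr_time	start,
				duration;
	volatile bool found = false;	/* keeps the calls from being elided */

	for (int i = 0; i < n; i++)
		list = lappend_int(list, i);

	INSTR_TIME_SET_CURRENT(start);
	for (long i = 0; i < 1000000000L; i++)
		found |= list_member_int(list, -1);		/* never matches */
	INSTR_TIME_SET_CURRENT(duration);
	INSTR_TIME_SUBTRACT(duration, start);

	elog(NOTICE, "n = %d: %.0f ms (found = %d)",
		 n, INSTR_TIME_GET_MILLISEC(duration), (int) found);

	list_free(list);
}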
    n | head (ms) | patched (ms)
------+-----------+--------------
    2 |      3884 |          3001
    4 |      5506 |          4092
    8 |      6209 |          3026
   16 |      8797 |          4458
   32 |     25051 |          7032
   64 |     37611 |         12763
  128 |     61886 |         22770
  256 |    111170 |         59885
  512 |    209612 |        103378
 1024 |    407462 |        189484

I've attached a work-in-progress patch that implements these optimizations for both x86 and arm, and I will register this in the July commitfest. I'm posting this a little early in order to gauge interest. Presumably you shouldn't be using a List if you have many elements that must be routinely searched, but it might be nice to speed up these functions anyway. This was mostly a fun weekend project, and I don't presently have any concrete examples of workloads where this might help.

-- 
Nathan Bossart
Amazon Web Services: https://aws.amazon.com
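To give a feel for what the patch does before diving into the diff, here is a minimal standalone sketch of the search loop (my own illustration, assuming SSE2 only; the patch proper goes through the Vector32/Vector64 abstractions in simd.h so that NEON and non-SIMD builds are covered, and for the 32-bit case it must additionally mask out the unused half of each ListCell):

#include <emmintrin.h>		/* SSE2 intrinsics */
#include <stdbool.h>
#include <stdint.h>

static bool
lfind32_sse2_sketch(uint32_t key, const uint32_t *base, uint32_t nelem)
{
	const __m128i keys = _mm_set1_epi32((int) key);	/* key in all 4 lanes */
	uint32_t	i = 0;

	/* process 16 elements (4 registers x 4 lanes) per iteration */
	for (; i + 16 <= nelem; i += 16)
	{
		__m128i		v1 = _mm_loadu_si128((const __m128i *) &base[i]);
		__m128i		v2 = _mm_loadu_si128((const __m128i *) &base[i + 4]);
		__m128i		v3 = _mm_loadu_si128((const __m128i *) &base[i + 8]);
		__m128i		v4 = _mm_loadu_si128((const __m128i *) &base[i + 12]);

		/* compare all four registers to the key, then OR the results */
		__m128i		cmp = _mm_or_si128(
			_mm_or_si128(_mm_cmpeq_epi32(keys, v1), _mm_cmpeq_epi32(keys, v2)),
			_mm_or_si128(_mm_cmpeq_epi32(keys, v3), _mm_cmpeq_epi32(keys, v4)));

		/* a single test detects a match anywhere in the block of 16 */
		if (_mm_movemask_epi8(cmp) != 0)
			return true;
	}

	/* scalar tail for the remaining 0..15 elements */
	for (; i < nelem; i++)
	{
		if (base[i] == key)
			return true;
	}
	return false;
}

The block-of-four-registers structure is the same trick pg_lfind32() already uses: each iteration issues four independent compare chains, which hides instruction latency much better than processing one vector at a time.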
>From 4e04f84766d98f9ba6bb6fdd03bcb431c8aad1d3 Mon Sep 17 00:00:00 2001 From: Nathan Bossart <nathandboss...@gmail.com> Date: Sat, 4 Mar 2023 23:16:07 -0800 Subject: [PATCH v1 1/1] speed up several list functions with SIMD --- src/backend/nodes/list.c | 254 +++++++++++++++++++++++++++++++----- src/include/port/pg_lfind.h | 189 +++++++++++++++++++++++++++ src/include/port/simd.h | 103 +++++++++++++++ 3 files changed, 513 insertions(+), 33 deletions(-) diff --git a/src/backend/nodes/list.c b/src/backend/nodes/list.c index a709d23ef1..acc56dddb7 100644 --- a/src/backend/nodes/list.c +++ b/src/backend/nodes/list.c @@ -19,6 +19,7 @@ #include "nodes/pg_list.h" #include "port/pg_bitutils.h" +#include "port/pg_lfind.h" #include "utils/memdebug.h" #include "utils/memutils.h" @@ -680,11 +681,25 @@ list_member(const List *list, const void *datum) bool list_member_ptr(const List *list, const void *datum) { +#ifdef USE_NO_SIMD const ListCell *cell; +#endif Assert(IsPointerList(list)); check_list_invariants(list); +#ifndef USE_NO_SIMD + + Assert(sizeof(ListCell) == 8); + Assert(sizeof(void *) == 8); + + if (list == NIL) + return false; + + return pg_lfind64((uint64) datum, (uint64 *) list->elements, list->length); + +#else + foreach(cell, list) { if (lfirst(cell) == datum) @@ -692,46 +707,121 @@ list_member_ptr(const List *list, const void *datum) } return false; + +#endif } /* - * Return true iff the integer 'datum' is a member of the list. + * Optimized linear search routine (using SIMD intrinsics where available) for + * lists with inline 32-bit data. */ -bool -list_member_int(const List *list, int datum) +static inline bool +list_member_inline_internal(const List *list, uint32 datum) { + uint32 i = 0; const ListCell *cell; - Assert(IsIntegerList(list)); - check_list_invariants(list); +#ifndef USE_NO_SIMD - foreach(cell, list) + /* + * For better instruction-level parallelism, each loop iteration operates + * on a block of four registers. + */ + const Vector32 keys = vector32_broadcast(datum); /* load copies of key */ + const uint32 nelem_per_vector = sizeof(Vector32) / sizeof(uint32); + const uint32 nelem_per_iteration = 4 * nelem_per_vector; + const Vector32 mask = (Vector32) vector64_broadcast(UINT64CONST(0xFFFFFFFF)); + const uint32 *elements = (const uint32 *) list->elements; + + /* round down to multiple of elements per iteration */ + const uint32 tail_idx = (list->length * 2) & ~(nelem_per_iteration - 1); + + /* + * The SIMD-optimized portion of this routine is written with the + * expectation that the 32-bit datum we are searching for only takes up + * half of a ListCell. If that changes, this routine must change, too. 
+ */ + Assert(sizeof(ListCell) == 8); + + for (i = 0; i < tail_idx; i += nelem_per_iteration) + { + Vector32 vals1, + vals2, + vals3, + vals4, + result1, + result2, + result3, + result4, + tmp1, + tmp2, + result, + masked; + + /* load the next block into 4 registers */ + vector32_load(&vals1, &elements[i]); + vector32_load(&vals2, &elements[i + nelem_per_vector]); + vector32_load(&vals3, &elements[i + nelem_per_vector * 2]); + vector32_load(&vals4, &elements[i + nelem_per_vector * 3]); + + /* compare each value to the key */ + result1 = vector32_eq(keys, vals1); + result2 = vector32_eq(keys, vals2); + result3 = vector32_eq(keys, vals3); + result4 = vector32_eq(keys, vals4); + + /* combine the results into a single variable */ + tmp1 = vector32_or(result1, result2); + tmp2 = vector32_or(result3, result4); + result = vector32_or(tmp1, tmp2); + + /* filter out matches in space between data */ + masked = vector32_and(result, mask); + + /* see if there was a match */ + if (vector32_is_highbit_set(masked)) + return true; + } + +#endif /* ! USE_NO_SIMD */ + + for_each_from(cell, list, i / 2) { - if (lfirst_int(cell) == datum) + if (lfirst_int(cell) == (int) datum) return true; } return false; } +/* + * Return true iff the integer 'datum' is a member of the list. + */ +bool +list_member_int(const List *list, int datum) +{ + Assert(IsIntegerList(list)); + check_list_invariants(list); + + if (list == NIL) + return false; + + return list_member_inline_internal(list, datum); +} + /* * Return true iff the OID 'datum' is a member of the list. */ bool list_member_oid(const List *list, Oid datum) { - const ListCell *cell; - Assert(IsOidList(list)); check_list_invariants(list); - foreach(cell, list) - { - if (lfirst_oid(cell) == datum) - return true; - } + if (list == NIL) + return false; - return false; + return list_member_inline_internal(list, datum); } /* @@ -740,18 +830,13 @@ list_member_oid(const List *list, Oid datum) bool list_member_xid(const List *list, TransactionId datum) { - const ListCell *cell; - Assert(IsXidList(list)); check_list_invariants(list); - foreach(cell, list) - { - if (lfirst_xid(cell) == datum) - return true; - } + if (list == NIL) + return false; - return false; + return list_member_inline_internal(list, datum); } /* @@ -875,16 +960,117 @@ list_delete_ptr(List *list, void *datum) Assert(IsPointerList(list)); check_list_invariants(list); +#ifndef USE_NO_SIMD + + Assert(sizeof(ListCell) == 8); + Assert(sizeof(void *) == 8); + + if (list == NIL) + return NIL; + + cell = (ListCell *) pg_lfind64_idx((uint64) datum, + (uint64 *) list->elements, + list->length); + if (cell != NULL) + return list_delete_cell(list, cell); + +#else + foreach(cell, list) { if (lfirst(cell) == datum) return list_delete_cell(list, cell); } +#endif + /* Didn't find a match: return the list unmodified */ return list; } +/* + * Optimized linear search routine (using SIMD intrinsics where available) for + * lists with inline 32-bit data. + */ +static inline ListCell * +list_member_inline_interal_idx(const List *list, uint32 datum) +{ + uint32 i = 0; + ListCell *cell; + +#ifndef USE_NO_SIMD + + /* + * For better instruction-level parallelism, each loop iteration operates + * on a block of four registers. 
+ */ + const Vector32 keys = vector32_broadcast(datum); /* load copies of key */ + const uint32 nelem_per_vector = sizeof(Vector32) / sizeof(uint32); + const uint32 nelem_per_iteration = 4 * nelem_per_vector; + const Vector32 mask = (Vector32) vector64_broadcast(UINT64CONST(0xFFFFFFFF)); + const uint32 *elements = (const uint32 *) list->elements; + + /* round down to multiple of elements per iteration */ + const uint32 tail_idx = (list->length * 2) & ~(nelem_per_iteration - 1); + + /* + * The SIMD-optimized portion of this routine is written with the + * expectation that the 32-bit datum we are searching for only takes up + * half of a ListCell. If that changes, this routine must change, too. + */ + Assert(sizeof(ListCell) == 8); + + for (i = 0; i < tail_idx; i += nelem_per_iteration) + { + Vector32 vals1, + vals2, + vals3, + vals4, + result1, + result2, + result3, + result4, + tmp1, + tmp2, + result, + masked; + + /* load the next block into 4 registers */ + vector32_load(&vals1, &elements[i]); + vector32_load(&vals2, &elements[i + nelem_per_vector]); + vector32_load(&vals3, &elements[i + nelem_per_vector * 2]); + vector32_load(&vals4, &elements[i + nelem_per_vector * 3]); + + /* compare each value to the key */ + result1 = vector32_eq(keys, vals1); + result2 = vector32_eq(keys, vals2); + result3 = vector32_eq(keys, vals3); + result4 = vector32_eq(keys, vals4); + + /* combine the results into a single variable */ + tmp1 = vector32_or(result1, result2); + tmp2 = vector32_or(result3, result4); + result = vector32_or(tmp1, tmp2); + + /* filter out matches in space between data */ + masked = vector32_and(result, mask); + + /* break out and find the exact element if there was a match */ + if (vector32_is_highbit_set(masked)) + break; + } + +#endif /* ! USE_NO_SIMD */ + + for_each_from(cell, list, i / 2) + { + if (lfirst_int(cell) == (int) datum) + return cell; + } + + return NULL; +} + /* As above, but for integers */ List * list_delete_int(List *list, int datum) @@ -894,11 +1080,12 @@ list_delete_int(List *list, int datum) Assert(IsIntegerList(list)); check_list_invariants(list); - foreach(cell, list) - { - if (lfirst_int(cell) == datum) - return list_delete_cell(list, cell); - } + if (list == NIL) + return NIL; + + cell = list_member_inline_interal_idx(list, datum); + if (cell) + return list_delete_cell(list, cell); /* Didn't find a match: return the list unmodified */ return list; @@ -913,11 +1100,12 @@ list_delete_oid(List *list, Oid datum) Assert(IsOidList(list)); check_list_invariants(list); - foreach(cell, list) - { - if (lfirst_oid(cell) == datum) - return list_delete_cell(list, cell); - } + if (list == NIL) + return NIL; + + cell = list_member_inline_interal_idx(list, datum); + if (cell) + return list_delete_cell(list, cell); /* Didn't find a match: return the list unmodified */ return list; diff --git a/src/include/port/pg_lfind.h b/src/include/port/pg_lfind.h index 59aa8245ed..05d0b12463 100644 --- a/src/include/port/pg_lfind.h +++ b/src/include/port/pg_lfind.h @@ -177,4 +177,193 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem) return false; } +/* + * pg_lfind64 + * + * Return true if there is an element in 'base' that equals 'key', otherwise + * return false. + */ +static inline bool +pg_lfind64(uint64 key, uint64 *base, uint32 nelem) +{ + uint32 i = 0; + +#ifndef USE_NO_SIMD + + /* + * For better instruction-level parallelism, each loop iteration operates + * on a block of four registers. 
+ */ + const Vector64 keys = vector64_broadcast(key); /* load copies of key */ + const uint32 nelem_per_vector = sizeof(Vector64) / sizeof(uint64); + const uint32 nelem_per_iteration = 4 * nelem_per_vector; + + /* round down to multiple of elements per iteration */ + const uint32 tail_idx = nelem & ~(nelem_per_iteration - 1); + +#if defined(USE_ASSERT_CHECKING) + bool assert_result = false; + + /* pre-compute the result for assert checking */ + for (i = 0; i < nelem; i++) + { + if (key == base[i]) + { + assert_result = true; + break; + } + } +#endif + + for (i = 0; i < tail_idx; i += nelem_per_iteration) + { + Vector64 vals1, + vals2, + vals3, + vals4, + result1, + result2, + result3, + result4, + tmp1, + tmp2, + result; + + /* load the next block into 4 registers */ + vector64_load(&vals1, &base[i]); + vector64_load(&vals2, &base[i + nelem_per_vector]); + vector64_load(&vals3, &base[i + nelem_per_vector * 2]); + vector64_load(&vals4, &base[i + nelem_per_vector * 3]); + + /* compare each value to the key */ + result1 = vector64_eq(keys, vals1); + result2 = vector64_eq(keys, vals2); + result3 = vector64_eq(keys, vals3); + result4 = vector64_eq(keys, vals4); + + /* combine the results into a single variable */ + tmp1 = vector64_or(result1, result2); + tmp2 = vector64_or(result3, result4); + result = vector64_or(tmp1, tmp2); + + /* see if there was a match */ + if (vector64_is_highbit_set(result)) + { + Assert(assert_result == true); + return true; + } + } +#endif /* ! USE_NO_SIMD */ + + /* Process the remaining elements one at a time. */ + for (; i < nelem; i++) + { + if (key == base[i]) + { +#ifndef USE_NO_SIMD + Assert(assert_result == true); +#endif + return true; + } + } + +#ifndef USE_NO_SIMD + Assert(assert_result == false); +#endif + return false; +} + +/* + * pg_lfind64_idx + * + * Return the address of the first element in 'base' that equals 'key', or NULL + * if no match is found. + */ +static inline uint64 * +pg_lfind64_idx(uint64 key, uint64 *base, uint32 nelem) +{ + uint32 i = 0; + +#ifndef USE_NO_SIMD + + /* + * For better instruction-level parallelism, each loop iteration operates + * on a block of four registers. 
+ */ + const Vector64 keys = vector64_broadcast(key); /* load copies of key */ + const uint32 nelem_per_vector = sizeof(Vector64) / sizeof(uint64); + const uint32 nelem_per_iteration = 4 * nelem_per_vector; + + /* round down to multiple of elements per iteration */ + const uint32 tail_idx = nelem & ~(nelem_per_iteration - 1); + +#if defined(USE_ASSERT_CHECKING) + uint64 *assert_result = NULL; + + /* pre-compute the result for assert checking */ + for (i = 0; i < nelem; i++) + { + if (key == base[i]) + { + assert_result = &base[i]; + break; + } + } +#endif + + for (i = 0; i < tail_idx; i += nelem_per_iteration) + { + Vector64 vals1, + vals2, + vals3, + vals4, + result1, + result2, + result3, + result4, + tmp1, + tmp2, + result; + + /* load the next block into 4 registers */ + vector64_load(&vals1, &base[i]); + vector64_load(&vals2, &base[i + nelem_per_vector]); + vector64_load(&vals3, &base[i + nelem_per_vector * 2]); + vector64_load(&vals4, &base[i + nelem_per_vector * 3]); + + /* compare each value to the key */ + result1 = vector64_eq(keys, vals1); + result2 = vector64_eq(keys, vals2); + result3 = vector64_eq(keys, vals3); + result4 = vector64_eq(keys, vals4); + + /* combine the results into a single variable */ + tmp1 = vector64_or(result1, result2); + tmp2 = vector64_or(result3, result4); + result = vector64_or(tmp1, tmp2); + + /* if there was a match, break out to find the matching element */ + if (vector64_is_highbit_set(result)) + break; + } +#endif /* ! USE_NO_SIMD */ + + /* Process the remaining elements one at a time. */ + for (; i < nelem; i++) + { + if (key == base[i]) + { +#ifndef USE_NO_SIMD + Assert(assert_result == &base[i]); +#endif + return &base[i]; + } + } + +#ifndef USE_NO_SIMD + Assert(assert_result == NULL); +#endif + return NULL; +} + #endif /* PG_LFIND_H */ diff --git a/src/include/port/simd.h b/src/include/port/simd.h index 1fa6c3bc6c..805b99cd61 100644 --- a/src/include/port/simd.h +++ b/src/include/port/simd.h @@ -32,6 +32,7 @@ #define USE_SSE2 typedef __m128i Vector8; typedef __m128i Vector32; +typedef __m128i Vector64; #elif defined(__aarch64__) && defined(__ARM_NEON) /* @@ -46,6 +47,7 @@ typedef __m128i Vector32; #define USE_NEON typedef uint8x16_t Vector8; typedef uint32x4_t Vector32; +typedef uint64x2_t Vector64; #else /* @@ -64,12 +66,14 @@ typedef uint64 Vector8; static inline void vector8_load(Vector8 *v, const uint8 *s); #ifndef USE_NO_SIMD static inline void vector32_load(Vector32 *v, const uint32 *s); +static inline void vector64_load(Vector64 *v, const uint64 *s); #endif /* assignment operations */ static inline Vector8 vector8_broadcast(const uint8 c); #ifndef USE_NO_SIMD static inline Vector32 vector32_broadcast(const uint32 c); +static inline Vector64 vector64_broadcast(const uint64 c); #endif /* element-wise comparisons to a scalar */ @@ -79,12 +83,16 @@ static inline bool vector8_has_le(const Vector8 v, const uint8 c); static inline bool vector8_is_highbit_set(const Vector8 v); #ifndef USE_NO_SIMD static inline bool vector32_is_highbit_set(const Vector32 v); +static inline bool vector64_is_highbit_set(const Vector64 v); #endif /* arithmetic operations */ static inline Vector8 vector8_or(const Vector8 v1, const Vector8 v2); #ifndef USE_NO_SIMD static inline Vector32 vector32_or(const Vector32 v1, const Vector32 v2); +static inline Vector64 vector64_or(const Vector64 v1, const Vector64 v2); +static inline Vector32 vector32_and(const Vector32 v1, const Vector32 v2); +static inline Vector64 vector64_and(const Vector64 v1, const Vector64 v2); static 
inline Vector8 vector8_ssub(const Vector8 v1, const Vector8 v2); #endif @@ -97,6 +105,7 @@ static inline Vector8 vector8_ssub(const Vector8 v1, const Vector8 v2); #ifndef USE_NO_SIMD static inline Vector8 vector8_eq(const Vector8 v1, const Vector8 v2); static inline Vector32 vector32_eq(const Vector32 v1, const Vector32 v2); +static inline Vector64 vector64_eq(const Vector64 v1, const Vector64 v2); #endif /* @@ -126,6 +135,18 @@ vector32_load(Vector32 *v, const uint32 *s) } #endif /* ! USE_NO_SIMD */ +#ifndef USE_NO_SIMD +static inline void +vector64_load(Vector64 *v, const uint64 *s) +{ +#ifdef USE_SSE2 + *v = _mm_loadu_si128((const __m128i *) s); +#elif defined(USE_NEON) + *v = vld1q_u64((const uint64_t *) s); +#endif +} +#endif /* ! USE_NO_SIMD */ + /* * Create a vector with all elements set to the same value. */ @@ -153,6 +174,18 @@ vector32_broadcast(const uint32 c) } #endif /* ! USE_NO_SIMD */ +#ifndef USE_NO_SIMD +static inline Vector64 +vector64_broadcast(const uint64 c) +{ +#ifdef USE_SSE2 + return _mm_set1_epi64x(c); +#elif defined(USE_NEON) + return vdupq_n_u64(c); +#endif +} +#endif /* ! USE_NO_SIMD */ + /* * Return true if any elements in the vector are equal to the given scalar. */ @@ -299,6 +332,23 @@ vector32_is_highbit_set(const Vector32 v) } #endif /* ! USE_NO_SIMD */ +/* + * Exactly like vector8_is_highbit_set except for the input type, so it + * looks at each byte separately. See the comment above + * vector32_is_highbit_set for more information. + */ +#ifndef USE_NO_SIMD +static inline bool +vector64_is_highbit_set(const Vector64 v) +{ +#if defined(USE_NEON) + return vector8_is_highbit_set((Vector8) v); +#else + return vector8_is_highbit_set(v); +#endif +} +#endif /* ! USE_NO_SIMD */ + /* * Return the bitwise OR of the inputs */ @@ -326,6 +376,42 @@ vector32_or(const Vector32 v1, const Vector32 v2) } #endif /* ! USE_NO_SIMD */ +#ifndef USE_NO_SIMD +static inline Vector64 +vector64_or(const Vector64 v1, const Vector64 v2) +{ +#ifdef USE_SSE2 + return _mm_or_si128(v1, v2); +#elif defined(USE_NEON) + return vorrq_u64(v1, v2); +#endif +} +#endif /* ! USE_NO_SIMD */ + +#ifndef USE_NO_SIMD +static inline Vector32 +vector32_and(const Vector32 v1, const Vector32 v2) +{ +#ifdef USE_SSE2 + return _mm_and_si128(v1, v2); +#elif defined(USE_NEON) + return vandq_u32(v1, v2); +#endif +} +#endif /* ! USE_NO_SIMD */ + +#ifndef USE_NO_SIMD +static inline Vector64 +vector64_and(const Vector64 v1, const Vector64 v2) +{ +#ifdef USE_SSE2 + return _mm_and_si128(v1, v2); +#elif defined(USE_NEON) + return vandq_u64(v1, v2); +#endif +} +#endif /* ! USE_NO_SIMD */ + /* * Return the result of subtracting the respective elements of the input * vectors using saturation (i.e., if the operation would yield a value less @@ -372,4 +458,21 @@ vector32_eq(const Vector32 v1, const Vector32 v2) } #endif /* ! USE_NO_SIMD */ +#ifndef USE_NO_SIMD +static inline Vector64 +vector64_eq(const Vector64 v1, const Vector64 v2) +{ +#ifdef USE_SSE2 + /* We have to work around SSE2's lack of _mm_cmpeq_epi64. */ + const Vector32 cmp = vector32_eq(v1, v2); + const Vector32 hi = _mm_shuffle_epi32(cmp, _MM_SHUFFLE(3, 3, 1, 1)); + const Vector32 lo = _mm_shuffle_epi32(cmp, _MM_SHUFFLE(2, 2, 0, 0)); + + return vector32_and(hi, lo); +#elif defined(USE_NEON) + return vceqq_u64(v1, v2); +#endif +} +#endif /* ! USE_NO_SIMD */ + #endif /* SIMD_H */ -- 2.25.1
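One SSE2 wrinkle in the patch that may deserve a second look: _mm_cmpeq_epi64 only exists as of SSE4.1, so vector64_eq() emulates a per-lane 64-bit compare by comparing 32-bit lanes and then ANDing the high half of each 64-bit result with its low half. A standalone check of that trick (the shuffle constants match the patch; the test scaffolding around them is mine):

#include <emmintrin.h>
#include <stdint.h>
#include <stdio.h>

/*
 * SSE2 emulation of a per-lane 64-bit compare-equal: compare the 32-bit
 * halves, then AND each half's result with its sibling so that a 64-bit
 * lane is all-ones only if both of its halves matched.
 */
static __m128i
cmpeq_epi64_sse2(__m128i v1, __m128i v2)
{
	__m128i		cmp = _mm_cmpeq_epi32(v1, v2);
	__m128i		hi = _mm_shuffle_epi32(cmp, _MM_SHUFFLE(3, 3, 1, 1));
	__m128i		lo = _mm_shuffle_epi32(cmp, _MM_SHUFFLE(2, 2, 0, 0));

	return _mm_and_si128(hi, lo);
}

int
main(void)
{
	/* high lanes equal (42 == 42), low lanes not (7 != 8) */
	__m128i		a = _mm_set_epi64x(42, 7);
	__m128i		b = _mm_set_epi64x(42, 8);
	uint64_t	out[2];

	_mm_storeu_si128((__m128i *) out, cmpeq_epi64_sse2(a, b));
	printf("lane0 = %016llx, lane1 = %016llx\n",	/* expect 0, all-ones */
		   (unsigned long long) out[0], (unsigned long long) out[1]);
	return 0;
}

Compiled with just -msse2, this prints all-zeroes for the unequal low lane and all-ones for the equal high lane, which is exactly the per-lane mask that vector64_is_highbit_set() looks for.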