cfbot's Windows build wasn't happy with a couple of casts, so v2 applies a
fix similar to commit c6a43c2.  The patch is still very much a work in
progress.

-- 
Nathan Bossart
Amazon Web Services: https://aws.amazon.com
From 055717233c47518ae48119938ebd203cc55f7f3c Mon Sep 17 00:00:00 2001
From: Nathan Bossart <nathandboss...@gmail.com>
Date: Sat, 4 Mar 2023 23:16:07 -0800
Subject: [PATCH v2 1/1] speed up several list functions with SIMD

---
 src/backend/nodes/list.c    | 262 +++++++++++++++++++++++++++++++-----
 src/include/port/pg_lfind.h | 189 ++++++++++++++++++++++++++
 src/include/port/simd.h     | 103 ++++++++++++++
 3 files changed, 521 insertions(+), 33 deletions(-)

diff --git a/src/backend/nodes/list.c b/src/backend/nodes/list.c
index a709d23ef1..02ddbeb3f2 100644
--- a/src/backend/nodes/list.c
+++ b/src/backend/nodes/list.c
@@ -19,6 +19,7 @@
 
 #include "nodes/pg_list.h"
 #include "port/pg_bitutils.h"
+#include "port/pg_lfind.h"
 #include "utils/memdebug.h"
 #include "utils/memutils.h"
 
@@ -680,11 +681,25 @@ list_member(const List *list, const void *datum)
 bool
 list_member_ptr(const List *list, const void *datum)
 {
+#ifdef USE_NO_SIMD
 	const ListCell *cell;
+#endif
 
 	Assert(IsPointerList(list));
 	check_list_invariants(list);
 
+#ifndef USE_NO_SIMD
+
+	Assert(sizeof(ListCell) == 8);
+	Assert(sizeof(void *) == 8);
+
+	if (list == NIL)
+		return false;
+
+	return pg_lfind64((uint64) datum, (uint64 *) list->elements, list->length);
+
+#else
+
 	foreach(cell, list)
 	{
 		if (lfirst(cell) == datum)
@@ -692,46 +707,125 @@ list_member_ptr(const List *list, const void *datum)
 	}
 
 	return false;
+
+#endif
 }
 
 /*
- * Return true iff the integer 'datum' is a member of the list.
+ * Optimized linear search routine (using SIMD intrinsics where available) for
+ * lists with inline 32-bit data.
  */
-bool
-list_member_int(const List *list, int datum)
+static inline bool
+list_member_inline_internal(const List *list, uint32 datum)
 {
+	uint32		i = 0;
 	const ListCell *cell;
 
-	Assert(IsIntegerList(list));
-	check_list_invariants(list);
+#ifndef USE_NO_SIMD
 
-	foreach(cell, list)
+	/*
+	 * For better instruction-level parallelism, each loop iteration operates
+	 * on a block of four registers.
+	 */
+	const Vector32 keys = vector32_broadcast(datum);	/* load copies of key */
+	const uint32 nelem_per_vector = sizeof(Vector32) / sizeof(uint32);
+	const uint32 nelem_per_iteration = 4 * nelem_per_vector;
+#ifdef USE_NEON
+	const Vector32 mask = (Vector32) vector64_broadcast(UINT64CONST(0xFFFFFFFF));
+#else
+	const Vector32 mask = vector64_broadcast(UINT64CONST(0xFFFFFFFF));
+#endif
+	const uint32 *elements = (const uint32 *) list->elements;
+
+	/* round down to multiple of elements per iteration */
+	const uint32 tail_idx = (list->length * 2) & ~(nelem_per_iteration - 1);
+
+	/*
+	 * The SIMD-optimized portion of this routine is written with the
+	 * expectation that the 32-bit datum we are searching for only takes up
+	 * half of a ListCell.  If that changes, this routine must change, too.
+	 */
+	Assert(sizeof(ListCell) == 8);
+
+	for (i = 0; i < tail_idx; i += nelem_per_iteration)
+	{
+		Vector32	vals1,
+					vals2,
+					vals3,
+					vals4,
+					result1,
+					result2,
+					result3,
+					result4,
+					tmp1,
+					tmp2,
+					result,
+					masked;
+
+		/* load the next block into 4 registers */
+		vector32_load(&vals1, &elements[i]);
+		vector32_load(&vals2, &elements[i + nelem_per_vector]);
+		vector32_load(&vals3, &elements[i + nelem_per_vector * 2]);
+		vector32_load(&vals4, &elements[i + nelem_per_vector * 3]);
+
+		/* compare each value to the key */
+		result1 = vector32_eq(keys, vals1);
+		result2 = vector32_eq(keys, vals2);
+		result3 = vector32_eq(keys, vals3);
+		result4 = vector32_eq(keys, vals4);
+
+		/* combine the results into a single variable */
+		tmp1 = vector32_or(result1, result2);
+		tmp2 = vector32_or(result3, result4);
+		result = vector32_or(tmp1, tmp2);
+
+		/* filter out matches in space between data */
+		masked = vector32_and(result, mask);
+
+		/* see if there was a match */
+		if (vector32_is_highbit_set(masked))
+			return true;
+	}
+
+#endif							/* ! USE_NO_SIMD */
+
+	for_each_from(cell, list, i / 2)
 	{
-		if (lfirst_int(cell) == datum)
+		if (lfirst_int(cell) == (int) datum)
 			return true;
 	}
 
 	return false;
 }
 
+/*
+ * Return true iff the integer 'datum' is a member of the list.
+ */
+bool
+list_member_int(const List *list, int datum)
+{
+	Assert(IsIntegerList(list));
+	check_list_invariants(list);
+
+	if (list == NIL)
+		return false;
+
+	return list_member_inline_internal(list, datum);
+}
+
 /*
  * Return true iff the OID 'datum' is a member of the list.
  */
 bool
 list_member_oid(const List *list, Oid datum)
 {
-	const ListCell *cell;
-
 	Assert(IsOidList(list));
 	check_list_invariants(list);
 
-	foreach(cell, list)
-	{
-		if (lfirst_oid(cell) == datum)
-			return true;
-	}
+	if (list == NIL)
+		return false;
 
-	return false;
+	return list_member_inline_internal(list, datum);
 }
 
 /*
@@ -740,18 +834,13 @@ list_member_oid(const List *list, Oid datum)
 bool
 list_member_xid(const List *list, TransactionId datum)
 {
-	const ListCell *cell;
-
 	Assert(IsXidList(list));
 	check_list_invariants(list);
 
-	foreach(cell, list)
-	{
-		if (lfirst_xid(cell) == datum)
-			return true;
-	}
+	if (list == NIL)
+		return false;
 
-	return false;
+	return list_member_inline_internal(list, datum);
 }
 
 /*
@@ -875,16 +964,121 @@ list_delete_ptr(List *list, void *datum)
 	Assert(IsPointerList(list));
 	check_list_invariants(list);
 
+#ifndef USE_NO_SIMD
+
+	Assert(sizeof(ListCell) == 8);
+	Assert(sizeof(void *) == 8);
+
+	if (list == NIL)
+		return NIL;
+
+	cell = (ListCell *) pg_lfind64_idx((uint64) datum,
+									   (uint64 *) list->elements,
+									   list->length);
+	if (cell != NULL)
+		return list_delete_cell(list, cell);
+
+#else
+
 	foreach(cell, list)
 	{
 		if (lfirst(cell) == datum)
 			return list_delete_cell(list, cell);
 	}
 
+#endif
+
 	/* Didn't find a match: return the list unmodified */
 	return list;
 }
 
+/*
+ * Optimized linear search (using SIMD intrinsics where available) for lists
+ * with inline 32-bit data.  Returns the first matching cell, or NULL if none.
+ */
+static inline ListCell *
+list_member_inline_interal_idx(const List *list, uint32 datum)
+{
+	uint32		i = 0;
+	ListCell   *cell;
+
+#ifndef USE_NO_SIMD
+
+	/*
+	 * For better instruction-level parallelism, each loop iteration operates
+	 * on a block of four registers.
+	 */
+	const Vector32 keys = vector32_broadcast(datum);	/* load copies of key */
+	const uint32 nelem_per_vector = sizeof(Vector32) / sizeof(uint32);
+	const uint32 nelem_per_iteration = 4 * nelem_per_vector;
+#ifdef USE_NEON
+	const Vector32 mask = (Vector32) vector64_broadcast(UINT64CONST(0xFFFFFFFF));
+#else
+	const Vector32 mask = vector64_broadcast(UINT64CONST(0xFFFFFFFF));
+#endif
+	const uint32 *elements = (const uint32 *) list->elements;
+
+	/* round down to multiple of elements per iteration */
+	const uint32 tail_idx = (list->length * 2) & ~(nelem_per_iteration - 1);
+
+	/*
+	 * The SIMD-optimized portion of this routine is written with the
+	 * expectation that the 32-bit datum we are searching for only takes up
+	 * half of a ListCell.  If that changes, this routine must change, too.
+	 */
+	Assert(sizeof(ListCell) == 8);
+
+	for (i = 0; i < tail_idx; i += nelem_per_iteration)
+	{
+		Vector32	vals1,
+					vals2,
+					vals3,
+					vals4,
+					result1,
+					result2,
+					result3,
+					result4,
+					tmp1,
+					tmp2,
+					result,
+					masked;
+
+		/* load the next block into 4 registers */
+		vector32_load(&vals1, &elements[i]);
+		vector32_load(&vals2, &elements[i + nelem_per_vector]);
+		vector32_load(&vals3, &elements[i + nelem_per_vector * 2]);
+		vector32_load(&vals4, &elements[i + nelem_per_vector * 3]);
+
+		/* compare each value to the key */
+		result1 = vector32_eq(keys, vals1);
+		result2 = vector32_eq(keys, vals2);
+		result3 = vector32_eq(keys, vals3);
+		result4 = vector32_eq(keys, vals4);
+
+		/* combine the results into a single variable */
+		tmp1 = vector32_or(result1, result2);
+		tmp2 = vector32_or(result3, result4);
+		result = vector32_or(tmp1, tmp2);
+
+		/* filter out matches in space between data */
+		masked = vector32_and(result, mask);
+
+		/* break out and find the exact element if there was a match */
+		if (vector32_is_highbit_set(masked))
+			break;
+	}
+
+#endif							/* ! USE_NO_SIMD */
+
+	for_each_from(cell, list, i / 2)
+	{
+		if (lfirst_int(cell) == (int) datum)
+			return cell;
+	}
+
+	return NULL;
+}
+
 /* As above, but for integers */
 List *
 list_delete_int(List *list, int datum)
@@ -894,11 +1088,12 @@ list_delete_int(List *list, int datum)
 	Assert(IsIntegerList(list));
 	check_list_invariants(list);
 
-	foreach(cell, list)
-	{
-		if (lfirst_int(cell) == datum)
-			return list_delete_cell(list, cell);
-	}
+	if (list == NIL)
+		return NIL;
+
+	cell = list_member_inline_interal_idx(list, datum);
+	if (cell)
+		return list_delete_cell(list, cell);
 
 	/* Didn't find a match: return the list unmodified */
 	return list;
@@ -913,11 +1108,12 @@ list_delete_oid(List *list, Oid datum)
 	Assert(IsOidList(list));
 	check_list_invariants(list);
 
-	foreach(cell, list)
-	{
-		if (lfirst_oid(cell) == datum)
-			return list_delete_cell(list, cell);
-	}
+	if (list == NIL)
+		return NIL;
+
+	cell = list_member_inline_interal_idx(list, datum);
+	if (cell)
+		return list_delete_cell(list, cell);
 
 	/* Didn't find a match: return the list unmodified */
 	return list;
diff --git a/src/include/port/pg_lfind.h b/src/include/port/pg_lfind.h
index 59aa8245ed..05d0b12463 100644
--- a/src/include/port/pg_lfind.h
+++ b/src/include/port/pg_lfind.h
@@ -177,4 +177,193 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem)
 	return false;
 }
 
+/*
+ * pg_lfind64
+ *
+ * Return true if there is an element in 'base' that equals 'key', otherwise
+ * return false.
+ */
+static inline bool
+pg_lfind64(uint64 key, uint64 *base, uint32 nelem)
+{
+	uint32		i = 0;
+
+#ifndef USE_NO_SIMD
+
+	/*
+	 * For better instruction-level parallelism, each loop iteration operates
+	 * on a block of four registers.
+	 */
+	const Vector64 keys = vector64_broadcast(key);	/* load copies of key */
+	const uint32 nelem_per_vector = sizeof(Vector64) / sizeof(uint64);
+	const uint32 nelem_per_iteration = 4 * nelem_per_vector;
+
+	/* round down to multiple of elements per iteration */
+	const uint32 tail_idx = nelem & ~(nelem_per_iteration - 1);
+
+#if defined(USE_ASSERT_CHECKING)
+	bool		assert_result = false;
+
+	/* pre-compute the result for assert checking */
+	for (i = 0; i < nelem; i++)
+	{
+		if (key == base[i])
+		{
+			assert_result = true;
+			break;
+		}
+	}
+#endif
+
+	for (i = 0; i < tail_idx; i += nelem_per_iteration)
+	{
+		Vector64	vals1,
+					vals2,
+					vals3,
+					vals4,
+					result1,
+					result2,
+					result3,
+					result4,
+					tmp1,
+					tmp2,
+					result;
+
+		/* load the next block into 4 registers */
+		vector64_load(&vals1, &base[i]);
+		vector64_load(&vals2, &base[i + nelem_per_vector]);
+		vector64_load(&vals3, &base[i + nelem_per_vector * 2]);
+		vector64_load(&vals4, &base[i + nelem_per_vector * 3]);
+
+		/* compare each value to the key */
+		result1 = vector64_eq(keys, vals1);
+		result2 = vector64_eq(keys, vals2);
+		result3 = vector64_eq(keys, vals3);
+		result4 = vector64_eq(keys, vals4);
+
+		/* combine the results into a single variable */
+		tmp1 = vector64_or(result1, result2);
+		tmp2 = vector64_or(result3, result4);
+		result = vector64_or(tmp1, tmp2);
+
+		/* see if there was a match */
+		if (vector64_is_highbit_set(result))
+		{
+			Assert(assert_result == true);
+			return true;
+		}
+	}
+#endif							/* ! USE_NO_SIMD */
+
+	/* Process the remaining elements one at a time. */
+	for (; i < nelem; i++)
+	{
+		if (key == base[i])
+		{
+#ifndef USE_NO_SIMD
+			Assert(assert_result == true);
+#endif
+			return true;
+		}
+	}
+
+#ifndef USE_NO_SIMD
+	Assert(assert_result == false);
+#endif
+	return false;
+}
+
+/*
+ * pg_lfind64_idx
+ *
+ * Return the address of the first element in 'base' that equals 'key', or NULL
+ * if no match is found.
+ */
+static inline uint64 *
+pg_lfind64_idx(uint64 key, uint64 *base, uint32 nelem)
+{
+	uint32		i = 0;
+
+#ifndef USE_NO_SIMD
+
+	/*
+	 * For better instruction-level parallelism, each loop iteration operates
+	 * on a block of four registers.
+	 */
+	const Vector64 keys = vector64_broadcast(key);	/* load copies of key */
+	const uint32 nelem_per_vector = sizeof(Vector64) / sizeof(uint64);
+	const uint32 nelem_per_iteration = 4 * nelem_per_vector;
+
+	/* round down to multiple of elements per iteration */
+	const uint32 tail_idx = nelem & ~(nelem_per_iteration - 1);
+
+#if defined(USE_ASSERT_CHECKING)
+	uint64	   *assert_result = NULL;
+
+	/* pre-compute the result for assert checking */
+	for (i = 0; i < nelem; i++)
+	{
+		if (key == base[i])
+		{
+			assert_result = &base[i];
+			break;
+		}
+	}
+#endif
+
+	for (i = 0; i < tail_idx; i += nelem_per_iteration)
+	{
+		Vector64	vals1,
+					vals2,
+					vals3,
+					vals4,
+					result1,
+					result2,
+					result3,
+					result4,
+					tmp1,
+					tmp2,
+					result;
+
+		/* load the next block into 4 registers */
+		vector64_load(&vals1, &base[i]);
+		vector64_load(&vals2, &base[i + nelem_per_vector]);
+		vector64_load(&vals3, &base[i + nelem_per_vector * 2]);
+		vector64_load(&vals4, &base[i + nelem_per_vector * 3]);
+
+		/* compare each value to the key */
+		result1 = vector64_eq(keys, vals1);
+		result2 = vector64_eq(keys, vals2);
+		result3 = vector64_eq(keys, vals3);
+		result4 = vector64_eq(keys, vals4);
+
+		/* combine the results into a single variable */
+		tmp1 = vector64_or(result1, result2);
+		tmp2 = vector64_or(result3, result4);
+		result = vector64_or(tmp1, tmp2);
+
+		/* if there was a match, break out to find the matching element */
+		if (vector64_is_highbit_set(result))
+			break;
+	}
+#endif							/* ! USE_NO_SIMD */
+
+	/* Process the remaining elements one at a time. */
+	for (; i < nelem; i++)
+	{
+		if (key == base[i])
+		{
+#ifndef USE_NO_SIMD
+			Assert(assert_result == &base[i]);
+#endif
+			return &base[i];
+		}
+	}
+
+#ifndef USE_NO_SIMD
+	Assert(assert_result == NULL);
+#endif
+	return NULL;
+}
+
 #endif							/* PG_LFIND_H */
diff --git a/src/include/port/simd.h b/src/include/port/simd.h
index 1fa6c3bc6c..805b99cd61 100644
--- a/src/include/port/simd.h
+++ b/src/include/port/simd.h
@@ -32,6 +32,7 @@
 #define USE_SSE2
 typedef __m128i Vector8;
 typedef __m128i Vector32;
+typedef __m128i Vector64;
 
 #elif defined(__aarch64__) && defined(__ARM_NEON)
 /*
@@ -46,6 +47,7 @@ typedef __m128i Vector32;
 #define USE_NEON
 typedef uint8x16_t Vector8;
 typedef uint32x4_t Vector32;
+typedef uint64x2_t Vector64;
 
 #else
 /*
@@ -64,12 +66,14 @@ typedef uint64 Vector8;
 static inline void vector8_load(Vector8 *v, const uint8 *s);
 #ifndef USE_NO_SIMD
 static inline void vector32_load(Vector32 *v, const uint32 *s);
+static inline void vector64_load(Vector64 *v, const uint64 *s);
 #endif
 
 /* assignment operations */
 static inline Vector8 vector8_broadcast(const uint8 c);
 #ifndef USE_NO_SIMD
 static inline Vector32 vector32_broadcast(const uint32 c);
+static inline Vector64 vector64_broadcast(const uint64 c);
 #endif
 
 /* element-wise comparisons to a scalar */
@@ -79,12 +83,16 @@ static inline bool vector8_has_le(const Vector8 v, const uint8 c);
 static inline bool vector8_is_highbit_set(const Vector8 v);
 #ifndef USE_NO_SIMD
 static inline bool vector32_is_highbit_set(const Vector32 v);
+static inline bool vector64_is_highbit_set(const Vector64 v);
 #endif
 
 /* arithmetic operations */
 static inline Vector8 vector8_or(const Vector8 v1, const Vector8 v2);
 #ifndef USE_NO_SIMD
 static inline Vector32 vector32_or(const Vector32 v1, const Vector32 v2);
+static inline Vector64 vector64_or(const Vector64 v1, const Vector64 v2);
+static inline Vector32 vector32_and(const Vector32 v1, const Vector32 v2);
+static inline Vector64 vector64_and(const Vector64 v1, const Vector64 v2);
 static inline Vector8 vector8_ssub(const Vector8 v1, const Vector8 v2);
 #endif
 
@@ -97,6 +105,7 @@ static inline Vector8 vector8_ssub(const Vector8 v1, const Vector8 v2);
 #ifndef USE_NO_SIMD
 static inline Vector8 vector8_eq(const Vector8 v1, const Vector8 v2);
 static inline Vector32 vector32_eq(const Vector32 v1, const Vector32 v2);
+static inline Vector64 vector64_eq(const Vector64 v1, const Vector64 v2);
 #endif
 
 /*
@@ -126,6 +135,18 @@ vector32_load(Vector32 *v, const uint32 *s)
 }
 #endif							/* ! USE_NO_SIMD */
 
+#ifndef USE_NO_SIMD
+static inline void
+vector64_load(Vector64 *v, const uint64 *s)
+{
+#ifdef USE_SSE2
+	*v = _mm_loadu_si128((const __m128i *) s);
+#elif defined(USE_NEON)
+	*v = vld1q_u64((const uint64_t *) s);
+#endif
+}
+#endif							/* ! USE_NO_SIMD */
+
 /*
  * Create a vector with all elements set to the same value.
  */
@@ -153,6 +174,18 @@ vector32_broadcast(const uint32 c)
 }
 #endif							/* ! USE_NO_SIMD */
 
+#ifndef USE_NO_SIMD
+static inline Vector64
+vector64_broadcast(const uint64 c)
+{
+#ifdef USE_SSE2
+	return _mm_set1_epi64x(c);
+#elif defined(USE_NEON)
+	return vdupq_n_u64(c);
+#endif
+}
+#endif							/* ! USE_NO_SIMD */
+
 /*
  * Return true if any elements in the vector are equal to the given scalar.
  */
@@ -299,6 +332,23 @@ vector32_is_highbit_set(const Vector32 v)
 }
 #endif							/* ! USE_NO_SIMD */
 
+/*
+ * Exactly like vector8_is_highbit_set except for the input type, so it
+ * looks at each byte separately.  See the comment above
+ * vector32_is_highbit_set for more information.
+ */
+#ifndef USE_NO_SIMD
+static inline bool
+vector64_is_highbit_set(const Vector64 v)
+{
+#if defined(USE_NEON)
+	return vector8_is_highbit_set((Vector8) v);
+#else
+	return vector8_is_highbit_set(v);
+#endif
+}
+#endif							/* ! USE_NO_SIMD */
+
 /*
  * Return the bitwise OR of the inputs
  */
@@ -326,6 +376,42 @@ vector32_or(const Vector32 v1, const Vector32 v2)
 }
 #endif							/* ! USE_NO_SIMD */
 
+#ifndef USE_NO_SIMD
+static inline Vector64
+vector64_or(const Vector64 v1, const Vector64 v2)
+{
+#ifdef USE_SSE2
+	return _mm_or_si128(v1, v2);
+#elif defined(USE_NEON)
+	return vorrq_u64(v1, v2);
+#endif
+}
+#endif							/* ! USE_NO_SIMD */
+
+#ifndef USE_NO_SIMD
+static inline Vector32
+vector32_and(const Vector32 v1, const Vector32 v2)
+{
+#ifdef USE_SSE2
+	return _mm_and_si128(v1, v2);
+#elif defined(USE_NEON)
+	return vandq_u32(v1, v2);
+#endif
+}
+#endif							/* ! USE_NO_SIMD */
+
+#ifndef USE_NO_SIMD
+static inline Vector64
+vector64_and(const Vector64 v1, const Vector64 v2)
+{
+#ifdef USE_SSE2
+	return _mm_and_si128(v1, v2);
+#elif defined(USE_NEON)
+	return vandq_u64(v1, v2);
+#endif
+}
+#endif							/* ! USE_NO_SIMD */
+
 /*
  * Return the result of subtracting the respective elements of the input
  * vectors using saturation (i.e., if the operation would yield a value less
@@ -372,4 +458,21 @@ vector32_eq(const Vector32 v1, const Vector32 v2)
 }
 #endif							/* ! USE_NO_SIMD */
 
+#ifndef USE_NO_SIMD
+static inline Vector64
+vector64_eq(const Vector64 v1, const Vector64 v2)
+{
+#ifdef USE_SSE2
+	/* We have to work around SSE2's lack of _mm_cmpeq_epi64. */
+	const Vector32 cmp = vector32_eq(v1, v2);
+	const Vector32 hi = _mm_shuffle_epi32(cmp, _MM_SHUFFLE(3, 3, 1, 1));
+	const Vector32 lo = _mm_shuffle_epi32(cmp, _MM_SHUFFLE(2, 2, 0, 0));
+
+	return vector32_and(hi, lo);
+#elif defined(USE_NEON)
+	return vceqq_u64(v1, v2);
+#endif
+}
+#endif							/* ! USE_NO_SIMD */
+
 #endif							/* SIMD_H */
-- 
2.25.1

Reply via email to