On Fri, 5 Jul 2024 at 18:37, Joel Jacobson <j...@compiler.org> wrote:
>
> On Fri, Jul 5, 2024, at 18:42, Joel Jacobson wrote:
> > Very nice, v7-optimize-numeric-mul_var-small-var1-arbitrary-var2.patch
> > is now the winner on all my CPUs:
>
> I thought it would be interesting to also measure the isolated effect
> on just numeric_mul() without the query overhead.
>
> Impressive speed-up, between 25% - 81%.
>

Cool. I think we should go with the mul_var_small() patch then, since
it's more generally applicable.

I also did some testing with much larger var2 values, and saw similar
speed-ups. One high-level function that benefits from that is
factorial(), which accepts inputs up to 32177, and so uses both the
1-digit and 2-digit code with very large var2 values. I doubt anyone
actually uses it with such large inputs, but it's interesting
nonetheless:

SELECT factorial(32177);
Time: 923.117 ms  -- HEAD
Time: 534.375 ms  -- mul_var_small() patch

I did one more round of (mostly cosmetic) copy-editing. Aside from
improving some of the comments, it occurred to me that there's no need
to pass rscale to mul_var_small(), or for it to call round_var(),
since it's always computing the exact result. That shaves off a few
more cycles.

Additionally, I didn't like how res_weight and res_ndigits were being
set 1 higher than they needed to be. That makes sense in mul_var()
because it may round the result, causing a non-zero carry to propagate
into the next digit up, but it's just confusing in mul_var_small(). So
I've reduced those by 1, which makes the code look much more logical. To be
clear, this doesn't change how many digits we're calculating. But now
res_ndigits is actually the number of digits being calculated, whereas
before, res_ndigits was 1 larger and we were calculating res_ndigits -
1 digits, which was confusing.

I think this is good to go, so unless there are any further comments,
I plan to commit it soon.

Possible future work would be to try extending it to larger var1
values. I have a feeling that might work quite well for 5 or 6 digits,
but at some point, we'll start seeing diminishing returns, and the
code bloat won't be worth it.

Regards,
Dean
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index 5510a20..b556861
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -551,6 +551,8 @@ static void sub_var(const NumericVar *va
 static void mul_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale);
+static void mul_var_small(const NumericVar *var1, const NumericVar *var2,
+						  NumericVar *result);
 static void div_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale, bool round);
@@ -8707,7 +8709,7 @@ mul_var(const NumericVar *var1, const Nu
 	var1digits = var1->digits;
 	var2digits = var2->digits;
 
-	if (var1ndigits == 0 || var2ndigits == 0)
+	if (var1ndigits == 0)
 	{
 		/* one or both inputs is zero; so is result */
 		zero_var(result);
@@ -8715,6 +8717,16 @@ mul_var(const NumericVar *var1, const Nu
 		return;
 	}
 
+	/*
+	 * If var1 has 1-4 digits and the exact result was requested, delegate to
+	 * mul_var_small() which uses a faster direct multiplication algorithm.
+	 */
+	if (var1ndigits <= 4 && rscale == var1->dscale + var2->dscale)
+	{
+		mul_var_small(var1, var2, result);
+		return;
+	}
+
 	/* Determine result sign and (maximum possible) weight */
 	if (var1->sign == var2->sign)
 		res_sign = NUMERIC_POS;
@@ -8862,6 +8874,212 @@ mul_var(const NumericVar *var1, const Nu
 
 	/* Strip leading and trailing zeroes */
 	strip_var(result);
+}
+
+
+/*
+ * mul_var_small() -
+ *
+ *	Special-case multiplication function used when var1 has 1-4 digits, var2
+ *	has at least as many digits as var1, and the exact product var1 * var2 is
+ *	requested.
+ */
+static void
+mul_var_small(const NumericVar *var1, const NumericVar *var2,
+			  NumericVar *result)
+{
+	int			var1ndigits = var1->ndigits;
+	int			var2ndigits = var2->ndigits;
+	NumericDigit *var1digits = var1->digits;
+	NumericDigit *var2digits = var2->digits;
+	int			res_sign;
+	int			res_weight;
+	int			res_ndigits;
+	NumericDigit *res_buf;
+	NumericDigit *res_digits;
+	uint32		carry;
+	uint32		term;
+
+	/* Check preconditions */
+	Assert(var1ndigits >= 1);
+	Assert(var1ndigits <= 4);
+	Assert(var2ndigits >= var1ndigits);
+
+	/*
+	 * Determine the result sign, weight, and number of digits to calculate.
+	 * The weight figured here is correct if the product has no leading zero
+	 * digits; otherwise strip_var() will fix things up.  Note that, unlike
+	 * mul_var(), we do not need to allocate an extra output digit, because we
+	 * are not rounding here.
+	 */
+	if (var1->sign == var2->sign)
+		res_sign = NUMERIC_POS;
+	else
+		res_sign = NUMERIC_NEG;
+	res_weight = var1->weight + var2->weight + 1;
+	res_ndigits = var1ndigits + var2ndigits;
+
+	/* Allocate result digit array */
+	res_buf = digitbuf_alloc(res_ndigits + 1);
+	res_buf[0] = 0;				/* spare digit for later rounding */
+	res_digits = res_buf + 1;
+
+	/*
+	 * Compute the result digits in reverse, in one pass, propagating the
+	 * carry up as we go.  The i'th result digit consists of the sum of the
+	 * products var1digits[i1] * var2digits[i2] for which i = i1 + i2 + 1.
+	 */
+	switch (var1ndigits)
+	{
+		case 1:
+			/* ---------
+			 * 1-digit case:
+			 *		var1ndigits = 1
+			 *		var2ndigits >= 1
+			 *		res_ndigits = var2ndigits + 1
+			 * ----------
+			 */
+			carry = 0;
+			for (int i = res_ndigits - 2; i >= 0; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+			res_digits[0] = (NumericDigit) carry;
+			break;
+
+		case 2:
+			/* ---------
+			 * 2-digit case:
+			 *		var1ndigits = 2
+			 *		var2ndigits >= 2
+			 *		res_ndigits = var2ndigits + 2
+			 * ----------
+			 */
+			/* last result digit and carry */
+			term = (uint32) var1digits[1] * var2digits[res_ndigits - 3];
+			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			/* remaining digits, except for the first two */
+			for (int i = res_ndigits - 3; i >= 1; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] +
+					(uint32) var1digits[1] * var2digits[i - 1] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+
+			/* first two digits */
+			term = (uint32) var1digits[0] * var2digits[0] + carry;
+			res_digits[1] = (NumericDigit) (term % NBASE);
+			res_digits[0] = (NumericDigit) (term / NBASE);
+			break;
+
+		case 3:
+			/* ---------
+			 * 3-digit case:
+			 *		var1ndigits = 3
+			 *		var2ndigits >= 3
+			 *		res_ndigits = var2ndigits + 3
+			 * ----------
+			 */
+			/* last two result digits */
+			term = (uint32) var1digits[2] * var2digits[res_ndigits - 4];
+			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[1] * var2digits[res_ndigits - 4] +
+				(uint32) var1digits[2] * var2digits[res_ndigits - 5] + carry;
+			res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			/* remaining digits, except for the first three */
+			for (int i = res_ndigits - 4; i >= 2; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] +
+					(uint32) var1digits[1] * var2digits[i - 1] +
+					(uint32) var1digits[2] * var2digits[i - 2] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+
+			/* first three digits */
+			term = (uint32) var1digits[0] * var2digits[1] +
+				(uint32) var1digits[1] * var2digits[0] + carry;
+			res_digits[2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[0] * var2digits[0] + carry;
+			res_digits[1] = (NumericDigit) (term % NBASE);
+			res_digits[0] = (NumericDigit) (term / NBASE);
+			break;
+
+		case 4:
+			/* ---------
+			 * 4-digit case:
+			 *		var1ndigits = 4
+			 *		var2ndigits >= 4
+			 *		res_ndigits = var2ndigits + 4
+			 * ----------
+			 */
+			/* last three result digits */
+			term = (uint32) var1digits[3] * var2digits[res_ndigits - 5];
+			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[2] * var2digits[res_ndigits - 5] +
+				(uint32) var1digits[3] * var2digits[res_ndigits - 6] + carry;
+			res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[1] * var2digits[res_ndigits - 5] +
+				(uint32) var1digits[2] * var2digits[res_ndigits - 6] +
+				(uint32) var1digits[3] * var2digits[res_ndigits - 7] + carry;
+			res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			/* remaining digits, except for the first four */
+			for (int i = res_ndigits - 5; i >= 3; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] +
+					(uint32) var1digits[1] * var2digits[i - 1] +
+					(uint32) var1digits[2] * var2digits[i - 2] +
+					(uint32) var1digits[3] * var2digits[i - 3] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+
+			/* first four digits */
+			term = (uint32) var1digits[0] * var2digits[2] +
+				(uint32) var1digits[1] * var2digits[1] +
+				(uint32) var1digits[2] * var2digits[0] + carry;
+			res_digits[3] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[0] * var2digits[1] +
+				(uint32) var1digits[1] * var2digits[0] + carry;
+			res_digits[2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[0] * var2digits[0] + carry;
+			res_digits[1] = (NumericDigit) (term % NBASE);
+			res_digits[0] = (NumericDigit) (term / NBASE);
+			break;
+	}
+
+	/* Store the product in result */
+	digitbuf_free(result->buf);
+	result->ndigits = res_ndigits;
+	result->buf = res_buf;
+	result->digits = res_digits;
+	result->weight = res_weight;
+	result->sign = res_sign;
+	result->dscale = var1->dscale + var2->dscale;
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
 }
 
 

Reply via email to