On Mon, 12 Aug 2024 at 16:17, Joel Jacobson <j...@compiler.org> wrote:
>
> On Mon, Aug 12, 2024, at 17:14, Joel Jacobson wrote:
> > The case found with the smallest rscale adjustment was this one:
> > -[ RECORD 1 ]------+--------------------------------
> > var1               | 0.0000000000009873307197037692
> > var2               | 0.426697279270850
> > rscale_adjustment  | -15
> > expected           | 0.0000000000004212913318381285
> > numeric_mul_rscale | 0.0000000000004212913318381284
> > diff               | -0.0000000000000000000000000001
>

Hmm, interesting example. There will of course always be cases where
the result isn't exact, but HEAD does produce the expected result in
this case, and the intention is for the patch to always produce a
result at least as accurate as HEAD, so the patch isn't working as
intended.

Looking more closely, the problem is that to fully compute the
required guard digits, it is necessary to compute at least one extra
output base-NBASE digit, because the product of base-NBASE^2 digits
contributes to the next base-NBASE digit up. So instead of

    maxdigitpairs = (maxdigits + 1) / 2;

we should do

    maxdigitpairs = maxdigits / 2 + 1;

Additionally, since maxdigits is based on res_weight, we should
actually do the res_weight adjustments for odd-length inputs before
computing maxdigits. (Otherwise we're actually computing more digits
than strictly necessary for odd-length inputs, so this is a minor
optimisation.)

Updated patch attached, which fixes the above example and all the
other differences produced by your test. I think, with a little
thought, it ought to be possible to find a more systematic (less
brute-force) way of producing examples that round incorrectly. It
should then be possible to construct examples where the patch differs
from HEAD, but hopefully only by being more accurate, not less.

Regards,
Dean
From 6c1820257997facfe8e74fac8b574c8f683bbebc Mon Sep 17 00:00:00 2001
From: Dean Rasheed <dean.a.rash...@gmail.com>
Date: Thu, 18 Jul 2024 17:38:59 +0100
Subject: [PATCH v5 1/2] Extend mul_var_short() to 5 and 6-digit inputs.

Commit ca481d3c9a introduced mul_var_short(), which is used by
mul_var() whenever the shorter input has 1-4 NBASE digits and the
exact product is requested. As speculated on in that commit, it can be
extended to work for more digits in the shorter input. This commit
extends it up to 6 NBASE digits (21-24 decimal digits), for which it
also gives a significant speedup.

To avoid excessive code bloat and duplication, refactor it a bit using
macros and exploiting the fact that some portions of the code are
shared between the different cases.
---
 src/backend/utils/adt/numeric.c | 175 ++++++++++++++++++++++----------
 1 file changed, 123 insertions(+), 52 deletions(-)

diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index d0f0923710..ca28d0e3b3 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -8720,10 +8720,10 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 	}
 
 	/*
-	 * If var1 has 1-4 digits and the exact result was requested, delegate to
+	 * If var1 has 1-6 digits and the exact result was requested, delegate to
 	 * mul_var_short() which uses a faster direct multiplication algorithm.
 	 */
-	if (var1ndigits <= 4 && rscale == var1->dscale + var2->dscale)
+	if (var1ndigits <= 6 && rscale == var1->dscale + var2->dscale)
 	{
 		mul_var_short(var1, var2, result);
 		return;
@@ -8882,7 +8882,7 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 /*
  * mul_var_short() -
  *
- *	Special-case multiplication function used when var1 has 1-4 digits, var2
+ *	Special-case multiplication function used when var1 has 1-6 digits, var2
  *	has at least as many digits as var1, and the exact product var1 * var2 is
  *	requested.
  */
@@ -8904,7 +8904,7 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
 
 	/* Check preconditions */
 	Assert(var1ndigits >= 1);
-	Assert(var1ndigits <= 4);
+	Assert(var1ndigits <= 6);
 	Assert(var2ndigits >= var1ndigits);
 
 	/*
@@ -8931,6 +8931,13 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
 	 * carry up as we go.  The i'th result digit consists of the sum of the
 	 * products var1digits[i1] * var2digits[i2] for which i = i1 + i2 + 1.
 	 */
+#define PRODSUM1(v1,i1,v2,i2) ((v1)[i1] * (v2)[i2])
+#define PRODSUM2(v1,i1,v2,i2) (PRODSUM1(v1,i1,v2,i2) + (v1)[i1+1] * (v2)[i2-1])
+#define PRODSUM3(v1,i1,v2,i2) (PRODSUM2(v1,i1,v2,i2) + (v1)[i1+2] * (v2)[i2-2])
+#define PRODSUM4(v1,i1,v2,i2) (PRODSUM3(v1,i1,v2,i2) + (v1)[i1+3] * (v2)[i2-3])
+#define PRODSUM5(v1,i1,v2,i2) (PRODSUM4(v1,i1,v2,i2) + (v1)[i1+4] * (v2)[i2-4])
+#define PRODSUM6(v1,i1,v2,i2) (PRODSUM5(v1,i1,v2,i2) + (v1)[i1+5] * (v2)[i2-5])
+
 	switch (var1ndigits)
 	{
 		case 1:
@@ -8942,9 +8949,9 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
 			 * ----------
 			 */
 			carry = 0;
-			for (int i = res_ndigits - 2; i >= 0; i--)
+			for (int i = var2ndigits - 1; i >= 0; i--)
 			{
-				term = (uint32) var1digits[0] * var2digits[i] + carry;
+				term = PRODSUM1(var1digits, 0, var2digits, i) + carry;
 				res_digits[i + 1] = (NumericDigit) (term % NBASE);
 				carry = term / NBASE;
 			}
@@ -8960,23 +8967,17 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
 			 * ----------
 			 */
 			/* last result digit and carry */
-			term = (uint32) var1digits[1] * var2digits[res_ndigits - 3];
+			term = PRODSUM1(var1digits, 1, var2digits, var2ndigits - 1);
 			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
 			carry = term / NBASE;
 
 			/* remaining digits, except for the first two */
-			for (int i = res_ndigits - 3; i >= 1; i--)
+			for (int i = var2ndigits - 1; i >= 1; i--)
 			{
-				term = (uint32) var1digits[0] * var2digits[i] +
-					(uint32) var1digits[1] * var2digits[i - 1] + carry;
+				term = PRODSUM2(var1digits, 0, var2digits, i) + carry;
 				res_digits[i + 1] = (NumericDigit) (term % NBASE);
 				carry = term / NBASE;
 			}
-
-			/* first two digits */
-			term = (uint32) var1digits[0] * var2digits[0] + carry;
-			res_digits[1] = (NumericDigit) (term % NBASE);
-			res_digits[0] = (NumericDigit) (term / NBASE);
 			break;
 
 		case 3:
@@ -8988,34 +8989,21 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
 			 * ----------
 			 */
 			/* last two result digits */
-			term = (uint32) var1digits[2] * var2digits[res_ndigits - 4];
+			term = PRODSUM1(var1digits, 2, var2digits, var2ndigits - 1);
 			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
 			carry = term / NBASE;
 
-			term = (uint32) var1digits[1] * var2digits[res_ndigits - 4] +
-				(uint32) var1digits[2] * var2digits[res_ndigits - 5] + carry;
+			term = PRODSUM2(var1digits, 1, var2digits, var2ndigits - 1) + carry;
 			res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
 			carry = term / NBASE;
 
 			/* remaining digits, except for the first three */
-			for (int i = res_ndigits - 4; i >= 2; i--)
+			for (int i = var2ndigits - 1; i >= 2; i--)
 			{
-				term = (uint32) var1digits[0] * var2digits[i] +
-					(uint32) var1digits[1] * var2digits[i - 1] +
-					(uint32) var1digits[2] * var2digits[i - 2] + carry;
+				term = PRODSUM3(var1digits, 0, var2digits, i) + carry;
 				res_digits[i + 1] = (NumericDigit) (term % NBASE);
 				carry = term / NBASE;
 			}
-
-			/* first three digits */
-			term = (uint32) var1digits[0] * var2digits[1] +
-				(uint32) var1digits[1] * var2digits[0] + carry;
-			res_digits[2] = (NumericDigit) (term % NBASE);
-			carry = term / NBASE;
-
-			term = (uint32) var1digits[0] * var2digits[0] + carry;
-			res_digits[1] = (NumericDigit) (term % NBASE);
-			res_digits[0] = (NumericDigit) (term / NBASE);
 			break;
 
 		case 4:
@@ -9027,45 +9015,128 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
 			 * ----------
 			 */
 			/* last three result digits */
-			term = (uint32) var1digits[3] * var2digits[res_ndigits - 5];
+			term = PRODSUM1(var1digits, 3, var2digits, var2ndigits - 1);
 			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
 			carry = term / NBASE;
 
-			term = (uint32) var1digits[2] * var2digits[res_ndigits - 5] +
-				(uint32) var1digits[3] * var2digits[res_ndigits - 6] + carry;
+			term = PRODSUM2(var1digits, 2, var2digits, var2ndigits - 1) + carry;
 			res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
 			carry = term / NBASE;
 
-			term = (uint32) var1digits[1] * var2digits[res_ndigits - 5] +
-				(uint32) var1digits[2] * var2digits[res_ndigits - 6] +
-				(uint32) var1digits[3] * var2digits[res_ndigits - 7] + carry;
+			term = PRODSUM3(var1digits, 1, var2digits, var2ndigits - 1) + carry;
 			res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
 			carry = term / NBASE;
 
 			/* remaining digits, except for the first four */
-			for (int i = res_ndigits - 5; i >= 3; i--)
+			for (int i = var2ndigits - 1; i >= 3; i--)
 			{
-				term = (uint32) var1digits[0] * var2digits[i] +
-					(uint32) var1digits[1] * var2digits[i - 1] +
-					(uint32) var1digits[2] * var2digits[i - 2] +
-					(uint32) var1digits[3] * var2digits[i - 3] + carry;
+				term = PRODSUM4(var1digits, 0, var2digits, i) + carry;
 				res_digits[i + 1] = (NumericDigit) (term % NBASE);
 				carry = term / NBASE;
 			}
+			break;
 
-			/* first four digits */
-			term = (uint32) var1digits[0] * var2digits[2] +
-				(uint32) var1digits[1] * var2digits[1] +
-				(uint32) var1digits[2] * var2digits[0] + carry;
-			res_digits[3] = (NumericDigit) (term % NBASE);
+		case 5:
+			/* ---------
+			 * 5-digit case:
+			 *		var1ndigits = 5
+			 *		var2ndigits >= 5
+			 *		res_ndigits = var2ndigits + 5
+			 * ----------
+			 */
+			/* last four result digits */
+			term = PRODSUM1(var1digits, 4, var2digits, var2ndigits - 1);
+			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
 			carry = term / NBASE;
 
-			term = (uint32) var1digits[0] * var2digits[1] +
-				(uint32) var1digits[1] * var2digits[0] + carry;
-			res_digits[2] = (NumericDigit) (term % NBASE);
+			term = PRODSUM2(var1digits, 3, var2digits, var2ndigits - 1) + carry;
+			res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = PRODSUM3(var1digits, 2, var2digits, var2ndigits - 1) + carry;
+			res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
 			carry = term / NBASE;
 
-			term = (uint32) var1digits[0] * var2digits[0] + carry;
+			term = PRODSUM4(var1digits, 1, var2digits, var2ndigits - 1) + carry;
+			res_digits[res_ndigits - 4] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			/* remaining digits, except for the first five */
+			for (int i = var2ndigits - 1; i >= 4; i--)
+			{
+				term = PRODSUM5(var1digits, 0, var2digits, i) + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+			break;
+
+		case 6:
+			/* ---------
+			 * 6-digit case:
+			 *		var1ndigits = 6
+			 *		var2ndigits >= 6
+			 *		res_ndigits = var2ndigits + 6
+			 * ----------
+			 */
+			/* last five result digits */
+			term = PRODSUM1(var1digits, 5, var2digits, var2ndigits - 1);
+			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = PRODSUM2(var1digits, 4, var2digits, var2ndigits - 1) + carry;
+			res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = PRODSUM3(var1digits, 3, var2digits, var2ndigits - 1) + carry;
+			res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = PRODSUM4(var1digits, 2, var2digits, var2ndigits - 1) + carry;
+			res_digits[res_ndigits - 4] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = PRODSUM5(var1digits, 1, var2digits, var2ndigits - 1) + carry;
+			res_digits[res_ndigits - 5] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			/* remaining digits, except for the first six */
+			for (int i = var2ndigits - 1; i >= 5; i--)
+			{
+				term = PRODSUM6(var1digits, 0, var2digits, i) + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+			break;
+	}
+
+	/*
+	 * Finally, for var1ndigits > 1, compute the remaining var1ndigits most
+	 * significant result digits.
+	 */
+	switch (var1ndigits)
+	{
+		case 6:
+			term = PRODSUM5(var1digits, 0, var2digits, 4) + carry;
+			res_digits[5] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+			/* FALLTHROUGH */
+		case 5:
+			term = PRODSUM4(var1digits, 0, var2digits, 3) + carry;
+			res_digits[4] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+			/* FALLTHROUGH */
+		case 4:
+			term = PRODSUM3(var1digits, 0, var2digits, 2) + carry;
+			res_digits[3] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+			/* FALLTHROUGH */
+		case 3:
+			term = PRODSUM2(var1digits, 0, var2digits, 1) + carry;
+			res_digits[2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+			/* FALLTHROUGH */
+		case 2:
+			term = PRODSUM1(var1digits, 0, var2digits, 0) + carry;
 			res_digits[1] = (NumericDigit) (term % NBASE);
 			res_digits[0] = (NumericDigit) (term / NBASE);
 			break;
-- 
2.35.3

From d37b42f04ab5a9ab0380c1e746afc615ffad50dd Mon Sep 17 00:00:00 2001
From: Dean Rasheed <dean.a.rash...@gmail.com>
Date: Thu, 18 Jul 2024 18:32:56 +0100
Subject: [PATCH v5 2/2] Optimise numeric multiplication using base-NBASE^2
 arithmetic.

Currently mul_var() uses the schoolbook multiplication algorithm,
which is O(n^2) in the number of NBASE digits. To improve performance
for large inputs, convert the inputs to base NBASE^2 before
multiplying, which effectively halves the number of digits in each
input, theoretically speeding up the computation by a factor of 4. In
practice, the actual speedup for large inputs varies between around 3
and 6 times, depending on the system and compiler used. In turn, this
significantly reduces the runtime of the numeric_big regression test.

For this to work, 64-bit integers are required for the products of
base-NBASE^2 digits, so this works best on 64-bit machines, for which
it is faster whenever the shorter input has more than 4 or 5 NBASE
digits. On 32-bit machines, the additional overheads, especially
during carry propagation and the final conversion back to base-NBASE,
are significantly higher, and it is only faster when the shorter input
has more than around 50 NBASE digits. When the shorter input has more
than 6 NBASE digits (so that mul_var_short() cannot be used), but
fewer than around 50 NBASE digits, there may be a noticeable slowdown
on 32-bit machines. That seems to be an acceptable tradeoff, given the
performance gains for other inputs, and the effort that would be
required to maintain code specifically targeting 32-bit machines.
---
 src/backend/utils/adt/numeric.c | 224 +++++++++++++++++++++-----------
 1 file changed, 150 insertions(+), 74 deletions(-)

diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index ca28d0e3b3..4c1cb8153d 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -101,6 +101,8 @@ typedef signed char NumericDigit;
 typedef int16 NumericDigit;
 #endif
 
+#define NBASE_SQR	(NBASE * NBASE)
+
 /*
  * The Numeric type as stored on disk.
  *
@@ -8674,21 +8676,30 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 		int rscale)
 {
 	int			res_ndigits;
+	int			res_ndigitpairs;
 	int			res_sign;
 	int			res_weight;
+	int			pair_offset;
 	int			maxdigits;
-	int		   *dig;
-	int			carry;
-	int			maxdig;
-	int			newdig;
+	int			maxdigitpairs;
+	uint64	   *dig,
+			   *dig_i1_off;
+	uint64		maxdig;
+	uint64		carry;
+	uint64		newdig;
 	int			var1ndigits;
 	int			var2ndigits;
+	int			var1ndigitpairs;
+	int			var2ndigitpairs;
 	NumericDigit *var1digits;
 	NumericDigit *var2digits;
+	uint32		var1digitpair;
+	uint32	   *var2digitpairs;
 	NumericDigit *res_digits;
 	int			i,
 				i1,
-				i2;
+				i2,
+				i2limit;
 
 	/*
 	 * Arrange for var1 to be the shorter of the two numbers.  This improves
@@ -8729,86 +8740,164 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 		return;
 	}
 
-	/* Determine result sign and (maximum possible) weight */
+	/* Determine result sign */
 	if (var1->sign == var2->sign)
 		res_sign = NUMERIC_POS;
 	else
 		res_sign = NUMERIC_NEG;
-	res_weight = var1->weight + var2->weight + 2;
 
 	/*
-	 * Determine the number of result digits to compute.  If the exact result
-	 * would have more than rscale fractional digits, truncate the computation
-	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
-	 * would only contribute to the right of that.  (This will give the exact
+	 * Determine the number of result digits to compute and the (maximum
+	 * possible) result weight.  If the exact result would have more than
+	 * rscale fractional digits, truncate the computation with
+	 * MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that would
+	 * only contribute to the right of that.  (This will give the exact
 	 * rounded-to-rscale answer unless carries out of the ignored positions
 	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
 	 *
 	 * Note: an exact computation could not produce more than var1ndigits +
-	 * var2ndigits digits, but we allocate one extra output digit in case
-	 * rscale-driven rounding produces a carry out of the highest exact digit.
+	 * var2ndigits digits, but we allocate at least one extra output digit in
+	 * case rscale-driven rounding produces a carry out of the highest exact
+	 * digit.
+	 *
+	 * The computation itself is done using base-NBASE^2 arithmetic, so we
+	 * actually process the input digits in pairs, producing a base-NBASE^2
+	 * intermediate result.  This significantly improves performance, since
+	 * schoolbook multiplication is O(N^2) in the number of input digits, and
+	 * working in base NBASE^2 effectively halves "N".
+	 *
+	 * Note: in a truncated computation, we must compute at least one extra
+	 * output digit to ensure that all the guard digits are fully computed.
 	 */
-	res_ndigits = var1ndigits + var2ndigits + 1;
+	/* digit pairs in each input */
+	var1ndigitpairs = (var1ndigits + 1) / 2;
+	var2ndigitpairs = (var2ndigits + 1) / 2;
+
+	/* digits in exact result */
+	res_ndigits = var1ndigits + var2ndigits;
+
+	/* digit pairs in exact result with at least one extra output digit */
+	res_ndigitpairs = res_ndigits / 2 + 1;
+
+	/* pair offset to align result to end of dig[] */
+	pair_offset = res_ndigitpairs - var1ndigitpairs - var2ndigitpairs + 1;
+
+	/* maximum possible result weight (odd-length inputs shifted up below) */
+	res_weight = var1->weight + var2->weight + 1 + 2 * res_ndigitpairs -
+		res_ndigits - (var1ndigits & 1) - (var2ndigits & 1);
+
+	/* rscale-based truncation with at least one extra output digit */
 	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
 		MUL_GUARD_DIGITS;
-	res_ndigits = Min(res_ndigits, maxdigits);
+	maxdigitpairs = maxdigits / 2 + 1;
+
+	res_ndigitpairs = Min(res_ndigitpairs, maxdigitpairs);
+	res_ndigits = 2 * res_ndigitpairs;
 
-	if (res_ndigits < 3)
+	/*
+	 * In the computation below, digit pair i1 of var1 and digit pair i2 of
+	 * var2 are multiplied and added to digit i1+i2+pair_offset of dig[]. Thus
+	 * input digit pairs with index >= res_ndigitpairs - pair_offset don't
+	 * contribute to the result, and can be ignored.
+	 */
+	if (res_ndigitpairs <= pair_offset)
 	{
 		/* All input digits will be ignored; so result is zero */
 		zero_var(result);
 		result->dscale = rscale;
 		return;
 	}
+	var1ndigitpairs = Min(var1ndigitpairs, res_ndigitpairs - pair_offset);
+	var2ndigitpairs = Min(var2ndigitpairs, res_ndigitpairs - pair_offset);
 
 	/*
-	 * We do the arithmetic in an array "dig[]" of signed int's.  Since
-	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
-	 * to avoid normalizing carries immediately.
+	 * We do the arithmetic in an array "dig[]" of unsigned 64-bit integers.
+	 * Since PG_UINT64_MAX is much larger than NBASE^4, this gives us a lot of
+	 * headroom to avoid normalizing carries immediately.
 	 *
 	 * maxdig tracks the maximum possible value of any dig[] entry; when this
-	 * threatens to exceed INT_MAX, we take the time to propagate carries.
-	 * Furthermore, we need to ensure that overflow doesn't occur during the
-	 * carry propagation passes either.  The carry values could be as much as
-	 * INT_MAX/NBASE, so really we must normalize when digits threaten to
-	 * exceed INT_MAX - INT_MAX/NBASE.
+	 * threatens to exceed PG_UINT64_MAX, we take the time to propagate
+	 * carries.  Furthermore, we need to ensure that overflow doesn't occur
+	 * during the carry propagation passes either.  The carry values could be
+	 * as much as PG_UINT64_MAX / NBASE^2, so really we must normalize when
+	 * digits threaten to exceed PG_UINT64_MAX - PG_UINT64_MAX / NBASE^2.
 	 *
-	 * To avoid overflow in maxdig itself, it actually represents the max
-	 * possible value divided by NBASE-1, ie, at the top of the loop it is
-	 * known that no dig[] entry exceeds maxdig * (NBASE-1).
+	 * To avoid overflow in maxdig itself, it actually represents the maximum
+	 * possible value divided by NBASE^2-1, i.e., at the top of the loop it is
+	 * known that no dig[] entry exceeds maxdig * (NBASE^2-1).
+	 *
+	 * The conversion of var1 to base NBASE^2 is done on the fly, as each new
+	 * digit is required.  The digits of var2 are converted upfront, and
+	 * stored at the end of dig[].  To avoid loss of precision, the input
+	 * digits are aligned with the start of the digit pair array, effectively
+	 * shifting them up (multiplying by NBASE) if the inputs have an odd
+	 * number of NBASE digits.
 	 */
-	dig = (int *) palloc0(res_ndigits * sizeof(int));
-	maxdig = 0;
+	dig = (uint64 *) palloc(res_ndigitpairs * sizeof(uint64) +
+							var2ndigitpairs * sizeof(uint32));
+
+	/* convert var2 to base NBASE^2, shifting up if its length is odd */
+	var2digitpairs = (uint32 *) (dig + res_ndigitpairs);
+
+	for (i2 = 0; i2 < var2ndigitpairs - 1; i2++)
+		var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
+
+	if (2 * i2 + 1 < var2ndigits)
+		var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
+	else
+		var2digitpairs[i2] = var2digits[2 * i2] * NBASE;
 
 	/*
-	 * The least significant digits of var1 should be ignored if they don't
-	 * contribute directly to the first res_ndigits digits of the result that
-	 * we are computing.
+	 * Start by multiplying var2 by the least significant contributing digit
+	 * pair from var1, storing the results at the end of dig[], and filling
+	 * the leading digits with zeros.
 	 *
-	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
-	 * i1+i2+2 of the accumulator array, so we need only consider digits of
-	 * var1 for which i1 <= res_ndigits - 3.
+	 * The loop here is the same as the inner loop below, except that we set
+	 * the results in dig[], rather than adding to them.  This is the
+	 * performance bottleneck for multiplication, so we want to keep it simple
+	 * enough so that it can be auto-vectorized.  Accordingly, process the
+	 * digits left-to-right even though schoolbook multiplication would
+	 * suggest right-to-left.  Since we aren't propagating carries in this
+	 * loop, the order does not matter.
+	 */
+	i1 = var1ndigitpairs - 1;
+	if (2 * i1 + 1 < var1ndigits)
+		var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
+	else
+		var1digitpair = var1digits[2 * i1] * NBASE;
+	maxdig = var1digitpair;
+
+	i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
+	dig_i1_off = &dig[i1 + pair_offset];
+
+	memset(dig, 0, (i1 + pair_offset) * sizeof(uint64));
+	for (i2 = 0; i2 < i2limit; i2++)
+		dig_i1_off[i2] = (uint64) var1digitpair * var2digitpairs[i2];
+
+	/*
+	 * Next, multiply var2 by the remaining digit pairs from var1, adding the
+	 * results to dig[] at the appropriate offsets, and normalizing whenever
+	 * there is a risk of any dig[] entry overflowing.
 	 */
-	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
+	for (i1 = i1 - 1; i1 >= 0; i1--)
 	{
-		NumericDigit var1digit = var1digits[i1];
-
-		if (var1digit == 0)
+		var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
+		if (var1digitpair == 0)
 			continue;
 
 		/* Time to normalize? */
-		maxdig += var1digit;
-		if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
+		maxdig += var1digitpair;
+		if (maxdig > (PG_UINT64_MAX - PG_UINT64_MAX / NBASE_SQR) / (NBASE_SQR - 1))
 		{
-			/* Yes, do it */
+			/* Yes, do it (to base NBASE^2) */
 			carry = 0;
-			for (i = res_ndigits - 1; i >= 0; i--)
+			for (i = res_ndigitpairs - 1; i >= 0; i--)
 			{
 				newdig = dig[i] + carry;
-				if (newdig >= NBASE)
+				if (newdig >= NBASE_SQR)
 				{
-					carry = newdig / NBASE;
-					newdig -= carry * NBASE;
+					carry = newdig / NBASE_SQR;
+					newdig -= carry * NBASE_SQR;
 				}
 				else
 					carry = 0;
@@ -8816,50 +8905,37 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 			}
 			Assert(carry == 0);
 			/* Reset maxdig to indicate new worst-case */
-			maxdig = 1 + var1digit;
+			maxdig = 1 + var1digitpair;
 		}
 
-		/*
-		 * Add the appropriate multiple of var2 into the accumulator.
-		 *
-		 * As above, digits of var2 can be ignored if they don't contribute,
-		 * so we only include digits for which i1+i2+2 < res_ndigits.
-		 *
-		 * This inner loop is the performance bottleneck for multiplication,
-		 * so we want to keep it simple enough so that it can be
-		 * auto-vectorized.  Accordingly, process the digits left-to-right
-		 * even though schoolbook multiplication would suggest right-to-left.
-		 * Since we aren't propagating carries in this loop, the order does
-		 * not matter.
-		 */
-		{
-			int			i2limit = Min(var2ndigits, res_ndigits - i1 - 2);
-			int		   *dig_i1_2 = &dig[i1 + 2];
+		/* Multiply and add */
+		i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
+		dig_i1_off = &dig[i1 + pair_offset];
 
-			for (i2 = 0; i2 < i2limit; i2++)
-				dig_i1_2[i2] += var1digit * var2digits[i2];
-		}
+		for (i2 = 0; i2 < i2limit; i2++)
+			dig_i1_off[i2] += (uint64) var1digitpair * var2digitpairs[i2];
 	}
 
 	/*
-	 * Now we do a final carry propagation pass to normalize the result, which
-	 * we combine with storing the result digits into the output. Note that
-	 * this is still done at full precision w/guard digits.
+	 * Now we do a final carry propagation pass to normalize back to base
+	 * NBASE^2, and construct the base-NBASE result digits.  Note that this is
+	 * still done at full precision w/guard digits.
 	 */
 	alloc_var(result, res_ndigits);
 	res_digits = result->digits;
 	carry = 0;
-	for (i = res_ndigits - 1; i >= 0; i--)
+	for (i = res_ndigitpairs - 1; i >= 0; i--)
 	{
 		newdig = dig[i] + carry;
-		if (newdig >= NBASE)
+		if (newdig >= NBASE_SQR)
 		{
-			carry = newdig / NBASE;
-			newdig -= carry * NBASE;
+			carry = newdig / NBASE_SQR;
+			newdig -= carry * NBASE_SQR;
 		}
 		else
 			carry = 0;
-		res_digits[i] = newdig;
+		res_digits[2 * i + 1] = (NumericDigit) ((uint32) newdig % NBASE);
+		res_digits[2 * i] = (NumericDigit) ((uint32) newdig / NBASE);
 	}
 	Assert(carry == 0);
 
-- 
2.35.3

Reply via email to