On Tue, Jul 2, 2024, at 18:20, Joel Jacobson wrote:
> * v3-optimize-numeric-mul_var-small-var1-arbitrary-var2.patch
Hmm, v3 contains a bug which I haven't been able to solve yet.
Reporting this now so that nobody wastes time reviewing a patch that is known to be buggy.
The attached patch is how I tested and found the bug.
It contains a file test-mul-var.sql, which tests mul_var()
with varying rscale, using the SQL-callable numeric_mul_patched(),
whose third argument is the rscale_adjustment.
Out of 2481600 random tests, all passed except 4:
SELECT
result = numeric_mul_patched(var1,var2,rscale_adjustment),
COUNT(*)
FROM test_numeric_mul_patched
GROUP BY 1
ORDER BY 1;
?column? | count
----------+---------
f | 4
t | 2481596
(2 rows)
SELECT
var1,
var2,
var1*var2 AS full_resolution,
rscale_adjustment,
result AS expected,
numeric_mul_patched(var1,var2,rscale_adjustment) AS actual
FROM test_numeric_mul_patched
WHERE result <> numeric_mul_patched(var1,var2,rscale_adjustment);
       var1        |      var2      |      full_resolution      | rscale_adjustment | expected | actual
-------------------+----------------+---------------------------+-------------------+----------+--------
     676.797214075 | 0.308068877759 | 208.500158210502929257925 |               -21 |      209 |    208
         555.07029 | 0.381033094735 |     211.50015039415392315 |               -17 |      212 |    211
    0.476637718921 |      66.088276 |     31.500165120061470196 |               -18 |       32 |     31
 0.060569165063082 |      998.85933 |   60.50007563356949425506 |               -20 |       61 |     60
(4 rows)
I'm still trying to wrap my head around what could cause this.
It's rounding down instead of up, and in all of these cases the exact product has a fractional part of the form .500XXXX.
Regards,
Joel
From 074aeb223ab496f23c2075eabd35e6e76241d1d8 Mon Sep 17 00:00:00 2001
From: Joel Jakobsson <git...@compiler.org>
Date: Mon, 1 Jul 2024 07:17:50 +0200
Subject: [PATCH] Add SQL-callable numeric_mul_patched() to bench Simplified
fast-path computation
---
src/backend/utils/adt/numeric.c | 438 ++++++++++++++++++++++++++++++++
src/include/catalog/pg_proc.dat | 3 +
src/include/utils/numeric.h | 2 +
test-mul-var.sql | 48 ++++
4 files changed, 491 insertions(+)
create mode 100644 test-mul-var.sql
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index 5510a203b0..8f5d553f15 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -551,6 +551,9 @@ static void sub_var(const NumericVar *var1, const
NumericVar *var2,
static void mul_var(const NumericVar *var1, const NumericVar *var2,
NumericVar *result,
int rscale);
+static void mul_var_patched(const NumericVar *var1, const NumericVar *var2,
+ NumericVar *result,
+ int rscale);
static void div_var(const NumericVar *var1, const NumericVar *var2,
NumericVar *result,
int rscale, bool round);
@@ -3115,6 +3118,131 @@ numeric_mul_opt_error(Numeric num1, Numeric num2, bool
*have_error)
}
+/*
+ * numeric_mul_patched() -
+ *
+ *	SQL-callable wrapper: multiplies arguments 0 and 1 through
+ *	numeric_mul_patched_opt_error(), forwarding argument 2 as the rscale
+ *	adjustment and NULL for have_error.  Exists for tests and benchmarks.
+ */
+Datum
+numeric_mul_patched(PG_FUNCTION_ARGS)
+{
+	Numeric		multiplicand = PG_GETARG_NUMERIC(0);
+	Numeric		multiplier = PG_GETARG_NUMERIC(1);
+	int32		adjustment = PG_GETARG_INT32(2);
+	Numeric		product = numeric_mul_patched_opt_error(multiplicand,
+														multiplier,
+														adjustment, NULL);
+
+	PG_RETURN_NUMERIC(product);
+}
+
+
+/*
+ * numeric_mul_patched_opt_error() -
+ *
+ * Internal version of numeric_mul_patched().  rscale_adjustment is added to
+ * dscale(num1) + dscale(num2) to get the rscale given to mul_var_patched().
+ * have_error is passed to make_result_opt_error() for caller-side handling.
+ */
+Numeric
+numeric_mul_patched_opt_error(Numeric num1, Numeric num2, int32
rscale_adjustment, bool *have_error)
+{
+ NumericVar arg1;
+ NumericVar arg2;
+ NumericVar result;
+ Numeric res;
+
+ /*
+ * Handle NaN and infinities
+ */
+ if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
+ {
+ if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
+ return make_result(&const_nan);
+ if (NUMERIC_IS_PINF(num1))
+ {
+ switch (numeric_sign_internal(num2))
+ {
+ case 0:
+ return make_result(&const_nan); /* Inf * 0 */
+ case 1:
+ return make_result(&const_pinf);
+ case -1:
+ return make_result(&const_ninf);
+ }
+ Assert(false);
+ }
+ if (NUMERIC_IS_NINF(num1))
+ {
+ switch (numeric_sign_internal(num2))
+ {
+ case 0:
+ return make_result(&const_nan); /* -Inf * 0 */
+ case 1:
+ return make_result(&const_ninf);
+ case -1:
+ return make_result(&const_pinf);
+ }
+ Assert(false);
+ }
+ /* by here, num1 must be finite, so num2 is not */
+ if (NUMERIC_IS_PINF(num2))
+ {
+ switch (numeric_sign_internal(num1))
+ {
+ case 0:
+ return make_result(&const_nan); /* 0 * Inf */
+ case 1:
+ return make_result(&const_pinf);
+ case -1:
+ return make_result(&const_ninf);
+ }
+ Assert(false);
+ }
+ Assert(NUMERIC_IS_NINF(num2));
+ switch (numeric_sign_internal(num1))
+ {
+ case 0:
+ return make_result(&const_nan); /* 0 * -Inf */
+ case 1:
+ return make_result(&const_ninf);
+ case -1:
+ return make_result(&const_pinf);
+ }
+ Assert(false);
+ }
+
+ /*
+ * Unpack the values, let mul_var_patched() compute the result and return
+ * it.  Unlike numeric_mul(), which always requests the exact product
+ * (rscale = sum of the input dscales), the requested rscale here is that
+ * sum plus rscale_adjustment, so a negative adjustment makes
+ * mul_var_patched() round its result to fewer fractional digits, possibly
+ * computing it from a truncated product.
+ *
+ * NOTE(review): nothing here rejects an adjustment that would make the
+ * requested rscale negative -- confirm callers never pass one.
+ */
+ init_var_from_num(num1, &arg1);
+ init_var_from_num(num2, &arg2);
+
+ init_var(&result);
+
+ mul_var_patched(&arg1, &arg2, &result, arg1.dscale + arg2.dscale +
rscale_adjustment);
+
+ if (result.dscale > NUMERIC_DSCALE_MAX)
+ round_var(&result, NUMERIC_DSCALE_MAX);
+
+ res = make_result_opt_error(&result, have_error);
+
+ free_var(&result);
+
+ return res;
+}
+
+
/*
* numeric_div() -
*
@@ -8864,6 +8992,316 @@ mul_var(const NumericVar *var1, const NumericVar *var2,
NumericVar *result,
strip_var(result);
}
+/*
+ * mul_var_patched() -
+ *
+ *	Variant of mul_var() with an added single-pass fast path for the case
+ *	where the shorter input has at most three NBASE digits.  It is kept
+ *	alongside the unchanged original mul_var() so the two can be compared
+ *	directly in tests and benchmarks.  The product var1 * var2 is stored in
+ *	result, rounded to rscale fractional decimal digits.
+ *	This is a temporary measure pending a decision on replacing mul_var().
+ */
+static void
+mul_var_patched(const NumericVar *var1, const NumericVar *var2,
+				NumericVar *result, int rscale)
+{
+	int			res_ndigits;
+	int			res_sign;
+	int			res_weight;
+	int			maxdigits;
+	int		   *dig;
+	int			carry;
+	int			maxdig;
+	int			newdig;
+	int			var1ndigits;
+	int			var2ndigits;
+	NumericDigit *var1digits;
+	NumericDigit *var2digits;
+	NumericDigit *res_digits;
+	int			i,
+				i1,
+				i2;
+
+	/*
+	 * Arrange for var1 to be the shorter of the two numbers.  This improves
+	 * performance because the inner multiplication loop is much simpler than
+	 * the outer loop, so it's better to have a smaller number of iterations
+	 * of the outer loop.  This also reduces the number of times that the
+	 * accumulator array needs to be normalized.
+	 */
+	if (var1->ndigits > var2->ndigits)
+	{
+		const NumericVar *tmp = var1;
+
+		var1 = var2;
+		var2 = tmp;
+	}
+
+	/* copy these values into local vars for speed in inner loop */
+	var1ndigits = var1->ndigits;
+	var2ndigits = var2->ndigits;
+	var1digits = var1->digits;
+	var2digits = var2->digits;
+
+	if (var1ndigits == 0 || var2ndigits == 0)
+	{
+		/* one or both inputs is zero; so is result */
+		zero_var(result);
+		result->dscale = rscale;
+		return;
+	}
+
+	/* Determine result sign and (maximum possible) weight */
+	if (var1->sign == var2->sign)
+		res_sign = NUMERIC_POS;
+	else
+		res_sign = NUMERIC_NEG;
+	res_weight = var1->weight + var2->weight + 2;
+
+	/*
+	 * Determine the number of result digits to compute.  If the exact result
+	 * would have more than rscale fractional digits, truncate the computation
+	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
+	 * would only contribute to the right of that.  (This will give the exact
+	 * rounded-to-rscale answer unless carries out of the ignored positions
+	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
+	 *
+	 * Note: an exact computation could not produce more than var1ndigits +
+	 * var2ndigits digits, but we allocate one extra output digit in case
+	 * rscale-driven rounding produces a carry out of the highest exact digit.
+	 */
+	res_ndigits = var1ndigits + var2ndigits + 1;
+	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
+		MUL_GUARD_DIGITS;
+	res_ndigits = Min(res_ndigits, maxdigits);
+
+	if (res_ndigits < 3)
+	{
+		/* All input digits will be ignored; so result is zero */
+		zero_var(result);
+		result->dscale = rscale;
+		return;
+	}
+
+	/*
+	 * Fast path for var1 of 1-3 digits, taken only when res_ndigits >=
+	 * var1ndigits + 2 so that no var2digits[] index below can be negative;
+	 * shorter truncated products use the general code further down.  It needs
+	 * no separate digit array, extra passes over var2, or carry passes.
+	 */
+	if (var1ndigits <= 3 && res_ndigits >= var1ndigits + 2)
+	{
+		NumericDigit *res_buf;
+
+		/* Allocate result digit array */
+		res_buf = digitbuf_alloc(res_ndigits);
+		res_buf[0] = 0;			/* spare digit for later rounding */
+		res_digits = res_buf + 1;
+
+		/*
+		 * Compute the result digits in one pass, from the lowest computed
+		 * digit (terms with i1 + i2 == res_ndigits - 3) up, carrying as we go.
+		 */
+		switch (var1ndigits)
+		{
+			case 1:
+				carry = 0;
+				for (i = res_ndigits - 3; i >= 0; i--)
+				{
+					newdig = (int) var1digits[0] * var2digits[i] + carry;
+					res_digits[i + 1] = (NumericDigit) (newdig % NBASE);
+					carry = newdig / NBASE;
+				}
+				res_digits[0] = (NumericDigit) carry;
+				break;
+
+			case 2:
+				newdig = (int) var1digits[1] * var2digits[res_ndigits - 4];
+				if (res_ndigits - 3 < var2ndigits)
+					newdig += (int) var1digits[0] * var2digits[res_ndigits - 3];
+				res_digits[res_ndigits - 2] = (NumericDigit) (newdig % NBASE);
+				carry = newdig / NBASE;
+				for (i = res_ndigits - 4; i >= 1; i--)
+				{
+					newdig = (int) var1digits[0] * var2digits[i] +
+						(int) var1digits[1] * var2digits[i - 1] + carry;
+					res_digits[i + 1] = (NumericDigit) (newdig % NBASE);
+					carry = newdig / NBASE;
+				}
+				newdig = (int) var1digits[0] * var2digits[0] + carry;
+				res_digits[1] = (NumericDigit) (newdig % NBASE);
+				res_digits[0] = (NumericDigit) (newdig / NBASE);
+				break;
+
+			case 3:
+				newdig = (int) var1digits[2] * var2digits[res_ndigits - 5];
+				if (res_ndigits - 4 < var2ndigits)
+					newdig += (int) var1digits[1] * var2digits[res_ndigits - 4];
+				if (res_ndigits - 3 < var2ndigits)
+					newdig += (int) var1digits[0] * var2digits[res_ndigits - 3];
+				res_digits[res_ndigits - 2] = (NumericDigit) (newdig % NBASE);
+				carry = newdig / NBASE;
+				for (i = res_ndigits - 4; i >= 2; i--)
+				{
+					newdig = carry;
+					if (i < var2ndigits)
+						newdig += (int) var1digits[0] * var2digits[i];
+					if (i - 1 >= 0 && i - 1 < var2ndigits)
+						newdig += (int) var1digits[1] * var2digits[i - 1];
+					if (i - 2 >= 0 && i - 2 < var2ndigits)
+						newdig += (int) var1digits[2] * var2digits[i - 2];
+					res_digits[i + 1] = (NumericDigit) (newdig % NBASE);
+					carry = newdig / NBASE;
+				}
+				newdig = carry;
+				if (var2ndigits > 1)
+					newdig += (int) var1digits[0] * var2digits[1];
+				if (var2ndigits > 0)
+					newdig += (int) var1digits[1] * var2digits[0];
+				res_digits[2] = (NumericDigit) (newdig % NBASE);
+				carry = newdig / NBASE;
+				newdig = (int) var1digits[0] * var2digits[0] + carry;
+				res_digits[1] = (NumericDigit) (newdig % NBASE);
+				res_digits[0] = (NumericDigit) (newdig / NBASE);
+				break;
+		}
+
+		/* Store the product in result (minus extra rounding digit) */
+		digitbuf_free(result->buf);
+		result->ndigits = res_ndigits - 1;
+		result->buf = res_buf;
+		result->digits = res_digits;
+		result->weight = res_weight - 1;
+		result->sign = res_sign;
+
+		/* Round to target rscale (and set result->dscale) */
+		round_var(result, rscale);
+
+		/* Strip leading and trailing zeroes */
+		strip_var(result);
+
+		return;
+	}
+
+	/*
+	 * We do the arithmetic in an array "dig[]" of signed int's.  Since
+	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
+	 * to avoid normalizing carries immediately.
+	 *
+	 * maxdig tracks the maximum possible value of any dig[] entry; when this
+	 * threatens to exceed INT_MAX, we take the time to propagate carries.
+	 * Furthermore, we need to ensure that overflow doesn't occur during the
+	 * carry propagation passes either.  The carry values could be as much as
+	 * INT_MAX/NBASE, so really we must normalize when digits threaten to
+	 * exceed INT_MAX - INT_MAX/NBASE.
+	 *
+	 * To avoid overflow in maxdig itself, it actually represents the max
+	 * possible value divided by NBASE-1, ie, at the top of the loop it is
+	 * known that no dig[] entry exceeds maxdig * (NBASE-1).
+	 */
+	dig = (int *) palloc0(res_ndigits * sizeof(int));
+	maxdig = 0;
+
+	/*
+	 * The least significant digits of var1 should be ignored if they don't
+	 * contribute directly to the first res_ndigits digits of the result that
+	 * we are computing.
+	 *
+	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
+	 * i1+i2+2 of the accumulator array, so we need only consider digits of
+	 * var1 for which i1 <= res_ndigits - 3.
+	 */
+	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
+	{
+		NumericDigit var1digit = var1digits[i1];
+
+		if (var1digit == 0)
+			continue;
+
+		/* Time to normalize? */
+		maxdig += var1digit;
+		if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
+		{
+			/* Yes, do it */
+			carry = 0;
+			for (i = res_ndigits - 1; i >= 0; i--)
+			{
+				newdig = dig[i] + carry;
+				if (newdig >= NBASE)
+				{
+					carry = newdig / NBASE;
+					newdig -= carry * NBASE;
+				}
+				else
+					carry = 0;
+				dig[i] = newdig;
+			}
+			Assert(carry == 0);
+			/* Reset maxdig to indicate new worst-case */
+			maxdig = 1 + var1digit;
+		}
+
+		/*
+		 * Add the appropriate multiple of var2 into the accumulator.
+		 *
+		 * As above, digits of var2 can be ignored if they don't contribute,
+		 * so we only include digits for which i1+i2+2 < res_ndigits.
+		 *
+		 * This inner loop is the performance bottleneck for multiplication,
+		 * so we want to keep it simple enough so that it can be
+		 * auto-vectorized.  Accordingly, process the digits left-to-right
+		 * even though schoolbook multiplication would suggest right-to-left.
+		 * Since we aren't propagating carries in this loop, the order does
+		 * not matter.
+		 */
+		{
+			int			i2limit = Min(var2ndigits, res_ndigits - i1 - 2);
+			int		   *dig_i1_2 = &dig[i1 + 2];
+
+			for (i2 = 0; i2 < i2limit; i2++)
+				dig_i1_2[i2] += var1digit * var2digits[i2];
+		}
+	}
+
+	/*
+	 * Now we do a final carry propagation pass to normalize the result, which
+	 * we combine with storing the result digits into the output.  Note that
+	 * this is still done at full precision w/guard digits.
+	 */
+	alloc_var(result, res_ndigits);
+	res_digits = result->digits;
+	carry = 0;
+	for (i = res_ndigits - 1; i >= 0; i--)
+	{
+		newdig = dig[i] + carry;
+		if (newdig >= NBASE)
+		{
+			carry = newdig / NBASE;
+			newdig -= carry * NBASE;
+		}
+		else
+			carry = 0;
+		res_digits[i] = newdig;
+	}
+	Assert(carry == 0);
+
+	pfree(dig);
+
+	/*
+	 * Finally, round the result to the requested precision.
+	 */
+	result->weight = res_weight;
+	result->sign = res_sign;
+
+	/* Round to target rscale (and set result->dscale) */
+	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+
+}
+
/*
* div_var() -
diff --git a/src/include/catalog/pg_proc.dat b/src/include/catalog/pg_proc.dat
index d4ac578ae6..5b3024cb6d 100644
--- a/src/include/catalog/pg_proc.dat
+++ b/src/include/catalog/pg_proc.dat
@@ -4465,6 +4465,9 @@
{ oid => '1726',
proname => 'numeric_mul', prorettype => 'numeric',
proargtypes => 'numeric numeric', prosrc => 'numeric_mul' },
+{ oid => '6347',
+ proname => 'numeric_mul_patched', prorettype => 'numeric',
+ proargtypes => 'numeric numeric int4', prosrc => 'numeric_mul_patched' },
{ oid => '1727',
proname => 'numeric_div', prorettype => 'numeric',
proargtypes => 'numeric numeric', prosrc => 'numeric_div' },
diff --git a/src/include/utils/numeric.h b/src/include/utils/numeric.h
index 43c75c436f..454a56da9a 100644
--- a/src/include/utils/numeric.h
+++ b/src/include/utils/numeric.h
@@ -97,6 +97,8 @@ extern Numeric numeric_sub_opt_error(Numeric num1, Numeric
num2,
bool
*have_error);
extern Numeric numeric_mul_opt_error(Numeric num1, Numeric num2,
bool
*have_error);
+extern Numeric numeric_mul_patched_opt_error(Numeric num1, Numeric num2,
+ int32
rscale_adjustment, bool *have_error);
extern Numeric numeric_div_opt_error(Numeric num1, Numeric num2,
bool
*have_error);
extern Numeric numeric_mod_opt_error(Numeric num1, Numeric num2,
diff --git a/test-mul-var.sql b/test-mul-var.sql
new file mode 100644
index 0000000000..ee7e3855bc
--- /dev/null
+++ b/test-mul-var.sql
@@ -0,0 +1,48 @@
+CREATE TABLE test_numeric_mul_patched (
+ var1 numeric,
+ var2 numeric,
+ rscale_adjustment int, -- third argument for numeric_mul_patched()
+ result numeric -- reference product, filled in by the UPDATE below
+);
+
+DO $$
+DECLARE
+var1 numeric;
+var2 numeric;
+BEGIN
+ FOR i IN 1..100 LOOP -- 100 rounds over the whole parameter grid
+ RAISE NOTICE '%', i;
+ FOR var1ndigits IN 1..4 LOOP
+ FOR var2ndigits IN 1..4 LOOP
+ FOR var1dscale IN 0..(var1ndigits*4) LOOP
+ FOR var2dscale IN 0..(var2ndigits*4) LOOP
+ FOR rscale_adjustment IN 0..(var1dscale+var2dscale) LOOP -- stored negated below
+ var1 := round(random(
+ format('1%s',repeat('0',(var1ndigits-1)*4-1))::numeric,
+ format('%s',repeat('9',var1ndigits*4))::numeric
+ ) / 10::numeric^var1dscale, var1dscale);
+ var2 := round(random(
+ format('1%s',repeat('0',(var2ndigits-1)*4-1))::numeric,
+ format('%s',repeat('9',var2ndigits*4))::numeric
+ ) / 10::numeric^var2dscale, var2dscale);
+ INSERT INTO test_numeric_mul_patched
+ (var1, var2, rscale_adjustment)
+ VALUES
+ (var1, var2, -rscale_adjustment); -- adjustment is stored as a value <= 0
+ END LOOP;
+ END LOOP;
+ END LOOP;
+ END LOOP;
+ END LOOP;
+ END LOOP;
+END $$;
+
+-- First, populate result using a numeric_mul_patched() build in which
+-- the simplified fast-path computation code has been commented out.
+UPDATE test_numeric_mul_patched SET result = numeric_mul_patched(var1, var2,
rscale_adjustment);
+
+-- Then, recompile with the simplified fast-path computation code enabled,
+-- and list every row whose product now differs from the stored result:
+SELECT *, numeric_mul_patched(var1,var2,rscale_adjustment)
+FROM test_numeric_mul_patched
+WHERE result IS DISTINCT FROM numeric_mul_patched(var1,var2,rscale_adjustment);
--
2.45.1