On Fri, 5 Jul 2024 at 18:37, Joel Jacobson <j...@compiler.org> wrote:
>
> On Fri, Jul 5, 2024, at 18:42, Joel Jacobson wrote:
> > Very nice, v7-optimize-numeric-mul_var-small-var1-arbitrary-var2.patch
> > is now the winner on all my CPUs:
>
> I thought it would be interesting to also measure the isolated effect
> on just numeric_mul() without the query overhead.
>
> Impressive speed-up, between 25% - 81%.
>

Cool. I think we should go with the mul_var_small() patch then, since
it's more generally applicable.

I also did some testing with much larger var2 values, and saw similar
speed-ups. One high-level function that benefits from that is
factorial(), which accepts inputs up to 32177, and so uses both the
1-digit and 2-digit code with very large var2 values. I doubt anyone
actually uses it with such large inputs, but it's interesting
nonetheless:

SELECT factorial(32177);
Time: 923.117 ms  -- HEAD
Time: 534.375 ms  -- mul_var_small() patch

I did one more round of (mostly cosmetic) copy-editing. Aside from
improving some of the comments, it occurred to me that there's no need
to pass rscale to mul_var_small(), or for it to call round_var(), since
it's always computing the exact result. That shaves off a few more
cycles.

Additionally, I didn't like how res_weight and res_ndigits were being
set 1 higher than they needed to be. That makes sense in mul_var()
because it may round the result, causing a non-zero carry to propagate
into the next digit up, but it's just confusing in mul_var_small(). So
I've reduced those by 1, which makes the code look much more logical.
To be clear, this doesn't change how many digits we're calculating. But
now res_ndigits is actually the number of digits being calculated,
whereas before, res_ndigits was 1 larger and we were calculating
res_ndigits - 1 digits, which was confusing.

I think this is good to go, so unless there are any further comments, I
plan to commit it soon.

Possible future work would be to try extending it to larger var1
values. I have a feeling that might work quite well for 5 or 6 digits,
but at some point, we'll start seeing diminishing returns, and the code
bloat won't be worth it.

Regards,
Dean

diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index 5510a20..b556861
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -551,6 +551,8 @@ static void sub_var(const NumericVar *va
 static void mul_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale);
+static void mul_var_small(const NumericVar *var1, const NumericVar *var2,
+						  NumericVar *result);
 static void div_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale, bool round);
@@ -8707,7 +8709,7 @@ mul_var(const NumericVar *var1, const Nu
 	var1digits = var1->digits;
 	var2digits = var2->digits;
 
-	if (var1ndigits == 0 || var2ndigits == 0)
+	if (var1ndigits == 0)
 	{
 		/* one or both inputs is zero; so is result */
 		zero_var(result);
@@ -8715,6 +8717,16 @@ mul_var(const NumericVar *var1, const Nu
 		return;
 	}
 
+	/*
+	 * If var1 has 1-4 digits and the exact result was requested, delegate to
+	 * mul_var_small() which uses a faster direct multiplication algorithm.
+	 */
+	if (var1ndigits <= 4 && rscale == var1->dscale + var2->dscale)
+	{
+		mul_var_small(var1, var2, result);
+		return;
+	}
+
 	/* Determine result sign and (maximum possible) weight */
 	if (var1->sign == var2->sign)
 		res_sign = NUMERIC_POS;
@@ -8862,6 +8874,212 @@ mul_var(const NumericVar *var1, const Nu
 
 	/* Strip leading and trailing zeroes */
 	strip_var(result);
+}
+
+
+/*
+ * mul_var_small() -
+ *
+ *	Special-case multiplication function used when var1 has 1-4 digits, var2
+ *	has at least as many digits as var1, and the exact product var1 * var2 is
+ *	requested.
+ */
+static void
+mul_var_small(const NumericVar *var1, const NumericVar *var2,
+			  NumericVar *result)
+{
+	int			var1ndigits = var1->ndigits;
+	int			var2ndigits = var2->ndigits;
+	NumericDigit *var1digits = var1->digits;
+	NumericDigit *var2digits = var2->digits;
+	int			res_sign;
+	int			res_weight;
+	int			res_ndigits;
+	NumericDigit *res_buf;
+	NumericDigit *res_digits;
+	uint32		carry;
+	uint32		term;
+
+	/* Check preconditions */
+	Assert(var1ndigits >= 1);
+	Assert(var1ndigits <= 4);
+	Assert(var2ndigits >= var1ndigits);
+
+	/*
+	 * Determine the result sign, weight, and number of digits to calculate.
+	 * The weight figured here is correct if the product has no leading zero
+	 * digits; otherwise strip_var() will fix things up. Note that, unlike
+	 * mul_var(), we do not need to allocate an extra output digit, because we
+	 * are not rounding here.
+	 */
+	if (var1->sign == var2->sign)
+		res_sign = NUMERIC_POS;
+	else
+		res_sign = NUMERIC_NEG;
+	res_weight = var1->weight + var2->weight + 1;
+	res_ndigits = var1ndigits + var2ndigits;
+
+	/* Allocate result digit array */
+	res_buf = digitbuf_alloc(res_ndigits + 1);
+	res_buf[0] = 0;				/* spare digit for later rounding */
+	res_digits = res_buf + 1;
+
+	/*
+	 * Compute the result digits in reverse, in one pass, propagating the
+	 * carry up as we go. The i'th result digit consists of the sum of the
+	 * products var1digits[i1] * var2digits[i2] for which i = i1 + i2 + 1.
+	 */
+	switch (var1ndigits)
+	{
+		case 1:
+			/* ---------
+			 * 1-digit case:
+			 *		var1ndigits = 1
+			 *		var2ndigits >= 1
+			 *		res_ndigits = var2ndigits + 1
+			 * ----------
+			 */
+			carry = 0;
+			for (int i = res_ndigits - 2; i >= 0; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+			res_digits[0] = (NumericDigit) carry;
+			break;
+
+		case 2:
+			/* ---------
+			 * 2-digit case:
+			 *		var1ndigits = 2
+			 *		var2ndigits >= 2
+			 *		res_ndigits = var2ndigits + 2
+			 * ----------
+			 */
+			/* last result digit and carry */
+			term = (uint32) var1digits[1] * var2digits[res_ndigits - 3];
+			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			/* remaining digits, except for the first two */
+			for (int i = res_ndigits - 3; i >= 1; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] +
+					(uint32) var1digits[1] * var2digits[i - 1] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+
+			/* first two digits */
+			term = (uint32) var1digits[0] * var2digits[0] + carry;
+			res_digits[1] = (NumericDigit) (term % NBASE);
+			res_digits[0] = (NumericDigit) (term / NBASE);
+			break;
+
+		case 3:
+			/* ---------
+			 * 3-digit case:
+			 *		var1ndigits = 3
+			 *		var2ndigits >= 3
+			 *		res_ndigits = var2ndigits + 3
+			 * ----------
+			 */
+			/* last two result digits */
+			term = (uint32) var1digits[2] * var2digits[res_ndigits - 4];
+			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[1] * var2digits[res_ndigits - 4] +
+				(uint32) var1digits[2] * var2digits[res_ndigits - 5] + carry;
+			res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			/* remaining digits, except for the first three */
+			for (int i = res_ndigits - 4; i >= 2; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] +
+					(uint32) var1digits[1] * var2digits[i - 1] +
+					(uint32) var1digits[2] * var2digits[i - 2] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+
+			/* first three digits */
+			term = (uint32) var1digits[0] * var2digits[1] +
+				(uint32) var1digits[1] * var2digits[0] + carry;
+			res_digits[2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[0] * var2digits[0] + carry;
+			res_digits[1] = (NumericDigit) (term % NBASE);
+			res_digits[0] = (NumericDigit) (term / NBASE);
+			break;
+
+		case 4:
+			/* ---------
+			 * 4-digit case:
+			 *		var1ndigits = 4
+			 *		var2ndigits >= 4
+			 *		res_ndigits = var2ndigits + 4
+			 * ----------
+			 */
+			/* last three result digits */
+			term = (uint32) var1digits[3] * var2digits[res_ndigits - 5];
+			res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[2] * var2digits[res_ndigits - 5] +
+				(uint32) var1digits[3] * var2digits[res_ndigits - 6] + carry;
+			res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[1] * var2digits[res_ndigits - 5] +
+				(uint32) var1digits[2] * var2digits[res_ndigits - 6] +
+				(uint32) var1digits[3] * var2digits[res_ndigits - 7] + carry;
+			res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			/* remaining digits, except for the first four */
+			for (int i = res_ndigits - 5; i >= 3; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] +
+					(uint32) var1digits[1] * var2digits[i - 1] +
+					(uint32) var1digits[2] * var2digits[i - 2] +
+					(uint32) var1digits[3] * var2digits[i - 3] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+
+			/* first four digits */
+			term = (uint32) var1digits[0] * var2digits[2] +
+				(uint32) var1digits[1] * var2digits[1] +
+				(uint32) var1digits[2] * var2digits[0] + carry;
+			res_digits[3] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[0] * var2digits[1] +
+				(uint32) var1digits[1] * var2digits[0] + carry;
+			res_digits[2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[0] * var2digits[0] + carry;
+			res_digits[1] = (NumericDigit) (term % NBASE);
+			res_digits[0] = (NumericDigit) (term / NBASE);
+			break;
+	}
+
+	/* Store the product in result */
+	digitbuf_free(result->buf);
+	result->ndigits = res_ndigits;
+	result->buf = res_buf;
+	result->digits = res_digits;
+	result->weight = res_weight;
+	result->sign = res_sign;
+	result->dscale = var1->dscale + var2->dscale;
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
 }
 
 