On Tue, 2 Jul 2024 at 21:10, Joel Jacobson <j...@compiler.org> wrote: > > I found the bug in the case 3 code, > and it turns out the same type of bug also exists in the case 2 code: > > case 2: > newdig = (int) var1digits[1] * > var2digits[res_ndigits - 4]; > > The problem here is that res_ndigits could become less than 4,
Yes. It can't be less than 3 though (per an earlier test), so the case 2 code was correct. I've been hacking on this a bit and trying to tidy it up. Firstly, I moved it to a separate function, because it was starting to look messy having so much extra code in mul_var(). Then I added a bunch more comments to explain what's going on, and the limits of the various variables. Note that most of the boundary checks are actually unnecessary -- in particular all the ones in or after the main loop, provided you pull out the first 2 result digits from the main loop in the 3-digit case. That does seem to work very well, but... I wasn't entirely happy with how messy that code is getting, so I tried a different approach. Similar to div_var_int(), I tried writing a mul_var_int() function instead. This can be used for 1 and 2 digit factors, and we could add a similar mul_var_int64() function on platforms with 128-bit integers. The code looks quite a lot neater, so it's probably less likely to contain bugs (though I have just written it in a hurry, so it might still have bugs). In testing, it seemed to give a decent speedup, but perhaps a little less than before. But that's to be balanced against having more maintainable code, and also a function that might be useful elsewhere in numeric.c. Anyway, here are both patches for comparison. I'll stop hacking for a while and let you see what you make of these. Regards, Dean
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c new file mode 100644 index 5510a20..81600b3 --- a/src/backend/utils/adt/numeric.c +++ b/src/backend/utils/adt/numeric.c @@ -551,6 +551,8 @@ static void sub_var(const NumericVar *va static void mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, int rscale); +static void mul_var_small(const NumericVar *var1, const NumericVar *var2, + NumericVar *result, int rscale); static void div_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, int rscale, bool round); @@ -8707,7 +8709,7 @@ mul_var(const NumericVar *var1, const Nu var1digits = var1->digits; var2digits = var2->digits; - if (var1ndigits == 0 || var2ndigits == 0) + if (var1ndigits == 0) { /* one or both inputs is zero; so is result */ zero_var(result); @@ -8715,6 +8717,16 @@ mul_var(const NumericVar *var1, const Nu return; } + /* + * If var1 has 3 digits or fewer, delegate to mul_var_small() which uses a + * faster short multiplication algorithm. + */ + if (var1ndigits <= 3) + { + mul_var_small(var1, var2, result, rscale); + return; + } + /* Determine result sign and (maximum possible) weight */ if (var1->sign == var2->sign) res_sign = NUMERIC_POS; @@ -8858,6 +8870,188 @@ mul_var(const NumericVar *var1, const Nu result->sign = res_sign; /* Round to target rscale (and set result->dscale) */ + round_var(result, rscale); + + /* Strip leading and trailing zeroes */ + strip_var(result); +} + + +/* + * mul_var_small() - + * + * This has the same API as mul_var, but it assumes that var1 has no more + * than 3 digits and var2 has at least as many digits as var1. For variables + * satisfying these conditions, the product can be computed more quickly than + * the general algorithm used in mul_var. 
+ */ +static void +mul_var_small(const NumericVar *var1, const NumericVar *var2, + NumericVar *result, int rscale) +{ + int var1ndigits = var1->ndigits; + int var2ndigits = var2->ndigits; + NumericDigit *var1digits = var1->digits; + NumericDigit *var2digits = var2->digits; + int res_sign; + int res_weight; + int res_ndigits; + int maxdigits; + NumericDigit *res_buf; + NumericDigit *res_digits; + int carry; + int term; + + /* Check preconditions */ + Assert(var1ndigits <= 3); + Assert(var2ndigits >= var1ndigits); + + /* Determine result sign and (maximum possible) weight */ + if (var1->sign == var2->sign) + res_sign = NUMERIC_POS; + else + res_sign = NUMERIC_NEG; + res_weight = var1->weight + var2->weight + 2; + + /* Determine the number of result digits to compute - see mul_var() */ + res_ndigits = var1ndigits + var2ndigits + 1; + maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS + + MUL_GUARD_DIGITS; + res_ndigits = Min(res_ndigits, maxdigits); + + if (res_ndigits < 3) + { + /* All input digits will be ignored; so result is zero */ + zero_var(result); + result->dscale = rscale; + return; + } + + /* Allocate result digit array */ + res_buf = digitbuf_alloc(res_ndigits); + res_buf[0] = 0; /* spare digit for later rounding */ + res_digits = res_buf + 1; + + /* + * Compute the result digits in reverse, in one pass, propagating the + * carry up as we go. + * + * This computes res_digits[res_ndigits - 2], ... res_digits[0] by summing + * the products var1digits[i1] * var2digits[i2] for which i1 + i2 + 1 is + * the result index. 
+ */ + switch (var1ndigits) + { + case 1: + /* --------- + * 1-digit case: + * var1ndigits = 1 + * var2ndigits >= 1 + * 3 <= res_ndigits <= var2ndigits + 2 + * ---------- + */ + carry = 0; + for (int i = res_ndigits - 3; i >= 0; i--) + { + term = (int) var1digits[0] * var2digits[i] + carry; + res_digits[i + 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + } + res_digits[0] = (NumericDigit) carry; + break; + + case 2: + /* --------- + * 2-digit case: + * var1ndigits = 2 + * var2ndigits >= 2 + * 3 <= res_ndigits <= var2ndigits + 3 + * ---------- + */ + /* last result digit and carry */ + term = 0; + if (res_ndigits - 3 < var2ndigits) + term += (int) var1digits[0] * var2digits[res_ndigits - 3]; + if (res_ndigits > 3) + term += (int) var1digits[1] * var2digits[res_ndigits - 4]; + res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + /* remaining digits, except for the first two */ + for (int i = res_ndigits - 4; i >= 1; i--) + { + term = (int) var1digits[0] * var2digits[i] + + (int) var1digits[1] * var2digits[i - 1] + carry; + res_digits[i + 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + } + + /* first two digits */ + term = (int) var1digits[0] * var2digits[0] + carry; + res_digits[1] = (NumericDigit) (term % NBASE); + res_digits[0] = (NumericDigit) (term / NBASE); + break; + + case 3: + /* --------- + * 3-digit case: + * var1ndigits = 3 + * var2ndigits >= 3 + * 3 <= res_ndigits <= var2ndigits + 4 + * ---------- + */ + /* last result digit and carry */ + term = 0; + if (res_ndigits - 3 < var2ndigits) + term += (int) var1digits[0] * var2digits[res_ndigits - 3]; + if (res_ndigits > 3 && res_ndigits - 4 < var2ndigits) + term += (int) var1digits[1] * var2digits[res_ndigits - 4]; + if (res_ndigits > 4) + term += (int) var1digits[2] * var2digits[res_ndigits - 5]; + res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + /* penultimate result digit */ + term = carry; + if 
(res_ndigits > 3 && res_ndigits - 4 < var2ndigits) + term += (int) var1digits[0] * var2digits[res_ndigits - 4]; + if (res_ndigits > 4) + term += (int) var1digits[1] * var2digits[res_ndigits - 5]; + if (res_ndigits > 5) + term += (int) var1digits[2] * var2digits[res_ndigits - 6]; + res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + /* remaining digits, except for the first three */ + for (int i = res_ndigits - 5; i >= 2; i--) + { + term = (int) var1digits[0] * var2digits[i] + + (int) var1digits[1] * var2digits[i - 1] + + (int) var1digits[2] * var2digits[i - 2] + carry; + res_digits[i + 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + } + + /* first three digits */ + term = (int) var1digits[0] * var2digits[1] + + (int) var1digits[1] * var2digits[0] + carry; + res_digits[2] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + term = (int) var1digits[0] * var2digits[0] + carry; + res_digits[1] = (NumericDigit) (term % NBASE); + res_digits[0] = (NumericDigit) (term / NBASE); + break; + } + + /* Store the product in result (minus extra rounding digit) */ + digitbuf_free(result->buf); + result->ndigits = res_ndigits - 1; + result->buf = res_buf; + result->digits = res_digits; + result->weight = res_weight - 1; + result->sign = res_sign; + + /* Round to target rscale (and set result->dscale) */ round_var(result, rscale); /* Strip leading and trailing zeroes */
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c new file mode 100644 index 5510a20..9e50ea7 --- a/src/backend/utils/adt/numeric.c +++ b/src/backend/utils/adt/numeric.c @@ -551,6 +551,8 @@ static void sub_var(const NumericVar *va static void mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, int rscale); +static void mul_var_int(const NumericVar *var, int ival, int ival_weight, + NumericVar *result, int rscale); static void div_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, int rscale, bool round); @@ -8707,7 +8709,7 @@ mul_var(const NumericVar *var1, const Nu var1digits = var1->digits; var2digits = var2->digits; - if (var1ndigits == 0 || var2ndigits == 0) + if (var1ndigits == 0) { /* one or both inputs is zero; so is result */ zero_var(result); @@ -8715,6 +8717,31 @@ mul_var(const NumericVar *var1, const Nu return; } + /* + * If var1 has just one or two digits, delegate to mul_var_int(), which + * uses a faster direct multiplication algorithm. + * + * TODO: Similarly, on platforms with 128-bit integers ... + */ + if (var1ndigits <= 2) + { + int ifactor; + int ifactor_weight; + + ifactor = var1->digits[0]; + ifactor_weight = var1->weight; + if (var1ndigits == 2) + { + ifactor = ifactor * NBASE + var1->digits[1]; + ifactor_weight--; + } + if (var1->sign == NUMERIC_NEG) + ifactor = -ifactor; + + mul_var_int(var2, ifactor, ifactor_weight, result, rscale); + return; + } + /* Determine result sign and (maximum possible) weight */ if (var1->sign == var2->sign) res_sign = NUMERIC_POS; @@ -8857,6 +8884,123 @@ mul_var(const NumericVar *var1, const Nu result->weight = res_weight; result->sign = res_sign; + /* Round to target rscale (and set result->dscale) */ + round_var(result, rscale); + + /* Strip leading and trailing zeroes */ + strip_var(result); +} + + +/* + * mul_var_int() - + * + * Multiply a numeric variable by a 32-bit integer with the specified weight. 
+ * The product var * ival * NBASE^ival_weight is stored in result. + */ +static void +mul_var_int(const NumericVar *var, int ival, int ival_weight, + NumericVar *result, int rscale) +{ + NumericDigit *var_digits = var->digits; + int var_ndigits = var->ndigits; + int res_sign; + int res_weight; + int res_ndigits; + int maxdigits; + NumericDigit *res_buf; + NumericDigit *res_digits; + uint32 factor; + uint32 carry; + + if (ival == 0 || var_ndigits == 0) + { + zero_var(result); + result->dscale = rscale; + return; + } + + /* + * Determine the result sign, (maximum possible) weight and number of + * digits to calculate. The weight figured here is correct if the emitted + * product has no leading zero digits; otherwise strip_var() will fix + * things up. + */ + if (var->sign == NUMERIC_POS) + res_sign = ival > 0 ? NUMERIC_POS : NUMERIC_NEG; + else + res_sign = ival > 0 ? NUMERIC_NEG : NUMERIC_POS; + res_weight = var->weight + ival_weight + 3; + /* The number of accurate result digits we need to produce: */ + res_ndigits = var_ndigits + 3; + maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS + + MUL_GUARD_DIGITS; + res_ndigits = Min(res_ndigits, maxdigits); + + if (res_ndigits < 3) + { + /* All input digits will be ignored; so result is zero */ + zero_var(result); + result->dscale = rscale; + return; + } + + res_buf = digitbuf_alloc(res_ndigits + 1); + res_buf[0] = 0; /* spare digit for later rounding */ + res_digits = res_buf + 1; + + /* + * Now compute the product digits by processing the input digits in reverse + * and propagating the carry up as we go. + * + * In this algorithm, the carry from one digit to the next is at most + * factor - 1, and product is at most factor * NBASE - 1, and so it needs + * to be a 64-bit integer if this exceeds UINT_MAX. 
+ */ + factor = abs(ival); + carry = 0; + + if (factor <= UINT_MAX / NBASE) + { + /* product cannot overflow 32 bits */ + uint32 product; + + for (int i = res_ndigits - 4; i >= 0; i--) + { + product = factor * var_digits[i] + carry; + res_digits[i + 3] = (NumericDigit) (product % NBASE); + carry = product / NBASE; + } + res_digits[2] = (NumericDigit) (carry % NBASE); + carry = carry / NBASE; + res_digits[1] = (NumericDigit) (carry % NBASE); + res_digits[0] = (NumericDigit) (carry / NBASE); + } + else + { + /* product may exceed 32 bits */ + uint64 product; + + for (int i = res_ndigits - 4; i >= 0; i--) + { + product = (uint64) factor * var_digits[i] + carry; + res_digits[i + 3] = (NumericDigit) (product % NBASE); + carry = (uint32) (product / NBASE); + } + res_digits[2] = (NumericDigit) (carry % NBASE); + carry = carry / NBASE; + res_digits[1] = (NumericDigit) (carry % NBASE); + res_digits[0] = (NumericDigit) (carry / NBASE); + } + + /* Store the product in result */ + digitbuf_free(result->buf); + result->ndigits = res_ndigits; + result->buf = res_buf; + result->digits = res_digits; + result->weight = res_weight; + result->sign = res_sign; + /* Round to target rscale (and set result->dscale) */ round_var(result, rscale);