I found a couple of places where numeric multiplication suffers from overflow errors for inputs that aren't necessarily very large in magnitude.
The first is with the numeric * operator, which attempts to always produce the exact result, even though the numeric type has a maximum of 16383 digits after the decimal point. If that limit is exceeded, an overflow error is produced, e.g.:

    SELECT (1+2e-10000) * (1+3e-10000);
    ERROR: value overflows numeric format

I can't imagine anyone actually wanting that many digits after the decimal point, but it can happen with a sequence of multiplications, where the number of digits after the decimal point grows with each multiply. Throwing an error is not particularly useful, and that error message is quite misleading, since the result is not very large. Therefore I propose making it round the result to 16383 digits after the decimal point when the exact result would exceed that, as in the first attached patch.

It's worth noting that to get correct rounding, it's necessary to compute the full exact product (which we're actually already doing) before rounding, as opposed to passing rscale=16383 to mul_var() and letting it round. The latter approach would compute a truncated product with MUL_GUARD_DIGITS extra digits of precision, which doesn't necessarily round the final digit the right way. The test case in the first patch is an example that would round the wrong way with a truncated product.

I considered doing the final rounding in mul_var() (after computing the full product), to prevent such overflows for all callers, but I think that's the wrong approach because it risks covering up other problems, such as the following.

While looking through the remaining numeric code, I found one other place that has a similar problem -- when calculating the sum of squares for aggregates like variance() and stddev(), the squares can end up with more than 16383 digits after the decimal point. When the query is running on a single backend, that's no problem, because the final result is rounded to 1000 digits.
However, if it uses parallel workers, the result from each worker is sent using numeric_send(), which requires first converting the NumericVar to a Numeric with make_result(), and that conversion fails with an overflow error if there are more than 16383 digits after the decimal point. Thus it's possible to have an aggregate query that fails if it uses parallel workers and succeeds otherwise.

In this case, I don't think that rounding the result from each worker is the right approach, since that can lead to the final result being different depending on whether or not the query uses parallel workers. Also, given that each worker is already doing the hard work of computing these squares, it seems a waste to just discard that information.

So the second patch fixes this by adding new numericvar_send/recv() functions capable of sending the full-precision NumericVars from each worker, without rounding. The first test case in this patch is an example that would round the wrong way if the result from each worker were rounded before being sent.

An additional benefit of this approach is that it also addresses the issue noted in the old code about its use of numeric_send/recv() being wasteful:

    /*
     * This is a little wasteful since make_result converts the NumericVar
     * into a Numeric and numeric_send converts it back again. Is it worth
     * splitting the tasks in numeric_send into separate functions to stop
     * this? Doing so would also remove the fmgr call overhead.
     */

So the patch converts all aggregate serialization/deserialization code to use the new numericvar_send/recv() functions. I doubt that gives much in the way of a performance improvement, but it makes the code a little neater as well as preventing overflows.

After writing that, I realised that there's another overflow risk -- if the input values are very large in magnitude, the sum of squares could genuinely overflow the numeric type, while the final variance could be quite small (and that can't be fixed by rounding).
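A worked example makes the large-magnitude case concrete. Take the inputs of the second patch's overflow test, x_i = c + i for i = 1..5 with c = 9e131071, and the sample-variance formula (N * sumX2 - sumX * sumX) / (N * (N - 1)) that numeric_stddev_internal() evaluates:

$$
\begin{aligned}
\sum_{i=1}^{5} x_i &= 5c + 15, \\
\sum_{i=1}^{5} x_i^2 &= 5c^2 + 30c + 55, \\
N\sum x_i^2 - \Bigl(\sum x_i\Bigr)^2 &= (25c^2 + 150c + 275) - (25c^2 + 150c + 225) = 50, \\
\operatorname{var}_{\mathrm{samp}} &= \frac{50}{N(N-1)} = \frac{50}{20} = 2.5 .
\end{aligned}
$$

The intermediate sum of squares is about 4.05e262144, which has roughly twice the 131072 digits that numeric allows before the decimal point. The c^2 terms still cancel exactly, so the value each worker sends must be able to exceed the numeric type's range. That is what the 32-bit weight field in numericvar_send() allows.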
So this case too fails when parallel workers are used and succeeds otherwise. The patch fixes it as well, and I added a separate test case for it.

Regards,
Dean
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index eb78f0b..d74001c
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -233,6 +233,7 @@ struct NumericData
  */
 
 #define NUMERIC_DSCALE_MASK 0x3FFF
+#define NUMERIC_DSCALE_MAX NUMERIC_DSCALE_MASK
 
 #define NUMERIC_SIGN(n) \
 	(NUMERIC_IS_SHORT(n) ? \
@@ -2955,7 +2956,11 @@ numeric_mul_opt_error(Numeric num1, Nume
 	 * Unlike add_var() and sub_var(), mul_var() will round its result. In the
 	 * case of numeric_mul(), which is invoked for the * operator on numerics,
 	 * we request exact representation for the product (rscale = sum(dscale of
-	 * arg1, dscale of arg2)).
+	 * arg1, dscale of arg2)). If the exact result has more digits after the
+	 * decimal point than can be stored in a numeric, we round it. Rounding
+	 * after computing the exact result ensures that the final result is
+	 * correctly rounded (rounding in mul_var() using a truncated product
+	 * would not guarantee this).
 	 */
 	init_var_from_num(num1, &arg1);
 	init_var_from_num(num2, &arg2);
@@ -2963,6 +2968,9 @@ numeric_mul_opt_error(Numeric num1, Nume
 	init_var(&result);
 	mul_var(&arg1, &arg2, &result, arg1.dscale + arg2.dscale);
 
+	if (result.dscale > NUMERIC_DSCALE_MAX)
+		round_var(&result, NUMERIC_DSCALE_MAX);
+
 	res = make_result_opt_error(&result, have_error);
 
 	free_var(&result);
diff --git a/src/test/regress/expected/numeric.out b/src/test/regress/expected/numeric.out
new file mode 100644
index 30a5642..e0bc6d9
--- a/src/test/regress/expected/numeric.out
+++ b/src/test/regress/expected/numeric.out
@@ -2145,6 +2145,12 @@ select 476999999999999999999999999999999
  47699999999999999999999999999999999999999999999999999999999999999999999999999999999999985230000000000000000000000000000000000000000000000000000000000000000000000000000000000001
 (1 row)
 
+select trim_scale((0.1 - 2e-16383) * (0.1 - 3e-16383));
+ trim_scale
+------------
+       0.01
+(1 row)
+
 --
 -- Test some corner cases for division
 --
diff --git a/src/test/regress/sql/numeric.sql b/src/test/regress/sql/numeric.sql
new file mode 100644
index db812c8..2e75076
--- a/src/test/regress/sql/numeric.sql
+++ b/src/test/regress/sql/numeric.sql
@@ -1044,6 +1044,8 @@ select 477099999999999999999999999999999
 select 4769999999999999999999999999999999999999999999999999999999999999999999999999999999999999 *
        9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999;
 
+select trim_scale((0.1 - 2e-16383) * (0.1 - 3e-16383));
+
 --
 -- Test some corner cases for division
 --
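The first patch's test case depends on the difference between rounding the exact product and rounding a truncated one. The toy model below is a hypothetical illustration, not mul_var() itself. It works in base 10 rather than base NBASE, and it models a truncated product simply by dropping every partial product a[i]*b[j] that falls beyond a cut-off digit position. The inputs have the same shape as the test, (0.1 - 2e-k) * (0.1 - 3e-k), with k = 4 to keep the numbers small. The names NDIG, product_upto() and round_to_scale() are made up for this sketch.

```c
/*
 * Toy base-10 model (hypothetical, not PostgreSQL code) of the first
 * patch's test case with k = 4: (0.1 - 2e-4) * (0.1 - 3e-4).
 * Digits after the decimal point, most significant first.
 */
#define NDIG 4
static const int a[NDIG] = {0, 9, 9, 8};	/* 0.0998 */
static const int b[NDIG] = {0, 9, 9, 7};	/* 0.0997 */

/*
 * Sum the partial products a[i] * b[j] whose decimal position
 * (i + 1) + (j + 1) is <= maxpos, returning the result in units of
 * 10^-(2 * NDIG).  maxpos = 2 * NDIG includes every partial product and
 * gives the exact product; a smaller maxpos drops the low-order partial
 * products, loosely mimicking a truncated product.
 */
static long long
product_upto(int maxpos)
{
	long long	sum = 0;

	for (int i = 0; i < NDIG; i++)
		for (int j = 0; j < NDIG; j++)
		{
			int			pos = (i + 1) + (j + 1);
			long long	term = (long long) a[i] * b[j];

			if (pos > maxpos)
				continue;		/* this partial product is never computed */
			for (int k = pos; k < 2 * NDIG; k++)
				term *= 10;		/* scale to units of 10^-(2 * NDIG) */
			sum += term;
		}
	return sum;
}

/*
 * Round a non-negative value in units of 10^-(2 * NDIG) to rscale digits
 * after the decimal point (a 5 rounds up), returning units of 10^-rscale.
 */
static long long
round_to_scale(long long v, int rscale)
{
	long long	unit = 1;

	for (int k = rscale; k < 2 * NDIG; k++)
		unit *= 10;
	return (v + unit / 2) / unit;
}
```

With every partial product included, product_upto(8) is 995006 (0.00995006), which rounds to 100 units of 1e-4, i.e. the correct 0.0100. If only positions up to 7 are kept (three guard digits beyond the four being kept), the result is 994950 (0.00994950), which rounds to 0.0099. The deciding term here is the very last partial product, so in cases like this only the full product is certain to round correctly. That is why the patch rounds with round_var() after the exact mul_var().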
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index eb78f0b..1bd5697
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -515,6 +515,9 @@ static void set_var_from_var(const Numer
 static char *get_str_from_var(const NumericVar *var);
 static char *get_str_from_var_sci(const NumericVar *var, int rscale);
 
+static void numericvar_recv(StringInfo buf, NumericVar *var);
+static bytea *numericvar_send(const NumericVar *var);
+
 static Numeric duplicate_numeric(Numeric num);
 static Numeric make_result(const NumericVar *var);
 static Numeric make_result_opt_error(const NumericVar *var, bool *error);
@@ -4943,7 +4946,6 @@ numeric_avg_serialize(PG_FUNCTION_ARGS)
 {
 	NumericAggState *state;
 	StringInfoData buf;
-	Datum temp;
 	bytea *sumX;
 	bytea *result;
 	NumericVar tmp_var;
@@ -4954,18 +4956,9 @@ numeric_avg_serialize(PG_FUNCTION_ARGS)
 
 	state = (NumericAggState *) PG_GETARG_POINTER(0);
 
-	/*
-	 * This is a little wasteful since make_result converts the NumericVar
-	 * into a Numeric and numeric_send converts it back again. Is it worth
-	 * splitting the tasks in numeric_send into separate functions to stop
-	 * this? Doing so would also remove the fmgr call overhead.
-	 */
 	init_var(&tmp_var);
 	accum_sum_final(&state->sumX, &tmp_var);
-
-	temp = DirectFunctionCall1(numeric_send,
-							   NumericGetDatum(make_result(&tmp_var)));
-	sumX = DatumGetByteaPP(temp);
+	sumX = numericvar_send(&tmp_var);
 	free_var(&tmp_var);
 
 	pq_begintypsend(&buf);
@@ -5006,10 +4999,11 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS
 {
 	bytea *sstate;
 	NumericAggState *result;
-	Datum temp;
 	NumericVar tmp_var;
 	StringInfoData buf;
 
+	init_var(&tmp_var);
+
 	if (!AggCheckCallContext(fcinfo, NULL))
 		elog(ERROR, "aggregate function called in non-aggregate context");
 
@@ -5029,11 +5023,7 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS
 	result->N = pq_getmsgint64(&buf);
 
 	/* sumX */
-	temp = DirectFunctionCall3(numeric_recv,
-							   PointerGetDatum(&buf),
-							   ObjectIdGetDatum(InvalidOid),
-							   Int32GetDatum(-1));
-	init_var_from_num(DatumGetNumeric(temp), &tmp_var);
+	numericvar_recv(&buf, &tmp_var);
 	accum_sum_add(&(result->sumX), &tmp_var);
 
 	/* maxScale */
@@ -5054,6 +5044,8 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS
 	pq_getmsgend(&buf);
 	pfree(buf.data);
 
+	free_var(&tmp_var);
+
 	PG_RETURN_POINTER(result);
 }
 
@@ -5067,7 +5059,6 @@ numeric_serialize(PG_FUNCTION_ARGS)
 {
 	NumericAggState *state;
 	StringInfoData buf;
-	Datum temp;
 	bytea *sumX;
 	NumericVar tmp_var;
 	bytea *sumX2;
@@ -5079,23 +5070,13 @@ numeric_serialize(PG_FUNCTION_ARGS)
 
 	state = (NumericAggState *) PG_GETARG_POINTER(0);
 
-	/*
-	 * This is a little wasteful since make_result converts the NumericVar
-	 * into a Numeric and numeric_send converts it back again. Is it worth
-	 * splitting the tasks in numeric_send into separate functions to stop
-	 * this? Doing so would also remove the fmgr call overhead.
-	 */
 	init_var(&tmp_var);
 	accum_sum_final(&state->sumX, &tmp_var);
 
-	temp = DirectFunctionCall1(numeric_send,
-							   NumericGetDatum(make_result(&tmp_var)));
-	sumX = DatumGetByteaPP(temp);
+	sumX = numericvar_send(&tmp_var);
 
 	accum_sum_final(&state->sumX2, &tmp_var);
-	temp = DirectFunctionCall1(numeric_send,
-							   NumericGetDatum(make_result(&tmp_var)));
-	sumX2 = DatumGetByteaPP(temp);
+	sumX2 = numericvar_send(&tmp_var);
 
 	free_var(&tmp_var);
 
@@ -5140,11 +5121,11 @@ numeric_deserialize(PG_FUNCTION_ARGS)
 {
 	bytea *sstate;
 	NumericAggState *result;
-	Datum temp;
-	NumericVar sumX_var;
-	NumericVar sumX2_var;
+	NumericVar tmp_var;
 	StringInfoData buf;
 
+	init_var(&tmp_var);
+
 	if (!AggCheckCallContext(fcinfo, NULL))
 		elog(ERROR, "aggregate function called in non-aggregate context");
 
@@ -5164,20 +5145,12 @@ numeric_deserialize(PG_FUNCTION_ARGS)
 	result->N = pq_getmsgint64(&buf);
 
 	/* sumX */
-	temp = DirectFunctionCall3(numeric_recv,
-							   PointerGetDatum(&buf),
-							   ObjectIdGetDatum(InvalidOid),
-							   Int32GetDatum(-1));
-	init_var_from_num(DatumGetNumeric(temp), &sumX_var);
-	accum_sum_add(&(result->sumX), &sumX_var);
+	numericvar_recv(&buf, &tmp_var);
+	accum_sum_add(&(result->sumX), &tmp_var);
 
 	/* sumX2 */
-	temp = DirectFunctionCall3(numeric_recv,
-							   PointerGetDatum(&buf),
-							   ObjectIdGetDatum(InvalidOid),
-							   Int32GetDatum(-1));
-	init_var_from_num(DatumGetNumeric(temp), &sumX2_var);
-	accum_sum_add(&(result->sumX2), &sumX2_var);
+	numericvar_recv(&buf, &tmp_var);
+	accum_sum_add(&(result->sumX2), &tmp_var);
 
 	/* maxScale */
 	result->maxScale = pq_getmsgint(&buf, 4);
@@ -5197,6 +5170,8 @@ numeric_deserialize(PG_FUNCTION_ARGS)
 	pq_getmsgend(&buf);
 	pfree(buf.data);
 
+	free_var(&tmp_var);
+
 	PG_RETURN_POINTER(result);
 }
 
@@ -5478,7 +5453,6 @@ numeric_poly_serialize(PG_FUNCTION_ARGS)
 	 * processing and we want a standard format to work with.
 	 */
 	{
-		Datum temp;
 		NumericVar num;
 
 		init_var(&num);
@@ -5488,18 +5462,14 @@ numeric_poly_serialize(PG_FUNCTION_ARGS)
 #else
 		accum_sum_final(&state->sumX, &num);
 #endif
-		temp = DirectFunctionCall1(numeric_send,
-								   NumericGetDatum(make_result(&num)));
-		sumX = DatumGetByteaPP(temp);
+		sumX = numericvar_send(&num);
 
 #ifdef HAVE_INT128
 		int128_to_numericvar(state->sumX2, &num);
 #else
 		accum_sum_final(&state->sumX2, &num);
 #endif
-		temp = DirectFunctionCall1(numeric_send,
-								   NumericGetDatum(make_result(&num)));
-		sumX2 = DatumGetByteaPP(temp);
+		sumX2 = numericvar_send(&num);
 
 		free_var(&num);
 	}
@@ -5530,12 +5500,11 @@ numeric_poly_deserialize(PG_FUNCTION_ARG
 {
 	bytea *sstate;
 	PolyNumAggState *result;
-	Datum sumX;
-	NumericVar sumX_var;
-	Datum sumX2;
-	NumericVar sumX2_var;
+	NumericVar tmp_var;
 	StringInfoData buf;
 
+	init_var(&tmp_var);
+
 	if (!AggCheckCallContext(fcinfo, NULL))
 		elog(ERROR, "aggregate function called in non-aggregate context");
 
@@ -5555,34 +5524,28 @@ numeric_poly_deserialize(PG_FUNCTION_ARG
 	result->N = pq_getmsgint64(&buf);
 
 	/* sumX */
-	sumX = DirectFunctionCall3(numeric_recv,
-							   PointerGetDatum(&buf),
-							   ObjectIdGetDatum(InvalidOid),
-							   Int32GetDatum(-1));
-
-	/* sumX2 */
-	sumX2 = DirectFunctionCall3(numeric_recv,
-								PointerGetDatum(&buf),
-								ObjectIdGetDatum(InvalidOid),
-								Int32GetDatum(-1));
+	numericvar_recv(&buf, &tmp_var);
 
-	init_var_from_num(DatumGetNumeric(sumX), &sumX_var);
 #ifdef HAVE_INT128
-	numericvar_to_int128(&sumX_var, &result->sumX);
+	numericvar_to_int128(&tmp_var, &result->sumX);
 #else
-	accum_sum_add(&result->sumX, &sumX_var);
+	accum_sum_add(&result->sumX, &tmp_var);
 #endif
 
-	init_var_from_num(DatumGetNumeric(sumX2), &sumX2_var);
+	/* sumX2 */
+	numericvar_recv(&buf, &tmp_var);
+
 #ifdef HAVE_INT128
-	numericvar_to_int128(&sumX2_var, &result->sumX2);
+	numericvar_to_int128(&tmp_var, &result->sumX2);
 #else
-	accum_sum_add(&result->sumX2, &sumX2_var);
+	accum_sum_add(&result->sumX2, &tmp_var);
 #endif
 
 	pq_getmsgend(&buf);
 	pfree(buf.data);
 
+	free_var(&tmp_var);
+
 	PG_RETURN_POINTER(result);
 }
 
@@ -5699,7 +5662,6 @@ int8_avg_serialize(PG_FUNCTION_ARGS)
 	 * want a standard format to work with.
 	 */
 	{
-		Datum temp;
 		NumericVar num;
 
 		init_var(&num);
@@ -5709,9 +5671,7 @@ int8_avg_serialize(PG_FUNCTION_ARGS)
 #else
 		accum_sum_final(&state->sumX, &num);
 #endif
-		temp = DirectFunctionCall1(numeric_send,
-								   NumericGetDatum(make_result(&num)));
-		sumX = DatumGetByteaPP(temp);
+		sumX = numericvar_send(&num);
 
 		free_var(&num);
 	}
@@ -5739,9 +5699,10 @@ int8_avg_deserialize(PG_FUNCTION_ARGS)
 	bytea *sstate;
 	PolyNumAggState *result;
 	StringInfoData buf;
-	Datum temp;
 	NumericVar num;
 
+	init_var(&num);
+
 	if (!AggCheckCallContext(fcinfo, NULL))
 		elog(ERROR, "aggregate function called in non-aggregate context");
 
@@ -5761,11 +5722,7 @@ int8_avg_deserialize(PG_FUNCTION_ARGS)
 	result->N = pq_getmsgint64(&buf);
 
 	/* sumX */
-	temp = DirectFunctionCall3(numeric_recv,
-							   PointerGetDatum(&buf),
-							   ObjectIdGetDatum(InvalidOid),
-							   Int32GetDatum(-1));
-	init_var_from_num(DatumGetNumeric(temp), &num);
+	numericvar_recv(&buf, &num);
 #ifdef HAVE_INT128
 	numericvar_to_int128(&num, &result->sumX);
 #else
@@ -5775,6 +5732,8 @@ int8_avg_deserialize(PG_FUNCTION_ARGS)
 	pq_getmsgend(&buf);
 	pfree(buf.data);
 
+	free_var(&num);
+
 	PG_RETURN_POINTER(result);
 }
 
@@ -7286,6 +7245,72 @@ get_str_from_var_sci(const NumericVar *v
 }
 
 
+/*
+ * numericvar_recv - converts external binary format to NumericVar
+ *
+ * At variable level, no checks are performed on the weight or dscale, allowing
+ * us to pass around intermediate values with higher precision than supported
+ * by the numeric type. Note: this is incompatible with numeric_send/recv(),
+ * which use 16-bit integers for these fields.
+ */
+static void
+numericvar_recv(StringInfo buf, NumericVar *var)
+{
+	int len,
+		i;
+
+	len = pq_getmsgint(buf, sizeof(int32));
+
+	alloc_var(var, len);
+
+	var->weight = pq_getmsgint(buf, sizeof(int32));
+
+	var->sign = pq_getmsgint(buf, sizeof(int32));
+	if (!(var->sign == NUMERIC_POS ||
+		  var->sign == NUMERIC_NEG ||
+		  var->sign == NUMERIC_NAN ||
+		  var->sign == NUMERIC_PINF ||
+		  var->sign == NUMERIC_NINF))
+		ereport(ERROR,
+				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
+				 errmsg("invalid sign in \"NumericVar\" value")));
+
+	var->dscale = pq_getmsgint(buf, sizeof(int32));
+
+	for (i = 0; i < len; i++)
+	{
+		NumericDigit d = pq_getmsgint(buf, sizeof(NumericDigit));
+
+		if (d < 0 || d >= NBASE)
+			ereport(ERROR,
+					(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
+					 errmsg("invalid digit in external \"NumericVar\" value")));
+		var->digits[i] = d;
+	}
+}
+
+/*
+ * numericvar_send - converts NumericVar to binary format
+ */
+static bytea *
+numericvar_send(const NumericVar *var)
+{
+	StringInfoData buf;
+	int i;
+
+	pq_begintypsend(&buf);
+
+	pq_sendint32(&buf, var->ndigits);
+	pq_sendint32(&buf, var->weight);
+	pq_sendint32(&buf, var->sign);
+	pq_sendint32(&buf, var->dscale);
+	for (i = 0; i < var->ndigits; i++)
+		pq_sendint16(&buf, var->digits[i]);
+
+	return pq_endtypsend(&buf);
+}
+
+
 /*
  * duplicate_numeric() - copy a packed-format Numeric
  *
diff --git a/src/test/regress/expected/numeric.out b/src/test/regress/expected/numeric.out
new file mode 100644
index 30a5642..4ad4851
--- a/src/test/regress/expected/numeric.out
+++ b/src/test/regress/expected/numeric.out
@@ -2967,6 +2967,56 @@ SELECT SUM((-9999)::numeric) FROM genera
 (1 row)
 
 --
+-- Tests for VARIANCE()
+--
+CREATE TABLE num_variance (a numeric);
+INSERT INTO num_variance VALUES (0);
+INSERT INTO num_variance VALUES (3e-500);
+INSERT INTO num_variance VALUES (-3e-500);
+INSERT INTO num_variance VALUES (4e-500 - 1e-16383);
+INSERT INTO num_variance VALUES (-4e-500 + 1e-16383);
+-- variance is just under 12.5e-1000 and so should round down to 12e-1000
+SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
+ trim_scale
+------------
+         12
+(1 row)
+
+-- check that parallel execution produces the same result
+BEGIN;
+ALTER TABLE num_variance SET (parallel_workers = 4);
+SET LOCAL parallel_setup_cost = 0;
+SET LOCAL max_parallel_workers_per_gather = 4;
+SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
+ trim_scale
+------------
+         12
+(1 row)
+
+ROLLBACK;
+-- case where sum of squares would overflow but variance does not
+DELETE FROM num_variance;
+INSERT INTO num_variance SELECT 9e131071 + x FROM generate_series(1, 5) x;
+SELECT variance(a) FROM num_variance;
+      variance
+--------------------
+ 2.5000000000000000
+(1 row)
+
+-- check that parallel execution produces the same result
+BEGIN;
+ALTER TABLE num_variance SET (parallel_workers = 4);
+SET LOCAL parallel_setup_cost = 0;
+SET LOCAL max_parallel_workers_per_gather = 4;
+SELECT variance(a) FROM num_variance;
+      variance
+--------------------
+ 2.5000000000000000
+(1 row)
+
+ROLLBACK;
+DROP TABLE num_variance;
+--
 -- Tests for GCD()
 --
 SELECT a, b, gcd(a, b), gcd(a, -b), gcd(-b, a), gcd(-b, -a)
diff --git a/src/test/regress/sql/numeric.sql b/src/test/regress/sql/numeric.sql
new file mode 100644
index db812c8..3784c52
--- a/src/test/regress/sql/numeric.sql
+++ b/src/test/regress/sql/numeric.sql
@@ -1278,6 +1278,42 @@ SELECT SUM(9999::numeric) FROM generate_
 SELECT SUM((-9999)::numeric) FROM generate_series(1, 100000);
 
 --
+-- Tests for VARIANCE()
+--
+CREATE TABLE num_variance (a numeric);
+INSERT INTO num_variance VALUES (0);
+INSERT INTO num_variance VALUES (3e-500);
+INSERT INTO num_variance VALUES (-3e-500);
+INSERT INTO num_variance VALUES (4e-500 - 1e-16383);
+INSERT INTO num_variance VALUES (-4e-500 + 1e-16383);
+
+-- variance is just under 12.5e-1000 and so should round down to 12e-1000
+SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
+
+-- check that parallel execution produces the same result
+BEGIN;
+ALTER TABLE num_variance SET (parallel_workers = 4);
+SET LOCAL parallel_setup_cost = 0;
+SET LOCAL max_parallel_workers_per_gather = 4;
+SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
+ROLLBACK;
+
+-- case where sum of squares would overflow but variance does not
+DELETE FROM num_variance;
+INSERT INTO num_variance SELECT 9e131071 + x FROM generate_series(1, 5) x;
+SELECT variance(a) FROM num_variance;
+
+-- check that parallel execution produces the same result
+BEGIN;
+ALTER TABLE num_variance SET (parallel_workers = 4);
+SET LOCAL parallel_setup_cost = 0;
+SET LOCAL max_parallel_workers_per_gather = 4;
+SELECT variance(a) FROM num_variance;
+ROLLBACK;
+
+DROP TABLE num_variance;
+
+--
 -- Tests for GCD()
 --
 SELECT a, b, gcd(a, b), gcd(a, -b), gcd(-b, a), gcd(-b, -a)
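The second patch's numericvar_send() writes four 32-bit header fields (ndigits, weight, sign, dscale) followed by one 16-bit value per base-NBASE digit. All values go out in network byte order, which is what pq_sendint32() and pq_sendint16() produce. By contrast, numeric_send() uses 16-bit header fields. The standalone sketch below mirrors that layout and the checks numericvar_recv() applies. SketchVar, sketch_send(), sketch_recv(), the SK_* constants and the fixed 16-digit limit are hypothetical stand-ins for NumericVar and the StringInfo/pq_* machinery, not PostgreSQL code. The sign values are assumed to match NUMERIC_POS/NEG/NAN/PINF/NINF, and NBASE = 10000 assumes the default DEC_DIGITS = 4.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical stand-ins; values assumed to mirror numeric.c */
#define SK_POS		0x0000
#define SK_NEG		0x4000
#define SK_NAN		0xC000
#define SK_PINF		0xD000
#define SK_NINF		0xF000
#define SK_NBASE	10000
#define SK_MAXDIGITS 16			/* sketch-only limit, for a fixed array */

typedef struct SketchVar
{
	int32_t		ndigits;		/* number of base-NBASE digits */
	int32_t		weight;			/* weight of the first digit */
	int32_t		sign;
	int32_t		dscale;			/* display scale */
	int16_t		digits[SK_MAXDIGITS];
} SketchVar;

/* Write the low nbytes of v in network (big-endian) order */
static uint8_t *
put_be(uint8_t *p, uint32_t v, int nbytes)
{
	for (int i = nbytes - 1; i >= 0; i--)
		*p++ = (uint8_t) (v >> (8 * i));
	return p;
}

/* Read nbytes in network order into *v */
static const uint8_t *
get_be(const uint8_t *p, uint32_t *v, int nbytes)
{
	uint32_t	r = 0;

	for (int i = 0; i < nbytes; i++)
		r = (r << 8) | *p++;
	*v = r;
	return p;
}

/* Same field order and widths as numericvar_send(); returns bytes written */
static size_t
sketch_send(const SketchVar *var, uint8_t *out)
{
	uint8_t    *p = out;

	p = put_be(p, (uint32_t) var->ndigits, 4);
	p = put_be(p, (uint32_t) var->weight, 4);
	p = put_be(p, (uint32_t) var->sign, 4);
	p = put_be(p, (uint32_t) var->dscale, 4);
	for (int i = 0; i < var->ndigits; i++)
		p = put_be(p, (uint16_t) var->digits[i], 2);
	return (size_t) (p - out);
}

/*
 * Decode, validating sign and digits as numericvar_recv() does; weight and
 * dscale are deliberately left unchecked.  The explicit length check stands
 * in for the end-of-message checks done by the pq_getmsg* functions.
 */
static bool
sketch_recv(const uint8_t *in, size_t len, SketchVar *var)
{
	const uint8_t *p = in;
	uint32_t	v;

	if (len < 16)
		return false;
	p = get_be(p, &v, 4);
	if (v > SK_MAXDIGITS || len != 16 + 2 * (size_t) v)
		return false;
	var->ndigits = (int32_t) v;
	p = get_be(p, &v, 4);
	var->weight = (int32_t) v;
	p = get_be(p, &v, 4);
	if (v != SK_POS && v != SK_NEG && v != SK_NAN && v != SK_PINF && v != SK_NINF)
		return false;
	var->sign = (int32_t) v;
	p = get_be(p, &v, 4);
	var->dscale = (int32_t) v;
	for (int i = 0; i < var->ndigits; i++)
	{
		p = get_be(p, &v, 2);
		if (v >= SK_NBASE)
			return false;		/* like the "invalid digit" error */
		var->digits[i] = (int16_t) v;
	}
	return true;
}
```

One test case uses a weight of 65536, which does not fit in an int16, and a dscale of 20000, which exceeds 16383. That value survives the round trip with every field intact, something make_result() plus numeric_send() cannot manage. Corrupting a digit so that it is >= NBASE makes sketch_recv() reject the message.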