Folks, I'd like to add weighted statistics to PostgreSQL. While the included weighted_avg() is trivial to calculate using existing machinery, the included weighted_stddev_*() functions are not.
I've only done the float8 versions, but if we decide to move forward, I'd be delighted to add the rest of the numeric types and maybe others as make sense. What say? Cheers, David. -- David Fetter <da...@fetter.org> http://fetter.org/ Phone: +1 415 235 3778 AIM: dfetter666 Yahoo!: dfetter Skype: davidfetter XMPP: david.fet...@gmail.com Remember to vote! Consider donating to Postgres: http://www.postgresql.org/about/donate
diff --git a/doc/src/sgml/func.sgml b/doc/src/sgml/func.sgml index 4d482ec..2174594 100644 --- a/doc/src/sgml/func.sgml +++ b/doc/src/sgml/func.sgml @@ -12443,6 +12443,29 @@ NULL baz</literallayout>(3 rows)</entry> <row> <entry> <indexterm> + <primary>weighted_average</primary> + </indexterm> + <indexterm> + <primary>weighted_avg</primary> + </indexterm> + <function>weighted_avg(<replaceable class="parameter">value expression</replaceable>, <replaceable class="parameter">weight expression</replaceable>)</function> + </entry> + <entry> + <type>smallint</type>, <type>int</type>, + <type>bigint</type>, <type>real</type>, <type>double + precision</type>, <type>numeric</type>, or <type>interval</type> + </entry> + <entry> + <type>numeric</type> for any integer-type argument, + <type>double precision</type> for a floating-point argument, + otherwise the same as the argument data type + </entry> + <entry>the average (arithmetic mean) of all input values, weighted by the input weights</entry> + </row> + + <row> + <entry> + <indexterm> <primary>bit_and</primary> </indexterm> <function>bit_and(<replaceable class="parameter">expression</replaceable>)</function> @@ -13086,6 +13109,29 @@ SELECT xmlagg(x) FROM (SELECT x FROM test ORDER BY y DESC) AS tab; <row> <entry> <indexterm> + <primary>weighted standard deviation</primary> + <secondary>population</secondary> + </indexterm> + <indexterm> + <primary>weighted_stddev_pop</primary> + </indexterm> + <function>weighted_stddev_pop(<replaceable class="parameter">value expression</replaceable>, <replaceable class="parameter">weight expression</replaceable>)</function> + </entry> + <entry> + <type>smallint</type>, <type>int</type>, + <type>bigint</type>, <type>real</type>, <type>double + precision</type>, or <type>numeric</type> + </entry> + <entry> + <type>double precision</type> for floating-point arguments, + otherwise <type>numeric</type> + </entry> + <entry>weighted population standard deviation of the input values</entry> + 
</row> + + <row> + <entry> + <indexterm> <primary>standard deviation</primary> <secondary>sample</secondary> </indexterm> @@ -13109,6 +13155,29 @@ SELECT xmlagg(x) FROM (SELECT x FROM test ORDER BY y DESC) AS tab; <row> <entry> <indexterm> + <primary>weighted standard deviation</primary> + <secondary>sample</secondary> + </indexterm> + <indexterm> + <primary>weighted_stddev_samp</primary> + </indexterm> + <function>weighted_stddev_samp(<replaceable class="parameter">value expression</replaceable>, <replaceable class="parameter">weight expression</replaceable>)</function> + </entry> + <entry> + <type>smallint</type>, <type>int</type>, + <type>bigint</type>, <type>real</type>, <type>double + precision</type>, or <type>numeric</type> + </entry> + <entry> + <type>double precision</type> for floating-point arguments, + otherwise <type>numeric</type> + </entry> + <entry>weighted sample standard deviation of the input values</entry> + </row> + + <row> + <entry> + <indexterm> <primary>variance</primary> </indexterm> <function>variance</function>(<replaceable class="parameter">expression</replaceable>) diff --git a/src/backend/utils/adt/float.c b/src/backend/utils/adt/float.c index 4e927d8..533ce0a 100644 --- a/src/backend/utils/adt/float.c +++ b/src/backend/utils/adt/float.c @@ -1774,6 +1774,7 @@ setseed(PG_FUNCTION_ARGS) * float8_accum - accumulate for AVG(), variance aggregates, etc. 
* float4_accum - same, but input data is float4 * float8_avg - produce final result for float AVG() + * float8_weighted_avg - produce final result for float WEIGHTED_AVG() * float8_var_samp - produce final result for float VAR_SAMP() * float8_var_pop - produce final result for float VAR_POP() * float8_stddev_samp - produce final result for float STDDEV_SAMP()
@@ -1929,6 +1930,46 @@ float8_avg(PG_FUNCTION_ARGS)
 }
 
+/*
+ * float8_weighted_avg - produce final result for float WEIGHTED_AVG()
+ *
+ * The transition array is the 6-element one built by float8_regr_accum
+ * (see the pg_aggregate entry); as read here, element 0 is N, element 1 is
+ * sum(W) and element 5 is sum(W*X).
+ *
+ * Returns NULL when no rows were accumulated, or when the weights sum to
+ * zero and the weighted mean is therefore undefined.
+ */
+Datum
+float8_weighted_avg(PG_FUNCTION_ARGS)
+{
+	ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
+	float8 *transvalues;
+	float8 N,
+		sumW,
+		sumWX,
+		result;
+
+	transvalues = check_float8_array(transarray, "float8_weighted_avg", 6);
+	N = transvalues[0];
+	sumW = transvalues[1];
+	sumWX = transvalues[5];
+
+	if (N < 1.0)
+		PG_RETURN_NULL();
+
+	/* Don't divide by a zero total weight */
+	if (sumW == 0.0)
+		PG_RETURN_NULL();
+
+	result = sumWX / sumW;
+
+	/* Range-check the quotient; zero is valid only if sumWX is zero */
+	CHECKFLOATVAL(result, isinf(sumWX) || isinf(sumW), sumWX == 0.0);
+
+	PG_RETURN_FLOAT8(result);
+}
+
 Datum
 float8_var_pop(PG_FUNCTION_ARGS)
 {
 	ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
@@ -2467,6 +2508,140 @@ float8_regr_intercept(PG_FUNCTION_ARGS)
 	PG_RETURN_FLOAT8(numeratorXXY / numeratorX);
 }
 
+/*
+ * ===================
+ * WEIGHTED AGGREGATES
+ * ===================
+ *
+ * WEIGHTED_AVG() borrows float8_regr_accum as its accumulator.  The
+ * weighted standard deviations cannot pirate an existing accumulator, so
+ * they get float8_weighted_accum below.  It updates a running mean and a
+ * running sum of squared deviations, which takes out some of the rounding
+ * error inherent in the general sum-of-squares method.
+ * https://en.wikipedia.org/wiki/Standard_deviation#Rapid_calculation_methods
+ *
+ * The transition datatype for float8_weighted_accum is a 4-element array
+ * of float8 holding, in this order:
+ *
+ * N, the number of positive-weighted values seen thus far,
+ * W, the running sum of weights,
+ * A, the running weighted mean, and
+ * Q, the running weighted sum of squared deviations from the mean.
+ */
+Datum
+float8_weighted_accum(PG_FUNCTION_ARGS)
+{
+	ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
+	float8 newvalX = PG_GETARG_FLOAT8(1);
+	float8 newvalW = PG_GETARG_FLOAT8(2);
+	float8 *transvalues;
+	float8 N,
+		W,
+		A,
+		Q;
+
+	/* The state is updated in place, which is safe only in Agg context */
+	if (!AggCheckCallContext(fcinfo, NULL))
+		ereport(ERROR,
+				(errmsg("float8_weighted_accum called outside agg context")));
+
+	transvalues = check_float8_array(transarray, "float8_weighted_accum", 4);
+
+	/*
+	 * We only care about positive weights.  Hand the state back unchanged
+	 * rather than returning NULL: this function is strict, so a NULL state
+	 * would never be updated again and the aggregate would come out NULL.
+	 */
+	if (newvalW <= 0.0)
+		PG_RETURN_ARRAYTYPE_P(transarray);
+
+	N = transvalues[0] + 1.0;
+	CHECKFLOATVAL(N, isinf(transvalues[0]), true);
+
+	W = transvalues[1] + newvalW;
+	CHECKFLOATVAL(W, isinf(transvalues[1]) || isinf(newvalW), true);
+
+	/* Move the mean toward newvalX by its share of the total weight */
+	A = transvalues[2] + newvalW * (newvalX - transvalues[2]) / W;
+	CHECKFLOATVAL(A,
+				  isinf(transvalues[2]) || isinf(newvalX) || isinf(newvalW),
+				  true);
+
+	/* Q needs the deviation from both the old mean and the updated one */
+	Q = transvalues[3] + newvalW * (newvalX - transvalues[2]) * (newvalX - A);
+	CHECKFLOATVAL(Q,
+				  isinf(transvalues[3]) || isinf(newvalX) || isinf(newvalW),
+				  true);
+
+	transvalues[0] = N;
+	transvalues[1] = W;
+	transvalues[2] = A;
+	transvalues[3] = Q;
+
+	PG_RETURN_ARRAYTYPE_P(transarray);
+}
+
+/*
+ * float8_weighted_stddev_samp - produce final result for float
+ * WEIGHTED_STDDEV_SAMP(), computed as sqrt(N * Q / ((N - 1) * W))
+ */
+Datum
+float8_weighted_stddev_samp(PG_FUNCTION_ARGS)
+{
+	ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
+	float8 *transvalues;
+	float8 N,
+		W,
+		Q;
+
+	transvalues = check_float8_array(transarray,
+									 "float8_weighted_stddev_samp", 4);
+	N = transvalues[0];
+	W = transvalues[1];
+	/* Skip A (transvalues[2]).  Not used in the calculation */
+	Q = transvalues[3];
+
+	/* Must have at least two samples to get a sample stddev */
+	if (N < 2.0)
+		PG_RETURN_NULL();
+
+	/* Rounding in the accumulator may leave Q marginally negative */
+	if (Q <= 0.0)
+		PG_RETURN_FLOAT8(0.0);
+
+	PG_RETURN_FLOAT8(sqrt(N * Q / ((N - 1.0) * W)));
+}
+
+/*
+ * float8_weighted_stddev_pop - produce final result for float
+ * WEIGHTED_STDDEV_POP(), computed as sqrt(Q / W)
+ */
+Datum
+float8_weighted_stddev_pop(PG_FUNCTION_ARGS)
+{
+	ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
+	float8 *transvalues;
+	float8 N,
+		W,
+		Q;
+
+	transvalues = check_float8_array(transarray,
+									 "float8_weighted_stddev_pop", 4);
+	N = transvalues[0];
+	W = transvalues[1];
+	/* Skip A (transvalues[2]).  Not used in the calculation */
+	Q = transvalues[3];
+
+	/* A population stddev only needs one counted value */
+	if (N < 1.0)
+		PG_RETURN_NULL();
+
+	/* Rounding in the accumulator may leave Q marginally negative */
+	if (Q <= 0.0)
+		PG_RETURN_FLOAT8(0.0);
+
+	PG_RETURN_FLOAT8(sqrt(Q / W));
+}
 
 /*
  * ====================================
diff --git a/src/include/catalog/pg_aggregate.h b/src/include/catalog/pg_aggregate.h
index dd6079f..6d2f9d4 100644
--- a/src/include/catalog/pg_aggregate.h
+++ b/src/include/catalog/pg_aggregate.h
@@ -133,6 +133,7 @@ DATA(insert ( 2103 n 0 numeric_avg_accum numeric_avg numeric_avg_accum numeric_a
 DATA(insert ( 2104 n 0 float4_accum float8_avg - - - f f 0 1022 0 0 0 "{0,0,0}" _null_ ));
 DATA(insert ( 2105 n 0 float8_accum float8_avg - - - f f 0 1022 0 0 0 "{0,0,0}" _null_ ));
 DATA(insert ( 2106 n 0 interval_accum interval_avg interval_accum interval_accum_inv interval_avg f f 0 1187 0 1187 0 "{0 second,0 second}" "{0 second,0 second}" ));
+DATA(insert ( 3998 n 0 float8_regr_accum float8_weighted_avg - - - f f 0 1022 0 0 0 "{0,0,0,0,0,0}" _null_ ));
 
 /* sum */
 DATA(insert ( 2107 n 0 int8_avg_accum numeric_poly_sum int8_avg_accum int8_avg_accum_inv numeric_poly_sum f f 0 2281 48 2281 48 _null_ _null_ ));
@@ -225,6 +226,7 @@ DATA(insert ( 2726 n 0 int2_accum numeric_poly_stddev_pop int2_accum int2_accum_
 DATA(insert ( 2727 n 0 float4_accum float8_stddev_pop - - - f f 0 1022 0 0 0 "{0,0,0}" _null_ ));
 DATA(insert ( 2728 n 0 float8_accum float8_stddev_pop - - - f f 0 1022 0 0 0 "{0,0,0}" _null_ ));
 DATA(insert ( 2729 n 0 numeric_accum numeric_stddev_pop numeric_accum numeric_accum_inv numeric_stddev_pop f f 0 2281 128 2281 128 _null_ _null_ ));
+DATA(insert ( 4066 n 0 float8_weighted_accum float8_weighted_stddev_pop - - - f f 0 1022 0 0 0 "{0,0,0,0}" _null_ ));
 
 /* stddev_samp */
 DATA(insert ( 2712 n 0 int8_accum numeric_stddev_samp int8_accum int8_accum_inv numeric_stddev_samp f f 0 2281 128 2281 128 _null_ _null_ ));
@@ -232,6 +234,7 @@ DATA(insert ( 2713 n 0 int4_accum numeric_poly_stddev_samp int4_accum int4_accum DATA(insert ( 2714 n 0 int2_accum numeric_poly_stddev_samp int2_accum int2_accum_inv numeric_poly_stddev_samp f f 0 2281 48 2281 48 _null_ _null_ )); DATA(insert ( 2715 n 0 float4_accum float8_stddev_samp - - - f f 0 1022 0 0 0 "{0,0,0}" _null_ )); DATA(insert ( 2716 n 0 float8_accum float8_stddev_samp - - - f f 0 1022 0 0 0 "{0,0,0}" _null_ )); +DATA(insert ( 4083 n 0 float8_weighted_accum float8_weighted_stddev_samp - - - f f 0 1022 0 0 0 "{0,0,0,0}" _null_ )); DATA(insert ( 2717 n 0 numeric_accum numeric_stddev_samp numeric_accum numeric_accum_inv numeric_stddev_samp f f 0 2281 128 2281 128 _null_ _null_ )); /* stddev: historical Postgres syntax for stddev_samp */ diff --git a/src/include/catalog/pg_proc.h b/src/include/catalog/pg_proc.h index f688454..83c4b64 100644 --- a/src/include/catalog/pg_proc.h +++ b/src/include/catalog/pg_proc.h @@ -2502,6 +2502,12 @@ DESCR("join selectivity of case-insensitive regex non-match"); /* Aggregate-related functions */ DATA(insert OID = 1830 ( float8_avg PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_avg _null_ _null_ _null_ )); DESCR("aggregate final function"); +DATA(insert OID = 3997 ( float8_weighted_avg PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_weighted_avg _null_ _null_ _null_ )); +DESCR("aggregate final function"); +DATA(insert OID = 4099 ( float8_weighted_stddev_pop PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_weighted_stddev_pop _null_ _null_ _null_ )); +DESCR("aggregate final function"); +DATA(insert OID = 4100 ( float8_weighted_stddev_samp PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_weighted_stddev_samp _null_ _null_ _null_ )); +DESCR("aggregate final function"); DATA(insert OID = 2512 ( float8_var_pop PGNSP PGUID 
12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_var_pop _null_ _null_ _null_ )); DESCR("aggregate final function"); DATA(insert OID = 1831 ( float8_var_samp PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_var_samp _null_ _null_ _null_ )); @@ -2585,6 +2591,8 @@ DATA(insert OID = 2805 ( int8inc_float8_float8 PGNSP PGUID 12 1 0 0 0 f f f f DESCR("aggregate transition function"); DATA(insert OID = 2806 ( float8_regr_accum PGNSP PGUID 12 1 0 0 0 f f f f t f i s 3 0 1022 "1022 701 701" _null_ _null_ _null_ _null_ _null_ float8_regr_accum _null_ _null_ _null_ )); DESCR("aggregate transition function"); +DATA(insert OID = 3999 ( float8_weighted_accum PGNSP PGUID 12 1 0 0 0 f f f f t f i s 3 0 1022 "1022 701 701" _null_ _null_ _null_ _null_ _null_ float8_weighted_accum _null_ _null_ _null_ )); +DESCR("aggregate transition function"); DATA(insert OID = 2807 ( float8_regr_sxx PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_regr_sxx _null_ _null_ _null_ )); DESCR("aggregate final function"); DATA(insert OID = 2808 ( float8_regr_syy PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_regr_syy _null_ _null_ _null_ )); @@ -3229,6 +3237,8 @@ DATA(insert OID = 2104 ( avg PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 701 DESCR("the average (arithmetic mean) as float8 of all float4 values"); DATA(insert OID = 2105 ( avg PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 701 "701" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("the average (arithmetic mean) as float8 of all float8 values"); +DATA(insert OID = 3998 ( weighted_avg PGNSP PGUID 12 1 0 0 0 t f f f f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); +DESCR("the weighted average (arithmetic mean) as float8 of all float8 values"); DATA(insert OID = 2106 ( avg PGNSP PGUID 12 1 0 0 
0 t f f f f f i s 1 0 1186 "1186" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("the average (arithmetic mean) as interval of all interval values"); @@ -3389,6 +3399,8 @@ DATA(insert OID = 2728 ( stddev_pop PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 DESCR("population standard deviation of float8 input values"); DATA(insert OID = 2729 ( stddev_pop PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 1700 "1700" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("population standard deviation of numeric input values"); +DATA(insert OID = 4066 ( weighted_stddev_pop PGNSP PGUID 12 1 0 0 0 t f f f f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); +DESCR("population weighted standard deviation of float8 input values"); DATA(insert OID = 2712 ( stddev_samp PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 1700 "20" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("sample standard deviation of bigint input values"); @@ -3402,6 +3414,8 @@ DATA(insert OID = 2716 ( stddev_samp PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 DESCR("sample standard deviation of float8 input values"); DATA(insert OID = 2717 ( stddev_samp PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 1700 "1700" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("sample standard deviation of numeric input values"); +DATA(insert OID = 4083 ( weighted_stddev_samp PGNSP PGUID 12 1 0 0 0 t f f f f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); +DESCR("sample weighted standard deviation of float8 input values"); DATA(insert OID = 2154 ( stddev PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 1700 "20" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("historical alias for stddev_samp"); diff --git a/src/include/utils/builtins.h b/src/include/utils/builtins.h index fc1679e..333d538 100644 
--- a/src/include/utils/builtins.h +++ b/src/include/utils/builtins.h @@ -413,8 +413,12 @@ extern Datum radians(PG_FUNCTION_ARGS); extern Datum drandom(PG_FUNCTION_ARGS); extern Datum setseed(PG_FUNCTION_ARGS); extern Datum float8_accum(PG_FUNCTION_ARGS); +extern Datum float8_weighted_accum(PG_FUNCTION_ARGS); extern Datum float4_accum(PG_FUNCTION_ARGS); extern Datum float8_avg(PG_FUNCTION_ARGS); +extern Datum float8_weighted_avg(PG_FUNCTION_ARGS); +extern Datum float8_weighted_stddev_pop(PG_FUNCTION_ARGS); +extern Datum float8_weighted_stddev_samp(PG_FUNCTION_ARGS); extern Datum float8_var_pop(PG_FUNCTION_ARGS); extern Datum float8_var_samp(PG_FUNCTION_ARGS); extern Datum float8_stddev_pop(PG_FUNCTION_ARGS); diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out index de826b5..a19fd1d 100644 --- a/src/test/regress/expected/aggregates.out +++ b/src/test/regress/expected/aggregates.out @@ -247,6 +247,18 @@ SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest; 653.62895538751 | 871.505273850014 (1 row) +SELECT weighted_avg(a, b) FROM aggtest; + weighted_avg +------------------ + 55.5553072763149 +(1 row) + +SELECT weighted_stddev_pop(a, b), weighted_stddev_samp(a, b) FROM aggtest; + weighted_stddev_pop | weighted_stddev_samp +---------------------+---------------------- + 24.3364627240769 | 28.1013266097382 +(1 row) + SELECT corr(b, a) FROM aggtest; corr ------------------- diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql index 8d501dc..77b6102 100644 --- a/src/test/regress/sql/aggregates.sql +++ b/src/test/regress/sql/aggregates.sql @@ -60,6 +60,8 @@ SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest; SELECT regr_r2(b, a) FROM aggtest; SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest; SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest; +SELECT weighted_avg(a, b) FROM aggtest; +SELECT weighted_stddev_pop(a, b), weighted_stddev_samp(a, b) FROM aggtest; SELECT corr(b, a) FROM 
aggtest; SELECT count(four) AS cnt_1000 FROM onek;
-- Sent via pgsql-hackers mailing list (pgsql-hackers@postgresql.org) To make changes to your subscription: http://www.postgresql.org/mailpref/pgsql-hackers