Folks,

I'd like to add weighted statistics to PostgreSQL.  While the included
weighted_avg() is trivial to calculate using existing machinery, the
included weighted_stddev_*() functions are not.

I've only done the float8 versions, but if we decide to move forward,
I'd be delighted to add the rest of the numeric types and maybe others
as make sense.

What say?

Cheers,
David.
-- 
David Fetter <da...@fetter.org> http://fetter.org/
Phone: +1 415 235 3778  AIM: dfetter666  Yahoo!: dfetter
Skype: davidfetter      XMPP: david.fet...@gmail.com

Remember to vote!
Consider donating to Postgres: http://www.postgresql.org/about/donate
diff --git a/doc/src/sgml/func.sgml b/doc/src/sgml/func.sgml
index 4d482ec..2174594 100644
--- a/doc/src/sgml/func.sgml
+++ b/doc/src/sgml/func.sgml
@@ -12443,6 +12443,29 @@ NULL baz</literallayout>(3 rows)</entry>
      <row>
       <entry>
        <indexterm>
+        <primary>weighted average</primary>
+       </indexterm>
+       <indexterm>
+        <primary>weighted_avg</primary>
+       </indexterm>
+       <function>weighted_avg(<replaceable class="parameter">value 
expression</replaceable>, <replaceable class="parameter">weight 
expression</replaceable>)</function>
+      </entry>
+      <entry>
+       <type>smallint</type>, <type>int</type>,
+       <type>bigint</type>, <type>real</type>, <type>double
+       precision</type>, <type>numeric</type>, or <type>interval</type>
+      </entry>
+      <entry>
+       <type>numeric</type> for any integer-type argument,
+       <type>double precision</type> for a floating-point argument,
+       otherwise the same as the argument data type
+      </entry>
+      <entry>the average (arithmetic mean) of all input values, weighted by 
the input weights</entry>
+     </row>
+
+     <row>
+      <entry>
+       <indexterm>
         <primary>bit_and</primary>
        </indexterm>
        <function>bit_and(<replaceable 
class="parameter">expression</replaceable>)</function>
@@ -13086,6 +13109,29 @@ SELECT xmlagg(x) FROM (SELECT x FROM test ORDER BY y 
DESC) AS tab;
      <row>
       <entry>
        <indexterm>
+        <primary>weighted standard deviation</primary>
+        <secondary>population</secondary>
+       </indexterm>
+       <indexterm>
+        <primary>weighted_stddev_pop</primary>
+       </indexterm>
+       <function>weighted_stddev_pop(<replaceable class="parameter">value 
expression</replaceable>, <replaceable class="parameter">weight 
expression</replaceable>)</function>
+      </entry>
+      <entry>
+       <type>smallint</type>, <type>int</type>,
+       <type>bigint</type>, <type>real</type>, <type>double
+       precision</type>, or <type>numeric</type>
+      </entry>
+      <entry>
+       <type>double precision</type> for floating-point arguments,
+       otherwise <type>numeric</type>
+      </entry>
+      <entry>weighted population standard deviation of the input values</entry>
+     </row>
+
+     <row>
+      <entry>
+       <indexterm>
         <primary>standard deviation</primary>
         <secondary>sample</secondary>
        </indexterm>
@@ -13109,6 +13155,29 @@ SELECT xmlagg(x) FROM (SELECT x FROM test ORDER BY y 
DESC) AS tab;
      <row>
       <entry>
        <indexterm>
+        <primary>weighted standard deviation</primary>
+        <secondary>sample</secondary>
+       </indexterm>
+       <indexterm>
+        <primary>weighted_stddev_samp</primary>
+       </indexterm>
+       <function>weighted_stddev_samp(<replaceable class="parameter">value 
expression</replaceable>, <replaceable class="parameter">weight 
expression</replaceable>)</function>
+      </entry>
+      <entry>
+       <type>smallint</type>, <type>int</type>,
+       <type>bigint</type>, <type>real</type>, <type>double
+       precision</type>, or <type>numeric</type>
+      </entry>
+      <entry>
+       <type>double precision</type> for floating-point arguments,
+       otherwise <type>numeric</type>
+      </entry>
+      <entry>weighted sample standard deviation of the input values</entry>
+     </row>
+
+     <row>
+      <entry>
+       <indexterm>
         <primary>variance</primary>
        </indexterm>
        <function>variance</function>(<replaceable 
class="parameter">expression</replaceable>)
diff --git a/src/backend/utils/adt/float.c b/src/backend/utils/adt/float.c
index 4e927d8..533ce0a 100644
--- a/src/backend/utils/adt/float.c
+++ b/src/backend/utils/adt/float.c
@@ -1774,6 +1774,7 @@ setseed(PG_FUNCTION_ARGS)
  *             float8_accum            - accumulate for AVG(), variance 
aggregates, etc.
  *             float4_accum            - same, but input data is float4
  *             float8_avg                      - produce final result for 
float AVG()
+ *             float8_weighted_avg     - produce final result for float 
WEIGHTED_AVG()
  *             float8_var_samp         - produce final result for float 
VAR_SAMP()
  *             float8_var_pop          - produce final result for float 
VAR_POP()
  *             float8_stddev_samp      - produce final result for float 
STDDEV_SAMP()
@@ -1929,6 +1930,28 @@ float8_avg(PG_FUNCTION_ARGS)
 }
 
 Datum
+/*
+ * float8_weighted_avg - final function for float8 WEIGHTED_AVG()
+ *
+ * Argument 0 is the 6-element float8 transition array built by the
+ * aggregate's transition function.  Slot 0 is read as N, slot 1 as the
+ * sum of the weights, and slot 5 as the sum of weight * value.
+ *
+ * Returns sum(weight * value) / sum(weight), or NULL when no rows were
+ * accumulated or when the weights sum to zero.
+ */
+float8_weighted_avg(PG_FUNCTION_ARGS)
+{
+       ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
+       float8     *transvalues;
+       float8          N,
+                               sumWX,
+                               sumW,
+                               result;
+
+       transvalues = check_float8_array(transarray, "float8_weighted_avg", 6);
+       N = transvalues[0];
+       sumW = transvalues[1];
+       sumWX = transvalues[5];
+
+       /* No input rows: the average is undefined, so return NULL */
+       if (N < 1.0)
+               PG_RETURN_NULL();
+
+       /* Weights summing to zero would divide by zero: return NULL instead */
+       if (sumW == 0.0)
+               PG_RETURN_NULL();
+
+       result = sumWX / sumW;
+
+       /*
+        * Range-check the quotient that is actually returned.  Infinity is
+        * legitimate only if an operand was already infinite, and zero only
+        * if the numerator was zero.
+        */
+       CHECKFLOATVAL(result, isinf(sumWX) || isinf(sumW), sumWX == 0.0);
+
+       PG_RETURN_FLOAT8(result);
+}
+
+Datum
 float8_var_pop(PG_FUNCTION_ARGS)
 {
        ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
@@ -2467,6 +2490,119 @@ float8_regr_intercept(PG_FUNCTION_ARGS)
        PG_RETURN_FLOAT8(numeratorXXY / numeratorX);
 }
 
+/*
+ *             ===================
+ *             WEIGHTED AGGREGATES
+ *             ===================
+ *
+ * The transition datatype for the weighted standard deviation
+ * aggregates is a 4-element array of float8, holding in this order:
+ *
+ * N, the number of positively weighted values seen thus far,
+ * W, the running sum of those weights,
+ * A, the running weighted mean of the values, and
+ * Q, the running weighted sum of squared deviations from that mean.
+ *
+ * First, an accumulator function for those we can't pirate from the
+ * other accumulators.  It updates A and Q incrementally rather than
+ * keeping raw sums of W*X and W*X*X, which takes out some of the
+ * rounding error inherent in the general one.
+ * https://en.wikipedia.org/wiki/Standard_deviation#Rapid_calculation_methods
+ *
+ * Arguments: 0 = transition array, 1 = value X, 2 = weight W.
+ * Returns the transition array, updated in place.
+ * Raises an error when called outside an aggregate context.
+ */
+Datum
+float8_weighted_accum(PG_FUNCTION_ARGS)
+{
+       ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
+       float8          newvalX = PG_GETARG_FLOAT8(1);
+       float8          newvalW = PG_GETARG_FLOAT8(2);
+       float8     *transvalues;
+       float8          N,
+                               W,
+                               A,
+                               Q;
+
+       transvalues = check_float8_array(transarray, "float8_weighted_accum", 4);
+
+       /*
+        * We only care about positive weights.  Skip other rows by handing
+        * back the state untouched; returning NULL from this strict function
+        * would null out the state and with it the whole aggregate result.
+        */
+       if (newvalW <= 0.0)
+               PG_RETURN_ARRAYTYPE_P(transarray);
+
+       N = transvalues[0];
+       W = transvalues[1];
+       A = transvalues[2];
+       Q = transvalues[3];
+
+       N += 1.0;
+       CHECKFLOATVAL(N, isinf(transvalues[0]), true);
+       W += newvalW;
+       CHECKFLOATVAL(W, isinf(transvalues[1]) || isinf(newvalW), true);
+
+       /* New mean: move the old mean toward X by the fraction newvalW / W */
+       A += newvalW * (newvalX - transvalues[2]) / W;
+       CHECKFLOATVAL(A, isinf(transvalues[2]) || isinf(newvalX) ||
+                                 isinf(newvalW) || isinf(1.0 / W), true);
+
+       /* Uses both the old mean (transvalues[2]) and the new mean (A) */
+       Q += newvalW * (newvalX - transvalues[2]) * (newvalX - A);
+       CHECKFLOATVAL(Q, isinf(transvalues[3]) || isinf(newvalX) ||
+                                 isinf(newvalW), true);
+
+       /* You do not need to call this directly. */
+       if (!AggCheckCallContext(fcinfo, NULL))
+               ereport(ERROR,
+                               (errmsg("float8_weighted_accum called outside agg context")));
+
+       /* Update in place is safe in Agg context */
+       transvalues[0] = N;
+       transvalues[1] = W;
+       transvalues[2] = A;
+       transvalues[3] = Q;
+
+       PG_RETURN_ARRAYTYPE_P(transarray);
+}
+
+/*
+ * float8_weighted_stddev_samp - final function for WEIGHTED_STDDEV_SAMP()
+ *
+ * Reads N, W and Q from slots 0, 1 and 3 of the 4-element transition
+ * array (slot 2 is not needed) and returns sqrt(N * Q / ((N - 1) * W)),
+ * or NULL when fewer than two rows were accumulated.
+ */
+Datum
+float8_weighted_stddev_samp(PG_FUNCTION_ARGS)
+{
+       ArrayType  *state = PG_GETARG_ARRAYTYPE_P(0);
+       float8     *slots;
+       float8          count;
+       float8          weightSum;
+       float8          sqDev;
+       float8          numerator;
+       float8          denominator;
+
+       slots = check_float8_array(state, "float8_weighted_stddev_samp", 4);
+       count = slots[0];
+       weightSum = slots[1];
+       sqDev = slots[3];
+
+       /* Must have at least two samples to get a stddev */
+       if (count < 2.0)
+               PG_RETURN_NULL();
+
+       numerator = count * sqDev;
+       denominator = (count - 1) * weightSum;
+
+       PG_RETURN_FLOAT8(sqrt(numerator / denominator));
+}
+
+/*
+ * float8_weighted_stddev_pop - final function for WEIGHTED_STDDEV_POP()
+ *
+ * Reads N, W and Q from slots 0, 1 and 3 of the 4-element transition
+ * array (slot 2 is not needed) and returns sqrt(Q / W), or NULL when
+ * no rows were accumulated.
+ */
+Datum
+float8_weighted_stddev_pop(PG_FUNCTION_ARGS)
+{
+       ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
+       float8     *transvalues;
+       float8          N,
+                               W,
+                               /* Skip A.  Not used in the calculation */
+                               Q;
+
+       transvalues = check_float8_array(transarray, "float8_weighted_stddev_pop", 4);
+       N = transvalues[0];
+       W = transvalues[1];
+       Q = transvalues[3];
+
+       /*
+        * Population stddev is undefined only when there is no input at all.
+        * Unlike the sample form there is no N - 1 divisor, so a single row
+        * is enough.
+        */
+       if (N < 1.0)
+               PG_RETURN_NULL();
+
+       PG_RETURN_FLOAT8(sqrt(Q / W));
+}
 
 /*
  *             ====================================
diff --git a/src/include/catalog/pg_aggregate.h 
b/src/include/catalog/pg_aggregate.h
index dd6079f..6d2f9d4 100644
--- a/src/include/catalog/pg_aggregate.h
+++ b/src/include/catalog/pg_aggregate.h
@@ -133,6 +133,7 @@ DATA(insert ( 2103  n 0 numeric_avg_accum numeric_avg       
numeric_avg_accum numeric_a
 DATA(insert ( 2104     n 0 float4_accum        float8_avg              -       
                        -                               -                       
                                        f f 0   1022    0       0               
0       "{0,0,0}" _null_ ));
 DATA(insert ( 2105     n 0 float8_accum        float8_avg              -       
                        -                               -                       
                                        f f 0   1022    0       0               
0       "{0,0,0}" _null_ ));
 DATA(insert ( 2106     n 0 interval_accum      interval_avg    interval_accum  
interval_accum_inv interval_avg                                 f f 0   1187    
0       1187    0       "{0 second,0 second}" "{0 second,0 second}" ));
+DATA(insert ( 3998     n 0 float8_regr_accum   float8_weighted_avg             
-                               -                               -               
                f f 0   1022    0       0               0       "{0,0,0,0,0,0}" 
_null_ ));
 
 /* sum */
 DATA(insert ( 2107     n 0 int8_avg_accum      numeric_poly_sum                
int8_avg_accum  int8_avg_accum_inv numeric_poly_sum f f 0       2281    48      
2281    48      _null_ _null_ ));
@@ -225,6 +226,7 @@ DATA(insert ( 2726  n 0 int2_accum  numeric_poly_stddev_pop 
int2_accum      int2_accum_
 DATA(insert ( 2727     n 0 float4_accum        float8_stddev_pop       -       
                        -                               -                       
                                f f 0   1022    0       0               0       
"{0,0,0}" _null_ ));
 DATA(insert ( 2728     n 0 float8_accum        float8_stddev_pop       -       
                        -                               -                       
                                f f 0   1022    0       0               0       
"{0,0,0}" _null_ ));
 DATA(insert ( 2729     n 0 numeric_accum       numeric_stddev_pop 
numeric_accum numeric_accum_inv numeric_stddev_pop                   f f 0   
2281    128 2281        128 _null_ _null_ ));
+DATA(insert ( 4066     n 0 float8_weighted_accum       
float8_weighted_stddev_pop      -                               -               
                -                                                       f f 0   
1022    0       0               0       "{0,0,0,0}" _null_ ));
 
 /* stddev_samp */
 DATA(insert ( 2712     n 0 int8_accum  numeric_stddev_samp             
int8_accum      int8_accum_inv  numeric_stddev_samp                             
f f 0   2281    128 2281        128 _null_ _null_ ));
@@ -232,6 +234,7 @@ DATA(insert ( 2713  n 0 int4_accum  
numeric_poly_stddev_samp        int4_accum      int4_accum
 DATA(insert ( 2714     n 0 int2_accum  numeric_poly_stddev_samp        
int2_accum      int2_accum_inv  numeric_poly_stddev_samp        f f 0   2281    
48      2281    48      _null_ _null_ ));
 DATA(insert ( 2715     n 0 float4_accum        float8_stddev_samp      -       
                        -                               -                       
                                f f 0   1022    0       0               0       
"{0,0,0}" _null_ ));
 DATA(insert ( 2716     n 0 float8_accum        float8_stddev_samp      -       
                        -                               -                       
                                f f 0   1022    0       0               0       
"{0,0,0}" _null_ ));
+DATA(insert ( 4083     n 0 float8_weighted_accum       
float8_weighted_stddev_samp     -                               -               
                -                                                       f f 0   
1022    0       0               0       "{0,0,0,0}" _null_ ));
 DATA(insert ( 2717     n 0 numeric_accum       numeric_stddev_samp 
numeric_accum numeric_accum_inv numeric_stddev_samp                 f f 0   
2281    128 2281        128 _null_ _null_ ));
 
 /* stddev: historical Postgres syntax for stddev_samp */
diff --git a/src/include/catalog/pg_proc.h b/src/include/catalog/pg_proc.h
index f688454..83c4b64 100644
--- a/src/include/catalog/pg_proc.h
+++ b/src/include/catalog/pg_proc.h
@@ -2502,6 +2502,12 @@ DESCR("join selectivity of case-insensitive regex 
non-match");
 /* Aggregate-related functions */
 DATA(insert OID = 1830 (  float8_avg      PGNSP PGUID 12 1 0 0 0 f f f f t f i 
s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_avg _null_ _null_ 
_null_ ));
 DESCR("aggregate final function");
+DATA(insert OID = 3997 (  float8_weighted_avg                  PGNSP PGUID 12 
1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ 
float8_weighted_avg _null_ _null_ _null_ ));
+DESCR("aggregate final function");
+DATA(insert OID = 4099 (  float8_weighted_stddev_pop                   PGNSP 
PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ 
_null_ float8_weighted_stddev_pop _null_ _null_ _null_ ));
+DESCR("aggregate final function");
+DATA(insert OID = 4100 (  float8_weighted_stddev_samp                  PGNSP 
PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ 
_null_ float8_weighted_stddev_samp _null_ _null_ _null_ ));
+DESCR("aggregate final function");
 DATA(insert OID = 2512 (  float8_var_pop   PGNSP PGUID 12 1 0 0 0 f f f f t f 
i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_var_pop _null_ 
_null_ _null_ ));
 DESCR("aggregate final function");
 DATA(insert OID = 1831 (  float8_var_samp  PGNSP PGUID 12 1 0 0 0 f f f f t f 
i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_var_samp _null_ 
_null_ _null_ ));
@@ -2585,6 +2591,8 @@ DATA(insert OID = 2805 (  int8inc_float8_float8           
PGNSP PGUID 12 1 0 0 0 f f f f
 DESCR("aggregate transition function");
 DATA(insert OID = 2806 (  float8_regr_accum                    PGNSP PGUID 12 
1 0 0 0 f f f f t f i s 3 0 1022 "1022 701 701" _null_ _null_ _null_ _null_ 
_null_ float8_regr_accum _null_ _null_ _null_ ));
 DESCR("aggregate transition function");
+DATA(insert OID = 3999 (  float8_weighted_accum                        PGNSP 
PGUID 12 1 0 0 0 f f f f t f i s 3 0 1022 "1022 701 701" _null_ _null_ _null_ 
_null_ _null_ float8_weighted_accum _null_ _null_ _null_ ));
+DESCR("aggregate transition function");
 DATA(insert OID = 2807 (  float8_regr_sxx                      PGNSP PGUID 12 
1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ 
float8_regr_sxx _null_ _null_ _null_ ));
 DESCR("aggregate final function");
 DATA(insert OID = 2808 (  float8_regr_syy                      PGNSP PGUID 12 
1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ 
float8_regr_syy _null_ _null_ _null_ ));
@@ -3229,6 +3237,8 @@ DATA(insert OID = 2104 (  avg                             
PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 701
 DESCR("the average (arithmetic mean) as float8 of all float4 values");
 DATA(insert OID = 2105 (  avg                          PGNSP PGUID 12 1 0 0 0 
t f f f f f i s 1 0 701 "701" _null_ _null_ _null_ _null_ _null_ 
aggregate_dummy _null_ _null_ _null_ ));
 DESCR("the average (arithmetic mean) as float8 of all float8 values");
+DATA(insert OID = 3998 (  weighted_avg         PGNSP PGUID 12 1 0 0 0 t f f f 
f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_     
aggregate_dummy _null_ _null_ _null_ ));
+DESCR("the weighted average (arithmetic mean) as float8 of all float8 values");
 DATA(insert OID = 2106 (  avg                          PGNSP PGUID 12 1 0 0 0 
t f f f f f i s 1 0 1186 "1186" _null_ _null_ _null_ _null_ _null_ 
aggregate_dummy _null_ _null_ _null_ ));
 DESCR("the average (arithmetic mean) as interval of all interval values");
 
@@ -3389,6 +3399,8 @@ DATA(insert OID = 2728 (  stddev_pop              PGNSP 
PGUID 12 1 0 0 0 t f f f f f i s 1 0
 DESCR("population standard deviation of float8 input values");
 DATA(insert OID = 2729 (  stddev_pop           PGNSP PGUID 12 1 0 0 0 t f f f 
f f i s 1 0 1700 "1700" _null_ _null_ _null_ _null_ _null_ aggregate_dummy 
_null_ _null_ _null_ ));
 DESCR("population standard deviation of numeric input values");
+DATA(insert OID = 4066 (  weighted_stddev_pop          PGNSP PGUID 12 1 0 0 0 
t f f f f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_     
aggregate_dummy _null_ _null_ _null_ ));
+DESCR("population weighted standard deviation of float8 input values");
 
 DATA(insert OID = 2712 (  stddev_samp          PGNSP PGUID 12 1 0 0 0 t f f f 
f f i s 1 0 1700 "20" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ 
_null_ _null_ ));
 DESCR("sample standard deviation of bigint input values");
@@ -3402,6 +3414,8 @@ DATA(insert OID = 2716 (  stddev_samp             PGNSP 
PGUID 12 1 0 0 0 t f f f f f i s 1
 DESCR("sample standard deviation of float8 input values");
 DATA(insert OID = 2717 (  stddev_samp          PGNSP PGUID 12 1 0 0 0 t f f f 
f f i s 1 0 1700 "1700" _null_ _null_ _null_ _null_ _null_ aggregate_dummy 
_null_ _null_ _null_ ));
 DESCR("sample standard deviation of numeric input values");
+DATA(insert OID = 4083 (  weighted_stddev_samp         PGNSP PGUID 12 1 0 0 0 
t f f f f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_     
aggregate_dummy _null_ _null_ _null_ ));
+DESCR("sample weighted standard deviation of float8 input values");
 
 DATA(insert OID = 2154 (  stddev                       PGNSP PGUID 12 1 0 0 0 
t f f f f f i s 1 0 1700 "20" _null_ _null_ _null_ _null_ _null_ 
aggregate_dummy _null_ _null_ _null_ ));
 DESCR("historical alias for stddev_samp");
diff --git a/src/include/utils/builtins.h b/src/include/utils/builtins.h
index fc1679e..333d538 100644
--- a/src/include/utils/builtins.h
+++ b/src/include/utils/builtins.h
@@ -413,8 +413,12 @@ extern Datum radians(PG_FUNCTION_ARGS);
 extern Datum drandom(PG_FUNCTION_ARGS);
 extern Datum setseed(PG_FUNCTION_ARGS);
 extern Datum float8_accum(PG_FUNCTION_ARGS);
+extern Datum float8_weighted_accum(PG_FUNCTION_ARGS);
 extern Datum float4_accum(PG_FUNCTION_ARGS);
 extern Datum float8_avg(PG_FUNCTION_ARGS);
+extern Datum float8_weighted_avg(PG_FUNCTION_ARGS);
+extern Datum float8_weighted_stddev_pop(PG_FUNCTION_ARGS);
+extern Datum float8_weighted_stddev_samp(PG_FUNCTION_ARGS);
 extern Datum float8_var_pop(PG_FUNCTION_ARGS);
 extern Datum float8_var_samp(PG_FUNCTION_ARGS);
 extern Datum float8_stddev_pop(PG_FUNCTION_ARGS);
diff --git a/src/test/regress/expected/aggregates.out 
b/src/test/regress/expected/aggregates.out
index de826b5..a19fd1d 100644
--- a/src/test/regress/expected/aggregates.out
+++ b/src/test/regress/expected/aggregates.out
@@ -247,6 +247,18 @@ SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest;
  653.62895538751 | 871.505273850014
 (1 row)
 
+SELECT weighted_avg(a, b) FROM aggtest;
+   weighted_avg   
+------------------
+ 55.5553072763149
+(1 row)
+
+SELECT weighted_stddev_pop(a, b), weighted_stddev_samp(a, b) FROM aggtest;
+ weighted_stddev_pop | weighted_stddev_samp 
+---------------------+----------------------
+    24.3364627240769 |     28.1013266097382
+(1 row)
+
 SELECT corr(b, a) FROM aggtest;
        corr        
 -------------------
diff --git a/src/test/regress/sql/aggregates.sql 
b/src/test/regress/sql/aggregates.sql
index 8d501dc..77b6102 100644
--- a/src/test/regress/sql/aggregates.sql
+++ b/src/test/regress/sql/aggregates.sql
@@ -60,6 +60,8 @@ SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest;
 SELECT regr_r2(b, a) FROM aggtest;
 SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
 SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest;
+SELECT weighted_avg(a, b) FROM aggtest;
+SELECT weighted_stddev_pop(a, b), weighted_stddev_samp(a, b) FROM aggtest;
 SELECT corr(b, a) FROM aggtest;
 
 SELECT count(four) AS cnt_1000 FROM onek;
-- 
Sent via pgsql-hackers mailing list (pgsql-hackers@postgresql.org)
To make changes to your subscription:
http://www.postgresql.org/mailpref/pgsql-hackers

Reply via email to