On 07/09/2015 12:44 PM, David Rowley wrote:
On 15 June 2015 at 12:05, David Rowley <david.row...@2ndquadrant.com> wrote:
This basically allows an aggregate's state to be shared with other
aggregate functions when both aggregates' transition functions (and a few
other things) match.
There are quite a number of aggregates in our standard set which will
benefit from this optimisation.
After compiling the original patch with another compiler, I noticed a
couple of warnings.
The attached fixes these.
I spent some time reviewing this. I refactored the ExecInitAgg code
rather heavily to make it more readable (IMHO); see attached. What do
you think? Did I break anything?
Some comments:
* In aggref_has_compatible_states(), you give up if aggtype or aggcollid
differ. But those properties apply to the final function, so you were
leaving some money on the table by disallowing state-sharing if they differ.
* The filter and input expressions are initialized for every Aggref,
before the deduplication logic kicks in. The AggrefExprState.aggfilter,
aggdirectargs and args fields really belong to the AggStatePerAggState
struct instead. This is not a new issue, but now that we have a
convenient per-aggstate struct to put them in, let's do so.
* There was a use-after-free bug in aggref_has_compatible_states():
you cannot call ReleaseSysCache() and then keep using a pointer into the
released tuple's struct.
* The code was a bit fuzzy about which parts of the per-aggstate are filled
in at what time. Some of the fields were overwritten every time a match
was found. They were overwritten with the same values, so no harm was done,
but I found it confusing. I refactored ExecInitAgg in the attached patch to
clear that up.
* The API of build_aggregate_fnexprs() was a bit strange now that some
callers use it to fill in only the final function, while others use it to
fill in both the transition function and the final function. I split it
into two separate functions.
* I wonder if we should do this duplicate elimination at plan time. It's
very fast, so I'm not worried about that right now, but you had grand
plans to expand this so that an aggregate could optionally use one of
many different kinds of state values. At that point, it certainly seems
like a planning decision to decide which aggregates share state. I think
we can leave it as it is for now, but it's something to perhaps consider
later.
BTW, the name of the AggStatePerAggStateData struct is pretty horrible.
The repeated "AggState" feels awkward. Now that I've stared at the patch
for some time, it doesn't bother me anymore, but it took me quite a
while to get over that, and I'm sure others will stumble over it too. And
it's not just that struct: the comments talk about "aggregate state", which
could be mistaken for "AggState", but it actually means
AggStatePerAggStateData. I don't have any great suggestions, but can you
come up with a better naming scheme?
- Heikki
diff --git a/src/backend/executor/execQual.c b/src/backend/executor/execQual.c
index 0f911f2..fd922bd 100644
--- a/src/backend/executor/execQual.c
+++ b/src/backend/executor/execQual.c
@@ -4485,35 +4485,15 @@ ExecInitExpr(Expr *node, PlanState *parent)
break;
case T_Aggref:
{
- Aggref *aggref = (Aggref *) node;
AggrefExprState *astate = makeNode(AggrefExprState);
astate->xprstate.evalfunc = (ExprStateEvalFunc) ExecEvalAggref;
if (parent && IsA(parent, AggState))
{
AggState *aggstate = (AggState *) parent;
- int naggs;
aggstate->aggs = lcons(astate, aggstate->aggs);
- naggs = ++aggstate->numaggs;
-
- astate->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs,
- parent);
- astate->args = (List *) ExecInitExpr((Expr *) aggref->args,
- parent);
- astate->aggfilter = ExecInitExpr(aggref->aggfilter,
- parent);
-
- /*
- * Complain if the aggregate's arguments contain any
- * aggregates; nested agg functions are semantically
- * nonsensical. (This should have been caught earlier,
- * but we defend against it here anyway.)
- */
- if (naggs != aggstate->numaggs)
- ereport(ERROR,
- (errcode(ERRCODE_GROUPING_ERROR),
- errmsg("aggregate function calls cannot be nested")));
+ aggstate->numaggs++;
}
else
{
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 2bf48c5..fcc3859 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -152,17 +152,28 @@
/*
- * AggStatePerAggData - per-aggregate working state for the Agg scan
+ * AggStatePerAggStateData - per aggregate state data for the Agg scan
+ *
+ * Working state for calculating the aggregate state, using the state
+ * transition function. This struct does not store the information needed
+ * to produce the final aggregate result from the state value; that's stored
+ * in AggStatePerAggData instead. This separation allows multiple aggregate
+ * results to be produced from a single state value.
*/
-typedef struct AggStatePerAggData
+typedef struct AggStatePerAggStateData
{
/*
* These values are set up during ExecInitAgg() and do not change
* thereafter:
*/
- /* Links to Aggref expr and state nodes this working state is for */
- AggrefExprState *aggrefstate;
+ /*
+ * Link to an Aggref expr this working state is for.
+ *
+ * There can actually be multiple AggRef's sharing the same working state,
+ * as long as the inputs and transition state are identical. This points
+ * to the first of them.
+ */
Aggref *aggref;
/*
@@ -186,25 +197,22 @@ typedef struct AggStatePerAggData
*/
int numTransInputs;
- /*
- * Number of arguments to pass to the finalfn. This is always at least 1
- * (the transition state value) plus any ordered-set direct args. If the
- * finalfn wants extra args then we pass nulls corresponding to the
- * aggregated input columns.
- */
- int numFinalArgs;
-
- /* Oids of transfer functions */
+ /* Oid of the state transition function */
Oid transfn_oid;
- Oid finalfn_oid; /* may be InvalidOid */
+
+ /* Oid of state value's datatype */
+ Oid aggtranstype;
+
+ /* ExprStates of the FILTER and argument expressions. */
+ ExprState *aggfilter; /* state of FILTER expression, if any */
+ List *args; /* states of aggregated-argument expressions */
+ List *aggdirectargs; /* states of direct-argument expressions */
/*
- * fmgr lookup data for transfer functions --- only valid when
- * corresponding oid is not InvalidOid. Note in particular that fn_strict
- * flags are kept here.
+ * fmgr lookup data for transfer function. Note in particular that the
+ * fn_strict flag is kept here.
*/
FmgrInfo transfn;
- FmgrInfo finalfn;
/* Input collation derived for aggregate */
Oid aggCollation;
@@ -236,17 +244,15 @@ typedef struct AggStatePerAggData
bool initValueIsNull;
/*
- * We need the len and byval info for the agg's input, result, and
- * transition data types in order to know how to copy/delete values.
+ * We need the len and byval info for the agg's input and transition data
+ * types in order to know how to copy/delete values.
*
* Note that the info for the input type is used only when handling
* DISTINCT aggs with just one argument, so there is only one input type.
*/
int16 inputtypeLen,
- resulttypeLen,
transtypeLen;
bool inputtypeByVal,
- resulttypeByVal,
transtypeByVal;
/*
@@ -288,6 +294,48 @@ typedef struct AggStatePerAggData
* worth the extra space consumption.
*/
FunctionCallInfoData transfn_fcinfo;
+} AggStatePerAggStateData;
+
+/*
+ * AggStatePerAggData - per-aggregate working state
+ *
+ * This contains the information needed to produce a final aggregate result
+ * from the state value.
+ */
+typedef struct AggStatePerAggData
+{
+ /*
+ * These values are set up during ExecInitAgg() and do not change
+ * thereafter:
+ */
+
+ /* index into aggstate->peraggstate[] of the transition state to finalize */
+ int stateno;
+
+ /* Optional Oid of final function (may be InvalidOid) */
+ Oid finalfn_oid;
+
+ /*
+ * fmgr lookup data for final function --- only valid when finalfn_oid
+ * is not InvalidOid.
+ */
+ FmgrInfo finalfn;
+
+ /*
+ * Number of arguments to pass to the finalfn. This is always at least 1
+ * (the transition state value) plus any ordered-set direct args. If the
+ * finalfn wants extra args then we pass nulls corresponding to the
+ * aggregated input columns.
+ */
+ int numFinalArgs;
+
+ /*
+ * We need the len and byval info for the agg's result data type in order
+ * to know how to copy the result value into the output context.
+ */
+ int16 resulttypeLen;
+ bool resulttypeByVal;
+
} AggStatePerAggData;
/*
@@ -358,25 +406,36 @@ typedef struct AggHashEntryData
AggStatePerGroupData pergroup[FLEXIBLE_ARRAY_MEMBER];
} AggHashEntryData;
+/*
+ * Result of comparing a new Aggref with the aggregates already initialized.
+ * ExecInitAgg() uses it to decide whether to reuse an existing peragg, reuse
+ * only its transition state, or build both. See find_compatible_aggref().
+ */
+typedef enum AggRefCompatibility
+{
+ AGGREF_NO_MATCH, /* nothing shareable: new peragg and new state */
+ AGGREF_STATE_MATCH, /* share transition state; separate finalfn */
+ AGGREF_EXACT_MATCH /* reuse the existing peragg (state and finalfn) */
+} AggRefCompatibility;
static void initialize_phase(AggState *aggstate, int newphase);
static TupleTableSlot *fetch_input_tuple(AggState *aggstate);
static void initialize_aggregates(AggState *aggstate,
- AggStatePerAgg peragg,
+ AggStatePerAggState peraggstates,
AggStatePerGroup pergroup,
int numReset);
static void advance_transition_function(AggState *aggstate,
- AggStatePerAgg peraggstate,
+ AggStatePerAggState peraggstate,
AggStatePerGroup pergroupstate);
static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup);
static void process_ordered_aggregate_single(AggState *aggstate,
- AggStatePerAgg peraggstate,
+ AggStatePerAggState peraggstate,
AggStatePerGroup pergroupstate);
static void process_ordered_aggregate_multi(AggState *aggstate,
- AggStatePerAgg peraggstate,
+ AggStatePerAggState peraggstate,
AggStatePerGroup pergroupstate);
static void finalize_aggregate(AggState *aggstate,
- AggStatePerAgg peraggstate,
+ AggStatePerAgg peragg,
AggStatePerGroup pergroupstate,
Datum *resultVal, bool *resultIsNull);
static void prepare_projection_slot(AggState *aggstate,
@@ -396,6 +455,14 @@ static TupleTableSlot *agg_retrieve_direct(AggState *aggstate);
static void agg_fill_hash_table(AggState *aggstate);
static TupleTableSlot *agg_retrieve_hash_table(AggState *aggstate);
static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
+static void build_peraggstate_for_aggref(AggStatePerAggState peraggstate,
+ AggState *aggsate, EState *estate,
+ Aggref *aggref, HeapTuple aggtuple,
+ Oid *inputTypes, int numArguments);
+static AggRefCompatibility find_compatible_aggref(Aggref *newagg,
+ AggState *aggstate, int lastaggno, int *foundaggno);
+static AggRefCompatibility aggref_has_compatible_states(Aggref *newagg,
+ AggStatePerAgg peragg, AggStatePerAggState peraggstate);
/*
@@ -498,7 +565,7 @@ fetch_input_tuple(AggState *aggstate)
* When called, CurrentMemoryContext should be the per-query context.
*/
static void
-initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
+initialize_aggregate(AggState *aggstate, AggStatePerAggState peraggstate,
AggStatePerGroup pergroupstate)
{
/*
@@ -569,7 +636,7 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
}
/*
- * Initialize all aggregates for a new group of input values.
+ * Initialize all aggregate states for a new group of input values.
*
* If there are multiple grouping sets, we initialize only the first numReset
* of them (the grouping sets are ordered so that the most specific one, which
@@ -580,26 +647,26 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
*/
static void
initialize_aggregates(AggState *aggstate,
- AggStatePerAgg peragg,
+ AggStatePerAggState peraggstates,
AggStatePerGroup pergroup,
int numReset)
{
- int aggno;
+ int stateno;
int numGroupingSets = Max(aggstate->phase->numsets, 1);
int setno = 0;
if (numReset < 1)
numReset = numGroupingSets;
- for (aggno = 0; aggno < aggstate->numaggs; aggno++)
+ for (stateno = 0; stateno < aggstate->numstates; stateno++)
{
- AggStatePerAgg peraggstate = &peragg[aggno];
+ AggStatePerAggState peraggstate = &peraggstates[stateno];
for (setno = 0; setno < numReset; setno++)
{
AggStatePerGroup pergroupstate;
- pergroupstate = &pergroup[aggno + (setno * (aggstate->numaggs))];
+ pergroupstate = &pergroup[stateno + (setno * (aggstate->numstates))];
aggstate->current_set = setno;
@@ -610,7 +677,7 @@ initialize_aggregates(AggState *aggstate,
/*
* Given new input value(s), advance the transition function of one aggregate
- * within one grouping set only (already set in aggstate->current_set)
+ * state within one grouping set only (already set in aggstate->current_set)
*
* The new values (and null flags) have been preloaded into argument positions
* 1 and up in peraggstate->transfn_fcinfo, so that we needn't copy them again
@@ -621,7 +688,7 @@ initialize_aggregates(AggState *aggstate,
*/
static void
advance_transition_function(AggState *aggstate,
- AggStatePerAgg peraggstate,
+ AggStatePerAggState peraggstate,
AggStatePerGroup pergroupstate)
{
FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo;
@@ -678,8 +745,8 @@ advance_transition_function(AggState *aggstate,
/* We run the transition functions in per-input-tuple memory context */
oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory);
- /* set up aggstate->curperagg for AggGetAggref() */
- aggstate->curperagg = peraggstate;
+ /* set up aggstate->curperaggstate for AggGetAggref() */
+ aggstate->curperaggstate = peraggstate;
/*
* OK to call the transition function
@@ -690,7 +757,7 @@ advance_transition_function(AggState *aggstate,
newVal = FunctionCallInvoke(fcinfo);
- aggstate->curperagg = NULL;
+ aggstate->curperaggstate = NULL;
/*
* If pass-by-ref datatype, must copy the new value into aggcontext and
@@ -718,7 +785,7 @@ advance_transition_function(AggState *aggstate,
}
/*
- * Advance all the aggregates for one input tuple. The input tuple
+ * Advance each aggregate state for one input tuple. The input tuple
* has been stored in tmpcontext->ecxt_outertuple, so that it is accessible
* to ExecEvalExpr. pergroup is the array of per-group structs to use
* (this might be in a hashtable entry).
@@ -728,15 +795,15 @@ advance_transition_function(AggState *aggstate,
static void
advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
{
- int aggno;
+ int stateno;
int setno = 0;
int numGroupingSets = Max(aggstate->phase->numsets, 1);
- int numAggs = aggstate->numaggs;
+ int numStates = aggstate->numstates;
- for (aggno = 0; aggno < numAggs; aggno++)
+ for (stateno = 0; stateno < numStates; stateno++)
{
- AggStatePerAgg peraggstate = &aggstate->peragg[aggno];
- ExprState *filter = peraggstate->aggrefstate->aggfilter;
+ AggStatePerAggState peraggstate = &aggstate->peraggstate[stateno];
+ ExprState *filter = peraggstate->aggfilter;
int numTransInputs = peraggstate->numTransInputs;
int i;
TupleTableSlot *slot;
@@ -806,7 +873,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
for (setno = 0; setno < numGroupingSets; setno++)
{
- AggStatePerGroup pergroupstate = &pergroup[aggno + (setno * numAggs)];
+ AggStatePerGroup pergroupstate = &pergroup[stateno + (setno * numStates)];
aggstate->current_set = setno;
@@ -841,7 +908,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
*/
static void
process_ordered_aggregate_single(AggState *aggstate,
- AggStatePerAgg peraggstate,
+ AggStatePerAggState peraggstate,
AggStatePerGroup pergroupstate)
{
Datum oldVal = (Datum) 0;
@@ -930,7 +997,7 @@ process_ordered_aggregate_single(AggState *aggstate,
*/
static void
process_ordered_aggregate_multi(AggState *aggstate,
- AggStatePerAgg peraggstate,
+ AggStatePerAggState peraggstate,
AggStatePerGroup pergroupstate)
{
MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory;
@@ -1009,10 +1076,14 @@ process_ordered_aggregate_multi(AggState *aggstate,
*
* The finalfunction will be run, and the result delivered, in the
* output-tuple context; caller's CurrentMemoryContext does not matter.
+ *
+ * The finalfn uses the state as set in the stateno. This also might be
+ * being used by another aggregate function, so it's important that we do
+ * nothing destructive here.
*/
static void
finalize_aggregate(AggState *aggstate,
- AggStatePerAgg peraggstate,
+ AggStatePerAgg peragg,
AggStatePerGroup pergroupstate,
Datum *resultVal, bool *resultIsNull)
{
@@ -1021,6 +1092,7 @@ finalize_aggregate(AggState *aggstate,
MemoryContext oldContext;
int i;
ListCell *lc;
+ AggStatePerAggState peraggstate = &aggstate->peraggstate[peragg->stateno];
oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory);
@@ -1031,7 +1103,7 @@ finalize_aggregate(AggState *aggstate,
* for the transition state value.
*/
i = 1;
- foreach(lc, peraggstate->aggrefstate->aggdirectargs)
+ foreach(lc, peraggstate->aggdirectargs)
{
ExprState *expr = (ExprState *) lfirst(lc);
@@ -1046,14 +1118,14 @@ finalize_aggregate(AggState *aggstate,
/*
* Apply the agg's finalfn if one is provided, else return transValue.
*/
- if (OidIsValid(peraggstate->finalfn_oid))
+ if (OidIsValid(peragg->finalfn_oid))
{
- int numFinalArgs = peraggstate->numFinalArgs;
+ int numFinalArgs = peragg->numFinalArgs;
- /* set up aggstate->curperagg for AggGetAggref() */
- aggstate->curperagg = peraggstate;
+ /* set up aggstate->curperaggstate for AggGetAggref() */
+ aggstate->curperaggstate = peraggstate;
- InitFunctionCallInfoData(fcinfo, &peraggstate->finalfn,
+ InitFunctionCallInfoData(fcinfo, &peragg->finalfn,
numFinalArgs,
peraggstate->aggCollation,
(void *) aggstate, NULL);
@@ -1082,7 +1154,7 @@ finalize_aggregate(AggState *aggstate,
*resultVal = FunctionCallInvoke(&fcinfo);
*resultIsNull = fcinfo.isnull;
}
- aggstate->curperagg = NULL;
+ aggstate->curperaggstate = NULL;
}
else
{
@@ -1093,12 +1165,12 @@ finalize_aggregate(AggState *aggstate,
/*
* If result is pass-by-ref, make sure it is in the right context.
*/
- if (!peraggstate->resulttypeByVal && !*resultIsNull &&
+ if (!peragg->resulttypeByVal && !*resultIsNull &&
!MemoryContextContains(CurrentMemoryContext,
DatumGetPointer(*resultVal)))
*resultVal = datumCopy(*resultVal,
- peraggstate->resulttypeByVal,
- peraggstate->resulttypeLen);
+ peragg->resulttypeByVal,
+ peragg->resulttypeLen);
MemoryContextSwitchTo(oldContext);
}
@@ -1173,7 +1245,7 @@ prepare_projection_slot(AggState *aggstate, TupleTableSlot *slot, int currentSet
*/
static void
finalize_aggregates(AggState *aggstate,
- AggStatePerAgg peragg,
+ AggStatePerAgg peraggs,
AggStatePerGroup pergroup,
int currentSet)
{
@@ -1189,10 +1261,12 @@ finalize_aggregates(AggState *aggstate,
for (aggno = 0; aggno < aggstate->numaggs; aggno++)
{
- AggStatePerAgg peraggstate = &peragg[aggno];
+ AggStatePerAgg peragg = &peraggs[aggno];
+ int stateno = peragg->stateno;
+ AggStatePerAggState peraggstate = &aggstate->peraggstate[stateno];
AggStatePerGroup pergroupstate;
- pergroupstate = &pergroup[aggno + (currentSet * (aggstate->numaggs))];
+ pergroupstate = &pergroup[stateno + (currentSet * (aggstate->numstates))];
if (peraggstate->numSortCols > 0)
{
@@ -1208,7 +1282,7 @@ finalize_aggregates(AggState *aggstate,
pergroupstate);
}
- finalize_aggregate(aggstate, peraggstate, pergroupstate,
+ finalize_aggregate(aggstate, peragg, pergroupstate,
&aggvalues[aggno], &aggnulls[aggno]);
}
}
@@ -1428,7 +1502,7 @@ lookup_hash_entry(AggState *aggstate, TupleTableSlot *inputslot)
if (isnew)
{
/* initialize aggregates for new tuple group */
- initialize_aggregates(aggstate, aggstate->peragg, entry->pergroup, 0);
+ initialize_aggregates(aggstate, aggstate->peraggstate, entry->pergroup, 0);
}
return entry;
@@ -1505,6 +1579,7 @@ agg_retrieve_direct(AggState *aggstate)
ExprContext *econtext;
ExprContext *tmpcontext;
AggStatePerAgg peragg;
+ AggStatePerAggState peraggstate;
AggStatePerGroup pergroup;
TupleTableSlot *outerslot;
TupleTableSlot *firstSlot;
@@ -1527,6 +1602,7 @@ agg_retrieve_direct(AggState *aggstate)
tmpcontext = aggstate->tmpcontext;
peragg = aggstate->peragg;
+ peraggstate = aggstate->peraggstate;
pergroup = aggstate->pergroup;
firstSlot = aggstate->ss.ss_ScanTupleSlot;
@@ -1716,7 +1792,7 @@ agg_retrieve_direct(AggState *aggstate)
/*
* Initialize working state for a new input tuple group.
*/
- initialize_aggregates(aggstate, peragg, pergroup, numReset);
+ initialize_aggregates(aggstate, peraggstate, pergroup, numReset);
if (aggstate->grp_firstTuple != NULL)
{
@@ -1945,17 +2021,18 @@ AggState *
ExecInitAgg(Agg *node, EState *estate, int eflags)
{
AggState *aggstate;
- AggStatePerAgg peragg;
+ AggStatePerAgg peraggs;
+ AggStatePerAggState peraggstates;
Plan *outerPlan;
ExprContext *econtext;
int numaggs,
+ stateno,
aggno;
int phase;
ListCell *l;
Bitmapset *all_grouped_cols = NULL;
int numGroupingSets = 1;
int numPhases;
- int currentsortno = 0;
int i = 0;
int j = 0;
@@ -1971,12 +2048,14 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
aggstate->aggs = NIL;
aggstate->numaggs = 0;
+ aggstate->numstates = 0;
aggstate->maxsets = 0;
aggstate->hashfunctions = NULL;
aggstate->projected_set = -1;
aggstate->current_set = 0;
aggstate->peragg = NULL;
- aggstate->curperagg = NULL;
+ aggstate->peraggstate = NULL;
+ aggstate->curperaggstate = NULL;
aggstate->agg_done = false;
aggstate->input_done = false;
aggstate->pergroup = NULL;
@@ -2209,8 +2288,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
econtext->ecxt_aggvalues = (Datum *) palloc0(sizeof(Datum) * numaggs);
econtext->ecxt_aggnulls = (bool *) palloc0(sizeof(bool) * numaggs);
- peragg = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs);
- aggstate->peragg = peragg;
+ peraggs = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs);
+ peraggstates = (AggStatePerAggState) palloc0(sizeof(AggStatePerAggStateData) * numaggs);
+
+ aggstate->peragg = peraggs;
+ aggstate->peraggstate = peraggstates;
if (node->aggstrategy == AGG_HASHED)
{
@@ -2232,69 +2314,67 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
/*
* Perform lookups of aggregate function info, and initialize the
- * unchanging fields of the per-agg data. We also detect duplicate
- * aggregates (for example, "SELECT sum(x) ... HAVING sum(x) > 0"). When
- * duplicates are detected, we only make an AggStatePerAgg struct for the
- * first one. The clones are simply pointed at the same result entry by
- * giving them duplicate aggno values.
+ * unchanging fields of the per-agg data.
*/
aggno = -1;
+ stateno = -1;
foreach(l, aggstate->aggs)
{
AggrefExprState *aggrefstate = (AggrefExprState *) lfirst(l);
Aggref *aggref = (Aggref *) aggrefstate->xprstate.expr;
- AggStatePerAgg peraggstate;
+ AggStatePerAgg peragg;
+ AggStatePerAggState peraggstate;
+ AggRefCompatibility agg_match;
Oid inputTypes[FUNC_MAX_ARGS];
int numArguments;
int numDirectArgs;
- int numInputs;
- int numSortCols;
- int numDistinctCols;
- List *sortlist;
HeapTuple aggTuple;
Form_pg_aggregate aggform;
- Oid aggtranstype;
AclResult aclresult;
Oid transfn_oid,
finalfn_oid;
- Expr *transfnexpr,
- *finalfnexpr;
- Datum textInitVal;
- int i;
- ListCell *lc;
+ Expr *finalfnexpr;
+ int existing_aggno;
/* Planner should have assigned aggregate to correct level */
Assert(aggref->agglevelsup == 0);
- /* Look for a previous duplicate aggregate */
- for (i = 0; i <= aggno; i++)
- {
- if (equal(aggref, peragg[i].aggref) &&
- !contain_volatile_functions((Node *) aggref))
- break;
- }
- if (i <= aggno)
+ /*
+ * For performance reasons we detect duplicate aggregates (for
+ * example, "SELECT sum(x) ... HAVING sum(x) > 0"). When duplicates
+ * are detected, we only make an AggStatePerAgg struct for the first
+ * one. The clones are simply pointed at the same result entry by
+ * giving them duplicate aggno values. We also do our best to reuse
+ * duplicate aggregate states. The query may use 2 or more aggregate
+ * functions which share the same transition function and initial
+ * value therefore would end up calculating the same state. In this
+ * case we can just calculate the state once, however if the finalfns
+ * do not match then we must create a new peragg to store the varying
+ * finalfn.
+ */
+
+ /* check if we have previous agg or state matches that can be reused */
+ agg_match = find_compatible_aggref(aggref, aggstate, aggno,
+ &existing_aggno);
+ if (agg_match == AGGREF_EXACT_MATCH)
{
- /* Found a match to an existing entry, so just mark it */
- aggrefstate->aggno = i;
+ /*
+ * Exact match -- this must be using same aggregate function or
+ * have the same transfn and finalfn. Just reuse the existing agg.
+ */
+ aggrefstate->aggno = existing_aggno;
continue;
}
- /* Nope, so assign a new PerAgg record */
- peraggstate = &peragg[++aggno];
+ /*
+ * Otherwise set up a new Per-Agg for this, and possibly a new
+ * per-AggState too.
+ */
/* Mark Aggref state node with assigned index in the result array */
+ peragg = &peraggs[++aggno];
aggrefstate->aggno = aggno;
- /* Begin filling in the peraggstate data */
- peraggstate->aggrefstate = aggrefstate;
- peraggstate->aggref = aggref;
- peraggstate->sortstates = (Tuplesortstate **)
- palloc0(sizeof(Tuplesortstate *) * numGroupingSets);
-
- for (currentsortno = 0; currentsortno < numGroupingSets; currentsortno++)
- peraggstate->sortstates[currentsortno] = NULL;
-
/* Fetch the pg_aggregate row */
aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(aggref->aggfnoid));
@@ -2311,8 +2391,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
get_func_name(aggref->aggfnoid));
InvokeFunctionExecuteHook(aggref->aggfnoid);
- peraggstate->transfn_oid = transfn_oid = aggform->aggtransfn;
- peraggstate->finalfn_oid = finalfn_oid = aggform->aggfinalfn;
+ transfn_oid = aggform->aggtransfn;
+ peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn;
/* Check that aggregate owner has permission to call component fns */
{
@@ -2327,12 +2407,20 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
aggOwner = ((Form_pg_proc) GETSTRUCT(procTuple))->proowner;
ReleaseSysCache(procTuple);
- aclresult = pg_proc_aclcheck(transfn_oid, aggOwner,
- ACL_EXECUTE);
- if (aclresult != ACLCHECK_OK)
- aclcheck_error(aclresult, ACL_KIND_PROC,
- get_func_name(transfn_oid));
- InvokeFunctionExecuteHook(transfn_oid);
+ /*
+ * If we're reusing an existing state, no need to check the
+ * transfn permission again.
+ */
+ if (agg_match == AGGREF_NO_MATCH)
+ {
+ aclresult = pg_proc_aclcheck(transfn_oid, aggOwner,
+ ACL_EXECUTE);
+ if (aclresult != ACLCHECK_OK)
+ aclcheck_error(aclresult, ACL_KIND_PROC,
+ get_func_name(transfn_oid));
+ InvokeFunctionExecuteHook(transfn_oid);
+ }
+
if (OidIsValid(finalfn_oid))
{
aclresult = pg_proc_aclcheck(finalfn_oid, aggOwner,
@@ -2350,236 +2438,333 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
* agg accepts ANY or a polymorphic type.
*/
numArguments = get_aggregate_argtypes(aggref, inputTypes);
- peraggstate->numArguments = numArguments;
/* Count the "direct" arguments, if any */
numDirectArgs = list_length(aggref->aggdirectargs);
- /* Count the number of aggregated input columns */
- numInputs = list_length(aggref->args);
- peraggstate->numInputs = numInputs;
-
- /* Detect how many arguments to pass to the transfn */
- if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
- peraggstate->numTransInputs = numInputs;
+ /*
+ * Build working state for invoking the transition function (or look
+ * up previously initialized working state, if we can share it).
+ */
+ if (agg_match == AGGREF_NO_MATCH)
+ {
+ peraggstate = &peraggstates[++stateno];
+ build_peraggstate_for_aggref(peraggstate, aggstate, estate,
+ aggref,
+ aggTuple, inputTypes, numArguments);
+ peragg->stateno = stateno;
+ }
else
- peraggstate->numTransInputs = numArguments;
+ {
+ int existing_stateno = peraggs[existing_aggno].stateno;
+
+ peraggstate = &peraggstates[existing_stateno];
+ peragg->stateno = existing_stateno;
+
+ /* when reusing the state the transfns should match! */
+ Assert(peraggstate->transfn_oid == aggform->aggtransfn);
+ }
/* Detect how many arguments to pass to the finalfn */
if (aggform->aggfinalextra)
- peraggstate->numFinalArgs = numArguments + 1;
+ peragg->numFinalArgs = numArguments + 1;
else
- peraggstate->numFinalArgs = numDirectArgs + 1;
-
- /* resolve actual type of transition state, if polymorphic */
- aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid,
- aggform->aggtranstype,
- inputTypes,
- numArguments);
-
- /* build expression trees using actual argument & result types */
- build_aggregate_fnexprs(inputTypes,
- numArguments,
- numDirectArgs,
- peraggstate->numFinalArgs,
- aggref->aggvariadic,
- aggtranstype,
- aggref->aggtype,
- aggref->inputcollid,
- transfn_oid,
- InvalidOid, /* invtrans is not needed here */
- finalfn_oid,
- &transfnexpr,
- NULL,
- &finalfnexpr);
-
- /* set up infrastructure for calling the transfn and finalfn */
- fmgr_info(transfn_oid, &peraggstate->transfn);
- fmgr_info_set_expr((Node *) transfnexpr, &peraggstate->transfn);
+ peragg->numFinalArgs = numDirectArgs + 1;
+ /*
+ * build expression trees using actual argument & result types for the
+ * finalfn, if it exists
+ */
if (OidIsValid(finalfn_oid))
{
- fmgr_info(finalfn_oid, &peraggstate->finalfn);
- fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn);
+ build_aggregate_finalfn_expr(inputTypes,
+ peragg->numFinalArgs,
+ peraggstate->aggtranstype,
+ aggref->aggtype,
+ aggref->inputcollid,
+ finalfn_oid,
+ &finalfnexpr);
+ fmgr_info(finalfn_oid, &peragg->finalfn);
+ fmgr_info_set_expr((Node *) finalfnexpr, &peragg->finalfn);
}
- peraggstate->aggCollation = aggref->inputcollid;
+ /* get info about the result type's datatype */
+ get_typlenbyval(aggref->aggtype,
+ &peragg->resulttypeLen,
+ &peragg->resulttypeByVal);
- InitFunctionCallInfoData(peraggstate->transfn_fcinfo,
- &peraggstate->transfn,
- peraggstate->numTransInputs + 1,
- peraggstate->aggCollation,
- (void *) aggstate, NULL);
+ ReleaseSysCache(aggTuple);
+ }
- /* get info about relevant datatypes */
- get_typlenbyval(aggref->aggtype,
- &peraggstate->resulttypeLen,
- &peraggstate->resulttypeByVal);
- get_typlenbyval(aggtranstype,
- &peraggstate->transtypeLen,
- &peraggstate->transtypeByVal);
+ /*
+ * Update numaggs to match the number of unique aggregates found. Also set
+ * numstates to the number of unique aggregate states found.
+ */
+ aggstate->numaggs = aggno + 1;
+ aggstate->numstates = stateno + 1;
- /*
- * initval is potentially null, so don't try to access it as a struct
- * field. Must do it the hard way with SysCacheGetAttr.
- */
- textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple,
- Anum_pg_aggregate_agginitval,
- &peraggstate->initValueIsNull);
+ return aggstate;
+}
- if (peraggstate->initValueIsNull)
- peraggstate->initValue = (Datum) 0;
- else
- peraggstate->initValue = GetAggInitVal(textInitVal,
- aggtranstype);
+/*
+ * Build the state needed to calculate a state value for an aggregate.
+ *
+ * This initializes all the fields in 'peraggstate'. 'aggTuple',
+ * 'inputTypes' and 'numArguments' could be derived from 'aggref', but the
+ * caller has calculated them already, so might as well pass them.
+ */
+static void
+build_peraggstate_for_aggref(AggStatePerAggState peraggstate,
+ AggState *aggstate, EState *estate,
+ Aggref *aggref, HeapTuple aggTuple,
+ Oid *inputTypes, int numArguments)
+{
+ Form_pg_aggregate aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
+ int numGroupingSets = Max(aggstate->maxsets, 1);
+ Expr *transfnexpr;
+ ListCell *lc;
+ int numInputs;
+ int numDirectArgs;
+ List *sortlist;
+ int numSortCols;
+ int numDistinctCols;
+ int currentsortno;
+ int naggs;
+ int i;
+ Datum textInitVal;
+ Oid transfn_oid;
- /*
- * If the transfn is strict and the initval is NULL, make sure input
- * type and transtype are the same (or at least binary-compatible), so
- * that it's OK to use the first aggregated input value as the initial
- * transValue. This should have been checked at agg definition time,
- * but we must check again in case the transfn's strictness property
- * has been changed.
- */
- if (peraggstate->transfn.fn_strict && peraggstate->initValueIsNull)
- {
- if (numArguments <= numDirectArgs ||
- !IsBinaryCoercible(inputTypes[numDirectArgs], aggtranstype))
- ereport(ERROR,
- (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
- errmsg("aggregate %u needs to have compatible input type and transition type",
- aggref->aggfnoid)));
- }
+ /* Begin filling in the peraggstate data */
+ peraggstate->aggref = aggref;
+ peraggstate->aggCollation = aggref->inputcollid;
+ peraggstate->transfn_oid = transfn_oid = aggform->aggtransfn;
- /*
- * Get a tupledesc corresponding to the aggregated inputs (including
- * sort expressions) of the agg.
- */
- peraggstate->evaldesc = ExecTypeFromTL(aggref->args, false);
+ /* Count the "direct" arguments, if any */
+ numDirectArgs = list_length(aggref->aggdirectargs);
+
+ /* Count the number of aggregated input columns */
+ peraggstate->numInputs = numInputs = list_length(aggref->args);
+
+ /* resolve actual type of transition state, if polymorphic */
+ peraggstate->aggtranstype =
+ resolve_aggregate_transtype(aggref->aggfnoid,
+ aggform->aggtranstype,
+ inputTypes,
+ numArguments);
+
+ /* Detect how many arguments to pass to the transfn */
+ if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
+ peraggstate->numTransInputs = numInputs;
+ else
+ peraggstate->numTransInputs = numArguments;
+
+ /*
+ * Set up infrastructure for calling the transfn
+ */
+ build_aggregate_transfn_expr(inputTypes,
+ numArguments,
+ numDirectArgs,
+ aggref->aggvariadic,
+ peraggstate->aggtranstype,
+ aggref->inputcollid,
+ transfn_oid,
+ InvalidOid, /* invtrans is not needed here */
+ &transfnexpr,
+ NULL);
+ fmgr_info(peraggstate->transfn_oid, &peraggstate->transfn);
+ fmgr_info_set_expr((Node *) transfnexpr, &peraggstate->transfn);
+
+ InitFunctionCallInfoData(peraggstate->transfn_fcinfo,
+ &peraggstate->transfn,
+ peraggstate->numTransInputs + 1,
+ peraggstate->aggCollation,
+ (void *) aggstate, NULL);
+
+
+ /*
+ * Look up the initial value.
+ *
+ * initval is potentially null, so don't try to access it as a struct
+ * field. Must do it the hard way with SysCacheGetAttr.
+ */
+ textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple,
+ Anum_pg_aggregate_agginitval,
+ &peraggstate->initValueIsNull);
+
+ if (peraggstate->initValueIsNull)
+ peraggstate->initValue = (Datum) 0;
+ else
+ peraggstate->initValue = GetAggInitVal(textInitVal,
+ peraggstate->aggtranstype);
- /* Create slot we're going to do argument evaluation in */
- peraggstate->evalslot = ExecInitExtraTupleSlot(estate);
- ExecSetSlotDescriptor(peraggstate->evalslot, peraggstate->evaldesc);
+	/*
+	 * If the transfn is strict and the initval is NULL, make sure input type
+	 * and transtype are the same (or at least binary-compatible), so that it's
+	 * OK to use the first aggregated input value as the initial transValue.
+	 * This should have been checked at agg definition time, but we must check
+	 * again in case the transfn's strictness property has been changed.
+	 */
+	if (peraggstate->transfn.fn_strict && peraggstate->initValueIsNull)
+	{
+		if (numArguments <= numDirectArgs ||
+			!IsBinaryCoercible(inputTypes[numDirectArgs],
+							   peraggstate->aggtranstype))
+			ereport(ERROR,
+					(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+					 errmsg("aggregate %u needs to have compatible input type and transition type",
+							aggref->aggfnoid)));
+	}
+
+ /* get info about the state value's datatype */
+ get_typlenbyval(peraggstate->aggtranstype,
+ &peraggstate->transtypeLen,
+ &peraggstate->transtypeByVal);
- /* Set up projection info for evaluation */
- peraggstate->evalproj = ExecBuildProjectionInfo(aggrefstate->args,
- aggstate->tmpcontext,
- peraggstate->evalslot,
- NULL);
+ /*
+ * Get a tupledesc corresponding to the aggregated inputs (including sort
+ * expressions) of the agg.
+ */
+ peraggstate->evaldesc = ExecTypeFromTL(aggref->args, false);
+ /* Create slot we're going to do argument evaluation in */
+ peraggstate->evalslot = ExecInitExtraTupleSlot(estate);
+ ExecSetSlotDescriptor(peraggstate->evalslot, peraggstate->evaldesc);
+
+ /* Initialize the input and FILTER expressions */
+ naggs = aggstate->numaggs;
+ peraggstate->aggfilter = ExecInitExpr(aggref->aggfilter,
+ (PlanState *) aggstate);
+ peraggstate->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs,
+ (PlanState *) aggstate);
+ peraggstate->args = (List *) ExecInitExpr((Expr *) aggref->args,
+ (PlanState *) aggstate);
+
+ /*
+ * Complain if the aggregate's arguments contain any aggregates; nested
+ * agg functions are semantically nonsensical. (This should have been
+ * caught earlier, but we defend against it here anyway.)
+ */
+ if (naggs != aggstate->numaggs)
+ ereport(ERROR,
+ (errcode(ERRCODE_GROUPING_ERROR),
+ errmsg("aggregate function calls cannot be nested")));
+
+ /* Set up projection info for evaluation */
+ peraggstate->evalproj = ExecBuildProjectionInfo(peraggstate->args,
+ aggstate->tmpcontext,
+ peraggstate->evalslot,
+ NULL);
+
+ /*
+ * If we're doing either DISTINCT or ORDER BY for a plain agg, then we
+ * have a list of SortGroupClause nodes; fish out the data in them and
+ * stick them into arrays. We ignore ORDER BY for an ordered-set agg,
+ * however; the agg's transfn and finalfn are responsible for that.
+ *
+ * Note that by construction, if there is a DISTINCT clause then the ORDER
+ * BY clause is a prefix of it (see transformDistinctClause).
+ */
+ if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
+ {
+ sortlist = NIL;
+ numSortCols = numDistinctCols = 0;
+ }
+ else if (aggref->aggdistinct)
+ {
+ sortlist = aggref->aggdistinct;
+ numSortCols = numDistinctCols = list_length(sortlist);
+ Assert(numSortCols >= list_length(aggref->aggorder));
+ }
+ else
+ {
+ sortlist = aggref->aggorder;
+ numSortCols = list_length(sortlist);
+ numDistinctCols = 0;
+ }
+
+ peraggstate->numSortCols = numSortCols;
+ peraggstate->numDistinctCols = numDistinctCols;
+
+ if (numSortCols > 0)
+ {
/*
- * If we're doing either DISTINCT or ORDER BY for a plain agg, then we
- * have a list of SortGroupClause nodes; fish out the data in them and
- * stick them into arrays. We ignore ORDER BY for an ordered-set agg,
- * however; the agg's transfn and finalfn are responsible for that.
- *
- * Note that by construction, if there is a DISTINCT clause then the
- * ORDER BY clause is a prefix of it (see transformDistinctClause).
+ * We don't implement DISTINCT or ORDER BY aggs in the HASHED case
+ * (yet)
*/
- if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
- {
- sortlist = NIL;
- numSortCols = numDistinctCols = 0;
- }
- else if (aggref->aggdistinct)
+ Assert(((Agg *) aggstate->ss.ps.plan)->aggstrategy != AGG_HASHED);
+
+ /* If we have only one input, we need its len/byval info. */
+ if (numInputs == 1)
{
- sortlist = aggref->aggdistinct;
- numSortCols = numDistinctCols = list_length(sortlist);
- Assert(numSortCols >= list_length(aggref->aggorder));
+ get_typlenbyval(inputTypes[numDirectArgs],
+ &peraggstate->inputtypeLen,
+ &peraggstate->inputtypeByVal);
}
- else
+ else if (numDistinctCols > 0)
{
- sortlist = aggref->aggorder;
- numSortCols = list_length(sortlist);
- numDistinctCols = 0;
+ /* we will need an extra slot to store prior values */
+ peraggstate->uniqslot = ExecInitExtraTupleSlot(estate);
+ ExecSetSlotDescriptor(peraggstate->uniqslot,
+ peraggstate->evaldesc);
}
- peraggstate->numSortCols = numSortCols;
- peraggstate->numDistinctCols = numDistinctCols;
-
- if (numSortCols > 0)
+ /* Extract the sort information for use later */
+ peraggstate->sortColIdx =
+ (AttrNumber *) palloc(numSortCols * sizeof(AttrNumber));
+ peraggstate->sortOperators =
+ (Oid *) palloc(numSortCols * sizeof(Oid));
+ peraggstate->sortCollations =
+ (Oid *) palloc(numSortCols * sizeof(Oid));
+ peraggstate->sortNullsFirst =
+ (bool *) palloc(numSortCols * sizeof(bool));
+
+ i = 0;
+ foreach(lc, sortlist)
{
- /*
- * We don't implement DISTINCT or ORDER BY aggs in the HASHED case
- * (yet)
- */
- Assert(node->aggstrategy != AGG_HASHED);
-
- /* If we have only one input, we need its len/byval info. */
- if (numInputs == 1)
- {
- get_typlenbyval(inputTypes[numDirectArgs],
- &peraggstate->inputtypeLen,
- &peraggstate->inputtypeByVal);
- }
- else if (numDistinctCols > 0)
- {
- /* we will need an extra slot to store prior values */
- peraggstate->uniqslot = ExecInitExtraTupleSlot(estate);
- ExecSetSlotDescriptor(peraggstate->uniqslot,
- peraggstate->evaldesc);
- }
+ SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
+ TargetEntry *tle = get_sortgroupclause_tle(sortcl, aggref->args);
- /* Extract the sort information for use later */
- peraggstate->sortColIdx =
- (AttrNumber *) palloc(numSortCols * sizeof(AttrNumber));
- peraggstate->sortOperators =
- (Oid *) palloc(numSortCols * sizeof(Oid));
- peraggstate->sortCollations =
- (Oid *) palloc(numSortCols * sizeof(Oid));
- peraggstate->sortNullsFirst =
- (bool *) palloc(numSortCols * sizeof(bool));
+ /* the parser should have made sure of this */
+ Assert(OidIsValid(sortcl->sortop));
- i = 0;
- foreach(lc, sortlist)
- {
- SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
- TargetEntry *tle = get_sortgroupclause_tle(sortcl,
- aggref->args);
-
- /* the parser should have made sure of this */
- Assert(OidIsValid(sortcl->sortop));
-
- peraggstate->sortColIdx[i] = tle->resno;
- peraggstate->sortOperators[i] = sortcl->sortop;
- peraggstate->sortCollations[i] = exprCollation((Node *) tle->expr);
- peraggstate->sortNullsFirst[i] = sortcl->nulls_first;
- i++;
- }
- Assert(i == numSortCols);
+ peraggstate->sortColIdx[i] = tle->resno;
+ peraggstate->sortOperators[i] = sortcl->sortop;
+ peraggstate->sortCollations[i] = exprCollation((Node *) tle->expr);
+ peraggstate->sortNullsFirst[i] = sortcl->nulls_first;
+ i++;
}
+ Assert(i == numSortCols);
+ }
- if (aggref->aggdistinct)
- {
- Assert(numArguments > 0);
+ if (aggref->aggdistinct)
+ {
+ Assert(numArguments > 0);
- /*
- * We need the equal function for each DISTINCT comparison we will
- * make.
- */
- peraggstate->equalfns =
- (FmgrInfo *) palloc(numDistinctCols * sizeof(FmgrInfo));
+ /*
+ * We need the equal function for each DISTINCT comparison we will
+ * make.
+ */
+ peraggstate->equalfns =
+ (FmgrInfo *) palloc(numDistinctCols * sizeof(FmgrInfo));
- i = 0;
- foreach(lc, aggref->aggdistinct)
- {
- SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
+ i = 0;
+ foreach(lc, aggref->aggdistinct)
+ {
+ SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
- fmgr_info(get_opcode(sortcl->eqop), &peraggstate->equalfns[i]);
- i++;
- }
- Assert(i == numDistinctCols);
+ fmgr_info(get_opcode(sortcl->eqop), &peraggstate->equalfns[i]);
+ i++;
}
-
- ReleaseSysCache(aggTuple);
+ Assert(i == numDistinctCols);
}
- /* Update numaggs to match number of unique aggregates found */
- aggstate->numaggs = aggno + 1;
-
- return aggstate;
+ peraggstate->sortstates = (Tuplesortstate **)
+ palloc0(sizeof(Tuplesortstate *) * numGroupingSets);
+ for (currentsortno = 0; currentsortno < numGroupingSets; currentsortno++)
+ peraggstate->sortstates[currentsortno] = NULL;
}
+
static Datum
GetAggInitVal(Datum textInitVal, Oid transtype)
{
@@ -2596,11 +2781,199 @@ GetAggInitVal(Datum textInitVal, Oid transtype)
return initVal;
}
+/*
+ * find_compatible_aggref
+ * Searches the previously processed aggregates (0 .. lastaggno) for one
+ * whose aggregate or aggregate state can be reused by 'newagg'. If a match
+ * is found then *foundaggno is set to the matching aggno; it is not set
+ * when AGGREF_NO_MATCH is returned. On AGGREF_STATE_MATCH the caller may
+ * only share the transition state of *foundaggno, not the aggno itself.
+ * On AGGREF_EXACT_MATCH the caller may share both the aggno and the state
+ * which that aggno uses. Nothing is shared if 'newagg' is volatile.
+ *
+ * Scenario 1 -- An aggregate function appears more than once in the query:
+ *
+ * SELECT SUM(x) FROM ... HAVING SUM(x) > 0
+ *
+ * Since in this case the aggregates are both the same we can optimize by
+ * calculating the aggregate state and calling the finalfn just once. This
+ * is an AGGREF_EXACT_MATCH, meaning both the state and the final function
+ * call are shared.
+ *
+ * Scenario 2 -- Two different aggregate functions appear in the query and
+ * happen to share the same transfn, but have different finalfns.
+ *
+ * SELECT SUM(x), AVG(x) FROM ...
+ *
+ * Assuming these two aggregates share the same transfn (their finalfns
+ * naturally differ), this situation is classed as an AGGREF_STATE_MATCH.
+ * This means that the same transition state can be shared by both
+ * aggregates, but since the finalfn calls differ the final result cannot
+ * be reused. For this case to be valid the aggregates' INITCONDs must also
+ * match; currently that is only recognized when neither aggregate has an
+ * INITCOND (see aggref_has_compatible_states).
+ *
+ * Scenario 3 -- The same aggregate function is called with different
+ * parameters.
+ *
+ * SELECT SUM(x),SUM(DISTINCT x) FROM ...
+ * SELECT SUM(x),SUM(y) FROM ...
+ * SELECT SUM(x),SUM(x) FILTER(WHERE x > 0) FROM ...
+ *
+ * In none of the above queries can the two aggregates share a state; each
+ * must be calculated independently.
+ *
+ * Scenario 4 -- Different aggregates with the same parameters and the same
+ * transfn and finalfn.
+ *
+ * SELECT SUM(x),SUM2(x) FROM ...
+ *
+ * A perhaps unlikely scenario where two aggregate functions exist which have
+ * both the same transfn and the same finalfn. In this case we can report an
+ * AGGREF_EXACT_MATCH, provided that neither aggregate has an INITCOND.
+ */
+static AggRefCompatibility
+find_compatible_aggref(Aggref *newagg, AggState *aggstate,
+ int lastaggno, int *foundaggno)
+{
+ int aggno;
+ int statematchaggno;
+ AggStatePerAggState peraggstates;
+ AggStatePerAgg peraggs;
+
+ /* we mustn't reuse anything if newagg contains volatile function calls */
+ if (contain_volatile_functions((Node *) newagg))
+ return AGGREF_NO_MATCH;
+
+ statematchaggno = -1;
+ peraggstates = aggstate->peraggstate;
+ peraggs = aggstate->peragg;
+
+ /*
+ * Search through the list of already seen aggregates. We'll stop at the
+ * first exact match, but until then we'll remember the latest state match
+ * we find. We fall back on that if no exact match turns up; -1 means no
+ * state match has been seen.
+ */
+ for (aggno = 0; aggno <= lastaggno; aggno++)
+ {
+ AggRefCompatibility matchtype;
+ AggStatePerAgg peragg;
+ AggStatePerAggState peraggstate;
+
+ peragg = &peraggs[aggno];
+ peraggstate = &peraggstates[peragg->stateno];
+
+ /* lookup the match type of this agg */
+ matchtype = aggref_has_compatible_states(newagg, peragg, peraggstate);
+
+ /* if it's an exact match then we're done. */
+ if (matchtype == AGGREF_EXACT_MATCH)
+ {
+ *foundaggno = aggno;
+ return AGGREF_EXACT_MATCH;
+ }
+
+ /* remember the latest state match, but keep looking for an exact one */
+ else if (matchtype == AGGREF_STATE_MATCH)
+ statematchaggno = aggno;
+ }
+
+ /* no exact match found, but did we find a state match? */
+ if (statematchaggno >= 0)
+ {
+ *foundaggno = statematchaggno;
+ return AGGREF_STATE_MATCH;
+ }
+
+ return AGGREF_NO_MATCH;
+}
+
+/*
+ * aggref_has_compatible_states
+ *		Determines match type of this aggregate. See comments in
+ *		find_compatible_aggref() for details.
+ *
+ * 'peragg' and 'peraggstate' describe an already-initialized aggregate and
+ * the transition state it uses; 'newagg' is the candidate to share them.
+ */
+static AggRefCompatibility
+aggref_has_compatible_states(Aggref *newagg,
+							 AggStatePerAgg peragg,
+							 AggStatePerAggState peraggstate)
+{
+	Aggref	   *existingRef = peraggstate->aggref;
+	HeapTuple	aggTuple;
+	Form_pg_aggregate aggform;
+	bool		initValueIsNull;
+	Oid			aggfinalfn;
+
+	/* all of the following must be the same or it's no match */
+	if (newagg->inputcollid != existingRef->inputcollid ||
+		newagg->aggstar != existingRef->aggstar ||
+		newagg->aggvariadic != existingRef->aggvariadic ||
+		newagg->aggkind != existingRef->aggkind ||
+		!equal(newagg->aggdirectargs, existingRef->aggdirectargs) ||
+		!equal(newagg->args, existingRef->args) ||
+		!equal(newagg->aggorder, existingRef->aggorder) ||
+		!equal(newagg->aggdistinct, existingRef->aggdistinct) ||
+		!equal(newagg->aggfilter, existingRef->aggfilter))
+		return AGGREF_NO_MATCH;
+
+	/* if it's the same aggregate function then report exact match */
+	if (newagg->aggfnoid == existingRef->aggfnoid &&
+		newagg->aggtype == existingRef->aggtype &&
+		newagg->aggcollid == existingRef->aggcollid)
+		return AGGREF_EXACT_MATCH;
+
+	/*
+	 * Aggregate functions differ. We'll need to do some more analysis before
+	 * we can know what the match type will be. If the transfns match and
+	 * the initvalue is the same then we can at least let the newagg share
+	 * the state, but if the finalfn also happens to match then we can
+	 * actually still report an exact match.
+	 */
+	aggTuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(newagg->aggfnoid));
+	if (!HeapTupleIsValid(aggTuple))
+		elog(ERROR, "cache lookup failed for aggregate %u", newagg->aggfnoid);
+	aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
+
+	/* if the transfns are not the same then the state can't be shared */
+	if (aggform->aggtransfn != peraggstate->transfn_oid)
+	{
+		ReleaseSysCache(aggTuple);
+		return AGGREF_NO_MATCH;
+	}
+
+	/*
+	 * aggform points into the syscache entry, so copy out everything we
+	 * still need before releasing it.
+	 */
+	aggfinalfn = aggform->aggfinalfn;
+	SysCacheGetAttr(AGGFNOID, aggTuple,
+					Anum_pg_aggregate_agginitval, &initValueIsNull);
+	ReleaseSysCache(aggTuple);
+
+	/*
+	 * If both INITCONDs are null then the outcome depends on if the
+	 * finalfns match.
+	 */
+	if (initValueIsNull && peraggstate->initValueIsNull)
+		return (aggfinalfn == peragg->finalfn_oid) ?
+			AGGREF_EXACT_MATCH : AGGREF_STATE_MATCH;
+
+	/*
+	 * XXX perhaps we should check the value of the initValue to see if
+	 * they match?
+	 */
+	return AGGREF_NO_MATCH;
+}
+
void
ExecEndAgg(AggState *node)
{
PlanState *outerPlan;
- int aggno;
+ int stateno;
int numGroupingSets = Max(node->maxsets, 1);
int setno;
@@ -2611,9 +2984,9 @@ ExecEndAgg(AggState *node)
if (node->sort_out)
tuplesort_end(node->sort_out);
- for (aggno = 0; aggno < node->numaggs; aggno++)
+ for (stateno = 0; stateno < node->numstates; stateno++)
{
- AggStatePerAgg peraggstate = &node->peragg[aggno];
+ AggStatePerAggState peraggstate = &node->peraggstate[stateno];
for (setno = 0; setno < numGroupingSets; setno++)
{
@@ -2646,7 +3019,7 @@ ExecReScanAgg(AggState *node)
ExprContext *econtext = node->ss.ps.ps_ExprContext;
PlanState *outerPlan = outerPlanState(node);
Agg *aggnode = (Agg *) node->ss.ps.plan;
- int aggno;
+ int stateno;
int numGroupingSets = Max(node->maxsets, 1);
int setno;
@@ -2678,11 +3051,11 @@ ExecReScanAgg(AggState *node)
}
/* Make sure we have closed any open tuplesorts */
- for (aggno = 0; aggno < node->numaggs; aggno++)
+ for (stateno = 0; stateno < node->numstates; stateno++)
{
for (setno = 0; setno < numGroupingSets; setno++)
{
- AggStatePerAgg peraggstate = &node->peragg[aggno];
+ AggStatePerAggState peraggstate = &node->peraggstate[stateno];
if (peraggstate->sortstates[setno])
{
@@ -2811,10 +3184,12 @@ AggGetAggref(FunctionCallInfo fcinfo)
{
if (fcinfo->context && IsA(fcinfo->context, AggState))
{
- AggStatePerAgg curperagg = ((AggState *) fcinfo->context)->curperagg;
+ AggStatePerAggState curperaggstate;
+
+ curperaggstate = ((AggState *) fcinfo->context)->curperaggstate;
- if (curperagg)
- return curperagg->aggref;
+ if (curperaggstate)
+ return curperaggstate->aggref;
}
return NULL;
}
diff --git a/src/backend/executor/nodeWindowAgg.c b/src/backend/executor/nodeWindowAgg.c
index ecf96f8..c371d4d 100644
--- a/src/backend/executor/nodeWindowAgg.c
+++ b/src/backend/executor/nodeWindowAgg.c
@@ -2218,20 +2218,16 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc,
numArguments);
/* build expression trees using actual argument & result types */
- build_aggregate_fnexprs(inputTypes,
- numArguments,
- 0, /* no ordered-set window functions yet */
- peraggstate->numFinalArgs,
- false, /* no variadic window functions yet */
- aggtranstype,
- wfunc->wintype,
- wfunc->inputcollid,
- transfn_oid,
- invtransfn_oid,
- finalfn_oid,
- &transfnexpr,
- &invtransfnexpr,
- &finalfnexpr);
+	build_aggregate_transfn_expr(inputTypes,
+								 numArguments,
+								 0,		/* no ordered-set window functions yet */
+								 false, /* no variadic window functions yet */
+								 aggtranstype,	/* transition state type, not wintype */
+								 wfunc->inputcollid,
+								 transfn_oid,
+								 invtransfn_oid,
+								 &transfnexpr,
+								 &invtransfnexpr);
/* set up infrastructure for calling the transfn(s) and finalfn */
fmgr_info(transfn_oid, &peraggstate->transfn);
@@ -2245,6 +2241,13 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc,
if (OidIsValid(finalfn_oid))
{
+ build_aggregate_finalfn_expr(inputTypes,
+ peraggstate->numFinalArgs,
+ aggtranstype,
+ wfunc->wintype,
+ wfunc->inputcollid,
+ finalfn_oid,
+ &finalfnexpr);
fmgr_info(finalfn_oid, &peraggstate->finalfn);
fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn);
}
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index 478d8ca..65e6a85 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -1819,44 +1819,40 @@ resolve_aggregate_transtype(Oid aggfuncid,
}
/*
- * Create expression trees for the transition and final functions
+ * Create an expression tree for the transition functions
* of an aggregate. These are needed so that polymorphic functions
- * can be used within an aggregate --- without the expression trees,
+ * can be used within an aggregate --- without the expression tree,
* such functions would not know the datatypes they are supposed to use.
* (The trees will never actually be executed, however, so we can skimp
* a bit on correctness.)
*
- * agg_input_types, agg_state_type, agg_result_type identify the input,
- * transition, and result types of the aggregate. These should all be
- * resolved to actual types (ie, none should ever be ANYELEMENT etc).
+ * agg_input_types and agg_state_type identify the input and transition types
+ * of the aggregate; both must be actual types (never ANYELEMENT etc).
* agg_input_collation is the aggregate function's input collation.
*
* For an ordered-set aggregate, remember that agg_input_types describes
* the direct arguments followed by the aggregated arguments.
*
- * transfn_oid, invtransfn_oid and finalfn_oid identify the funcs to be
- * called; the latter two may be InvalidOid.
+ * transfn_oid and invtransfn_oid identify the funcs to be called; the
+ * latter may be InvalidOid. However, if invtransfn_oid is set then
+ * transfn_oid must also be set.
*
* Pointers to the constructed trees are returned into *transfnexpr,
- * *invtransfnexpr and *finalfnexpr. If there is no invtransfn or finalfn,
- * the respective pointers are set to NULL. Since use of the invtransfn is
- * optional, NULL may be passed for invtransfnexpr.
+ * and *invtransfnexpr. If there is no invtransfn, *invtransfnexpr is set
+ * to NULL. Since use of the invtransfn is optional, NULL may be passed for
+ * invtransfnexpr.
*/
void
-build_aggregate_fnexprs(Oid *agg_input_types,
+build_aggregate_transfn_expr(Oid *agg_input_types,
int agg_num_inputs,
int agg_num_direct_inputs,
- int num_finalfn_inputs,
bool agg_variadic,
Oid agg_state_type,
- Oid agg_result_type,
Oid agg_input_collation,
Oid transfn_oid,
Oid invtransfn_oid,
- Oid finalfn_oid,
Expr **transfnexpr,
- Expr **invtransfnexpr,
- Expr **finalfnexpr)
+ Expr **invtransfnexpr)
{
Param *argp;
List *args;
@@ -1919,13 +1915,24 @@ build_aggregate_fnexprs(Oid *agg_input_types,
else
*invtransfnexpr = NULL;
}
+}
- /* see if we have a final function */
- if (!OidIsValid(finalfn_oid))
- {
- *finalfnexpr = NULL;
- return;
- }
+/*
+ * Like build_aggregate_transfn_expr, but creates an expression tree for
+ * the final function of an aggregate, rather than the transition function.
+ */
+void
+build_aggregate_finalfn_expr(Oid *agg_input_types,
+ int num_finalfn_inputs,
+ Oid agg_state_type,
+ Oid agg_result_type,
+ Oid agg_input_collation,
+ Oid finalfn_oid,
+ Expr **finalfnexpr)
+{
+ Param *argp;
+ List *args;
+ int i;
/*
* Build expr tree for final function
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index 303fc3c..65c0f74 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -609,9 +609,6 @@ typedef struct WholeRowVarExprState
typedef struct AggrefExprState
{
ExprState xprstate;
- List *aggdirectargs; /* states of direct-argument expressions */
- List *args; /* states of aggregated-argument expressions */
- ExprState *aggfilter; /* state of FILTER expression, if any */
int aggno; /* ID number for agg within its plan node */
} AggrefExprState;
@@ -1825,6 +1822,7 @@ typedef struct GroupState
*/
/* these structs are private in nodeAgg.c: */
typedef struct AggStatePerAggData *AggStatePerAgg;
+typedef struct AggStatePerAggStateData *AggStatePerAggState;
typedef struct AggStatePerGroupData *AggStatePerGroup;
typedef struct AggStatePerPhaseData *AggStatePerPhase;
@@ -1833,14 +1831,16 @@ typedef struct AggState
ScanState ss; /* its first field is NodeTag */
List *aggs; /* all Aggref nodes in targetlist & quals */
int numaggs; /* length of list (could be zero!) */
+ int numstates; /* number of peraggstate items */
AggStatePerPhase phase; /* pointer to current phase data */
int numphases; /* number of phases */
int current_phase; /* current phase number */
FmgrInfo *hashfunctions; /* per-grouping-field hash fns */
AggStatePerAgg peragg; /* per-Aggref information */
+ AggStatePerAggState peraggstate; /* per-Agg State information */
ExprContext **aggcontexts; /* econtexts for long-lived data (per GS) */
ExprContext *tmpcontext; /* econtext for input expressions */
- AggStatePerAgg curperagg; /* identifies currently active aggregate */
+ AggStatePerAggState curperaggstate; /* identifies currently active aggregate */
bool input_done; /* indicates end of input */
bool agg_done; /* indicates completion of Agg scan */
int projected_set; /* The last projected grouping set */
diff --git a/src/include/parser/parse_agg.h b/src/include/parser/parse_agg.h
index 6a5f9bb..e2b3894 100644
--- a/src/include/parser/parse_agg.h
+++ b/src/include/parser/parse_agg.h
@@ -35,19 +35,23 @@ extern Oid resolve_aggregate_transtype(Oid aggfuncid,
Oid *inputTypes,
int numArguments);
-extern void build_aggregate_fnexprs(Oid *agg_input_types,
+extern void build_aggregate_transfn_expr(Oid *agg_input_types,
int agg_num_inputs,
int agg_num_direct_inputs,
- int num_finalfn_inputs,
bool agg_variadic,
Oid agg_state_type,
- Oid agg_result_type,
Oid agg_input_collation,
Oid transfn_oid,
Oid invtransfn_oid,
- Oid finalfn_oid,
Expr **transfnexpr,
- Expr **invtransfnexpr,
+ Expr **invtransfnexpr);
+
+extern void build_aggregate_finalfn_expr(Oid *agg_input_types,
+ int num_finalfn_inputs,
+ Oid agg_state_type,
+ Oid agg_result_type,
+ Oid agg_input_collation,
+ Oid finalfn_oid,
Expr **finalfnexpr);
#endif /* PARSE_AGG_H */
diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out
index 8852051..4dad4fe 100644
--- a/src/test/regress/expected/aggregates.out
+++ b/src/test/regress/expected/aggregates.out
@@ -1580,3 +1580,171 @@ select least_agg(variadic array[q1,q2]) from int8_tbl;
-4567890123456789
(1 row)
+-- test aggregates with common transition functions share the same states
+begin work;
+create type avg_state as (total bigint, count bigint);
+create or replace function avg_transfn(state avg_state, n int) returns avg_state as
+$$
+declare new_state avg_state;
+begin
+ raise notice 'avg_transfn called with %', n;
+ if state is null then
+ if n is not null then
+ new_state.total := n;
+ new_state.count := 1;
+ return new_state;
+ end if;
+ return null;
+ elsif n is not null then
+ state.total := state.total + n;
+ state.count := state.count + 1;
+ return state;
+ end if;
+
+ return null;
+end
+$$ language plpgsql;
+create function avg_finalfn(state avg_state) returns int4 as
+$$
+begin
+ if state is null then
+ return NULL;
+ else
+ return state.total / state.count;
+ end if;
+end
+$$ language plpgsql;
+create function sum_finalfn(state avg_state) returns int4 as
+$$
+begin
+ if state is null then
+ return NULL;
+ else
+ return state.total;
+ end if;
+end
+$$ language plpgsql;
+create aggregate my_avg(int4)
+(
+ stype = avg_state,
+ sfunc = avg_transfn,
+ finalfunc = avg_finalfn
+);
+create aggregate my_sum(int4)
+(
+ stype = avg_state,
+ sfunc = avg_transfn,
+ finalfunc = sum_finalfn
+);
+-- aggregate state should be shared as transfn is the same for both aggs.
+select my_avg(one),my_sum(one) from (values(1,2),(3,4)) t(one,two);
+NOTICE: avg_transfn called with 1
+NOTICE: avg_transfn called with 3
+ my_avg | my_sum
+--------+--------
+ 2 | 4
+(1 row)
+
+-- shouldn't share states due to the distinctness not matching.
+select my_avg(distinct one),my_sum(one) from (values(1,2),(3,4)) t(one,two);
+NOTICE: avg_transfn called with 1
+NOTICE: avg_transfn called with 3
+NOTICE: avg_transfn called with 1
+NOTICE: avg_transfn called with 3
+ my_avg | my_sum
+--------+--------
+ 2 | 4
+(1 row)
+
+-- this should not share the state due to different input columns.
+select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two);
+NOTICE: avg_transfn called with 2
+NOTICE: avg_transfn called with 1
+NOTICE: avg_transfn called with 4
+NOTICE: avg_transfn called with 3
+ my_avg | my_sum
+--------+--------
+ 2 | 6
+(1 row)
+
+create aggregate my_sum_init(int4)
+(
+ stype = avg_state,
+ sfunc = avg_transfn,
+ finalfunc = sum_finalfn,
+ initcond = '(10,0)'
+);
+create aggregate my_avg_init(int4)
+(
+ stype = avg_state,
+ sfunc = avg_transfn,
+ finalfunc = avg_finalfn,
+ initcond = '(5,0)'
+);
+-- Varying INITCONDs should cause the states not to be shared.
+select my_avg_init(one),my_sum_init(one) from (values(1,2),(3,4)) t(one,two);
+NOTICE: avg_transfn called with 1
+NOTICE: avg_transfn called with 1
+NOTICE: avg_transfn called with 3
+NOTICE: avg_transfn called with 3
+ my_avg_init | my_sum_init
+-------------+-------------
+ 4 | 14
+(1 row)
+
+rollback;
+-- test aggregate state sharing to ensure it works if one aggregate has a
+-- finalfn and the other one has none.
+begin work;
+create or replace function sum_transfn(state int4, n int4) returns int4 as
+$$
+declare new_state int4;
+begin
+ raise notice 'sum_transfn called with %', n;
+ if state is null then
+ if n is not null then
+ new_state := n;
+ return new_state;
+ end if;
+ return null;
+ elsif n is not null then
+ state := state + n;
+ return state;
+ end if;
+
+ return null;
+end
+$$ language plpgsql;
+create function halfsum_finalfn(state int4) returns int4 as
+$$
+begin
+ if state is null then
+ return NULL;
+ else
+ return state / 2;
+ end if;
+end
+$$ language plpgsql;
+create aggregate my_sum(int4)
+(
+ stype = int4,
+ sfunc = sum_transfn
+);
+create aggregate my_half_sum(int4)
+(
+ stype = int4,
+ sfunc = sum_transfn,
+ finalfunc = halfsum_finalfn
+);
+-- Agg state should be shared even though my_sum has no finalfn
+select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one);
+NOTICE: sum_transfn called with 1
+NOTICE: sum_transfn called with 2
+NOTICE: sum_transfn called with 3
+NOTICE: sum_transfn called with 4
+ my_sum | my_half_sum
+--------+-------------
+ 10 | 5
+(1 row)
+
+rollback;
diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql
index a84327d..42c3b3c 100644
--- a/src/test/regress/sql/aggregates.sql
+++ b/src/test/regress/sql/aggregates.sql
@@ -590,3 +590,151 @@ drop view aggordview1;
-- variadic aggregates
select least_agg(q1,q2) from int8_tbl;
select least_agg(variadic array[q1,q2]) from int8_tbl;
+
+
+-- test aggregates with common transition functions share the same states
+begin work;
+
+create type avg_state as (total bigint, count bigint);
+
+create or replace function avg_transfn(state avg_state, n int) returns avg_state as
+$$
+declare new_state avg_state;
+begin
+ raise notice 'avg_transfn called with %', n;
+ if state is null then
+ if n is not null then
+ new_state.total := n;
+ new_state.count := 1;
+ return new_state;
+ end if;
+ return null;
+ elsif n is not null then
+ state.total := state.total + n;
+ state.count := state.count + 1;
+ return state;
+ end if;
+
+ return null;
+end
+$$ language plpgsql;
+
+create function avg_finalfn(state avg_state) returns int4 as
+$$
+begin
+ if state is null then
+ return NULL;
+ else
+ return state.total / state.count;
+ end if;
+end
+$$ language plpgsql;
+
+create function sum_finalfn(state avg_state) returns int4 as
+$$
+begin
+ if state is null then
+ return NULL;
+ else
+ return state.total;
+ end if;
+end
+$$ language plpgsql;
+
+create aggregate my_avg(int4)
+(
+ stype = avg_state,
+ sfunc = avg_transfn,
+ finalfunc = avg_finalfn
+);
+
+create aggregate my_sum(int4)
+(
+ stype = avg_state,
+ sfunc = avg_transfn,
+ finalfunc = sum_finalfn
+);
+
+-- aggregate state should be shared as transfn is the same for both aggs.
+select my_avg(one),my_sum(one) from (values(1,2),(3,4)) t(one,two);
+
+-- shouldn't share states due to the distinctness not matching.
+select my_avg(distinct one),my_sum(one) from (values(1,2),(3,4)) t(one,two);
+
+-- this should not share the state due to different input columns.
+select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two);
+
+
+create aggregate my_sum_init(int4)
+(
+ stype = avg_state,
+ sfunc = avg_transfn,
+ finalfunc = sum_finalfn,
+ initcond = '(10,0)'
+);
+
+create aggregate my_avg_init(int4)
+(
+ stype = avg_state,
+ sfunc = avg_transfn,
+ finalfunc = avg_finalfn,
+ initcond = '(5,0)'
+);
+
+-- Varying INITCONDs should cause the states not to be shared.
+select my_avg_init(one),my_sum_init(one) from (values(1,2),(3,4)) t(one,two);
+
+rollback;
+
+-- test aggregate state sharing to ensure it works if one aggregate has a
+-- finalfn and the other one has none.
+begin work;
+
+create or replace function sum_transfn(state int4, n int4) returns int4 as
+$$
+declare new_state int4;
+begin
+ raise notice 'sum_transfn called with %', n;
+ if state is null then
+ if n is not null then
+ new_state := n;
+ return new_state;
+ end if;
+ return null;
+ elsif n is not null then
+ state := state + n;
+ return state;
+ end if;
+
+ return null;
+end
+$$ language plpgsql;
+
+create function halfsum_finalfn(state int4) returns int4 as
+$$
+begin
+ if state is null then
+ return NULL;
+ else
+ return state / 2;
+ end if;
+end
+$$ language plpgsql;
+
+create aggregate my_sum(int4)
+(
+ stype = int4,
+ sfunc = sum_transfn
+);
+
+create aggregate my_half_sum(int4)
+(
+ stype = int4,
+ sfunc = sum_transfn,
+ finalfunc = halfsum_finalfn
+);
+
+-- Agg state should be shared even though my_sum has no finalfn
+select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one);
+
+rollback;
--
Sent via pgsql-hackers mailing list (pgsql-hackers@postgresql.org)
To make changes to your subscription:
http://www.postgresql.org/mailpref/pgsql-hackers