Changeset: bc14799583b5 for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/bc14799583b5
Modified Files:
        sql/server/rel_optimize_proj.c
        sql/test/miscellaneous/Tests/simple_plans.test
Branch: default
Log Message:

First iteration of sum(x + 12) into sum(x) + 12*count(*) optimization. Tomorrow 
I will tune it further


diffs (truncated from 327 to 300 lines):

diff --git a/sql/server/rel_optimize_proj.c b/sql/server/rel_optimize_proj.c
--- a/sql/server/rel_optimize_proj.c
+++ b/sql/server/rel_optimize_proj.c
@@ -1333,6 +1333,231 @@ rel_project_cse(visitor *v, sql_rel *rel
        return rel;
 }
 
+static int exp_is_const_op(sql_exp *exp, sql_exp *tope, sql_rel *expr);
+/* Return 1 when 'exps' is empty or when every expression in it is accepted
+ * by exp_is_const_op(); return 0 otherwise.
+ * 'tope' and 'expr' are handed to exp_is_const_op() unchanged.
+ */
+static int
+exps_are_const_op(list *exps, sql_exp *tope, sql_rel *expr)
+{
+       int ok = 1;
+
+       if (list_empty(exps))
+               return 1;
+       for (node *n = exps->h; n && ok; n = n->next) /* stops at the first rejected expression */
+               ok &= exp_is_const_op(n->data, tope, expr);
+       return ok;
+}
+
+/* Return 1 when 'exp' can be treated as a constant operand, 0 otherwise.
+ *
+ * exp:  expression to inspect
+ * tope: expression of expr->exps in which 'exp' occurs; a column reference
+ *       only qualifies when it binds to an expression listed before 'tope'
+ * expr: relation whose expression list is used to resolve column references
+ */
+static int
+exp_is_const_op(sql_exp *exp, sql_exp *tope, sql_rel *expr)
+{
+       switch (exp->type) {
+       case e_atom:
+               return exp->f ? 0 : 1; /* atoms with 'f' set are rejected */
+       case e_convert:
+               return exp_is_const_op(exp->l, tope, expr);
+       case e_func:
+       case e_aggr: {
+               sql_subfunc *f = exp->f;
+               /* functions with side effects and analytic functions never qualify */
+               if (f->func->side_effect || IS_ANALYTIC(f->func))
+                       return 0;
+               return exps_are_const_op(exp->l, tope, expr);
+       }
+       case e_cmp:
+               /* or/filter: both sides are expression lists */
+               if (exp->flag == cmp_or || exp->flag == cmp_filter)
+                       return exps_are_const_op(exp->l, tope, expr) &&
+                              exps_are_const_op(exp->r, tope, expr);
+               /* in/not in: single expression on the left, list on the right */
+               if (exp->flag == cmp_in || exp->flag == cmp_notin)
+                       return exp_is_const_op(exp->l, tope, expr) &&
+                              exps_are_const_op(exp->r, tope, expr);
+               /* every other comparison keeps single expressions in 'l', 'r' and the
+                * optional 'f', so they must not be walked as lists */
+               return exp_is_const_op(exp->l, tope, expr) &&
+                      exp_is_const_op(exp->r, tope, expr) &&
+                      (!exp->f || exp_is_const_op(exp->f, tope, expr));
+       case e_column: {
+               if (is_simple_project(expr->op) || is_groupby(expr->op)) {
+                       /* in a simple projection, self-references may occur */
+                       sql_exp *nexp = (exp->l ? exps_bind_column2(expr->exps, exp->l, exp->r, NULL) :
+                                                 exps_bind_column(expr->exps, exp->r, NULL, NULL, 0));
+                       /* only follow references to expressions listed before 'tope' */
+                       if (nexp && list_position(expr->exps, nexp) < list_position(expr->exps, tope))
+                               return exp_is_const_op(nexp, exp, expr);
+               }
+               return 0;
+       }
+       default:
+               return 0;
+       }
+}
+/* Return a count(*) expression for the group by 'rel'.
+ *
+ * The lookup order is: the cached 'count_star_exp' when it is set, then the
+ * first non-distinct count aggregate without arguments already present in
+ * rel->exps, and finally a newly bound "sys"."count" aggregate that is added
+ * to 'rel' through rel_groupby_add_aggr().
+ * '*count_added' is set to true only in that last case; it is never reset
+ * here.
+ */
+static sql_exp *
+rel_groupby_add_count_star(mvc *sql, sql_rel *rel, sql_exp *count_star_exp, 
bool *count_added)
+{
+       if (count_star_exp)
+               return count_star_exp;
+       if (!list_empty(rel->exps)) {
+               for (node *n=rel->exps->h; n ; n = n->next) {
+                       sql_exp *e = n->data;
+
+                       if (exp_aggr_is_count(e) && !need_distinct(e) && 
list_length(e->l) == 0)
+                               return e;
+               }
+       }
+       sql_subfunc *cf = sql_bind_func(sql, "sys", "count", 
sql_bind_localtype("void"), NULL, F_AGGR, true);
+       *count_added = true;
+       return rel_groupby_add_aggr(sql, rel, exp_aggr(sql->sa, NULL, cf, 0, 0, 
rel->card, 0));
+}
+
+/* optimize sum(x + 12) into sum(x) + 12*count(*)
+ *
+ * Only acts when 'rel' is a group by with a non-empty expression list. For
+ * every non-distinct "sum" aggregate (type F_AGGR, no 's' set on the
+ * function) the argument is followed through column references, skipping
+ * numeric upcasts, until its defining expression is found. When that
+ * expression is a "sql_add" or "sql_sub" call with exactly one operand
+ * accepted by exp_is_const_op(), the aggregate is replaced by a sum over
+ * the other operand, and a projection placed on top of the group by
+ * computes sum +/- constant * count(*), keeping the original operand order.
+ *
+ * v:   visitor; v->sql is used for allocation, labels and function binding
+ * rel: relation to inspect
+ *
+ * Returns 'rel'; after the first rewrite this is the projection created
+ * through rel_inplace_project(). Candidates for which a helper expression
+ * cannot be built are skipped after clearing the session status and the
+ * error string.
+ *
+ * NOTE(review): count(*) is requested before the remaining steps are
+ * validated; if one of them fails before the upper projection exists, a
+ * freshly added count aggregate presumably stays in the group by -- confirm
+ * this is harmless.
+ */
+static inline sql_rel *
+rel_simplify_sum(visitor *v, sql_rel *rel)
+{
+       if (is_groupby(rel->op) && !list_empty(rel->exps)) {
+               sql_rel *upper = NULL, *groupby = rel, *l = groupby->l;
+               sql_exp *count_star_exp = NULL;
+               bool count_added = false;
+
+               for (node *n=groupby->exps->h; n ; n = n->next) {
+                       sql_exp *e = n->data;
+                       list *el = e->l;
+                       sql_subfunc *sf = e->f;
+
+                       if (e->type == e_aggr && !need_distinct(e) && 
sf->func->type == F_AGGR && !sf->func->s && !strcmp(sf->func->base.name, 
"sum")) {
+                               sql_rel *expr = groupby;
+                               sql_exp *exp = (sql_exp*) el->h->data, *oexp = 
exp;
+
+                               while (is_numeric_upcast(exp))
+                                       exp = exp->l;
+                               /* we want to find a +/- call, so expect them 
to appear only on simple projections */
+                               while (exp && exp->type == e_column && 
(is_simple_project(expr->op) || is_groupby(expr->op)) && expr->l) {
+                                       sql_rel *nexpr = NULL;
+                                       sql_exp *nexp = 
rel_find_exp_and_corresponding_rel(l, exp, false, &nexpr, NULL);
+
+                                       /* break when it loops on the same 
relation */
+                                       if (nexpr == expr && 
list_position(expr->exps, nexp) >= list_position(expr->exps, exp))
+                                               break;
+                                       expr = nexpr;
+                                       exp = oexp = nexp;
+                                       while (exp && is_numeric_upcast(exp))
+                                               exp = exp->l;
+                               }
+
+                               list *expl = exp ? exp->l : NULL;
+                               sql_subfunc *expf = exp ? exp->f : NULL;
+                               /* found a candidate function */
+                               if (exp && exp->type == e_func && 
expf->func->type == F_FUNC && !expf->func->s &&
+                                       (!strcmp(expf->func->base.name, 
"sql_sub") || !strcmp(expf->func->base.name, "sql_add"))) {
+                                       sql_exp *e1 = (sql_exp*) expl->h->data, 
*e2 = (sql_exp*) expl->h->next->data;
+                                       int e1ok = exp_is_const_op(e1, oexp, 
expr), e2ok = exp_is_const_op(e2, oexp, expr);
+
+                                       if ((!e1ok && e2ok) || (e1ok && !e2ok)) 
{ /* exactly one side is constant */
+                                               sql_exp *ocol = e1ok ? e2 : e1, 
*constant = e1ok ? e1 : e2, *mul, *colref, *naggr, *newop, *col = ocol, *match;
+                                               bool add_col = true;
+
+                                               /* add count star, reusing the 
cached or an already present one when possible */
+                                               count_star_exp = 
rel_groupby_add_count_star(v->sql, groupby, count_star_exp, &count_added);
+                                               /* multiply constant by count 
star */
+                                               if (!(mul = rel_binop_(v->sql, 
NULL, constant, exp_ref(v->sql, count_star_exp), "sys", "sql_mul", 
card_value))) {
+                                                       v->sql->session->status 
= 0;
+                                                       v->sql->errstr[0] = 
'\0';
+                                                       continue; /* error state cleared, skip this aggregate */
+                                               }
+                                               if (!has_label(mul))
+                                                       exp_label(v->sql->sa, 
mul, ++v->sql->label);
+
+                                               /* if 'col' is a projection 
from the under relation, then use it */
+                                               while (is_numeric_upcast(col))
+                                                       col = col->l;
+                                               if (col->type == e_column) {
+                                                       sql_rel *crel = NULL;
+                                                       sql_exp *colref = 
rel_find_exp_and_corresponding_rel(l, col, false, &crel, NULL); /* shadows the outer 'colref' */
+
+                                                       if (colref && l == 
crel) {
+                                                               add_col = false;
+                                                       } else if 
(is_simple_project(l->op) && list_empty(l->r) && !rel_is_ref(l) && 
!need_distinct(l)) {
+                                                               
list_prepend(l->exps, exp_ref(v->sql, col));
+                                                               add_col = false;
+                                                       }
+                                               }
+                                               /* 'col' is not in the under 
relation, so add it */
+                                               if (add_col) {
+                                                       ocol = exp_ref(v->sql, 
ocol);
+                                                       exp_label(v->sql->sa, 
ocol, ++v->sql->label);
+                                               }
+
+                                               colref = exp_ref(v->sql, ocol);
+                                               /* 'oexp' contains the type for 
the input for the 'sum' aggregate */
+                                               if (!(colref = 
exp_check_type(v->sql, exp_subtype(oexp), groupby, colref, type_equal))) {
+                                                       v->sql->session->status 
= 0;
+                                                       v->sql->errstr[0] = 
'\0';
+                                                       continue;
+                                               }
+                                               /* update sum to use the column 
side only */
+                                               sql_subfunc *a = 
sql_bind_func(v->sql, "sys", "sum", exp_subtype(colref), NULL, F_AGGR, true);
+                                               if (!a)
+                                                       continue;
+                                               naggr = exp_aggr(v->sql->sa, 
list_append(sa_list(v->sql->sa), colref), a, need_distinct(e), need_no_nil(e), 
groupby->card, has_nil(e));
+                                               if ((match = 
exps_any_match(groupby->exps, naggr)) && list_position(groupby->exps, match) < 
list_position(groupby->exps, e)) { /* found a matching aggregate, use it */
+                                                       naggr = exp_ref(v->sql, 
match);
+                                                       exp_label(v->sql->sa, 
naggr, ++v->sql->label);
+                                               } else if (!has_label(naggr)) { 
/* otherwise use the new one */
+                                                       exp_label(v->sql->sa, 
naggr, ++v->sql->label);
+                                               }
+
+                                               /* generate 
addition/subtraction. subtraction is not commutative, so keep original order! */
+                                               if (!(newop = 
rel_binop_(v->sql, NULL, e1 == constant ? mul : exp_ref(v->sql, naggr), e1 == 
constant ? exp_ref(v->sql, naggr) : mul, "sys", expf->func->base.name, 
card_value))) {
+                                                       v->sql->session->status 
= 0;
+                                                       v->sql->errstr[0] = 
'\0';
+                                                       continue;
+                                               }
+                                               if (!(newop = 
exp_check_type(v->sql, exp_subtype(e), groupby, newop, type_equal))) {
+                                                       v->sql->session->status 
= 0;
+                                                       v->sql->errstr[0] = 
'\0';
+                                                       continue;
+                                               }
+
+                                               /* the newly generated function 
calls are valid, update relations */
+                                               /* we need a new relation for 
the multiplication and addition/subtraction */
+                                               if (!upper) {
+                                                       /* be careful with 
relations with more than 1 reference, so do in-place replacement */
+                                                       list *projs = 
rel_projections(v->sql, rel, NULL, 1, 1);
+                                                       sql_rel *nrel = 
rel_groupby(v->sql, rel->l, NULL);
+                                                       nrel->exps = rel->exps;
+                                                       nrel->r = rel->r;
+                                                       nrel->card = rel->card;
+                                                       nrel->nrcols = 
list_length(rel->exps);
+                                                       set_processed(nrel);
+                                                       rel_dup(rel->l);
+                                                       upper = rel = 
rel_inplace_project(v->sql->sa, rel, nrel, projs);
+                                                       rel->card = 
exps_card(projs);
+                                                       groupby = nrel; /* 
update pointers :) */
+                                                       l = groupby->l;
+                                               }
+                                               for (node *n = upper->exps->h ; 
n ; ) {
+                                                       node *next = n->next;
+                                                       sql_exp *re = n->data;
+
+                                                       /* remove the old 
reference to the aggregate because we will use the addition/subtraction instead,
+                                                          as well as the count 
star if count_added */
+                                                       if (exp_refers(e, re) 
|| (count_added && exp_refers(count_star_exp, re)))
+                                                               
list_remove_node(upper->exps, NULL, n);
+                                                       n = next;
+                                               }
+
+                                               /* update sum aggregate with 
new aggregate or reference to an existing one; 'n' is the group by node of the 
outer loop again, the inner loop variable is out of scope here */
+                                               n->data = naggr;
+                                               list_hash_clear(groupby->exps);
+
+                                               /* add column reference with 
new label, if 'col' was not found */
+                                               if (add_col) {
+                                                       if 
(!is_simple_project(l->op) || !list_empty(l->r) || rel_is_ref(l) || 
need_distinct(l) || is_single(l))
+                                                               groupby->l = l 
= rel_project(v->sql->sa, l, rel_projections(v->sql, l, NULL, 1, 1));
+                                                       list_append(l->exps, 
ocol);
+                                               }
+
+                                               /* propagate alias and add new 
expression */
+                                               exp_prop_alias(v->sql->sa, 
newop, e);
+                                               list_append(upper->exps, newop);
+                                       }
+                               }
+                       }
+               }
+       }
+       return rel;
+}
+
 /* optimize group by x+1,(y-2)*3,2-z into group by x,y,z */
 static inline sql_rel *
 rel_simplify_groupby_columns(visitor *v, sql_rel *rel)
@@ -2631,8 +2856,10 @@ rel_optimize_projections_(visitor *v, sq
        if (!rel || !is_groupby(rel->op))
                return rel;
 
-       if (v->value_based_opt)
+       if (v->value_based_opt) {
+               rel = rel_simplify_sum(v, rel);
                rel = rel_simplify_groupby_columns(v, rel);
+       }
        rel = rel_groupby_cse(v, rel);
        rel = rel_push_aggr_down(v, rel);
        rel = rel_push_groupby_down(v, rel);
diff --git a/sql/test/miscellaneous/Tests/simple_plans.test 
b/sql/test/miscellaneous/Tests/simple_plans.test
--- a/sql/test/miscellaneous/Tests/simple_plans.test
+++ b/sql/test/miscellaneous/Tests/simple_plans.test
@@ -624,3 +624,76 @@ project (
 | | ) [  ] [ "sys"."min" no nil ("x"."x") as "%11"."%11" ]
 | ) [ ("%11"."%11") > (tinyint(8) "0") ]
 ) [ "%11"."%11" as "x"."x" ]
+
+# optimize sum(x + 2) into sum(x) + 2*count(*)
+query IIII nosort
+SELECT sum(x), sum(x + 1), sum(x + 2), sum(x + 3) FROM (VALUES 
(1),(2),(3),(4),(5)) x(x)
+----
+15
+20
+25
+30
+
+query T nosort
+PLAN SELECT sum(x), sum(x + 1), sum(x + 2), sum(x + 3) FROM (VALUES 
(1),(2),(3),(4),(5)) x(x)
+----
+project (
+| project (
+| | group by (
+| | | project (
+| | | |  [  [ tinyint(3) "1", tinyint(3) "2", tinyint(3) "3", tinyint(3) "4", 
tinyint(3) "5" ] as "x"."x" ]
+| | | ) [ "x"."x" ]
+| | ) [  ] [ "sys"."sum" no nil ("x"."x") as "%7"."%7", "%7"."%7" as 
"%20"."%20", "%7"."%7" as "%22"."%22", "%7"."%7" as "%24"."%24", 
"sys"."count"() NOT NULL as "%16"."%16" ]
+| ) [ "%7"."%7", "sys"."sql_add"("%20"."%20", "sys"."sql_mul"(tinyint(2) "1", 
"%16"."%16" NOT NULL) NOT NULL) as "%10"."%10", "sys"."sql_add"("%22"."%22", 
"sys"."sql_mul"(tinyint(3) "2", "%16"."%16" NOT NULL) NOT NULL) as "%11"."%11", 
"sys"."sql_add"("%24"."%24", "sys"."sql_mul"(tinyint(3) "3", "%16"."%16" NOT 
NULL) NOT NULL) as "%12"."%12" ]
+) [ "%7"."%7", "%10"."%10", "%11"."%11", "%12"."%12" ]
+
+query III nosort
+SELECT 10*sum(5 - x) as aa, sum(x + 12) + 15 as bb, count(*) as cc FROM 
(VALUES (1),(2),(3),(4),(5)) x(x)
+----
+100
+90
+5
+
+query T nosort
+PLAN SELECT 10*sum(5 - x) as aa, sum(x + 12) * 2 as bb, count(*) as cc FROM 
(VALUES (1),(2),(3),(4),(5)) x(x)
+----
+project (
+| project (
+| | group by (
+| | | project (
+| | | |  [  [ tinyint(3) "1", tinyint(3) "2", tinyint(3) "3", tinyint(3) "4", 
tinyint(3) "5" ] as "x"."x" ]
+| | | ) [ "x"."x" ]
+| | ) [  ] [ "sys"."sum" no nil ("x"."x") as "%15"."%15", "%15"."%15" as 
"%17"."%17", "sys"."count"() NOT NULL as "%11"."%11" ]
+| ) [ "%11"."%11" NOT NULL, "sys"."sql_sub"("sys"."sql_mul"(tinyint(4) "5", 
"%11"."%11" NOT NULL) NOT NULL, "%15"."%15") as "%7"."%7", 
"sys"."sql_add"("%17"."%17", "sys"."sql_mul"(tinyint(5) "12", "%11"."%11" NOT 
NULL) NOT NULL) as "%10"."%10" ]
+) [ "sys"."sql_mul"(tinyint(4) "10", "%7"."%7") as "aa", 
"sys"."sql_mul"("%10"."%10", tinyint(2) "2") as "bb", "%11"."%11" NOT NULL as 
"cc" ]
+
+query II nosort
+SELECT sum(5 - x), count(*) FROM (VALUES (1),(2),(3),(4),(5)) x(x)
+----
_______________________________________________
checkin-list mailing list -- checkin-list@monetdb.org
To unsubscribe send an email to checkin-list-le...@monetdb.org

Reply via email to