Changeset: 400aff01f0d6 for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=400aff01f0d6
Modified Files:
        sql/backends/monet5/rel_weld.c
Branch: rel-weld
Log Message:

Weld implementation for join - it can't be tested yet because of missing Weld features


diffs (193 lines):

diff --git a/sql/backends/monet5/rel_weld.c b/sql/backends/monet5/rel_weld.c
--- a/sql/backends/monet5/rel_weld.c
+++ b/sql/backends/monet5/rel_weld.c
@@ -46,6 +46,7 @@
 
 #define REL 0
 #define ALIAS 1
+#define ANY 2 /* get_col_name: use ALIAS naming when exp->name is set, otherwise REL naming */
 
 /* From sql_statement.c */
 #define meta(Id, Tpe)                \
@@ -157,6 +158,12 @@ get_col_name(sql_allocator *sa, sql_exp 
                sprintf(col_name, "%s_%s", exp->l ? (str)exp->l : (str)exp->r, 
(str)exp->r);
        } else if (name_type == ALIAS) {
                sprintf(col_name, "%s_%s", exp->rname ? exp->rname : exp->name, 
exp->name);
+       } else if (name_type == ANY) {
+               if (exp->name) {
+                       return get_col_name(sa, exp, ALIAS);
+               } else {
+                       return get_col_name(sa, exp, REL);
+               }
        }
        for (i = 0; i < strlen(col_name); i++) {
                if (!isalnum(col_name[i])) {
@@ -688,6 +695,159 @@ groupby_produce(backend *be, sql_rel *re
 }
 
 static void
+join_produce(backend *be, sql_rel *rel, weld_state *wstate) /* Emit Weld code for a join: merge the right input (rel->r) into a groupmerger, then loop over lookup() results inside the left input's (rel->l) pipeline. Sets wstate->error = 1 on failure. */
+{
+       char new_builder[STR_BUF_SIZE]; /* text of the groupmerger builder expression; installed as wstate->builder while the right side is produced */
+       str col_name;
+       int len = 0, i, count; /* len: number of characters written to new_builder so far */
+       node *en;
+       sql_exp *exp;
+       sql_rel *right = rel->r; /* narrowed below to the operator whose exps name the right-side columns */
+       list *right_cols = sa_list(wstate->sa); /* names of right->exps, in order */
+       list *right_cmp_cols = sa_list(wstate->sa); /* join-condition columns found in right_cols: used as the merge key */
+       list *left_cmp_cols = sa_list(wstate->sa); /* remaining join-condition columns: used as the lookup key */
+       produce_func left_produce, right_produce;
+
+       /* === Produce === save the enclosing pipeline's state; it is restored before the left side is produced */
+       int old_num_parens = wstate->num_parens;
+       int old_num_loops = wstate->num_loops;
+       str old_builder = wstate->builder;
+
+       /* Create a new builder: start a fresh pipeline whose value is bound to the new variable v<result_var> */
+       wstate->num_parens = wstate->num_loops = 0;
+       int result_var = wstate->next_var++;
+       wprintf(wstate, "let v%d = (", result_var);
+       wstate->num_parens++;
+
+       /* Find the operator that produces the columns: follow ->l links from rel->r to the first project or base table */
+       while (right != NULL && right->op != op_project && right->op != 
op_basetable) {
+               right = right->l;
+       }
+       if (right == NULL) {
+               wstate->error = 1; /* NOTE(review): num_parens/num_loops were overwritten above and are not restored on this path - confirm that is acceptable after an error */
+               goto cleanup;
+       }
+       for (en = right->exps->h; en; en = en->next) {
+               list_append(right_cols, get_col_name(wstate->sa, en->data, 
ANY));
+       }
+
+       len = 0;
+       len += sprintf(new_builder + len, "groupmerger[{"); /* NOTE(review): none of the sprintf calls below check len against STR_BUF_SIZE */
+       for (en = rel->exps->h; en; en = en->next) {
+               /* left cmp: classify the ->l operand's column as right-side or left-side by membership in right_cols */
+               exp = ((sql_exp*)en->data)->l;
+               col_name = get_col_name(wstate->sa, exp, ANY);
+               if (list_find(right_cols, col_name, (fcmp)strcmp)) {
+                       list_append(right_cmp_cols, col_name);
+               } else {
+                       list_append(left_cmp_cols, col_name);
+               }
+               /* right cmp: same classification for the ->r operand */
+               exp = ((sql_exp*)en->data)->r;
+               col_name = get_col_name(wstate->sa, exp, ANY);
+               if (list_find(right_cols, col_name, (fcmp)strcmp)) {
+                       list_append(right_cmp_cols, col_name);
+               } else {
+                       list_append(left_cmp_cols, col_name);
+               }
+
+               /* both operands are assumed to have the same type; the type of the ->r operand is used for the key */
+               int type = exp_subtype(exp)->type->localtype;
+               len += sprintf(new_builder + len, "%s", getWeldType(type));
+               if (en->next != NULL) {
+                       len += sprintf(new_builder + len, ", ");
+               }
+       }
+       len += sprintf(new_builder + len, "}, {"); /* key struct done; the value struct has one field per right->exps entry */
+       for (en = right->exps->h; en; en = en->next) {
+               exp = en->data;
+               int type = exp_subtype(exp)->type->localtype;
+               if (type == TYPE_str) { /* string columns get a "?" placeholder type; the value merged for them is <col>_stridx (see the merge below) */
+                       len += sprintf(new_builder + len, "?");
+               } else {
+                       len += sprintf(new_builder + len, "%s", 
getWeldType(type));
+               }
+               if (en->next != NULL) {
+                       len += sprintf(new_builder + len, ", ");
+               }
+       }
+       len += sprintf(new_builder + len, "}]");
+
+       wstate->builder = new_builder;
+       right_produce = getproduce_func(rel->r);
+       left_produce = getproduce_func(rel->l);
+       if (right_produce == NULL || left_produce == NULL) {
+               wstate->error = 1; /* NOTE(review): on this path wstate->builder still points at the local new_builder and the saved counters are not restored - confirm callers never read them after an error */
+               goto cleanup;
+       }
+       right_produce(be, rel->r, wstate);
+
+       /* === Consume === merge one {key, value} pair into builder b<num_loops> of the right pipeline */
+       wprintf(wstate, "merge(b%d, {{", wstate->num_loops);
+       /* Build the key from the right-side join columns */
+       for (en = right_cmp_cols->h; en; en = en->next) {
+               wprintf(wstate, "%s", (str)en->data);
+               if (en->next != NULL) {
+                       wprintf(wstate, ", ");
+               }
+       }
+       wprintf(wstate, "}, {");
+       /* Build the value: every right column, with string columns referenced through their _stridx variable */
+       for (en = right->exps->h, count = 0; en; en = en->next, count++) {
+               exp = en->data;
+               wprintf(wstate, "%s", (str)list_fetch(right_cols, count));
+               if (exp_subtype(exp)->type->localtype == TYPE_str) {
+                       wprintf(wstate, "_stridx");
+               }
+               if (en->next != NULL) {
+                       wprintf(wstate, ", ");
+               }
+       }
+       wprintf(wstate, "}})");
+       for (i = 0; i < wstate->num_parens; i++) { /* emit one ')' per counted open parenthesis of the right pipeline */
+               wprintf(wstate, ")");
+       }
+       wprintf(wstate, ";");
+       /* Materialize the hashtable: rebind v<result_var> to result() of the builder expression */
+       wprintf(wstate, "let v%d = result(v%d);", result_var, result_var);
+
+       /* Resume the pipeline: restore the saved state and produce the left input inside it */
+       wstate->num_parens = old_num_parens;
+       wstate->num_loops = old_num_loops;
+       wstate->builder = old_builder;
+       left_produce(be, rel->l, wstate);
+
+       /* === 2nd Consume === open a loop, one level deeper, over the lookup() result for the left key */
+       wstate->num_loops++;
+       wstate->num_parens++;
+       wprintf(wstate, "for(lookup(v%d, {", result_var);
+       for (en = left_cmp_cols->h; en; en = en->next) {
+               /* Hashtable key: the left-side join columns */
+               wprintf(wstate, "%s", (str)en->data);
+               if (en->next != NULL) {
+                       wprintf(wstate, ", ");
+               }
+       }
+       wprintf(wstate, "}), b%d, |b%d, i%d, n%d|", wstate->num_loops - 1, 
wstate->num_loops,
+                       wstate->num_loops, wstate->num_loops);
+       for (en = right->exps->h, count = 0; en; en = en->next, count++) { /* bind each right column name to field $count of the loop element n<num_loops> */
+               exp = en->data;
+               col_name = list_fetch(right_cols, count);
+               if (exp_subtype(exp)->type->localtype == TYPE_str) { /* strings: the field is bound as <col>_stridx and also used to compute the strslice() offset */
+                       wprintf(wstate, "let %s = strslice(%s_strcol, 
i64(n%d.$%d) + %s_stroffset);", 
+                                                  col_name, col_name, 
wstate->num_loops, count, col_name);
+                       wprintf(wstate, "let %s_stridx = n%d.$%d;", col_name, 
wstate->num_loops, count);
+               } else {
+                       wprintf(wstate, "let %s = n%d.$%d;", col_name, 
wstate->num_loops, count);
+               }
+       }
+cleanup: /* reached on success and on both error paths; only the three local lists are released here */
+       list_destroy(right_cols);
+       list_destroy(right_cmp_cols);
+       list_destroy(left_cmp_cols);
+}
+
+static void
 push_args(MalBlkPtr mb, InstrPtr *weld_instr, list *stmt_list, int *arg_names, 
int *idx) {
        node *en;
        for (en = stmt_list->h; en; en = en->next) {
@@ -925,6 +1085,8 @@ produce_func getproduce_func(sql_rel *re
                        return &project_produce;
                case op_groupby:
                        return &groupby_produce;
+               case op_join:
+                       return &join_produce;
                default:
                        return NULL;
        }
_______________________________________________
checkin-list mailing list
checkin-list@monetdb.org
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to