Changeset: 8814bfed9eaa for MonetDB
URL: http://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=8814bfed9eaa
Modified Files:
        sql/backends/monet5/rel_bin.c
        sql/backends/monet5/sql_gencode.c
        sql/backends/monet5/sql_statement.c
        sql/backends/monet5/sql_statement.h
        sql/server/rel_graph.c
        sql/server/rel_graph.h
        sql/server/rel_rel.c
        sql/server/rel_rel.h
        sql/server/rel_select.c
Branch: graph1
Log Message:

Codegen: add the computation of shortest paths for simple joins


diffs (truncated from 504 to 300 lines):

diff --git a/sql/backends/monet5/rel_bin.c b/sql/backends/monet5/rel_bin.c
--- a/sql/backends/monet5/rel_bin.c
+++ b/sql/backends/monet5/rel_bin.c
@@ -616,6 +616,7 @@ exp_bin(mvc *sql, sql_exp *e, stmt *left
                        sql_subaggr *aggr_count = sql_bind_aggr(sql->sa, 
sql->session->schema, "count", NULL);
                        stmt *domain =NULL, *query =NULL;
                        stmt *spfw =NULL;
+                       stmt *weights = NULL;
                        int spfw_flags = 0;
 
                        // generate the depending expressions
@@ -631,6 +632,10 @@ exp_bin(mvc *sql, sql_exp *e, stmt *left
                        if(!efrom) { assert(0); return NULL; }
                        eto = exp_bin(sql, g->dst->h->data, graph, NULL, NULL, 
NULL, NULL, NULL, refs);
                        if(!eto) { assert(0); return NULL; }
+                       if(g->cost){
+                               weights = exp_bin(sql, g->cost, graph, NULL, 
NULL, NULL, NULL, NULL, refs);
+                               if(!weights) { assert(0); return NULL; }
+                       }
 
                        // find the domain and map the operands into it
                        // this is, like, super fun...
@@ -646,7 +651,13 @@ exp_bin(mvc *sql, sql_exp *e, stmt *left
                        eperm = stmt_result(sql->sa, efrom, 1);
                        efrom = stmt_prefixsum(sql->sa, efrom, smpl_sz);
                        eto = stmt_project(sql->sa, eperm, eto);
-                       graph = stmt_list(sql->sa, 
list_append(list_append(sa_list(sql->sa), efrom), eto));
+
+                       // generate the graph
+                       l = sa_list(sql->sa);
+                       list_append(l, efrom);
+                       list_append(l, eto);
+                       if(weights) { list_append(l, weights); }
+                       graph = stmt_list(sql->sa, l);
 
                        // generate the query
                        domain = stmt_project(sql->sa, smpl, nodes);
@@ -659,9 +670,14 @@ exp_bin(mvc *sql, sql_exp *e, stmt *left
                        list_append(l, stmt_result(sql->sa, qto, 1));
                        query = stmt_list(sql->sa, l);
 
-                       // run the operator
+                       // create the operator
                        if(right != NULL){ spfw_flags |= SPFW_CROSS_PRODUCT; }
+                       if(weights != NULL){ spfw_flags |= SPFW_SHORTEST_PATH; }
                        spfw = stmt_spfw(sql->sa, query, graph, spfw_flags);
+                       if(weights != NULL){ // propagate the temp~ name for 
the column
+                               spfw->tname = g->cost->rname;
+                               spfw->cname = g->cost->name;
+                       }
 
                        print_tree(sql->sa, spfw);
 
@@ -1710,6 +1726,7 @@ rel2bin_join( mvc *sql, sql_rel *rel, li
        node *en = NULL, *n;
        stmt *left = NULL, *right = NULL, *join = NULL, *jl, *jr;
        stmt *ld = NULL, *rd = NULL;
+       list *shoooortestpaths = sa_list(sql->sa);
        int need_left = (rel->flag == LEFT_JOIN);
 
        if (rel->l) /* first construct the left sub relation */
@@ -1814,6 +1831,10 @@ rel2bin_join( mvc *sql, sql_rel *rel, li
                                join = s;
                        list_append(lje, s->op1);
                        list_append(rje, s->op2);
+
+                       if(s->type == st_spfw && (s->flag & 
SPFW_SHORTEST_PATH)){
+                               list_append(shoooortestpaths, s);
+                       }
                }
                if (list_length(lje) > 1) {
                        join = releqjoin(sql, lje, rje, used_hash, cmp_equal, 
need_left);
@@ -1865,6 +1886,11 @@ rel2bin_join( mvc *sql, sql_rel *rel, li
                                assert(0);
                                return NULL;
                        }
+
+                       if ( s->type == st_spfw && (s->flag & 
SPFW_SHORTEST_PATH) ){
+                               list_append(shoooortestpaths, s);
+                       }
+
                        sel = s;
                }
                /* recreate join output */
@@ -1920,6 +1946,12 @@ rel2bin_join( mvc *sql, sql_rel *rel, li
                s = stmt_alias(sql->sa, s, rnme, nme);
                list_append(l, s);
        }
+       for (n = shoooortestpaths->h; n; n = n->next){
+               stmt *spfw = n->data;
+               stmt *s = stmt_result(sql->sa, spfw, 2);
+               s = stmt_alias(sql->sa, s, spfw->tname, spfw->cname);
+               list_append(l, s);
+       }
        return stmt_list(sql->sa, l);
 }
 
@@ -3021,8 +3053,7 @@ insert_check_ukey(mvc *sql, list *insert
                                stmt *cs = list_fetch(inserts, c->c->colnr); 
 
                                if (orderby_grp)
-                                       orderby = stmt_reorder(sql->sa,         
// strong suspects these are the output cols of the operator
-cs->op1, 1, orderby_ids, orderby_grp);
+                                       orderby = stmt_reorder(sql->sa, 
cs->op1, 1, orderby_ids, orderby_grp);
                                else
                                        orderby = stmt_order(sql->sa, cs->op1, 
1);
                                orderby_ids = stmt_result(sql->sa, orderby, 1);
@@ -4613,105 +4644,6 @@ rel2bin_ddl(mvc *sql, sql_rel *rel, list
        return s;
 }
 
-//static list*
-//spfw_project(mvc *sql, list* l, stmt *spfw, stmt *input){
-//     for( node* n = input->op4.lval->h; n; n = n->next ) {
-//             stmt *c = n->data;
-//             const char *rnme = table_name(sql->sa, c);
-//             const char *nme = column_name(sql->sa, c);
-//             stmt *s = stmt_project(sql->sa, spfw, column(sql->sa, c));
-//
-//             s = stmt_alias(sql->sa, s, rnme, nme);
-//             list_append(l, s);
-//     }
-//     return l;
-//}
-//
-//static stmt *
-//rel2bin_spfw(mvc *sql, sql_rel *rel, list *refs)
-//{
-//     stmt *edges = NULL, *spfw = NULL, *graph = NULL, *query = NULL;
-//     stmt *left = NULL, *right = NULL;
-//     stmt *c = NULL, *g = NULL, *groups = NULL, *smpl = NULL;
-//     sql_subaggr *aggr_count = NULL;
-//     stmt *D = NULL, *D_sz = NULL, *vrtx = NULL;
-//     list *l = NULL;
-//     stmt *split = NULL;
-//     stmt *e_from = NULL, *e_to = NULL;
-//     stmt *q_from = NULL, *q_to = NULL;
-//     stmt *mk_perm = NULL;
-//     stmt *result = NULL;
-//     node *n = NULL; // generic var to iterate through a list
-//
-//     // materialize the input relations
-//     left = subrel_bin(sql, rel->l, refs);
-//     if(!left) return NULL;
-//     right = subrel_bin(sql, rel->r, refs);
-//     if(!right) return NULL;
-//     edges = subrel_bin(sql, rel_edges(rel), refs);
-//     if(!edges) return NULL;
-//
-//     // refer to the columns
-//     assert(rel->exps->cnt == 4 && "Expected four columns as input (ftb)"); 
// TODO weights are missing
-//     n = rel->exps->h;
-//     q_from = exp_bin(sql, n->data, left, NULL, NULL, NULL, NULL, NULL);
-//     n = n->next;
-//     q_to = exp_bin(sql, n->data, right, NULL, NULL, NULL, NULL, NULL);
-//     n = n->next;
-//     e_from = exp_bin(sql, n->data, edges, NULL, NULL, NULL, NULL, NULL);
-//     n = n->next;
-//     e_to = exp_bin(sql, n->data, edges, NULL, NULL, NULL, NULL, NULL);
-//
-//     // create the graph
-//     // this is, like, super fun....
-//     l = sa_list(sql->sa);
-//     list_append(l, e_from);
-//     list_append(l, e_to);
-//     c = stmt_concat(sql->sa, l);
-//     g = stmt_group(sql->sa, c, NULL, NULL, NULL);
-//     groups = stmt_result(sql->sa, g, 0);
-//     smpl = stmt_result(sql->sa, g, 1);
-//     split = stmt_slices(sql->sa, groups, 2);
-//     e_from = stmt_result(sql->sa, split, 0);
-//     e_to = stmt_result(sql->sa, split, 1);
-//     // mkgraph (naive approach)
-//     e_from = stmt_order(sql->sa, e_from, /* direction (0 = DESC, 1 = ASC) = 
*/ 1);
-//     mk_perm = stmt_result(sql->sa, e_from, 1);
-//     aggr_count = sql_bind_aggr(sql->sa, sql->session->schema, "count", 
NULL);
-//     D_sz = stmt_aggr(sql->sa, smpl, NULL, NULL, aggr_count, 1, 0);
-//     e_from = stmt_prefixsum(sql->sa, e_from, D_sz);
-//     // FIXME e_weights = stmt_project(sql->sa, mk_perm, e_weights) etc..
-//     e_to = stmt_project(sql->sa, mk_perm, e_to);
-//     l = sa_list(sql->sa);
-//     list_append(l, e_from);
-//     list_append(l, e_to);
-//     // FIXME list_append(l, e_weights);
-//     graph = stmt_list(sql->sa, l);
-//
-//     // map the values in qfrom, qto into vertex IDs
-//     D = stmt_project(sql->sa, smpl, c); // domain
-//     // TODO I was not able to figure out how to perform a join with a 
candidate list at this layer
-//     // postpone the translation at the mal codegen ftb
-//     vrtx = stmt_exp2vrtx(sql->sa, q_from, q_to, D);
-//     l = sa_list(sql->sa);
-//     for(int i = 0; i < 3; i++)
-//             list_append(l, stmt_result(sql->sa, vrtx, i));
-//     query = stmt_list(sql->sa, l);
-//
-//     // execute the shortest path operator
-//     spfw = stmt_spfw(sql->sa, query, graph);
-//
-//     // almost done, create the new resultset
-//     l = sa_list(sql->sa);
-//     spfw_project(sql, l, spfw, left);
-//     spfw_project(sql, l, spfw, right);
-//     result = stmt_list(sql->sa, l);
-//
-//     print_tree(sql->sa, result); // FIXME debug only
-//
-//     return result;
-//}
-
 static stmt *
 subrel_bin(mvc *sql, sql_rel *rel, list *refs) 
 {
diff --git a/sql/backends/monet5/sql_gencode.c 
b/sql/backends/monet5/sql_gencode.c
--- a/sql/backends/monet5/sql_gencode.c
+++ b/sql/backends/monet5/sql_gencode.c
@@ -2886,12 +2886,12 @@ static int
                        // abi convention
                        s->nr = getDestVar(q);
                        for(int i = 1; i < num_slices; i++) {
-                               snprintf(mb->var[getDestVar(q)]->id, IDLENGTH, 
"r%d_%d", i, s->nr);
+                               snprintf(mb->var[getArg(q, i)]->id, IDLENGTH, 
"r%d_%d", i, s->nr);
                        }
                } break;
                case st_spfw: {
                        const /*size_t*/ int query_sz = 4; // num operands for 
the query
-                       const /*size_t*/ int graph_sz = 2; // num operands for 
the graph
+                       const /*size_t*/ int graph_sz = 2; // num operands 2 or 
3
                        node* n = NULL;
 
                        // generate the query
@@ -2905,6 +2905,7 @@ static int
                        // command spfw(qf:bat[:oid], qt:bat[:oid], 
V:bat[:oid], E:bat[:oid]) --> :bat[:oid]
                        q = newStmt(mb, graphRef, "spfw");
                        q = pushReturn(mb, q, newTmpVariable(mb, TYPE_any));
+                       q = pushReturn(mb, q, newTmpVariable(mb, TYPE_any));
 
                        // set the query params
                        assert(s->op1->type == st_list && s->op1->op4.lval->cnt 
== query_sz);
@@ -2915,17 +2916,28 @@ static int
                        }
 
                        // set the graph params
-                       assert(s->op2->type == st_list && s->op2->op4.lval->cnt 
== graph_sz);
+                       assert(s->op2->type == st_list && s->op2->op4.lval->cnt 
>= graph_sz);
                        n = s->op2->op4.lval->h;
                        for(int i = 0; i < graph_sz; i++){
                                q = pushArgument(mb, q, ((stmt*) n->data)->nr);
                                n = n->next;
                        }
+
+                       // weights for the shortest path
+                       if( n != NULL ){ // n is the last entry in 
s->op2->op4.lval
+                               q = pushArgument(mb, q, ((stmt*) n->data)->nr);
+                               assert(n->next == NULL && "Additional argument 
not handled");
+                       } else {
+                               q = pushNil(mb, q, TYPE_bat);
+                       }
+
                        q = pushBit(mb, q, s->flag & SPFW_CROSS_PRODUCT);
+                       q = pushBit(mb, q, s->flag & SPFW_SHORTEST_PATH);
 
                        // abi convention
                        s->nr = getDestVar(q); // filter src
                        renameVariable(mb, getArg(q, 1), "r1_%d", s->nr); // 
filter dst
+                       renameVariable(mb, getArg(q, 2), "r2_%d", s->nr); // 
shortest path (if required)
                } break;
                case st_void2oid: {
                        int ref_op = _dumpstmt(sql, mb, s->op1);
diff --git a/sql/backends/monet5/sql_statement.c 
b/sql/backends/monet5/sql_statement.c
--- a/sql/backends/monet5/sql_statement.c
+++ b/sql/backends/monet5/sql_statement.c
@@ -1858,18 +1858,32 @@ print_stmt(sql_allocator *sa, stmt *s)
                case st_aggr:
                        printf("%s, ", s->op4.aggrval->aggr->base.name);
                        break;
-               case st_slices:
-                       printf("%d, ", s->flag);
-                       break;
                default:
                        break;
                }
+
+               // input statements
                if (s->op1)
                        printf("z%d", s->op1->nr);
                if (s->op2)
                        printf(", z%d", s->op2->nr);
                if (s->op3)
                        printf(", z%d", s->op3->nr);
+
+               // after the input statements
+               switch(s->type){
+               case st_alias: {
+                       if(s->tname){ printf(", \"%s\"", s->tname); } else { 
printf(", NULL"); }
+                       if(s->cname){ printf(", \"%s\"", s->cname); } else { 
printf(", NULL"); }
+               } break;
+               case st_result:
+               case st_slices:
+                       printf(", %d", s->flag);
_______________________________________________
checkin-list mailing list
checkin-list@monetdb.org
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to