On Sun, Jun 22, 2014 at 3:09 AM, Prathamesh Kulkarni <bilbotheelffri...@gmail.com> wrote:
> On Fri, Jun 20, 2014 at 3:02 AM, Prathamesh Kulkarni
> <bilbotheelffri...@gmail.com> wrote:
>>
>> On Fri, Jun 20, 2014 at 2:53 AM, Prathamesh Kulkarni
>> <bilbotheelffri...@gmail.com> wrote:
>> > Hi,
>> > The attached patch attempts to generate commutative variants for
>> > a given expression.
>> >
>> > Example:
>> > For the AST: (PLUS_EXPR (PLUS_EXPR @0 @1) @2),
>> >
>> > the commutative variants are:
>> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) @2 )
>> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) @2 )
>> > (PLUS_EXPR @2 (PLUS_EXPR @0 @1 ) )
>> > (PLUS_EXPR @2 (PLUS_EXPR @1 @0 ) )
>> >
>> >
>> > * Basic Idea:
>> > Consider an expression e with two operands o0 and o1,
>> > and expr-code denoting the expression's code (plus/mult, etc.)
>> >
>> > Commutative variants are stored in a vector (vec<operand *>).
>> >
>> > vec<operand *>
>> > commutative (e)
>> > {
>> >   if (e is not commutative)
>> >     return [e]; // vector with only one expression
>> >
>> >   v1 = commutative (o0);
>> >   v2 = commutative (o1);
>> >   ret = [];
>> >
>> >   for i = 0 ... v1.length ()
>> >     for j = 0 ... v2.length ()
>> >       {
>> >         ne = new expr with <expr-code> and operands: v1[i], v2[j];
>> >         append ne to ret;
>> >       }
>> >
>> >   for i = 0 ... v2.length ()
>> >     for j = 0 ... v1.length ()
>> >       {
>> >         ne = new expr with <expr-code> and operands: v2[i], v1[j];
>> >         append ne to ret;
>> >       }
>> >
>> >   return ret;
>> > }
>> >
>> > Example:
>> > (plus (plus @0 @1) (plus @2 @3))
>> > generates the following commutative variants:
>> oops.
>> the pattern given to genmatch was (bogus):
>> (plus (plus @0 @1) (plus @0 @3))
>> >
>> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @0 @3 ) )
>> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @3 @0 ) )
>> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @0 @3 ) )
>> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @3 @0 ) )
>> > (PLUS_EXPR (PLUS_EXPR @0 @3 ) (PLUS_EXPR @0 @1 ) )
>> > (PLUS_EXPR (PLUS_EXPR @0 @3 ) (PLUS_EXPR @1 @0 ) )
>> > (PLUS_EXPR (PLUS_EXPR @3 @0 ) (PLUS_EXPR @0 @1 ) )
>> > (PLUS_EXPR (PLUS_EXPR @3 @0 ) (PLUS_EXPR @1 @0 ) )
>> >
>> >
>> > * Decide which operators are commutative.
>> > Currently I assume all PLUS_EXPR and MULT_EXPR are true.
>> s/true/commutative
> There's a bug in the previous patch: if the operator is not
> commutative, it does not try to generate commutative variants of its
> operands, and it does not commutate the captured expression (.what).
> Example:
> (negate (plus @0 @1)) has two commutative variants (including the
> original pattern), but the patch does not generate them, since negate
> is not commutative.
>
> The attached patch fixes that. As a quick hack I handled each operator
> class (unary, binary, ternary) specially (commutate_unary,
> commutate_binary, commutate_ternary). Ideally it should be unified
> (I tried that, but it was segfaulting). I will try to come up
> with a better way.
> Also, the current patch won't work for built-in functions/operators
> with more than 3 operands (the maximum so far in match.pd is 3, for
> cond; I hope this doesn't get "in the way").
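One way to unify the per-class handling, and to drop the 3-operand limit, is to recurse into every operand regardless of the operator. Take the Cartesian product of the per-operand variant lists, and add the swapped order only when the operator itself is marked commutative. The sketch below is illustrative only. It assumes a hypothetical string-based Node type in place of genmatch's operand/expr/capture classes, and std::vector in place of vec. Captures on inner expressions (.what) are not modeled.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified AST (not genmatch's classes): a leaf is a capture
// such as "@0"; an inner node is an operator with its operands.
struct Node {
  std::string name;          // "@0" for a leaf, "plus" for an operator
  bool commutative = false;  // set for operators marked commutative
  std::vector<std::shared_ptr<Node>> ops;
};
using NodeP = std::shared_ptr<Node>;

NodeP leaf(const std::string &name) {
  auto n = std::make_shared<Node>();
  n->name = name;
  return n;
}

NodeP make_op(const std::string &name, bool commutative, std::vector<NodeP> ops) {
  auto n = std::make_shared<Node>();
  n->name = name;
  n->commutative = commutative;
  n->ops = std::move(ops);
  return n;
}

// Append to OUT every tuple that picks one variant per operand. CUR grows and
// shrinks with push_back/pop_back, so it never needs to be pre-sized.
void product(const std::vector<std::vector<NodeP>> &per_op, std::size_t n,
             std::vector<NodeP> &cur, std::vector<std::vector<NodeP>> &out) {
  if (n == per_op.size()) {
    out.push_back(cur);
    return;
  }
  for (const NodeP &choice : per_op[n]) {
    cur.push_back(choice);
    product(per_op, n + 1, cur, out);
    cur.pop_back();
  }
}

// Variants of E: recurse into every operand whatever the arity, and add the
// swapped order only when E itself is a commutative binary operator.
std::vector<NodeP> commutate(const NodeP &e) {
  if (e->ops.empty())
    return {e};
  std::vector<std::vector<NodeP>> per_op;
  for (const NodeP &o : e->ops)
    per_op.push_back(commutate(o));
  std::vector<std::vector<NodeP>> combos;
  std::vector<NodeP> cur;
  product(per_op, 0, cur, combos);

  std::vector<NodeP> ret;
  for (const auto &c : combos)
    ret.push_back(make_op(e->name, e->commutative, c));
  if (e->commutative && e->ops.size() == 2)
    for (const auto &c : combos)
      ret.push_back(make_op(e->name, e->commutative, {c[1], c[0]}));
  return ret;
}

// Render a node in genmatch-like syntax, e.g. "(negate (plus @0 @1))".
std::string render(const NodeP &e) {
  if (e->ops.empty())
    return e->name;
  std::string s = "(" + e->name;
  for (const NodeP &o : e->ops)
    s += " " + render(o);
  return s + ")";
}
```

With a non-commutative negate around a commutative plus, this gives the two variants listed for (negate (plus @0 @1)). With every plus commutative, (plus (plus @0 @1) (plus @0 @3)) gives 8 variants: 4 operand combinations, each in both orders. That is the same count as the list above, though the order of the swapped half differs from the pseudocode.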
>
> With the current patch, for the expression (negate (plus @0 @1))
> it generates the following commutative variants:
> (negate (plus @0 @1))
> (negate (plus @1 @0))
>
> and for the following pattern (involving a captured expression):
> (negate (plus@0 @1 @2))
> it generates the following variants:
> (negate (plus@0 @1 @2))
> (negate (plus@0 @2 @1))
>
> * generates multiple matching patterns
> Since at the AST level we do not test captures for equality (true/match),
> it treats both captures as different, even though they are the same.
> Example: the following expression also has 2 variants generated:
> (BUILT_IN_SQRT (mult @0 @0))
> commutative variants:
> (BUILT_IN_SQRT (mult @0 @0))
> (BUILT_IN_SQRT (mult @0 @0))
> I guess this won't really be a problem with the decision tree. If we decide
> to emit a warning, we should warn only for user-defined patterns, and not
> for generated ones.
>
> * syntax for commutative operators
> Currently, I assume any PLUS_EXPR / MULT_EXPR to be commutative.
> I guess we should have a syntax for users to mark an operator as
> commutative, something like:
> a) op:c
> b) op "c"
> c) op!
> d) op "commutative"
>
> Or any other that you would like :-)
>
> * cloning AST nodes
> Currently I do not do a deep copy of the AST for each distinct
> commutative variant, so the nodes are shared between the expressions
> that are commutative variants of the original expression.
> Is this OK, or should we clone each AST node, so that each expression
> is represented by a distinct AST?
> Cloning would use more memory, while sharing would require more careful
> memory management (freeing one AST may also free nodes of another
> expression).

This patch removes the hack of special handling according to operator
classes. For now, I added the op:c syntax to mark operator op as commutative.
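The op:c handling that parse_expr gains in the patch below checks for a ':' immediately followed by the name c, before any @ capture. On plain strings it can be summarized as follows. This is only a sketch. The real code works on cpp_reader tokens and requires that no whitespace precede the name. It also rejects :c on anything but a two-operand expression once the operands have been parsed. OpToken and parse_op_token are hypothetical names.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical split of an operator token such as "plus:c@0" into its parts.
struct OpToken {
  std::string op;       // operator name, e.g. "plus"
  bool commutative;     // true when the ":c" flag is present
  std::string capture;  // capture name after '@', empty if none
};

OpToken parse_op_token(const std::string &tok) {
  OpToken r{"", false, ""};
  std::size_t at = tok.find('@');
  std::string head = tok.substr(0, at);  // whole token when there is no '@'
  if (at != std::string::npos)
    r.capture = tok.substr(at + 1);
  std::size_t colon = head.find(':');
  r.op = head.substr(0, colon);
  if (colon != std::string::npos) {
    // As in the patch, "c" is the only name accepted after ':'.
    if (head.substr(colon + 1) != "c")
      throw std::invalid_argument("not implemented: predicates on expressions");
    r.commutative = true;
  }
  return r;
}
```

For example, "plus:c@0" splits into the operator plus, the commutative flag, and capture 0. A bare "mult" comes back unflagged and uncaptured.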
Example (the outer plus is not commutated, since it is not marked commutative):
(plus (plus:c@0 @1 @2) @3)
generates the following variants:
(PLUS_EXPR (PLUS_EXPR@0 @1 @2 ) @3 )
(PLUS_EXPR (PLUS_EXPR@0 @2 @1 ) @3 )

How do we resize a vector to hold n elements up front?
I tried:
vec<operand *> v = vNULL;
v.resize (n);
v.resize_exact (n);
However, accessing v[i] led to an internal abort in operator[] (vec.h line 735).
As a workaround I did (in cartesian_product):
for (unsigned i = 0; i < n_ops; ++i)
  v.safe_push (0);
This makes the vector "big enough" to hold n_ops elements, but is rather ugly.

Thanks and Regards,
Prathamesh
>
> Thanks and Regards,
> Prathamesh
>> > Maybe we should add syntax to mark a particular operator as commutative?
>> >
>> > * Cloning AST nodes
>> > While creating another AST that represents one of
>> > the commutative variants, should we clone the AST nodes,
>> > so that all commutative variants have distinct AST nodes?
>> > That's not done currently, and AST nodes are shared amongst
>> > different commutative expressions, and we end up with a DAG
>> > for a set of commutative expressions.
>> >
>> > Thanks and Regards,
>> > Prathamesh
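On the vector question, the abort looks like the usual capacity-versus-length distinction. Suppose the calls tried above only reserve storage, the way vec::reserve and vec::reserve_exact do. Then the length stays 0, and as far as I can tell operator[] in vec.h checks the index against the length, not the capacity. In vec.h itself, safe_grow_cleared (n) should be the call that sets the length to n with zeroed elements; that is worth confirming against vec.h at this revision. std::vector draws the same line between reserve and resize, as this standard-library sketch shows. The function names are hypothetical.

```cpp
#include <cstddef>
#include <vector>

struct operand;  // opaque here: only pointers are stored

// reserve() only grows capacity: size() stays 0, so v[i] would be out of range.
std::vector<operand *> reserved_only(std::size_t n) {
  std::vector<operand *> v;
  v.reserve(n);
  return v;
}

// resize() sets the length to n; the new slots are value-initialized (nullptr).
std::vector<operand *> presized(std::size_t n) {
  std::vector<operand *> v;
  v.resize(n);
  return v;
}
```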
Index: genmatch.c
===================================================================
--- genmatch.c (revision 211732)
+++ genmatch.c (working copy)
@@ -119,7 +119,7 @@ struct id_base : typed_free_remove<id_ba
 {
   enum id_kind { CODE, FN } kind;

-  id_base (id_kind, const char *);
+  id_base (id_kind, const char *);

   hashval_t hashval;
   const char *id;
@@ -146,7 +146,7 @@ id_base::equal (const value_type *op1,

 static hash_table<id_base> operators;

-id_base::id_base (id_kind kind_, const char *id_)
+id_base::id_base (id_kind kind_, const char *id_)
 {
   kind = kind_;
   id = id_;
@@ -218,8 +218,9 @@ struct predicate : public operand
 };

 struct e_operation {
-  e_operation (const char *id);
+  e_operation (const char *id, bool is_commutative_ = false);
   id_base *op;
+  bool is_commutative;
 };


@@ -258,9 +259,11 @@ struct capture : public operand
 };


-e_operation::e_operation (const char *id)
+e_operation::e_operation (const char *id, bool is_commutative_)
 {
   id_base tem (id_base::CODE, id);
+  is_commutative = is_commutative_;
+
   op = operators.find_with_hash (&tem, tem.hashval);
   if (op)
     return;
@@ -293,14 +296,14 @@ e_operation::e_operation (const char *id
 struct simplify {
   simplify (const char *name_,
-            struct operand *match_, source_location match_location_,
+            vec<operand *> matchers_, source_location match_location_,
             struct operand *ifexpr_, source_location ifexpr_location_,
             struct operand *result_, source_location result_location_)
-      : name (name_), match (match_), match_location (match_location_),
+      : name (name_), matchers (matchers_), match_location (match_location_),
       ifexpr (ifexpr_), ifexpr_location (ifexpr_location_),
       result (result_), result_location (result_location_) {}

   const char *name;
-  struct operand *match;
+  vec<operand *> matchers; // vector to hold commutative expressions
   source_location match_location;
   struct operand *ifexpr;
   source_location ifexpr_location;
@@ -308,7 +311,148 @@ struct simplify {
   source_location result_location;
 };

+void
+print_operand (operand *o, FILE *f = stderr)
+{
+  if (o->type == operand::OP_CAPTURE)
+    {
+      capture *c = static_cast<capture *> (o);
+      fprintf (f, "@%s", (static_cast<capture *> (o))->where);
+      if (c->what)
+        {
+          putc (':', f);
+          print_operand (c->what, f);
+          putc (' ', f);
+        }
+    }
+
+  else if (o->type == operand::OP_PREDICATE)
+    fprintf (f, "%s", (static_cast<predicate *> (o))->ident);
+
+  else if (o->type == operand::OP_C_EXPR)
+    fprintf (f, "c_expr");
+
+  else if (o->type == operand::OP_EXPR)
+    {
+      expr *e = static_cast<expr *> (o);
+      fprintf (f, "(%s ", e->operation->op->id);
+
+      for (unsigned i = 0; i < e->ops.length (); ++i)
+        {
+          print_operand (e->ops[i], f);
+          putc (' ', f);
+        }
+
+      putc (')', f);
+    }
+
+  else
+    gcc_unreachable ();
+}
+
+void
+print_matches (struct simplify *s, FILE *f = stderr)
+{
+  if (s->matchers.length () == 1)
+    return;
+
+  fprintf (f, "for expression: ");
+  print_operand (s->matchers[0], f); // s->matchers[0] is equivalent to original expression
+  putc ('\n', f);
+
+  fprintf (f, "commutative expressions:\n");
+  for (unsigned i = 0; i < s->matchers.length (); ++i)
+    {
+      print_operand (s->matchers[i], f);
+      putc ('\n', f);
+    }
+}
+
+void
+cartesian_product (const vec< vec<operand *> >& ops_vector, vec< vec<operand *> >& result, vec<operand *>& v, unsigned n)
+{
+  if (n == ops_vector.length ())
+    {
+      vec<operand *> xv = v.copy ();
+      result.safe_push (xv);
+      return;
+    }
+
+  for (unsigned i = 0; i < ops_vector[n].length (); ++i)
+    {
+      v[n] = ops_vector[n][i];
+      cartesian_product (ops_vector, result, v, n + 1);
+    }
+}
+
+void
+cartesian_product (const vec< vec<operand *> >& ops_vector, vec< vec<operand *> >& result, unsigned n_ops)
+{
+  vec<operand *> v = vNULL;
+// FIXME: this is done to resize v to length n_ops.
+  for (unsigned i = 0; i < n_ops; ++i)
+    v.safe_push (0);
+  cartesian_product (ops_vector, result, v, 0);
+}
+
+vec<operand *>
+commutate (operand *op)
+{
+  vec<operand *> ret = vNULL;
+
+  if (op->type == operand::OP_CAPTURE)
+    {
+      capture *c = static_cast<capture *> (op);
+      if (!c->what)
+        {
+          ret.safe_push (op);
+          return ret;
+        }
+      vec<operand *> v = commutate (c->what);
+      for (unsigned i = 0; i < v.length (); ++i)
+        {
+          capture *nc = new capture (c->where, v[i]);
+          ret.safe_push (nc);
+        }
+      return ret;
+    }
+
+  if (op->type != operand::OP_EXPR)
+    {
+      ret.safe_push (op);
+      return ret;
+    }
+
+  expr *e = static_cast<expr *> (op);
+
+  vec< vec<operand *> > ops_vector = vNULL;
+  for (unsigned i = 0; i < e->ops.length (); ++i)
+    ops_vector.safe_push (commutate (e->ops[i]));
+
+  vec< vec<operand *> > result = vNULL;
+  cartesian_product (ops_vector, result, e->ops.length ());
+
+  for (unsigned i = 0; i < result.length (); ++i)
+    {
+      expr *ne = new expr (e->operation);
+      for (unsigned j = 0; j < result[i].length (); ++j)
+        ne->append_op (result[i][j]);
+      ret.safe_push (ne);
+    }
+  if (!e->operation->is_commutative)
+    return ret;
+
+  for (unsigned i = 0; i < result.length (); ++i)
+    {
+      expr *ne = new expr (e->operation);
+      for (unsigned j = result[i].length (); j; --j) // result[i].length () is 2 since e->operation is binary
+        ne->append_op (result[i][j-1]);
+      ret.safe_push (ne);
+    }
+
+  return ret;
+}

 /* Code gen off the AST. */

@@ -574,11 +718,15 @@ write_nary_simplifiers (FILE *f, vec<sim
     {
       simplify *s = simplifiers[i];
       /* ??? This means we can't capture the outermost expression. */
-      if (s->match->type != operand::OP_EXPR)
+      for (unsigned i = 0; i < s->matchers.length (); ++i)
+        {
+          operand *match = s->matchers[i];
+          if (match->type != operand::OP_EXPR)
         continue;
-      expr *e = static_cast <expr *> (s->match);
+      expr *e = static_cast <expr *> (match);
       if (e->ops.length () != n)
         continue;
+
       char fail_label[16];
       snprintf (fail_label, 16, "fail%d", label_cnt++);
       output_line_directive (f, s->match_location);
@@ -627,6 +775,7 @@ write_nary_simplifiers (FILE *f, vec<sim
       fprintf (f, " }\n");
       fprintf (f, "%s:\n", fail_label);
     }
+  }
   fprintf (f, " return false;\n");
   fprintf (f, "}\n");
 }
@@ -827,6 +976,25 @@ parse_expr (cpp_reader *r)
   expr *e = new expr (parse_operation (r));
   const cpp_token *token = peek (r);
   operand *op;
+  bool is_commutative = false;
+
+  if (token->type == CPP_COLON)
+    {
+      eat_token (r, CPP_COLON);
+      token = peek (r);
+      if (token->type == CPP_NAME
+          && !(token->flags & PREV_WHITE))
+        {
+          const char *s = (const char *)CPP_HASHNODE (token->val.node.node)->ident.str;
+          eat_token (r, CPP_NAME);
+          token = peek (r);
+          if (s[0] == 'c' && !s[1])
+            is_commutative = true;
+          else
+            fatal_at (token, "not implemented: predicates on expressions");
+        }
+    }
+
   if (token->type == CPP_ATSIGN
       && !(token->flags & PREV_WHITE))
     op = parse_capture (r, e);
@@ -847,6 +1015,13 @@ parse_expr (cpp_reader *r)
             fatal_at (token, "got %d operands instead of the required %d",
                       e->ops.length (), opr->get_required_nargs ());
         }
+      if (is_commutative)
+        {
+          if (e->ops.length () == 2)
+            e->operation->is_commutative = true;
+          else
+            fatal_at (token, "only binary operators or function with two arguments can be marked commutative");
+        }
       return op;
     }
   e->append_op (parse_op (r));
@@ -971,7 +1146,7 @@ parse_match_and_simplify (cpp_reader *r,
       ifexpr = parse_c_expr (r, CPP_OPEN_PAREN);
     }
   token = peek (r);
-  return new simplify (id, match, match_location,
+  return new simplify (id, commutate (match), match_location,
                        ifexpr, ifexpr_location, parse_op (r), token->src_loc);
 }

@@ -1043,6 +1218,9 @@ main(int argc, char **argv)
     }
   while (1);

+  for (unsigned i = 0; i < simplifiers.length (); ++i)
+    print_matches (simplifiers[i]);
+
   write_gimple (stdout, simplifiers);

   cpp_finish (r, NULL);