Hi All, This is a respin of this patch using the new approach.
Thanks, Tamar gcc/ChangeLog: * doc/md.texi: Document optabs. * internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New. * optabs.def (cmul_optab, cmul_conj_optab): New. * tree-vect-slp-patterns.c (vect_build_perm_groups, vect_can_combine_node_p, vect_slp_make_combine_linear, vect_match_call_complex_mla, vect_slp_matches_complex_mul, class complex_mul_pattern, complex_mul_pattern::matches, complex_mul_pattern::validate_p, complex_operations_pattern::matches): Add complex_mul_pattern. > -----Original Message----- > From: Gcc-patches <gcc-patches-boun...@gcc.gnu.org> On Behalf Of Tamar > Christina > Sent: Friday, September 25, 2020 3:29 PM > To: gcc-patches@gcc.gnu.org > Cc: nd <n...@arm.com>; rguent...@suse.de; o...@ucw.cz > Subject: [PATCH v2 7/16]middle-end: Add Complex Multiplication and > Multiplication with Conjucate detection > > Hi All, > > This patch adds pattern detection for the following operations: > > Complex multiplication and complex multiplication by the conjugate of the > second parameter. > > c = a * b and c = a * conj (b) > > For the conjugate cases it also supports, under fast-math, the conjugated > operand being the first one, by flipping the arguments to the optab. This > allows it to support c = conj (a) * b and c += conj (a) * b. > > where a, b and c are complex numbers. > > It also provides a shared class for anything needing to recognize complex MLA > patterns. > > Bootstrapped and regtested on aarch64-none-linux-gnu with no issues. > > Ok for master? > > Thanks, > Tamar > > gcc/ChangeLog: > > * doc/md.texi: Document optabs. > * internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New. > * optabs.def (cmul_optab, cmul_conj_optab): New. > * tree-vect-slp-patterns.c (class ComplexMLAPattern, > class ComplexMulPattern): New. > (slp_patterns): Add ComplexMulPattern. > > --
diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi index 71e226505b2619d10982b59a4ebbed73a70f29be..ddaf1abaccbd44dae11ea902ec38b474aacfb8e1 100644 --- a/gcc/doc/md.texi +++ b/gcc/doc/md.texi @@ -6143,6 +6143,28 @@ rotations @var{m} of 90 or 270. This pattern is not allowed to @code{FAIL}. +@cindex @code{cmul@var{m}4} instruction pattern +@item @samp{cmul@var{m}4} +Perform a vector floating point multiplication of complex numbers in operand 0 +and operand 1. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmul_conj@var{m}4} instruction pattern +@item @samp{cmul_conj@var{m}4} +Perform a vector floating point multiplication of complex numbers in operand 0 +and the conjugate of operand 1. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + @cindex @code{ffs@var{m}2} instruction pattern @item @samp{ffs@var{m}2} Store into operand 0 one plus the index of the least significant 1-bit diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def index 33c54be1e158ddea25c4cd6b1148df8cf4a509b5..cb41643f5e332518a0271bb8e1af4883c8bd6880 100644 --- a/gcc/internal-fn.def +++ b/gcc/internal-fn.def @@ -279,6 +279,8 @@ DEF_INTERNAL_FLT_FLOATN_FN (FMAX, ECF_CONST, fmax, binary) DEF_INTERNAL_OPTAB_FN (XORSIGN, ECF_CONST, xorsign, binary) DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT90, ECF_CONST, cadd90, binary) DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT270, ECF_CONST, cadd270, binary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL, ECF_CONST, cmul, binary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL_CONJ, ECF_CONST, cmul_conj, binary) /* FP scales. 
*/ diff --git a/gcc/optabs.def b/gcc/optabs.def index 2bb0bf857977035bf562a77f5f6848e80edf936d..9c267d422478d0011f288b1f5f62daabe3989ba7 100644 --- a/gcc/optabs.def +++ b/gcc/optabs.def @@ -292,6 +292,8 @@ OPTAB_D (copysign_optab, "copysign$F$a3") OPTAB_D (xorsign_optab, "xorsign$F$a3") OPTAB_D (cadd90_optab, "cadd90$a3") OPTAB_D (cadd270_optab, "cadd270$a3") +OPTAB_D (cmul_optab, "cmul$a3") +OPTAB_D (cmul_conj_optab, "cmul_conj$a3") OPTAB_D (cos_optab, "cos$a2") OPTAB_D (cosh_optab, "cosh$a2") OPTAB_D (exp10_optab, "exp10$a2") diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c index 0732cf0a6d93be8590b84c39dff82940b280e46b..2edb0117f9cbbfc40e9ed3a96120a3c88f84a68e 100644 --- a/gcc/tree-vect-slp-patterns.c +++ b/gcc/tree-vect-slp-patterns.c @@ -196,6 +196,65 @@ linear_loads_p (slp_tree root, bool *linear) return loads; } +/* Builds a permutation group from the operands in OPS and stores it in BLOCKS. + The group describes how to combine the operators to get a valid linear node. + + This is used when combining multiple children from a two_operators node into + one using a lane permute to select the appropriate lane. As an example the + permute { [0 0] [1 4] [2 2] [3 3] [1 4] [5 5] } says the nodes which occur + twice in a group, e.g [0 0] only needs itself to possibly be made linear + whereas [1 4] means to combine the nodes 1 and 4. 
*/ + +static void +vect_build_perm_groups (map_t *blocks, vec<slp_tree> ops) +{ + slp_tree op; + unsigned i; + bool is_linear = false; + unsigned min_eq = -1, max_eq = 0; + unsigned min_idx = 0, max_idx = 0; + FOR_EACH_VEC_ELT (ops, i, op) + { + load_permutation_t perms = linear_loads_p (op, &is_linear); + unsigned x, imin = -1, imax = 0; + for (x = 0; x < perms.length () && !is_linear; x++) + { + imin = MIN (imin, perms[x]); + imax = MAX (imax, perms[x]); + } + + if (imin != imax || perms.length () == 0 || is_linear) + blocks[i] = {i, i}; + else + { + if (imin <= min_eq) + { + min_eq = imin; + min_idx = i; + } + + if (imin >= max_eq) + { + max_eq = imin; + max_idx = i; + } + } + } + + /* Now fill in the gap. */ + blocks[min_idx] = {min_idx, max_idx}; + blocks[max_idx] = {min_idx, max_idx}; + + if (dump_enabled_p ()) + { + dump_printf_loc (MSG_NOTE, vect_location, "pattern group: { "); + for (i = 0; i < ops.length (); i++) + dump_printf (MSG_NOTE,"[%d %d] ", blocks[i].a, blocks[i].b); + dump_printf (MSG_NOTE,"}\n"); + } + +} + /* This function attempts to make a node rooted in NODE linear. If the node if already linear than the node itself is returned in RESULT. @@ -265,6 +324,85 @@ vect_slp_make_linear (slp_tree parent, slp_tree node, slp_tree *result) return is_linear; } +/* Helper utility to check to see if the permutation PERM is one that can be + used in a node combination operation. This is defined as the permute not + having all the elements being the same. e.g [0 0]. */ + +static inline bool +vect_can_combine_node_p (load_permutation_t perm, bool is_linear) +{ + if (is_linear) + return false; + + unsigned i, x; + FOR_EACH_VEC_ELT (perm, i, x) + if (perm[0] != x) + return false; + + return true; +} + +/* This function combines the nodes in MAP together to make a new node using a + lane permute. The nodes to combine are stored in ENTRIES and the resulting + node is returned in RESULT. 
+ + If the nodes are already linear then this function fails and returns FALSE. + Otherwise it returns the new node and TRUE. */ + +static bool +vect_slp_make_combine_linear (slp_tree parent, vec<slp_tree> entries, map_t map, + slp_tree *result) +{ + if (map.a == map.b) + return false; + + slp_tree node_a = entries[map.a]; + slp_tree node_b = entries[map.b]; + + bool is_a_linear = false; + bool is_b_linear = false; + + load_permutation_t load_perm_a = linear_loads_p (node_a, &is_a_linear); + if (!vect_can_combine_node_p (load_perm_a, is_a_linear)) + return false; + load_permutation_t load_perm_b = linear_loads_p (node_b, &is_b_linear); + if (!vect_can_combine_node_p (load_perm_b, is_b_linear)) + return false; + + /* Now we need to figure which node is first. */ + auto_vec<slp_tree> nodes; + nodes.create (2); + vec<std::pair<unsigned, unsigned> > perm; + perm.create (2); + if (load_perm_a[0] < load_perm_b[0]) + { + perm.quick_push (std::make_pair (0, 0)); + perm.quick_push (std::make_pair (1, 0)); + } + else + { + perm.quick_push (std::make_pair (1, 0)); + perm.quick_push (std::make_pair (0, 0)); + } + + nodes.quick_push (node_a); + nodes.quick_push (node_b); + /* Already connected to a, just need b. 
*/ + SLP_TREE_REF_COUNT (node_a)++; + SLP_TREE_REF_COUNT (node_b)++; + + slp_tree vnode = vect_create_new_slp_node (vNULL, 1); + SLP_TREE_CODE (vnode) = VEC_PERM_EXPR; + SLP_TREE_LANE_PERMUTATION (vnode) = perm; + SLP_TREE_VECTYPE (vnode) = SLP_TREE_VECTYPE (parent); + SLP_TREE_CHILDREN (vnode).safe_splice (nodes); + SLP_TREE_REF_COUNT (vnode) = 1; + SLP_TREE_LANES (vnode) = SLP_TREE_LANES (parent); + SLP_TREE_REPRESENTATIVE (vnode) = SLP_TREE_REPRESENTATIVE (parent); + *result = vnode; + return true; +} + /******************************************************************************* * Simple vector pattern matcher ******************************************************************************/ @@ -727,6 +865,313 @@ complex_add_pattern::matches () return matches (op, this->m_ops); } +/******************************************************************************* + * complex_mul_pattern + ******************************************************************************/ + +/* Helper function of that looks for a match in the CHILDth child of NODE. The + child used is stored in RES. + + If the match is successful then ARGS will contain the operands matched + and the complex_operation_t type is returned. If match is not successful + then CMPLX_NONE is returned and ARGS is left unmodified. */ + +static complex_operation_t +vect_match_call_complex_mla (slp_tree node, unsigned child, + vec<slp_tree> *args = NULL, slp_tree *res = NULL) +{ + gcc_assert (child < SLP_TREE_CHILDREN (node).length ()); + + slp_tree data = SLP_TREE_CHILDREN (node)[child]; + + if (res) + *res = data; + + return vect_detect_pair_op (data, false, args); +} + +/* This helper attemps to find a complex MUL pattern rooted in ROOT. If the + match succeeds then the pattern type is set in IFN and the operands are + returned in OPS. + + This function matches both a normal complex multiply and complex conjucate + multiply. Additionally it also matches the MUL part in a FMS and FMA + sequence. 
However due to the additional TWO_OPERATORS node that an FMS + has the location of the negate node that denotes a conjucate changes. + + In order to differentiate when and where we should check for a conjucate + the value MULTIPLY is set when this should only match a normal complex + multiply operation and INVERSE is set when we're matching a sequence for an + FMS where the negate node is on the other side. + + Note that this function also deals with that the canonicalization of the + sequence is off if there is a type cast in between. This is likely a mid-end + bug but for now we deal with it here. */ +static bool +vect_slp_matches_complex_mul (complex_operation_t op, slp_tree root, + internal_fn *ifn, vec<slp_tree> *ops, + bool multiply, bool inverse = false) +{ + *ifn = IFN_LAST; + + if (op != MINUS_PLUS) + return false; + + /* Now operand1+3 must lead to another expression. */ + auto_vec<slp_tree> args0; + complex_operation_t op2 = vect_match_call_complex_mla (root, 0, &args0); + + if (op2 != MULT_MULT) + return false; + + /* Now operand2+4 must lead to another expression. */ + auto_vec<slp_tree> args1; + complex_operation_t op3 = vect_match_call_complex_mla (root, 1, &args1); + + if (op3 != MULT_MULT) + return false; + + vec<slp_tree> args2 = SLP_TREE_CHILDREN (args1[inverse ? 0 : 1]); + slp_tree neg_node = NULL; + bool first_neg = false, second_neg = false; + + /* Now operand2+4 may lead to another expression. */ + if ((first_neg = vect_match_expression_p (args2[0], NEGATE_EXPR))) + neg_node = SLP_TREE_CHILDREN (args2[0])[0]; + else if ((second_neg = vect_match_expression_p (args2[1], NEGATE_EXPR))) + neg_node = SLP_TREE_CHILDREN (args2[1])[0]; + + if (first_neg && multiply) + return false; + + /* Check if the neg node is a dup, otherwise not a pattern we want. 
*/ + bool is_dup = false; + bool same_operand = true; + stmt_vec_info elem; + unsigned i; + load_permutation_t perm = linear_loads_p (neg_node, &is_dup); + for (i = 0; i < perm.length (); i++) + if (perm[i] != perm[0]) + return false; + + + /* Check if the conjucate is on the second first and flip the order so we + get it in the right place. We can't check the DR of the new child since + we may not be a load. We can't recurse all the way down because we + may find a child with multiple children or external. So instead just + check the operands to the multiply which tell us where the conjucate + was and how it's interpeting the permute. */ + vec<stmt_vec_info> stmts = SLP_TREE_SCALAR_STMTS (args0[0]); + tree first_op = gimple_op (STMT_VINFO_STMT (stmts[0]), 1); + FOR_EACH_VEC_ELT (stmts, i, elem) + if (first_op != gimple_op (STMT_VINFO_STMT (elem), 1)) + { + same_operand = false; + break; + } + + /* Reject operations that we don't have an optab for. */ + if (first_neg && !multiply && same_operand && !inverse) + return false; + + bool is_neg = first_neg || second_neg; + + if (!is_neg) + { + /* Indicates a rotation in the complex number, not a pattern we are + looking for.. */ + vec<slp_tree> params = SLP_TREE_CHILDREN (args0[0]); + if (vect_match_expression_p (params[0], NEGATE_EXPR) + || vect_match_expression_p (params[1], NEGATE_EXPR)) + return false; + *ifn = IFN_COMPLEX_MUL; + ops->safe_splice (params); + ops->safe_push (SLP_TREE_CHILDREN (args1[1])[0]); + ops->safe_push (SLP_TREE_CHILDREN (args1[1])[1]); + } + else if (is_neg) + { + *ifn = IFN_COMPLEX_MUL_CONJ; + vec<slp_tree> params = SLP_TREE_CHILDREN (args0[inverse ? 1 : 0]); + slp_tree value = second_neg ? args2[0] : args2[1]; + /* Check if the conjucate is on the first or second parameter. 
*/ + if (same_operand) + { + ops->safe_push (params[0]); + ops->safe_push (params[1]); + ops->safe_push (neg_node); + ops->safe_push (value); + } + else + { + ops->safe_push (params[1]); + ops->safe_push (params[0]); + ops->safe_push (neg_node); + ops->safe_push (value); + } + + /* The two_operators with an FMS reverse the nodes so we have to swap them + back to make a sensible operation. */ + if (inverse) + std::swap ((*ops)[2], (*ops)[3]); + } + + return *ifn != IFN_LAST; +} + +static bool +vect_slp_matches_complex_mul (slp_tree root, internal_fn *ifn, + vec<slp_tree> *ops, bool multiply, + bool inverse = false) +{ + return vect_slp_matches_complex_mul (vect_detect_pair_op (root), root, ifn, + ops, multiply, inverse); +} + +class complex_mul_pattern : public complex_pattern +{ + protected: + /* Allocate enough space for FMA as well. */ + map_t m_blocks[6] = {}; + bool m_inplace = false; + auto_vec<slp_tree> workset; + complex_mul_pattern (slp_tree *node, vec_info *vinfo) + : complex_pattern (node, vinfo) + { + this->m_arity = 2; + this->m_num_args = 2; + } + + public: + static vect_pattern* create (slp_tree *node, vec_info *vinfo) + { + return new complex_mul_pattern (node, vinfo); + } + + const char* get_name () + { + return "Complex Multiplication"; + } + + bool validate_p (); + bool matches (); + bool matches (complex_operation_t op, vec<slp_tree> ops); +}; + + +/* Pattern matcher for trying to match complex multiply pattern in SLP tree + If the operation matches then IFN is set to the operation it matched + and the arguments to the two replacement statements are put in M_OPS. + + If no match is found then IFN is set to IFN_LAST and M_OPS is unchanged. + + This function matches the patterns shaped as: + + double ax = (b[i+1] * a[i]); + double bx = (a[i+1] * b[i]); + + c[i] = c[i] - ax; + c[i+1] = c[i+1] + bx; + + If a match occurred then TRUE is returned, else FALSE. 
*/ + +bool +complex_mul_pattern::matches (complex_operation_t op, vec<slp_tree> /* ops */) +{ + bool res + = vect_slp_matches_complex_mul (op, *this->m_node, &this->m_ifn, + &this->m_ops, true); + if (res) + { + vect_build_perm_groups (&this->m_blocks[0], this->m_ops); + this->workset.safe_splice (SLP_TREE_CHILDREN (*this->m_node)); + save_match (); + } + return res; +} + +bool +complex_mul_pattern::matches () +{ + complex_operation_t op + = vect_detect_pair_op (*this->m_node); + return matches (op, this->m_ops); +} + + +/* Validates to see if the Complex MUL that we have matched is valid. This is + done through a combination of making nodes linear and combining nodes. */ + +bool +complex_mul_pattern::validate_p () +{ + if (!this->m_match) + return false; + + slp_tree node; + unsigned ix; + hash_set<slp_tree> cache; + FOR_EACH_VEC_ELT (this->workset, ix, node) + { + auto_vec<slp_tree> nodes; + nodes.create (this->m_num_args); + slp_tree tmp = NULL; + + unsigned i; + for (i = 0; i < this->m_num_args; i++) + { + unsigned index = (ix * this->m_num_args) + i; + map_t map = this->m_blocks[index]; + slp_tree vnode = NULL; + bool needs_linear = map.a == map.b; + tmp = this->m_ops[index]; + cache.add (tmp); + if (needs_linear && vect_slp_make_linear (node, tmp, &vnode)) + nodes.quick_push (vnode); + else if (!needs_linear + && vect_slp_make_combine_linear (node, this->m_ops, map, + &vnode)) + nodes.quick_push (vnode); + else + { + if (dump_enabled_p ()) + dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location, + "stmts could not be made %s %p\n", + needs_linear ? 
"linear" : "linear/combined", + tmp); + nodes.release(); + return false; + } + + vect_mark_stmts_as_in_pattern (&cache, node); + } + + if (m_inplace) + { + SLP_TREE_CHILDREN (*this->m_node).truncate (0); + SLP_TREE_CHILDREN (*this->m_node).safe_splice (nodes); + } + else + { + slp_tree new_node + = vect_create_new_slp_node (SLP_TREE_SCALAR_STMTS (node), + SLP_TREE_CHILDREN (node).length ()); + SLP_TREE_VECTYPE (new_node) = SLP_TREE_VECTYPE (node); + SLP_TREE_LANE_PERMUTATION (new_node) + = SLP_TREE_LANE_PERMUTATION (node); + SLP_TREE_CODE (new_node) = SLP_TREE_CODE (node); + SLP_TREE_REF_COUNT (new_node) = SLP_TREE_REF_COUNT (node); + SLP_TREE_REPRESENTATIVE (new_node) = SLP_TREE_REPRESENTATIVE (node); + SLP_TREE_CHILDREN (new_node).safe_splice (nodes); + SLP_TREE_LANES (new_node) = SLP_TREE_LANES (node); + + SLP_TREE_CHILDREN (*this->m_node)[ix] = new_node; + } + } + + return true; +} + /******************************************************************************* * complex_operations_pattern class ******************************************************************************/ @@ -776,6 +1221,12 @@ complex_operations_pattern::matches () return false; /* Check which pattern this may be. Match longest pattern first. */ + this->m_patt = complex_mul_pattern::create (this->m_node, this->m_vinfo); + if (this->m_patt->matches (op, this->m_ops)) + return true; + + delete this->m_patt; + this->m_patt = complex_add_pattern::create (this->m_node, this->m_vinfo); if (this->m_patt->matches (op, this->m_ops)) return true;