On Mon, 28 Dec 2020, Tamar Christina wrote: > Hi All, > > This adds support for FMA and FMA conjugated to the slp pattern matcher. > > Bootstrapped Regtested on aarch64-none-linux-gnu, x86_64-pc-linux-gnu > and no issues. > > Ok for master? > > Thanks, > Tamar > > gcc/ChangeLog: > > * internal-fn.def (COMPLEX_FMA, COMPLEX_FMA_CONJ): New. > * optabs.def (cmla_optab, cmla_conj_optab): New. > * doc/md.texi: Document them. > * tree-vect-slp-patterns.c (vect_match_call_p, > class complex_fma_pattern, vect_slp_reset_pattern, > complex_fma_pattern::matches, complex_fma_pattern::recognize, > complex_fma_pattern::build): New. > > --- inline copy of patch -- > diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi > index > b8cc90e1a75e402abbf8a8cf2efefc1a333f8b3a..6d5a98c4946d3ff4c2b8abea5c29caa6863fd3f7 > 100644 > --- a/gcc/doc/md.texi > +++ b/gcc/doc/md.texi > @@ -6202,6 +6202,51 @@ The operation is only supported for vector modes > @var{m}. > > This pattern is not allowed to @code{FAIL}. > > +@cindex @code{cmla@var{m}4} instruction pattern > +@item @samp{cmla@var{m}4} > +Perform a vector multiply and accumulate that is semantically the same as > +a multiply and accumulate of complex numbers. > + > +@smallexample > + complex TYPE c[N]; > + complex TYPE a[N]; > + complex TYPE b[N]; > + for (int i = 0; i < N; i += 1) > + @{ > + c[i] += a[i] * b[i]; > + @} > +@end smallexample > + > +In GCC lane ordering the real part of the number must be in the even lanes > with > +the imaginary part in the odd lanes. > + > +The operation is only supported for vector modes @var{m}. > + > +This pattern is not allowed to @code{FAIL}. > + > +@cindex @code{cmla_conj@var{m}4} instruction pattern > +@item @samp{cmla_conj@var{m}4} > +Perform a vector multiply by conjugate and accumulate that is semantically > +the same as a multiply and accumulate of complex numbers where the second > +multiply arguments is conjugated. 
> + > +@smallexample > + complex TYPE c[N]; > + complex TYPE a[N]; > + complex TYPE b[N]; > + for (int i = 0; i < N; i += 1) > + @{ > + c[i] += a[i] * conj (b[i]); > + @} > +@end smallexample > + > +In GCC lane ordering the real part of the number must be in the even lanes > with > +the imaginary part in the odd lanes. > + > +The operation is only supported for vector modes @var{m}. > + > +This pattern is not allowed to @code{FAIL}. > + > @cindex @code{cmul@var{m}4} instruction pattern > @item @samp{cmul@var{m}4} > Perform a vector multiply that is semantically the same as multiply of > diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def > index > 5a0bbe3fe5dee591d54130e60f6996b28164ae38..305450e026d4b94ab62ceb9ca719ec5570ff43eb > 100644 > --- a/gcc/internal-fn.def > +++ b/gcc/internal-fn.def > @@ -288,6 +288,8 @@ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary) > > /* Ternary math functions. */ > DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary) > +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary) > +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary) > > /* Unary integer ops. 
*/ > DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary) > diff --git a/gcc/optabs.def b/gcc/optabs.def > index > e82396bae1117c6de91304761a560b7fbcb69ce1..8e2758d685ed85e02df10dac571eb40d45a294ed > 100644 > --- a/gcc/optabs.def > +++ b/gcc/optabs.def > @@ -294,6 +294,8 @@ OPTAB_D (cadd90_optab, "cadd90$a3") > OPTAB_D (cadd270_optab, "cadd270$a3") > OPTAB_D (cmul_optab, "cmul$a3") > OPTAB_D (cmul_conj_optab, "cmul_conj$a3") > +OPTAB_D (cmla_optab, "cmla$a4") > +OPTAB_D (cmla_conj_optab, "cmla_conj$a4") > OPTAB_D (cos_optab, "cos$a2") > OPTAB_D (cosh_optab, "cosh$a2") > OPTAB_D (exp10_optab, "exp10$a2") > diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c > index > 82721acbab8cf81c4d6f9954c98fb913a7bb6282..3625a80c08e3d70fd362fc52e17e65b3b2c7da83 > 100644 > --- a/gcc/tree-vect-slp-patterns.c > +++ b/gcc/tree-vect-slp-patterns.c > @@ -325,6 +325,24 @@ vect_match_expression_p (slp_tree node, tree_code code) > return true; > } > > +/* Checks to see if the expression represented by NODE is a call to the > internal > + function FN. */ > + > +static inline bool > +vect_match_call_p (slp_tree node, internal_fn fn) > +{ > + if (!node > + || !SLP_TREE_REPRESENTATIVE (node)) > + return false; > + > + gimple* expr = STMT_VINFO_STMT (SLP_TREE_REPRESENTATIVE (node)); > + if (!expr > + || !gimple_call_internal_p (expr, fn)) > + return false; > + > + return true; > +} > + > /* Check if the given lane permute in PERMUTES matches an alternating > sequence > of {even odd even odd ...}. This to account for unrolled loops. Further > mode there resulting permute must be linear. 
*/ > @@ -1081,6 +1099,161 @@ complex_mul_pattern::build (vec_info *vinfo) > complex_pattern::build (vinfo); > } > > +/******************************************************************************* > + * complex_fma_pattern class > + > ******************************************************************************/ > + > +class complex_fma_pattern : public complex_pattern > +{ > + protected: > + complex_fma_pattern (slp_tree *node, vec<slp_tree> *m_ops, internal_fn > ifn) > + : complex_pattern (node, m_ops, ifn) > + { > + this->m_num_args = 3; > + } > + > + public: > + void build (vec_info *); > + static internal_fn > + matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree > *, > + vec<slp_tree> *); > + > + static vect_pattern* > + recognize (slp_tree_to_load_perm_map_t *, slp_tree *); > + > + static vect_pattern* > + mkInstance (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn) > + { > + return new complex_fma_pattern (node, m_ops, ifn); > + } > +}; > + > +/* Helper function to "reset" a previously matched node and undo the changes > + made enough so that the node is treated as an irrelevant node. */ > + > +static inline void > +vect_slp_reset_pattern (slp_tree node) > +{ > + stmt_vec_info stmt_info = vect_orig_stmt (SLP_TREE_REPRESENTATIVE (node)); > + STMT_VINFO_IN_PATTERN_P (stmt_info) = false; > + STMT_SLP_TYPE (stmt_info) = pure_slp; > + SLP_TREE_REPRESENTATIVE (node) = stmt_info; > +} > + > +/* Pattern matcher for trying to match complex multiply and accumulate > + and multiply and subtract patterns in SLP tree. > + If the operation matches then IFN is set to the operation it matched and > + the arguments to the two replacement statements are put in m_ops. > + > + If no match is found then IFN is set to IFN_LAST and m_ops is unchanged. 
> + > + This function matches the patterns shaped as: > + > + double ax = (b[i+1] * a[i]) + (b[i] * a[i]); > + double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]); > + > + c[i] = c[i] - ax; > + c[i+1] = c[i+1] + bx; > + > + If a match occurred then TRUE is returned, else FALSE. The match is > + performed after COMPLEX_MUL which would have done the majority of the > work. > + This function merely matches an ADD with a COMPLEX_MUL IFN. The initial > + match is expected to be in OP1 and the initial match operands in args0. > */ > + > +internal_fn > +complex_fma_pattern::matches (complex_operation_t op, > + slp_tree_to_load_perm_map_t * /* perm_cache */, > + slp_tree *ref_node, vec<slp_tree> *ops) > +{ > + internal_fn ifn = IFN_LAST; > + > + /* Find the two components. We match Complex MUL first which reduces the > + amount of work this pattern has to do. After that we just match the > + head node and we're done.: > + > + * FMA: + +. > + > + We need to ignore the two_operands nodes that may also match. > + For that we can check if they have any scalar statements and also > + check that it's not a permute node as we're looking for a normal > + PLUS_EXPR operation. */ > + if (op != CMPLX_NONE) > + return IFN_LAST; > + > + /* Find the two components. We match Complex MUL first which reduces the > + amount of work this pattern has to do. After that we just match the > + head node and we're done.: > + > + * FMA: + + on a non-two_operands node. */ > + slp_tree vnode = *ref_node; > + if (SLP_TREE_LANE_PERMUTATION (vnode).exists () > + /* Need to exclude the plus two-operands node. These are not marked > + so we have to infer it based on conditions. */ > + || !SLP_TREE_SCALAR_STMTS (vnode).exists ()
As said earlier, we shouldn't test this. The existing lane permute check should already cover this, although the test would better be SLP_TREE_CODE (vnode) == VEC_PERM_EXPR > + || !vect_match_expression_p (vnode, PLUS_EXPR)) But then it shouldn't match this anyway (vect_match_expression_p should only ever match when SLP_TREE_CODE (vnode) != VEC_PERM_EXPR). > + return IFN_LAST; > + > + slp_tree node = SLP_TREE_CHILDREN (vnode)[1]; > + > + if (vect_match_call_p (node, IFN_COMPLEX_MUL)) > + ifn = IFN_COMPLEX_FMA; > + else if (vect_match_call_p (node, IFN_COMPLEX_MUL_CONJ)) > + ifn = IFN_COMPLEX_FMA_CONJ; > + else > + return IFN_LAST; > + > + if (!vect_pattern_validate_optab (ifn, vnode)) > + return IFN_LAST; > + > + vect_slp_reset_pattern (node); I don't understand this ... it deserves a comment at least. Having no testcases with this patch makes it impossible for me to dig in myself :/ Otherwise looks OK. Thanks, Richard. > + ops->truncate (0); > + ops->create (3); > + > + if (ifn == IFN_COMPLEX_FMA) > + { > + ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]); > + ops->quick_push (SLP_TREE_CHILDREN (node)[1]); > + ops->quick_push (SLP_TREE_CHILDREN (node)[0]); > + } > + else > + { > + ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]); > + ops->quick_push (SLP_TREE_CHILDREN (node)[0]); > + ops->quick_push (SLP_TREE_CHILDREN (node)[1]); > + } > + > + return ifn; > +} > + > +/* Attempt to recognize a complex mul pattern. */ > + > +vect_pattern* > +complex_fma_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache, > + slp_tree *node) > +{ > + auto_vec<slp_tree> ops; > + complex_operation_t op > + = vect_detect_pair_op (*node, true, &ops); > + internal_fn ifn > + = complex_fma_pattern::matches (op, perm_cache, node, &ops); > + if (ifn == IFN_LAST) > + return NULL; > + > + return new complex_fma_pattern (node, &ops, ifn); > +} > + > +/* Perform a replacement of the detected complex mul pattern with the new > + instruction sequences. 
*/ > + > +void > +complex_fma_pattern::build (vec_info *vinfo) > +{ > + SLP_TREE_CHILDREN (*this->m_node).truncate (0); > + SLP_TREE_CHILDREN (*this->m_node).safe_splice (this->m_ops); > + > + complex_pattern::build (vinfo); > +} > + > > /******************************************************************************* > * Pattern matching definitions > > ******************************************************************************/ > > > -- Richard Biener <rguent...@suse.de> SUSE Software Solutions Germany GmbH, Maxfeldstrasse 5, 90409 Nuernberg, Germany; GF: Felix Imendörffer; HRB 36809 (AG Nuernberg)