Hi All, This patch adds pattern detection for the following operations:
Complex multiplication and complex multiplication by the conjugate of the second parameter: c = a * b and c = a * conj (b). For the conjugate cases it supports, under fast-math, that the operand being conjugated be flipped by flipping the arguments to the optab. This allows it to support c = conj (a) * b and c += conj (a) * b, where a, b and c are complex numbers. The patch also provides a shared class for anything needing to recognize complex MLA patterns. Bootstrapped and regtested on aarch64-none-linux-gnu with no issues. Ok for master? Thanks, Tamar gcc/ChangeLog: * doc/md.texi: Document optabs. * internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New. * optabs.def (cmul_optab, cmul_conj_optab): New. * tree-vect-slp-patterns.c (class ComplexMLAPattern, class ComplexMulPattern): New. (slp_patterns): Add ComplexMulPattern. --
diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi index 71e226505b2619d10982b59a4ebbed73a70f29be..ddaf1abaccbd44dae11ea902ec38b474aacfb8e1 100644 --- a/gcc/doc/md.texi +++ b/gcc/doc/md.texi @@ -6143,6 +6143,28 @@ rotations @var{m} of 90 or 270. This pattern is not allowed to @code{FAIL}. +@cindex @code{cmul@var{m}3} instruction pattern +@item @samp{cmul@var{m}3} +Perform a vector floating point multiplication of complex numbers in operand 1 +and operand 2, storing the result in operand 0. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmul_conj@var{m}3} instruction pattern +@item @samp{cmul_conj@var{m}3} +Perform a vector floating point multiplication of complex numbers in operand 1 +and the conjugate of operand 2, storing the result in operand 0. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + @cindex @code{ffs@var{m}2} instruction pattern @item @samp{ffs@var{m}2} Store into operand 0 one plus the index of the least significant 1-bit diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def index 956a65a338c157b51de7e78a3fb005b5af78ef31..51bebf8701af262b22d66d19a29a8dafb74db1f0 100644 --- a/gcc/internal-fn.def +++ b/gcc/internal-fn.def @@ -277,6 +277,9 @@ DEF_INTERNAL_FLT_FLOATN_FN (FMAX, ECF_CONST, fmax, binary) DEF_INTERNAL_OPTAB_FN (XORSIGN, ECF_CONST, xorsign, binary) DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT90, ECF_CONST, cadd90, binary) DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT270, ECF_CONST, cadd270, binary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL, ECF_CONST, cmul, binary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL_CONJ, ECF_CONST, cmul_conj, binary) + /* FP scales. 
*/ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary)
diff --git a/gcc/optabs.def b/gcc/optabs.def
index 2bb0bf857977035bf562a77f5f6848e80edf936d..9c267d422478d0011f288b1f5f62daabe3989ba7 100644
--- a/gcc/optabs.def
+++ b/gcc/optabs.def
@@ -292,6 +292,8 @@ OPTAB_D (copysign_optab, "copysign$F$a3")
 OPTAB_D (xorsign_optab, "xorsign$F$a3")
 OPTAB_D (cadd90_optab, "cadd90$a3")
 OPTAB_D (cadd270_optab, "cadd270$a3")
+OPTAB_D (cmul_optab, "cmul$a3")
+OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
 OPTAB_D (cos_optab, "cos$a2")
 OPTAB_D (cosh_optab, "cosh$a2")
 OPTAB_D (exp10_optab, "exp10$a2")
diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index b2b0ac62e9a69145470f41d2bac736dd970be735..bef7cc73b21c020e4c0128df5d186a034809b103 100644
--- a/gcc/tree-vect-slp-patterns.c
+++ b/gcc/tree-vect-slp-patterns.c
@@ -743,6 +743,182 @@ class ComplexAddPattern : public ComplexPattern
     }
 };
 
+class ComplexMLAPattern : public ComplexPattern
+{
+  protected:
+    ComplexMLAPattern (slp_tree node, vec_info *vinfo)
+      : ComplexPattern (node, vinfo)
+    { }
+
+  protected:
+    /* Helper function of vect_match_call_complex_mla that looks up the
+       definition of LHS_0 and LHS_1 by finding the statements starting in
+       position BASE + IDX in child ROOT of NODE and tries to match the
+       definition against pair ops.
+
+       Once the bounds and type checks pass, RES (if non-null) receives the
+       child examined.  On a match ARGS holds the operands matched and the
+       pair operation is returned; otherwise CMPLX_NONE is returned.  */
+
+    complex_operation_t
+    vect_match_call_complex_mla_1 (slp_tree node, slp_tree *res, int root,
+                                   int base, int idx, vec<stmt_vec_info> *args)
+    {
+      gcc_assert (base >= 0 && idx >= 0 && node != NULL);
+
+      if ((unsigned)root >= SLP_TREE_CHILDREN (node).length ())
+        return CMPLX_NONE;
+
+      slp_tree data = SLP_TREE_CHILDREN (node)[root];
+
+      /* If it's a VEC_PERM_EXPR we need to look one deeper.  */
+      if (node->code == VEC_PERM_EXPR)
+        data = SLP_TREE_CHILDREN (data)[root];
+
+      int lhs_0 = base + idx;
+      int lhs_1 = base + idx + 1;
+
+      vec<stmt_vec_info> stmts = SLP_TREE_SCALAR_STMTS (data);
+      /* Both LHS_0 and LHS_1 are indexed below, so LHS_1 must be in range.  */
+      if (stmts.length () <= (unsigned)lhs_1)
+        return CMPLX_NONE;
+
+      gimple *stmt_0 = STMT_VINFO_STMT (stmts[lhs_0]);
+      gimple *stmt_1 = STMT_VINFO_STMT (stmts[lhs_1]);
+
+      if (gimple_expr_type (stmt_0) != gimple_expr_type (stmt_1))
+        return CMPLX_NONE;
+
+      if (res)
+        *res = data;
+
+      return vect_detect_pair_op (base, data, args);
+    }
+};
+
+class ComplexMulPattern : public ComplexMLAPattern
+{
+  protected:
+    ComplexMulPattern (slp_tree node, vec_info *vinfo)
+      : ComplexMLAPattern (node, vinfo)
+    {
+      this->m_arity = 2;
+      this->m_num_args = 2;
+      this->m_vects.create (0);
+      this->m_defs.create (0);
+    }
+
+  public:
+    ~ComplexMulPattern ()
+    {
+      this->m_vects.release ();
+      this->m_defs.release ();
+    }
+
+    static VectPattern* create (slp_tree node, vec_info *vinfo)
+    {
+       return new ComplexMulPattern (node, vinfo);
+    }
+
+    const char* get_name ()
+    {
+      return "Complex Multiplication";
+    }
+
+
+    /* Pattern matcher for trying to match a complex multiply pattern in the
+       SLP tree of this pattern's node, using the root statements STMTS and
+       the pair of statements ending in position IDX.  If the operation
+       matches then m_last_ifn is set to the operation it matched and the
+       arguments to the replacement statements are put in m_vects.
+
+       If no match is found then m_last_ifn is IFN_LAST and m_vects is empty.
+
+       This function matches the patterns shaped as:
+
+         double ax = (b[i+1] * a[i]);
+         double bx = (a[i+1] * b[i]);
+
+         c[i] = c[i] - ax;
+         c[i+1] = c[i+1] + bx;
+
+       Returns TRUE only if a match occurred and store_results succeeds.  */
+
+    bool
+    matches (stmt_vec_info *stmts, int idx)
+    {
+      this->m_last_ifn = IFN_LAST;
+      /* Reuse the existing storage: vec::create would drop the old buffer
+         without freeing it if this matcher runs more than once.  */
+      this->m_vects.truncate (0);
+      this->m_vects.reserve (6);
+      int base = idx - (this->m_arity - 1);
+      this->m_last_idx = idx;
+      this->m_stmt_info = stmts[0];
+
+      complex_operation_t op1 = vect_detect_pair_op (base, this->m_node, NULL);
+
+      if (op1 != MINUS_PLUS)
+        return false;
+
+      slp_tree sub1a, sub1b, sub2;
+      /* Now operand1+3 must lead to another expression.  */
+      auto_vec<stmt_vec_info> args0;
+      complex_operation_t op2
+        = vect_match_call_complex_mla_1 (this->m_node, &sub1a, 0, base, 0,
+                                         &args0);
+
+      if (op2 != MULT_MULT)
+        return false;
+
+      /* Now operand2+4 must lead to another expression.  */
+      auto_vec<stmt_vec_info> args1;
+      complex_operation_t op3
+        = vect_match_call_complex_mla_1 (this->m_node, &sub1b, 1, base, 0,
+                                         &args1);
+
+      if (op3 != MULT_MULT)
+        return false;
+
+      /* Now operand2+4 may lead to another expression.  */
+      auto_vec<stmt_vec_info> args2;
+      complex_operation_t op4
+        = vect_match_call_complex_mla_1 (sub1b, &sub2, 1, base, 0, &args2);
+
+      if (op4 != CMPLX_NONE && op4 != NEG_NEG)
+        return false;
+
+      if (op4 == CMPLX_NONE)
+        {
+          this->m_last_ifn = IFN_COMPLEX_MUL;
+          /* Correct the arguments after matching.  */
+          std::swap (args0[2], args1[0]);
+        }
+      else if (op4 == NEG_NEG)
+        {
+          this->m_last_ifn = IFN_COMPLEX_MUL_CONJ;
+          /* Check if the conjugate is on the first or second parameter.  */
+          if (args1[1] == args1[3] && args0[1] == args0[3])
+            {
+              this->m_vects.quick_push (args0[3]);
+              this->m_vects.quick_push (args0[0]);
+              this->m_vects.quick_push (args2[0]);
+              this->m_vects.quick_push (args0[2]);
+            }
+          else
+            {
+              /* Correct the arguments after matching.  */
+              std::swap (args0[2], args2[0]);
+            }
+        }
+
+      if (this->m_vects.length () == 0)
+        this->m_vects.splice (args0);
+
+      return this->m_last_ifn != IFN_LAST && store_results ();
+    }
+};
+
 #define SLP_PATTERN(x) &x::create
 VectPatternDecl slp_patterns[]
 {
@@ -750,6 +926,7 @@ VectPatternDecl slp_patterns[]
      order patterns from the largest to the smallest.  Especially if they
      overlap in what they can detect.  */
 
+  SLP_PATTERN (ComplexMulPattern),
   SLP_PATTERN (ComplexAddPattern),
 };
 #undef SLP_PATTERN