Hi All, This patch adds pattern detection for the following operations:
Complex FMLA, complex FMLS, and their variants with the second multiplicand conjugated: c += a * b, c += a * conj (b), c -= a * b and c -= a * conj (b). For the conjugate cases, under fast-math it also supports the conjugated operand being the first one, by swapping the arguments to the optab. This allows it to support c = conj (a) * b and c += conj (a) * b, where a, b and c are complex numbers. Bootstrapped and regtested on aarch64-none-linux-gnu with no issues. Ok for master? Thanks, Tamar gcc/ChangeLog: * doc/md.texi: Document optabs. * internal-fn.def (COMPLEX_FMA, COMPLEX_FMA_CONJ, COMPLEX_FMS, COMPLEX_FMS_CONJ): New. * optabs.def (cmla_optab, cmla_conj_optab, cmls_optab, cmls_conj_optab): New. * tree-vect-slp-patterns.c (class ComplexFMAPattern): New. (slp_patterns): Add ComplexFMAPattern. --
diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi index ddaf1abaccbd44dae11ea902ec38b474aacfb8e1..d8142f745050d963e8d15c7793fae06d9ad02020 100644 --- a/gcc/doc/md.texi +++ b/gcc/doc/md.texi @@ -6143,6 +6143,50 @@ rotations @var{m} of 90 or 270. This pattern is not allowed to @code{FAIL}. +@cindex @code{cmla@var{m}4} instruction pattern +@item @samp{cmla@var{m}4} +Perform a vector floating point multiply and accumulate of complex numbers +in operand 0, operand 1 and operand 2. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmla_conj@var{m}4} instruction pattern +@item @samp{cmla_conj@var{m}4} +Perform a vector floating point multiply and accumulate of complex numbers +in operand 0, operand 1 and the conjugate of operand 2. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmls@var{m}4} instruction pattern +@item @samp{cmls@var{m}4} +Perform a vector floating point multiply and subtract of complex numbers +in operand 0, operand 1 and operand 2. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmls_conj@var{m}4} instruction pattern +@item @samp{cmls_conj@var{m}4} +Perform a vector floating point multiply and subtract of complex numbers +in operand 0, operand 1 and the conjugate of operand 2. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. 
+
 @cindex @code{cmul@var{m}4} instruction pattern
 @item @samp{cmul@var{m}4}
 Perform a vector floating point multiplication of complex numbers in operand 0
diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def
index 51bebf8701af262b22d66d19a29a8dafb74db1f0..cc0135cb2c1c14b593181edeaa5f896fa6c4c659 100644
--- a/gcc/internal-fn.def
+++ b/gcc/internal-fn.def
@@ -286,6 +286,10 @@ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary)
 
 /* Ternary math functions. */
 DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS, ECF_CONST, cmls, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS_CONJ, ECF_CONST, cmls_conj, ternary)
 
 /* Unary integer ops. */
 DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary)
diff --git a/gcc/optabs.def b/gcc/optabs.def
index 9c267d422478d0011f288b1f5f62daabe3989ba7..19db9c00896cd08adfd20a01669990bbbebd79f1 100644
--- a/gcc/optabs.def
+++ b/gcc/optabs.def
@@ -294,6 +294,10 @@ OPTAB_D (cadd90_optab, "cadd90$a3")
 OPTAB_D (cadd270_optab, "cadd270$a3")
 OPTAB_D (cmul_optab, "cmul$a3")
 OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
+OPTAB_D (cmla_optab, "cmla$a4")
+OPTAB_D (cmla_conj_optab, "cmla_conj$a4")
+OPTAB_D (cmls_optab, "cmls$a4")
+OPTAB_D (cmls_conj_optab, "cmls_conj$a4")
 OPTAB_D (cos_optab, "cos$a2")
 OPTAB_D (cosh_optab, "cosh$a2")
 OPTAB_D (exp10_optab, "exp10$a2")
diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index bef7cc73b21c020e4c0128df5d186a034809b103..d9554aaaf2cce14bb5b9c68e6141ea7f555a35de 100644
--- a/gcc/tree-vect-slp-patterns.c
+++ b/gcc/tree-vect-slp-patterns.c
@@ -916,6 +916,214 @@ class ComplexMulPattern : public ComplexMLAPattern
   }
 };
 
+/* SLP pattern for complex multiply-accumulate (FMA) and multiply-subtract
+   (FMS), optionally with one conjugated multiplicand.  A successful match
+   selects IFN_COMPLEX_FMA, IFN_COMPLEX_FMA_CONJ, IFN_COMPLEX_FMS or
+   IFN_COMPLEX_FMS_CONJ.  */
+class ComplexFMAPattern : public ComplexMLAPattern
+{
+  protected:
+    ComplexFMAPattern (slp_tree node, vec_info *vinfo)
+      : ComplexMLAPattern (node, vinfo)
+    {
+      this->m_arity = 2;
+      /* The COMPLEX_FM{A,S}* internal functions are ternary.  */
+      this->m_num_args = 3;
+      this->m_vects.create (0);
+      this->m_defs.create (0);
+    }
+
+  public:
+    ~ComplexFMAPattern ()
+    {
+      this->m_vects.release ();
+      this->m_defs.release ();
+    }
+
+    static VectPattern* create (slp_tree node, vec_info *vinfo)
+    {
+      return new ComplexFMAPattern (node, vinfo);
+    }
+
+    const char* get_name ()
+    {
+      return "Complex FM(A|S)";
+    }
+
+    /* Try to match a complex multiply-accumulate or multiply-subtract in
+       m_node, starting at lane IDX - (m_arity - 1) and ending at lane IDX,
+       with STMTS[0] recorded as the root statement.
+
+       The root lane pair must be PLUS_PLUS (FMA forms) or PLUS_MINUS (FMS
+       forms), its operand 1 a MINUS_PLUS pair, and both operands of that a
+       MULT_MULT pair.  For the plain FMA case this is roughly:
+
+         c[i]   = c[i]   + ((a[i] * b[i])   - (a[i+1] * b[i+1]));
+         c[i+1] = c[i+1] + ((a[i] * b[i+1]) + (a[i+1] * b[i]));
+
+       A NEG_NEG pair below one of the MULT_MULT pairs selects the
+       corresponding _CONJ variant.
+
+       m_vects is cleared on entry.  On a match, m_last_ifn is set to the
+       matched internal function, six statements (presumably two lanes for
+       each of the three call operands) are pushed to m_vects, and the result
+       of store_results () is returned.  Otherwise m_last_ifn is left as
+       IFN_LAST and FALSE is returned.  */
+    bool
+    matches (stmt_vec_info *stmts, int idx)
+    {
+      this->m_last_ifn = IFN_LAST;
+      /* Clear and reserve in place: create () would overwrite the pointer to
+         the storage allocated by a previous call without freeing it.  */
+      this->m_vects.truncate (0);
+      this->m_vects.reserve_exact (6);
+      int base = idx - (this->m_arity - 1);
+      this->m_last_idx = idx;
+      slp_tree node = this->m_node;
+      this->m_stmt_info = stmts[0];
+
+      /* Find the two components.  Rotation in the complex plane will modify
+         the operations:
+
+         * Rotation   0: + +
+         * Rotation  90: - +
+         * Rotation 180: - -
+         * Rotation 270: + -.  */
+      auto_vec<stmt_vec_info> args0;
+      complex_operation_t op1 = vect_detect_pair_op (base, node, &args0);
+
+      if (op1 == CMPLX_NONE)
+        return false;
+
+      slp_tree sub1, sub2a, sub2b, sub3;
+
+      /* Operand 1 of the root must be the MINUS_PLUS (or PLUS_MINUS) pair
+         combining the two products.  */
+      auto_vec<stmt_vec_info> args1;
+      complex_operation_t op2
+        = vect_match_call_complex_mla_1 (node, &sub1, 1, base, 0, &args1);
+
+      if (op2 != MINUS_PLUS && op2 != PLUS_MINUS)
+        return false;
+
+      /* Operand 0 of that pair must be a MULT_MULT pair.  */
+      auto_vec<stmt_vec_info> args2;
+      complex_operation_t op3
+        = vect_match_call_complex_mla_1 (sub1, &sub2a, 0, base, 0, &args2);
+
+      if (op3 != MULT_MULT)
+        return false;
+
+      /* Operand 1 of that pair must be a MULT_MULT pair as well.  */
+      auto_vec<stmt_vec_info> args3;
+      complex_operation_t op4
+        = vect_match_call_complex_mla_1 (sub1, &sub2b, 1, base, 0, &args3);
+
+      if (op4 != MULT_MULT)
+        return false;
+
+      /* Operand 1 of the second product may lead to another expression
+         (a NEG_NEG here marks a conjugated multiplicand).  */
+      auto_vec<stmt_vec_info> args4;
+      complex_operation_t op5
+        = vect_match_call_complex_mla_1 (sub2b, &sub3, 1, base, 0, &args4);
+
+      /* Or operand 0 of the second product may lead to another
+         expression.  */
+      auto_vec<stmt_vec_info> args5;
+      complex_operation_t op6
+        = vect_match_call_complex_mla_1 (sub2b, &sub3, 0, base, 0, &args5);
+
+      if (op1 == PLUS_MINUS && op2 == MINUS_PLUS)
+        {
+          /* The FMS conjugate has a different layout: the NEG_NEG may sit
+             below the first product instead, so check that too.  */
+          if (op5 == CMPLX_NONE && op6 == CMPLX_NONE)
+            {
+              op6 = vect_match_call_complex_mla_1 (sub2a, &sub3, 0, base, 0,
+                                                   &args5);
+              if (op6 == CMPLX_NONE)
+                op6 = vect_match_call_complex_mla_1 (sub2a, &sub3, 1, base, 0,
+                                                     &args5);
+            }
+          if (op5 == CMPLX_NONE && op6 != NEG_NEG)
+            this->m_last_ifn = IFN_COMPLEX_FMS;
+          else if (op5 == NEG_NEG || op6 == NEG_NEG)
+            this->m_last_ifn = IFN_COMPLEX_FMS_CONJ;
+        }
+      else if (op1 == PLUS_PLUS && op2 == MINUS_PLUS)
+        {
+          if (op5 == CMPLX_NONE && op6 != NEG_NEG)
+            this->m_last_ifn = IFN_COMPLEX_FMA;
+          else if (op5 == NEG_NEG || op6 == NEG_NEG)
+            this->m_last_ifn = IFN_COMPLEX_FMA_CONJ;
+        }
+
+      if (this->m_last_ifn == IFN_LAST)
+        return false;
+
+      if (this->m_last_ifn == IFN_COMPLEX_FMA_CONJ)
+        {
+          /* Check if the conjugate is on the first or second parameter.  */
+          if (op5 == NEG_NEG)
+            {
+              this->m_vects.quick_push (args0[0]);
+              this->m_vects.quick_push (args2[2]);
+              this->m_vects.quick_push (args3[2]);
+              this->m_vects.quick_push (args0[2]);
+              this->m_vects.quick_push (args4[0]);
+              this->m_vects.quick_push (args2[3]);
+            }
+          else
+            {
+              this->m_vects.quick_push (args0[0]);
+              this->m_vects.quick_push (args2[3]);
+              this->m_vects.quick_push (args2[0]);
+              this->m_vects.quick_push (args0[2]);
+              this->m_vects.quick_push (args5[0]);
+              this->m_vects.quick_push (args2[2]);
+            }
+        }
+      else if (this->m_last_ifn == IFN_COMPLEX_FMS_CONJ)
+        {
+          /* Check if the conjugate is on the first or second parameter.  */
+          if (op6 == NEG_NEG)
+            {
+              this->m_vects.quick_push (args0[0]);
+              this->m_vects.quick_push (args3[1]);
+              this->m_vects.quick_push (args2[3]);
+              this->m_vects.quick_push (args0[2]);
+              this->m_vects.quick_push (args5[0]);
+              this->m_vects.quick_push (args2[1]);
+            }
+          else
+            {
+              /* NOTE(review): this branch is reached with op5 == NEG_NEG,
+                 which was matched into ARGS4, yet ARGS5 is read below;
+                 confirm ARGS5 is populated here or whether ARGS4 was
+                 intended.  */
+              this->m_vects.quick_push (args0[0]);
+              this->m_vects.quick_push (args2[2]);
+              this->m_vects.quick_push (args3[2]);
+              this->m_vects.quick_push (args0[2]);
+              this->m_vects.quick_push (args2[0]);
+              this->m_vects.quick_push (args5[0]);
+            }
+        }
+      else
+        {
+          this->m_vects.quick_push (args0[0]);
+          this->m_vects.quick_push (args2[3]);
+          this->m_vects.quick_push (args3[2]);
+          this->m_vects.quick_push (args0[2]);
+          this->m_vects.quick_push (args3[3]);
+          this->m_vects.quick_push (args2[2]);
+        }
+
+      return store_results ();
+    }
+};
+
 #define SLP_PATTERN(x) &x::create
 VectPatternDecl slp_patterns[]
 {
@@ -923,6 +1131,7 @@ VectPatternDecl slp_patterns[]
    order patterns from the largest to the smallest. Especially if they
    overlap in what they can detect. */
 
+  SLP_PATTERN (ComplexFMAPattern),
   SLP_PATTERN (ComplexMulPattern),
   SLP_PATTERN (ComplexAddPattern),
 };