Hi All,

This is a respin of the patch using the new approach.

Bootstrapped and regtested on aarch64-none-linux-gnu with no issues.

Ok for master?

Thanks,
Tamar

gcc/ChangeLog:

        * doc/md.texi: Document optabs.
        * internal-fn.def (COMPLEX_FMA, COMPLEX_FMA_CONJ, COMPLEX_FMS,
        COMPLEX_FMS_CONJ): New.
        * optabs.def (cmla_optab, cmla_conj_optab, cmls_optab, cmls_conj_optab):
        New.
        * tree-vect-slp-patterns.c (class complex_fma_pattern,
        complex_fma_pattern::matches): New.
        (slp_patterns): Add complex_fma_pattern.

> -----Original Message-----
> From: Gcc-patches <gcc-patches-boun...@gcc.gnu.org> On Behalf Of Tamar
> Christina
> Sent: Friday, September 25, 2020 3:30 PM
> To: gcc-patches@gcc.gnu.org
> Cc: nd <n...@arm.com>; rguent...@suse.de; o...@ucw.cz
> Subject: [PATCH v2 8/16]middle-end: add Complex Multiply and
> Accumulate/Subtract and Multiply and Accumulate/Subtract with Conjucate
> detection
> 
> Hi All,
> 
> This patch adds pattern detections for the following operation:
> 
>   Complex FMLA and FMLS, each also with the second parameter conjugated.
> 
>     c += a * b, c += a * conj (b), c -= a * b and c -= a * conj (b)
> 
>   For the conjugate cases it also supports, under fast-math, conjugating
>   the first operand instead of the second, by swapping the arguments to
>   the optab.  This allows it to support c = conj (a) * b and
>   c += conj (a) * b.
> 
>   where a, b and c are complex numbers.
> 
> Bootstrapped Regtested on aarch64-none-linux-gnu and no issues.
> 
> Ok for master?
> 
> Thanks,
> Tamar
> 
> gcc/ChangeLog:
> 
>       * doc/md.texi: Document optabs.
>       * internal-fn.def (COMPLEX_FMA, COMPLEX_FMA_CONJ,
> COMPLEX_FMS,
>       COMPLEX_FMS_CONJ): New.
>       * optabs.def (cmla_optab, cmla_conj_optab, cmls_optab,
> cmls_conj_optab):
>       New.
>       * tree-vect-slp-patterns.c (class ComplexFMAPattern): New.
>       (slp_patterns): Add ComplexFMAPattern.
> 
> --
diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi
index ddaf1abaccbd44dae11ea902ec38b474aacfb8e1..d8142f745050d963e8d15c7793fae06d9ad02020 100644
--- a/gcc/doc/md.texi
+++ b/gcc/doc/md.texi
@@ -6143,6 +6143,50 @@ rotations @var{m} of 90 or 270.
 
 This pattern is not allowed to @code{FAIL}.
 
+@cindex @code{cmla@var{m}4} instruction pattern
+@item @samp{cmla@var{m}4}
+Perform a vector floating point multiply and accumulate of complex numbers,
+setting operand 0 to operand 1 plus the product of operands 2 and 3.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmla_conj@var{m}4} instruction pattern
+@item @samp{cmla_conj@var{m}4}
+Perform a vector floating point multiply and accumulate of complex numbers,
+setting operand 0 to operand 1 plus operand 2 times the conjugate of operand 3.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmls@var{m}4} instruction pattern
+@item @samp{cmls@var{m}4}
+Perform a vector floating point multiply and subtract of complex numbers,
+setting operand 0 to operand 1 minus the product of operands 2 and 3.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmls_conj@var{m}4} instruction pattern
+@item @samp{cmls_conj@var{m}4}
+Perform a vector floating point multiply and subtract of complex numbers,
+setting operand 0 to operand 1 minus operand 2 times the conjugate of operand 3.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
 @cindex @code{cmul@var{m}4} instruction pattern
 @item @samp{cmul@var{m}4}
 Perform a vector floating point multiplication of complex numbers in operand 0
diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def
index cb41643f5e332518a0271bb8e1af4883c8bd6880..acb7d9f3bdc757437d5492a652144ba31c2ef702 100644
--- a/gcc/internal-fn.def
+++ b/gcc/internal-fn.def
@@ -288,6 +288,10 @@ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary)
 
 /* Ternary math functions.  */
 DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS, ECF_CONST, cmls, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS_CONJ, ECF_CONST, cmls_conj, ternary)
 
 /* Unary integer ops.  */
 DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary)
diff --git a/gcc/optabs.def b/gcc/optabs.def
index 9c267d422478d0011f288b1f5f62daabe3989ba7..19db9c00896cd08adfd20a01669990bbbebd79f1 100644
--- a/gcc/optabs.def
+++ b/gcc/optabs.def
@@ -294,6 +294,10 @@ OPTAB_D (cadd90_optab, "cadd90$a3")
 OPTAB_D (cadd270_optab, "cadd270$a3")
 OPTAB_D (cmul_optab, "cmul$a3")
 OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
+OPTAB_D (cmla_optab, "cmla$a4")
+OPTAB_D (cmla_conj_optab, "cmla_conj$a4")
+OPTAB_D (cmls_optab, "cmls$a4")
+OPTAB_D (cmls_conj_optab, "cmls_conj$a4")
 OPTAB_D (cos_optab, "cos$a2")
 OPTAB_D (cosh_optab, "cosh$a2")
 OPTAB_D (exp10_optab, "exp10$a2")
diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index 2edb0117f9cbbfc40e9ed3a96120a3c88f84a68e..c2987c2afac2fbd55e2acd6b56fc13c7d3ad13c1 100644
--- a/gcc/tree-vect-slp-patterns.c
+++ b/gcc/tree-vect-slp-patterns.c
@@ -1172,6 +1172,176 @@ complex_mul_pattern::validate_p ()
   return true;
 }
 
+
+/*******************************************************************************
+ * complex_fma_pattern class
+ ******************************************************************************/
+
+class complex_fma_pattern : public complex_mul_pattern
+{
+  public:
+    /* Return a new instance of this pattern for NODE.  */
+    static vect_pattern* create (slp_tree *node, vec_info *vinfo)
+    { return new complex_fma_pattern (node, vinfo); }
+
+    /* Name of this pattern.  */
+    const char* get_name ()
+    { return "Complex FM(A|S)"; }
+
+    /* Classify the current node and try to match it.  */
+    bool matches ();
+    /* Try to match given the classification OP and its operands OPS.  */
+    bool matches (complex_operation_t op, vec<slp_tree> ops);
+  protected:
+    /* Two replacement statements (M_ARITY), three arguments each.  */
+    complex_fma_pattern (slp_tree *node, vec_info *vinfo)
+      : complex_mul_pattern (node, vinfo)
+    {
+      this->m_arity = 2;
+      this->m_num_args = 3;
+    }
+};
+
+/* Pattern matcher for complex multiply and accumulate and multiply and
+   subtract patterns in the SLP tree.  OP1 classifies *M_NODE and ARGS0
+   holds the operands vect_detect_pair_op collected; ARGS0 is a by-value
+   copy of the caller's vector, so it is only read here, never grown.
+   On a match M_IFN is set to the operation matched and the arguments of
+   the two replacement statements (accumulator first, then the operands
+   of the multiply) are put in M_OPS.  If the node shape or the multiply
+   does not match, M_IFN stays IFN_LAST and no other member is modified;
+   the final conjugate check may however reject after M_OPS is filled.
+
+   This function matches the patterns shaped as (NOTE(review): AX uses
+   a[i] twice, unlike a complex product; confirm this example):
+
+   double ax = (b[i+1] * a[i]) + (b[i] * a[i]);
+   double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]);
+
+   c[i] = c[i] - ax;
+   c[i+1] = c[i+1] + bx;
+
+   If a match occurred then TRUE is returned, else FALSE.  */
+bool
+complex_fma_pattern::matches (complex_operation_t op1, vec<slp_tree> args0)
+{
+  this->m_ifn = IFN_LAST;
+  /* Find the two components, the accumulator ACC and the multiply MUL.  We
+     match Complex MUL first which reduces the amount of work this pattern
+     has to do.  After that we just match the head node and we're done:
+
+     * FMA: + +
+     * FMS: - +.  */
+  slp_tree acc = NULL, mul = NULL;
+
+  /* Ignore the two_operands nodes that may also match: require scalar
+     statements and a non-permute node, as we are looking for a normal
+     PLUS_EXPR operation.  */
+  if (op1 == PLUS_MINUS)
+    {
+      if (args0.length () < 2 || SLP_TREE_CHILDREN (args0[0]).is_empty ()
+	  || SLP_TREE_CHILDREN (args0[1]).length () < 2)
+	return false;
+      acc = SLP_TREE_CHILDREN (args0[0])[0];
+      mul = SLP_TREE_CHILDREN (args0[1])[1];
+    }
+  else if (SLP_TREE_SCALAR_STMTS (*this->m_node).length () > 0
+	   && SLP_TREE_CODE (*this->m_node) != VEC_PERM_EXPR
+	   && vect_match_expression_p (*this->m_node, PLUS_EXPR))
+    {
+      vec<slp_tree> &children = SLP_TREE_CHILDREN (*this->m_node);
+      if (children.length () != 2)
+	return false;
+      op1 = PLUS_PLUS;
+      acc = children[0];
+      mul = children[1];
+    }
+  else
+    return false;
+
+  auto_vec<slp_tree> ops;
+  internal_fn mulfn = IFN_LAST;
+  /* The accumulation step produces an inverse tree from normal
+     multiply so match the nodes in reverse.  */
+  if (!vect_slp_matches_complex_mul (mul, &mulfn, &ops, false,
+				     op1 == PLUS_MINUS)
+      || ops.length () < 4)
+    return false;
+
+  /* Decide the replacement IFN before modifying any other member.  */
+  const bool sub_p = op1 == PLUS_MINUS;
+  if (mulfn == IFN_COMPLEX_MUL)
+    this->m_ifn = sub_p ? IFN_COMPLEX_FMS : IFN_COMPLEX_FMA;
+  else if (mulfn == IFN_COMPLEX_MUL_CONJ)
+    this->m_ifn = sub_p ? IFN_COMPLEX_FMS_CONJ : IFN_COMPLEX_FMA_CONJ;
+  else
+    return false;
+
+  this->m_ops.create (6);
+  if (sub_p)
+    {
+      this->workset.safe_splice (SLP_TREE_CHILDREN (*this->m_node));
+      save_match ();
+    }
+  else
+    {
+      /* Add doesn't generate a two_operators node, so for it we replace it
+	 inline by turning the add node itself into a pattern.  */
+      this->m_inplace = true;
+      this->workset.safe_push (*this->m_node);
+      this->m_match
+	= new vect_simple_pattern_match (this->m_arity, this->m_ifn,
+					 this->m_vinfo, &this->workset,
+					 this->m_num_args);
+    }
+
+  /* The conjugate nodes have a different ordering; oddly enough the SUB
+     node has the same order regardless of the conjugate.  This needs to be
+     made more consistent in the mid-end.  */
+  if (sub_p || mulfn == IFN_COMPLEX_MUL)
+    {
+      this->m_ops.quick_push (acc);
+      this->m_ops.quick_push (ops[1]);
+      this->m_ops.quick_push (ops[0]);
+      this->m_ops.quick_push (acc);
+      this->m_ops.quick_push (ops[3]);
+      this->m_ops.quick_push (ops[2]);
+    }
+  else
+    {
+      this->m_ops.quick_push (acc);
+      this->m_ops.quick_push (ops[0]);
+      this->m_ops.quick_push (ops[1]);
+      this->m_ops.quick_push (acc);
+      this->m_ops.quick_push (ops[2]);
+      this->m_ops.quick_push (ops[3]);
+    }
+
+  vect_build_perm_groups (&this->m_blocks[0], this->m_ops);
+
+  /* The conjugate sequence is remarkably similar to rotations by 180 and
+     270, so reject the block orderings that belong to those.  */
+  if (mulfn == IFN_COMPLEX_MUL_CONJ)
+    for (unsigned i = 0; i < this->m_ops.length (); i++)
+      {
+	const map_t &m = this->m_blocks[i];
+	if (m.a > m.b)
+	  return false;
+      }
+
+  return true;
+}
+
+/* Classify *M_NODE with vect_detect_pair_op and try to match it.  */
+bool
+complex_fma_pattern::matches ()
+{
+  auto_vec<slp_tree> operands;
+  const complex_operation_t kind
+    = vect_detect_pair_op (*this->m_node, true, &operands);
+  return this->matches (kind, operands);
+}
+
 /*******************************************************************************
  * complex_operations_pattern class
  ******************************************************************************/
@@ -1303,6 +1473,10 @@ vect_pattern_decl_t slp_patterns[]
      order patterns from the largest to the smallest.  Especially if they
      overlap in what they can detect.  */
 
+  /* FMA overlaps with MUL but matches the longer sequence.  Patterns are
+     matched in a post-order traversal, so FMA cannot be detected from inside
+     complex_operations_pattern and is therefore listed on its own, first.  */
+  SLP_PATTERN (complex_fma_pattern),
   SLP_PATTERN (complex_operations_pattern),
 };
 #undef SLP_PATTERN

Reply via email to