Hi All,

This is a respin of this patch using the new approach.

Thanks,
Tamar

gcc/ChangeLog:

        * doc/md.texi: Document optabs.
        * internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
        * optabs.def (cmul_optab, cmul_conj_optab): New.
        * tree-vect-slp-patterns.c (vect_build_perm_groups,
        vect_can_combine_node_p, vect_slp_make_combine_linear,
        vect_match_call_complex_mla, vect_slp_matches_complex_mul,
        class complex_mul_pattern, complex_mul_pattern::matches,
        complex_mul_pattern::validate_p,
        complex_operations_pattern::matches): Add complex_mul_pattern.


> -----Original Message-----
> From: Gcc-patches <gcc-patches-boun...@gcc.gnu.org> On Behalf Of Tamar
> Christina
> Sent: Friday, September 25, 2020 3:29 PM
> To: gcc-patches@gcc.gnu.org
> Cc: nd <n...@arm.com>; rguent...@suse.de; o...@ucw.cz
> Subject: [PATCH v2 7/16]middle-end: Add Complex Multiplication and
> Multiplication with Conjucate detection
> 
> Hi All,
> 
> This patch adds pattern detections for the following operation:
> 
>   Complex multiplication, and complex multiplication by the conjugate of the
>      second parameter.
> 
>     c = a * b and c = a * conj (b)
> 
>   For the conjugate cases, under fast-math it also supports the conjugate being
>   on the other operand, by flipping the arguments passed to the optab.  This
>   allows it to support c = conj (a) * b and c += conj (a) * b.
> 
>   where a, b and c are complex numbers.
> 
> and provides a shared class for anything needing to recognize complex MLA
> patterns.
> 
> Bootstrapped Regtested on aarch64-none-linux-gnu and no issues.
> 
> Ok for master?
> 
> Thanks,
> Tamar
> 
> gcc/ChangeLog:
> 
>       * doc/md.texi: Document optabs.
>       * internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
>       * optabs.def (cmul_optab, cmul_conj_optab): New.
>       * tree-vect-slp-patterns.c (class ComplexMLAPattern,
>       class ComplexMulPattern): New.
>       (slp_patterns): Add ComplexMulPattern.
> 
> --
diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi
index 71e226505b2619d10982b59a4ebbed73a70f29be..ddaf1abaccbd44dae11ea902ec38b474aacfb8e1 100644
--- a/gcc/doc/md.texi
+++ b/gcc/doc/md.texi
@@ -6143,6 +6143,28 @@ rotations @var{m} of 90 or 270.
 
 This pattern is not allowed to @code{FAIL}.
 
+@cindex @code{cmul@var{m}3} instruction pattern
+@item @samp{cmul@var{m}3}
+Perform a vector floating point multiplication of the complex numbers in
+operand 1 and operand 2, storing the result in operand 0.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmul_conj@var{m}3} instruction pattern
+@item @samp{cmul_conj@var{m}3}
+Perform a vector floating point multiplication of the complex numbers in
+operand 1 and the conjugate of operand 2, storing the result in operand 0.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
 @cindex @code{ffs@var{m}2} instruction pattern
 @item @samp{ffs@var{m}2}
 Store into operand 0 one plus the index of the least significant 1-bit
diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def
index 33c54be1e158ddea25c4cd6b1148df8cf4a509b5..cb41643f5e332518a0271bb8e1af4883c8bd6880 100644
--- a/gcc/internal-fn.def
+++ b/gcc/internal-fn.def
@@ -279,6 +279,8 @@ DEF_INTERNAL_FLT_FLOATN_FN (FMAX, ECF_CONST, fmax, binary)
 DEF_INTERNAL_OPTAB_FN (XORSIGN, ECF_CONST, xorsign, binary)
 DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT90, ECF_CONST, cadd90, binary)
 DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT270, ECF_CONST, cadd270, binary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL, ECF_CONST, cmul, binary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL_CONJ, ECF_CONST, cmul_conj, binary)
 
 
 /* FP scales.  */
diff --git a/gcc/optabs.def b/gcc/optabs.def
index 2bb0bf857977035bf562a77f5f6848e80edf936d..9c267d422478d0011f288b1f5f62daabe3989ba7 100644
--- a/gcc/optabs.def
+++ b/gcc/optabs.def
@@ -292,6 +292,8 @@ OPTAB_D (copysign_optab, "copysign$F$a3")
 OPTAB_D (xorsign_optab, "xorsign$F$a3")
 OPTAB_D (cadd90_optab, "cadd90$a3")
 OPTAB_D (cadd270_optab, "cadd270$a3")
+OPTAB_D (cmul_optab, "cmul$a3")
+OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
 OPTAB_D (cos_optab, "cos$a2")
 OPTAB_D (cosh_optab, "cosh$a2")
 OPTAB_D (exp10_optab, "exp10$a2")
diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index 0732cf0a6d93be8590b84c39dff82940b280e46b..2edb0117f9cbbfc40e9ed3a96120a3c88f84a68e 100644
--- a/gcc/tree-vect-slp-patterns.c
+++ b/gcc/tree-vect-slp-patterns.c
@@ -196,6 +196,65 @@ linear_loads_p (slp_tree root, bool *linear)
   return loads;
 }
 
+/* Builds a permutation group from the operands in OPS and stores it in BLOCKS.
+   The group describes how to combine the operands to get a valid linear node.
+
+   This is used when combining multiple children from a two_operators node into
+   one using a lane permute to select the appropriate lane.  As an example, in
+   the group { [0 0] [1 4] [2 2] [3 3] [1 4] [5 5] } an entry whose two indices
+   are equal, e.g. [0 0], only needs that node itself to be made linear,
+   whereas [1 4] means that nodes 1 and 4 must be combined.  */
+
+static void
+vect_build_perm_groups (map_t *blocks, vec<slp_tree> ops)
+{
+  slp_tree op;
+  unsigned i;
+  bool is_linear = false;
+  unsigned min_eq = -1, max_eq = 0;
+  unsigned min_idx = 0, max_idx = 0;
+  FOR_EACH_VEC_ELT (ops, i, op)
+    {
+      load_permutation_t perms = linear_loads_p (op, &is_linear);
+      unsigned x, imin = -1, imax = 0;
+      for (x = 0; x < perms.length () && !is_linear; x++)
+	{
+	  imin = MIN (imin, perms[x]);
+	  imax = MAX (imax, perms[x]);
+	}
+
+	if (imin != imax || perms.length () == 0 || is_linear)
+	  blocks[i] = {i, i};
+	else
+	  {
+	    if (imin <= min_eq)
+	      {
+		min_eq = imin;
+		min_idx = i;
+	      }
+
+	    if (imin >= max_eq)
+	      {
+		max_eq = imin;
+		max_idx = i;
+	      }
+	  }
+    }
+
+  /* Now fill in the gap.  */
+  blocks[min_idx] = {min_idx, max_idx};
+  blocks[max_idx] = {min_idx, max_idx};
+
+  if (dump_enabled_p ())
+    {
+      dump_printf_loc (MSG_NOTE, vect_location, "pattern group: { ");
+      for (i = 0; i < ops.length (); i++)
+	dump_printf (MSG_NOTE,"[%d %d] ", blocks[i].a, blocks[i].b);
+      dump_printf (MSG_NOTE,"}\n");
+    }
+
+}
+
 /* This function attempts to make a node rooted in NODE linear.  If the node
    if already linear than the node itself is returned in RESULT.
 
@@ -265,6 +324,85 @@ vect_slp_make_linear (slp_tree parent, slp_tree node, slp_tree *result)
   return is_linear;
 }
 
+/* Helper utility to check to see if the permutation PERM is one that can be
+   used in a node combination operation.  This is defined as a non-linear
+   permute in which all the elements are the same, e.g. [0 0].  */
+
+static inline bool
+vect_can_combine_node_p (load_permutation_t perm, bool is_linear)
+{
+  if (is_linear)
+    return false;
+
+  unsigned i, x;
+  FOR_EACH_VEC_ELT (perm, i, x)
+    if (perm[0] != x)
+      return false;
+
+  return true;
+}
+
+/* This function combines the two nodes of ENTRIES selected by the indices in
+   MAP to make a new node using a lane permute.  The resulting node is returned
+   in RESULT.
+
+   If MAP selects a single node, or either node cannot be combined (e.g. it is
+   already linear), this function returns FALSE.  Otherwise it returns TRUE.  */
+
+static bool
+vect_slp_make_combine_linear (slp_tree parent, vec<slp_tree> entries, map_t map,
+			      slp_tree *result)
+{
+  if (map.a == map.b)
+    return false;
+
+  slp_tree node_a = entries[map.a];
+  slp_tree node_b = entries[map.b];
+
+  bool is_a_linear = false;
+  bool is_b_linear = false;
+
+  load_permutation_t load_perm_a = linear_loads_p (node_a, &is_a_linear);
+  if (!vect_can_combine_node_p (load_perm_a, is_a_linear))
+    return false;
+  load_permutation_t load_perm_b = linear_loads_p (node_b, &is_b_linear);
+  if (!vect_can_combine_node_p (load_perm_b, is_b_linear))
+    return false;
+
+  /* Now we need to figure which node is first.  */
+  auto_vec<slp_tree> nodes;
+  nodes.create (2);
+  vec<std::pair<unsigned, unsigned> > perm;
+  perm.create (2);
+  if (load_perm_a[0] < load_perm_b[0])
+    {
+      perm.quick_push (std::make_pair (0, 0));
+      perm.quick_push (std::make_pair (1, 0));
+    }
+  else
+    {
+      perm.quick_push (std::make_pair (1, 0));
+      perm.quick_push (std::make_pair (0, 0));
+    }
+
+  nodes.quick_push (node_a);
+  nodes.quick_push (node_b);
+  /* Both nodes gain the new permute node as a parent.  */
+  SLP_TREE_REF_COUNT (node_a)++;
+  SLP_TREE_REF_COUNT (node_b)++;
+
+  slp_tree vnode = vect_create_new_slp_node (vNULL, 1);
+  SLP_TREE_CODE (vnode) = VEC_PERM_EXPR;
+  SLP_TREE_LANE_PERMUTATION (vnode) = perm;
+  SLP_TREE_VECTYPE (vnode) = SLP_TREE_VECTYPE (parent);
+  SLP_TREE_CHILDREN (vnode).safe_splice (nodes);
+  SLP_TREE_REF_COUNT (vnode) = 1;
+  SLP_TREE_LANES (vnode) = SLP_TREE_LANES (parent);
+  SLP_TREE_REPRESENTATIVE (vnode) = SLP_TREE_REPRESENTATIVE (parent);
+  *result = vnode;
+  return true;
+}
+
 /*******************************************************************************
  * Simple vector pattern matcher
  ******************************************************************************/
@@ -727,6 +865,313 @@ complex_add_pattern::matches ()
   return matches (op, this->m_ops);
 }
 
+/*******************************************************************************
+ * complex_mul_pattern
+ ******************************************************************************/
+
+/* Helper function that looks for a match in the CHILDth child of NODE.  The
+   child used is stored in RES.
+
+   If the match is successful then ARGS will contain the operands matched
+   and the complex_operation_t type is returned.  If match is not successful
+   then CMPLX_NONE is returned and ARGS is left unmodified.  */
+
+static complex_operation_t
+vect_match_call_complex_mla (slp_tree node, unsigned child,
+			     vec<slp_tree> *args = NULL, slp_tree *res = NULL)
+{
+  gcc_assert (child < SLP_TREE_CHILDREN (node).length ());
+
+  slp_tree data = SLP_TREE_CHILDREN (node)[child];
+
+  if (res)
+    *res = data;
+
+  return vect_detect_pair_op (data, false, args);
+}
+
+/* This helper attempts to find a complex MUL pattern rooted in ROOT.  If the
+   match succeeds then the pattern type is set in IFN and the operands are
+   returned in OPS.
+
+   This function matches both a normal complex multiply and complex conjugate
+   multiply.  Additionally it also matches the MUL part in a FMS and FMA
+   sequence.  However due to the additional TWO_OPERATORS node that an FMS
+   has the location of the negate node that denotes a conjugate changes.
+
+   In order to differentiate when and where we should check for a conjugate
+   the value MULTIPLY is set when this should only match a normal complex
+   multiply operation and INVERSE is set when we're matching a sequence for an
+   FMS where the negate node is on the other side.
+
+   Note that this function also copes with the canonicalization of the
+   sequence being off if there is a type cast in between.  This is likely a
+   mid-end bug but for now we deal with it here.  */
+static bool
+vect_slp_matches_complex_mul (complex_operation_t op, slp_tree root,
+			      internal_fn *ifn, vec<slp_tree> *ops,
+			      bool multiply, bool inverse = false)
+{
+  *ifn = IFN_LAST;
+
+  if (op != MINUS_PLUS)
+    return false;
+
+  /* Now operand1+3 must lead to another expression.  */
+  auto_vec<slp_tree> args0;
+  complex_operation_t op2 = vect_match_call_complex_mla (root, 0, &args0);
+
+  if (op2 != MULT_MULT)
+    return false;
+
+  /* Now operand2+4 must lead to another expression.  */
+  auto_vec<slp_tree> args1;
+  complex_operation_t op3 = vect_match_call_complex_mla (root, 1, &args1);
+
+  if (op3 != MULT_MULT)
+    return false;
+
+  vec<slp_tree> args2 = SLP_TREE_CHILDREN (args1[inverse ? 0 : 1]);
+  slp_tree neg_node = NULL;
+  bool first_neg = false, second_neg = false;
+
+  /* Now operand2+4 may lead to another expression.  */
+  if ((first_neg = vect_match_expression_p (args2[0], NEGATE_EXPR)))
+     neg_node = SLP_TREE_CHILDREN (args2[0])[0];
+  else if ((second_neg = vect_match_expression_p (args2[1], NEGATE_EXPR)))
+     neg_node = SLP_TREE_CHILDREN (args2[1])[0];
+
+  if (first_neg && multiply)
+    return false;
+
+  /* Check if the neg node is a dup, otherwise not a pattern we want.  */
+  bool is_dup = false;
+  bool same_operand = true;
+  stmt_vec_info elem;
+  unsigned i;
+  load_permutation_t perm = linear_loads_p (neg_node, &is_dup);
+  for (i = 0; i < perm.length (); i++)
+    if (perm[i] != perm[0])
+      return false;
+
+
+  /* Check whether the conjugate is on the first or second operand and flip the
+     order so we get it in the right place.  We can't check the DR of the new
+     child since we may not be a load.  We can't recurse all the way down
+     because we may find a child with multiple children or external.  So
+     instead just check the operands to the multiply which tell us where the
+     conjugate was and how it's interpreting the permute.  */
+  vec<stmt_vec_info> stmts = SLP_TREE_SCALAR_STMTS (args0[0]);
+  tree first_op = gimple_op (STMT_VINFO_STMT (stmts[0]), 1);
+  FOR_EACH_VEC_ELT (stmts, i, elem)
+  if (first_op != gimple_op (STMT_VINFO_STMT (elem), 1))
+    {
+      same_operand = false;
+      break;
+    }
+
+  /* Reject operations that we don't have an optab for.  */
+  if (first_neg && !multiply && same_operand && !inverse)
+    return false;
+
+  bool is_neg = first_neg || second_neg;
+
+  if (!is_neg)
+    {
+      /* Indicates a rotation in the complex number, not a pattern we are
+	 looking for.  */
+      vec<slp_tree> params = SLP_TREE_CHILDREN (args0[0]);
+      if (vect_match_expression_p (params[0], NEGATE_EXPR)
+	  || vect_match_expression_p (params[1], NEGATE_EXPR))
+	return false;
+      *ifn = IFN_COMPLEX_MUL;
+      ops->safe_splice (params);
+      ops->safe_push (SLP_TREE_CHILDREN (args1[1])[0]);
+      ops->safe_push (SLP_TREE_CHILDREN (args1[1])[1]);
+    }
+  else if (is_neg)
+    {
+      *ifn = IFN_COMPLEX_MUL_CONJ;
+      vec<slp_tree> params = SLP_TREE_CHILDREN (args0[inverse ? 1 : 0]);
+      slp_tree value = second_neg ? args2[0] : args2[1];
+      /* Check if the conjugate is on the first or second parameter.  */
+      if (same_operand)
+	{
+	  ops->safe_push (params[0]);
+	  ops->safe_push (params[1]);
+	  ops->safe_push (neg_node);
+	  ops->safe_push (value);
+	}
+      else
+	{
+	  ops->safe_push (params[1]);
+	  ops->safe_push (params[0]);
+	  ops->safe_push (neg_node);
+	  ops->safe_push (value);
+	}
+
+      /* The two_operators with an FMS reverse the nodes so we have to swap them
+	 back to make a sensible operation.  */
+      if (inverse)
+	std::swap ((*ops)[2], (*ops)[3]);
+    }
+
+  return *ifn != IFN_LAST;
+}
+
+static bool
+vect_slp_matches_complex_mul (slp_tree root, internal_fn *ifn,
+			      vec<slp_tree> *ops, bool multiply,
+			      bool inverse = false)
+{
+  return vect_slp_matches_complex_mul (vect_detect_pair_op (root), root, ifn,
+				       ops, multiply, inverse);
+}
+
+class complex_mul_pattern : public complex_pattern
+{
+  protected:
+    /* Allocate enough space for FMA as well.  */
+    map_t m_blocks[6] = {};
+    bool m_inplace = false;
+    auto_vec<slp_tree> workset;
+    complex_mul_pattern (slp_tree *node, vec_info *vinfo)
+      : complex_pattern (node, vinfo)
+    {
+      this->m_arity = 2;
+      this->m_num_args = 2;
+    }
+
+  public:
+    static vect_pattern* create (slp_tree *node, vec_info *vinfo)
+    {
+       return new complex_mul_pattern (node, vinfo);
+    }
+
+    const char* get_name ()
+    {
+      return "Complex Multiplication";
+    }
+
+    bool validate_p ();
+    bool matches ();
+    bool matches (complex_operation_t op, vec<slp_tree> ops);
+};
+
+
+/* Pattern matcher for trying to match complex multiply pattern in SLP tree.
+   If the operation matches then IFN is set to the operation it matched
+   and the arguments to the two replacement statements are put in M_OPS.
+
+   If no match is found then IFN is set to IFN_LAST and M_OPS is unchanged.
+
+   This function matches the patterns shaped as:
+
+   double ax = (b[i+1] * a[i]);
+   double bx = (a[i+1] * b[i]);
+
+   c[i] = c[i] - ax;
+   c[i+1] = c[i+1] + bx;
+
+   If a match occurred then TRUE is returned, else FALSE.  */
+
+bool
+complex_mul_pattern::matches (complex_operation_t op, vec<slp_tree> /* ops */)
+{
+  bool res
+    = vect_slp_matches_complex_mul (op, *this->m_node, &this->m_ifn,
+				    &this->m_ops, true);
+  if (res)
+    {
+      vect_build_perm_groups (&this->m_blocks[0], this->m_ops);
+      this->workset.safe_splice (SLP_TREE_CHILDREN (*this->m_node));
+      save_match ();
+    }
+  return res;
+}
+
+bool
+complex_mul_pattern::matches ()
+{
+  complex_operation_t op
+    = vect_detect_pair_op (*this->m_node);
+  return matches (op, this->m_ops);
+}
+
+
+/* Validates to see if the Complex MUL that we have matched is valid.  This is
+   done through a combination of making nodes linear and combining nodes.  */
+
+bool
+complex_mul_pattern::validate_p ()
+{
+  if (!this->m_match)
+    return false;
+
+  slp_tree node;
+  unsigned ix;
+  hash_set<slp_tree> cache;
+  FOR_EACH_VEC_ELT (this->workset, ix, node)
+    {
+      auto_vec<slp_tree> nodes;
+      nodes.create (this->m_num_args);
+      slp_tree tmp = NULL;
+
+      unsigned i;
+      for (i = 0; i < this->m_num_args; i++)
+	{
+	  unsigned index = (ix * this->m_num_args) + i;
+	  map_t map = this->m_blocks[index];
+	  slp_tree vnode = NULL;
+	  bool needs_linear = map.a == map.b;
+	  tmp = this->m_ops[index];
+	  cache.add (tmp);
+	  if (needs_linear && vect_slp_make_linear (node, tmp, &vnode))
+	    nodes.quick_push (vnode);
+	  else if (!needs_linear
+		   && vect_slp_make_combine_linear (node, this->m_ops, map,
+						    &vnode))
+	    nodes.quick_push (vnode);
+	  else
+	    {
+	      if (dump_enabled_p ())
+		dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location,
+				"stmts could not be made %s %p\n",
+				needs_linear ? "linear" : "linear/combined",
+				tmp);
+	      nodes.release();
+	      return false;
+	    }
+
+	  vect_mark_stmts_as_in_pattern (&cache, node);
+	}
+
+      if (m_inplace)
+	{
+	  SLP_TREE_CHILDREN (*this->m_node).truncate (0);
+	  SLP_TREE_CHILDREN (*this->m_node).safe_splice (nodes);
+	}
+      else
+	{
+	  slp_tree new_node
+	    = vect_create_new_slp_node (SLP_TREE_SCALAR_STMTS (node),
+					SLP_TREE_CHILDREN (node).length ());
+	  SLP_TREE_VECTYPE (new_node) = SLP_TREE_VECTYPE (node);
+	  SLP_TREE_LANE_PERMUTATION (new_node)
+	    = SLP_TREE_LANE_PERMUTATION (node);
+	  SLP_TREE_CODE (new_node) = SLP_TREE_CODE (node);
+	  SLP_TREE_REF_COUNT (new_node) = SLP_TREE_REF_COUNT (node);
+	  SLP_TREE_REPRESENTATIVE (new_node) = SLP_TREE_REPRESENTATIVE (node);
+	  SLP_TREE_CHILDREN (new_node).safe_splice (nodes);
+	  SLP_TREE_LANES (new_node) = SLP_TREE_LANES (node);
+
+	  SLP_TREE_CHILDREN (*this->m_node)[ix] = new_node;
+	}
+    }
+
+  return true;
+}
+
 /*******************************************************************************
  * complex_operations_pattern class
  ******************************************************************************/
@@ -776,6 +1221,12 @@ complex_operations_pattern::matches ()
     return false;
 
   /* Check which pattern this may be.  Match longest pattern first.  */
+  this->m_patt = complex_mul_pattern::create (this->m_node, this->m_vinfo);
+  if (this->m_patt->matches (op, this->m_ops))
+    return true;
+
+  delete this->m_patt;
+
   this->m_patt = complex_add_pattern::create (this->m_node, this->m_vinfo);
   if (this->m_patt->matches (op, this->m_ops))
     return true;

Reply via email to