From 35309fea033413977a4e5b927a26db7b4c1442e8 Mon Sep 17 00:00:00 2001
From: "dzhao.ampere" <di.zhao@amperecomputing.com>
Date: Thu, 14 Sep 2023 16:48:20 +0800
Subject: [PATCH] Consider FMA in get_reassociation_width

---
 gcc/testsuite/gcc.dg/pr110279.c |  62 ++++++++++++++
 gcc/tree-ssa-reassoc.cc         | 147 ++++++++++++++++++++++++++++----
 2 files changed, 194 insertions(+), 15 deletions(-)
 create mode 100644 gcc/testsuite/gcc.dg/pr110279.c

diff --git a/gcc/testsuite/gcc.dg/pr110279.c b/gcc/testsuite/gcc.dg/pr110279.c
new file mode 100644
index 00000000000..9dc72658bff
--- /dev/null
+++ b/gcc/testsuite/gcc.dg/pr110279.c
@@ -0,0 +1,62 @@
+/* { dg-do compile } */
+/* { dg-options "-Ofast --param avoid-fma-max-bits=512 --param tree-reassoc-width=4 -fdump-tree-widening_mul-details" } */
+/* { dg-additional-options "-march=armv8.2-a" } */
+
+#define LOOP_COUNT 800000000
+typedef double data_e;
+
+/* Check that FMAs with backedge dependency are avoided. Otherwise there won't
+   be FMA generated with "--param avoid-fma-max-bits=512".   */
+
+data_e foo1 (data_e a, data_e b, data_e c, data_e d)
+{
+  data_e result = 0;
+
+  for (int ic = 0; ic < LOOP_COUNT; ic++)
+    {
+      result += (a * b + c * d);
+
+      a -= 0.1;
+      b += 0.9;
+      c *= 1.02;
+      d *= 0.61;
+    }
+
+  return result;
+}
+
+data_e foo2 (data_e a, data_e b, data_e c, data_e d)
+{
+  data_e result = 0;
+
+  for (int ic = 0; ic < LOOP_COUNT; ic++)
+    {
+      result += a * b + result + c * d;
+
+      a -= 0.1;
+      b += 0.9;
+      c *= 1.02;
+      d *= 0.61;
+    }
+
+  return result;
+}
+
+data_e foo3 (data_e a, data_e b, data_e c, data_e d)
+{
+  data_e result = 0;
+
+  for (int ic = 0; ic < LOOP_COUNT; ic++)
+    {
+      result += result + a * b + c * d;
+
+      a -= 0.1;
+      b += 0.9;
+      c *= 1.02;
+      d *= 0.61;
+    }
+
+  return result;
+}
+
+/* { dg-final { scan-tree-dump-times "Generated FMA" 3 "widening_mul"} } */
diff --git a/gcc/tree-ssa-reassoc.cc b/gcc/tree-ssa-reassoc.cc
index eda03bf98a6..94db11edd4b 100644
--- a/gcc/tree-ssa-reassoc.cc
+++ b/gcc/tree-ssa-reassoc.cc
@@ -5427,17 +5427,96 @@ get_required_cycles (int ops_num, int cpu_width)
   return res;
 }
 
+/* Given that LHS is the result SSA_NAME of OPS, returns whether ranking the ops
+   results in better parallelism.  */
+static bool
+rank_ops_for_better_parallelism_p (vec<operand_entry *> *ops, tree lhs)
+{
+  /* If there's code like "acc = a * b + c * d + acc" in a tight loop, some
+     uarchs can execute results like:
+
+	_1 = a * b;
+	_2 = .FMA (c, d, _1);
+	acc_1 = acc_0 + _2;
+
+     in parallel, while turning it into
+
+	_1 = .FMA(a, b, acc_0);
+	acc_1 = .FMA(c, d, _1);
+
+     hinders that, because then the first FMA depends on the result of preceding
+     iteration.  */
+  if (maybe_le (tree_to_poly_int64 (TYPE_SIZE (TREE_TYPE (lhs))),
+		param_avoid_fma_max_bits))
+    {
+      /* Look for cross backedge dependency:
+	1. LHS is a phi argument in the same basic block it is defined.
+	2. And the result of the phi node is used in OPS.  */
+      basic_block bb = gimple_bb (SSA_NAME_DEF_STMT (lhs));
+      gimple_stmt_iterator gsi;
+      for (gsi = gsi_start_phis (bb); !gsi_end_p (gsi); gsi_next (&gsi))
+	{
+	  gphi *phi = dyn_cast<gphi *> (gsi_stmt (gsi));
+	  for (unsigned i = 0; i < gimple_phi_num_args (phi); ++i)
+	    {
+	      tree op = PHI_ARG_DEF (phi, i);
+	      if (!(op == lhs && gimple_phi_arg_edge (phi, i)->src == bb))
+		continue;
+	      tree phi_result = gimple_phi_result (phi);
+	      operand_entry *oe;
+	      unsigned int j;
+	      FOR_EACH_VEC_ELT (*ops, j, oe)
+		{
+		  if (TREE_CODE (oe->op) != SSA_NAME)
+		    continue;
+
+		  /* Result of phi is operand of PLUS_EXPR.  */
+		  if (oe->op == phi_result)
+		    return true;
+
+		  /* Check if result of phi is operand of MULT_EXPR.  */
+		  gimple *def_stmt = SSA_NAME_DEF_STMT (oe->op);
+		  if (is_gimple_assign (def_stmt)
+		      && gimple_assign_rhs_code (def_stmt) == NEGATE_EXPR)
+		    {
+		      tree rhs = gimple_assign_rhs1 (def_stmt);
+		      if (TREE_CODE (rhs) == SSA_NAME)
+			{
+			  if (rhs == phi_result)
+			    return true;
+			  def_stmt = SSA_NAME_DEF_STMT (rhs);
+			}
+		    }
+		  if (is_gimple_assign (def_stmt)
+		      && gimple_assign_rhs_code (def_stmt) == MULT_EXPR)
+		    {
+		      if (gimple_assign_rhs1 (def_stmt) == phi_result
+			  || gimple_assign_rhs2 (def_stmt) == phi_result)
+			return true;
+		    }
+		}
+	    }
+	}
+    }
+
+  return false;
+}
+
 /* Returns an optimal number of registers to use for computation of
-   given statements.  */
+   given statements.
+
+   MULT_NUM is the number of MULT_EXPRs in OPS.  LHS is the result SSA_NAME of
+   the operators.  */
 
 static int
-get_reassociation_width (int ops_num, enum tree_code opc,
-			 machine_mode mode)
+get_reassociation_width (vec<operand_entry *> *ops, int mult_num, tree lhs,
+			 enum tree_code opc, machine_mode mode)
 {
   int param_width = param_tree_reassoc_width;
   int width;
   int width_min;
   int cycles_best;
+  int ops_num = ops->length ();
 
   if (param_width > 0)
     width = param_width;
@@ -5468,6 +5547,37 @@ get_reassociation_width (int ops_num, enum tree_code opc,
 	break;
     }
 
+  /* For a complete FMA chain, rewriting to parallel reduces the number of FMA,
+     so the code size increases.  Check if fewer partitions results in better
+     (or same) cycle number.  */
+  if (mult_num >= ops_num - 1 && width > 1)
+    {
+      width_min = 1;
+      while (width > width_min)
+	{
+	  int width_mid = (width + width_min) / 2;
+	  int elog = exact_log2 (width_mid);
+	  elog = elog >= 0 ? elog : floor_log2 (width_mid) + 1;
+	  int attempt_cycles = CEIL (mult_num, width_mid) + elog;
+	  /* Since CYCLES_BEST doesn't count the cycle of multiplications,
+	     compare with CYCLES_BEST + 1.  */
+	  if (cycles_best + 1 >= attempt_cycles)
+	    {
+	      width = width_mid;
+	      cycles_best = attempt_cycles - 1;
+	    }
+	  else if (width_min < width_mid)
+	    width_min = width_mid;
+	  else
+	    break;
+	}
+    }
+
+  /* If there's loop dependent FMA result, rewrite to avoid that.  This is
+     better than skipping the FMA candidates in widening_mul.  */
+  if (width == 1 && mult_num && rank_ops_for_better_parallelism_p (ops, lhs))
+    return 2;
+
   return width;
 }
 
@@ -6780,8 +6890,10 @@ transform_stmt_to_multiply (gimple_stmt_iterator *gsi, gimple *stmt,
    Rearrange ops to -> e + a * b + c * d generates:
 
    _4  = .FMA (c_7(D), d_8(D), _3);
-   _11 = .FMA (a_5(D), b_6(D), _4);  */
-static bool
+   _11 = .FMA (a_5(D), b_6(D), _4);
+
+   Return the number of MULT_EXPRs in the chain.  */
+static unsigned
 rank_ops_for_fma (vec<operand_entry *> *ops)
 {
   operand_entry *oe;
@@ -6813,7 +6925,8 @@ rank_ops_for_fma (vec<operand_entry *> *ops)
      Putting ops that not def from mult in front can generate more FMAs.
 
      2. If all ops are defined with mult, we don't need to rearrange them.  */
-  if (ops_mult.length () >= 2 && ops_mult.length () != ops_length)
+  unsigned mult_num = ops_mult.length ();
+  if (mult_num >= 2 && mult_num != ops_length)
     {
       /* Put no-mult ops and mult ops alternately at the end of the
 	 queue, which is conducive to generating more FMA and reducing the
@@ -6829,9 +6942,8 @@ rank_ops_for_fma (vec<operand_entry *> *ops)
 	  if (opindex > 0)
 	    opindex--;
 	}
-      return true;
     }
-  return false;
+  return mult_num;
 }
 /* Reassociate expressions in basic block BB and its post-dominator as
    children.
@@ -6995,9 +7107,10 @@ reassociate_bb (basic_block bb)
 	      else
 		{
 		  machine_mode mode = TYPE_MODE (TREE_TYPE (lhs));
-		  int ops_num = ops.length ();
+		  unsigned ops_num = ops.length ();
 		  int width;
-		  bool has_fma = false;
+		  /* Number of MULT_EXPRs in the op list.  */
+		  unsigned mult_num = 0;
 
 		  /* For binary bit operations, if there are at least 3
 		     operands and the last operand in OPS is a constant,
@@ -7020,16 +7133,18 @@ reassociate_bb (basic_block bb)
 						      opt_type)
 		      && (rhs_code == PLUS_EXPR || rhs_code == MINUS_EXPR))
 		    {
-		      has_fma = rank_ops_for_fma (&ops);
+		      mult_num = rank_ops_for_fma (&ops);
 		    }
 
 		  /* Only rewrite the expression tree to parallel in the
 		     last reassoc pass to avoid useless work back-and-forth
 		     with initial linearization.  */
+		  bool has_fma = mult_num >= 2 && mult_num != ops_num;
 		  if (!reassoc_insert_powi_p
-		      && ops.length () > 3
-		      && (width = get_reassociation_width (ops_num, rhs_code,
-							   mode)) > 1)
+		      && ops_num > 3
+		      && (width = get_reassociation_width (&ops, mult_num, lhs,
+							   rhs_code, mode))
+			   > 1)
 		    {
 		      if (dump_file && (dump_flags & TDF_DETAILS))
 			fprintf (dump_file,
@@ -7046,7 +7161,9 @@ reassociate_bb (basic_block bb)
 			 to make sure the ones that get the double
 			 binary op are chosen wisely.  */
 		      int len = ops.length ();
-		      if (len >= 3 && !has_fma)
+		      if (len >= 3
+			  && (!has_fma
+			      || rank_ops_for_better_parallelism_p (&ops, lhs)))
 			swap_ops_for_binary_stmt (ops, len - 3);
 
 		      new_lhs = rewrite_expr_tree (stmt, rhs_code, 0, ops,
-- 
2.25.1

