Hello Lili Cui,

Since I'm also trying to improve this lately, I've tested your patch on
several aarch64 machines we have, including neoverse-n1 and ampere1
architectures. However, I haven't been able to reproduce the 6.00% improvement
in the 503.bwaves_r single-copy run you mentioned. Could you share more information
about the aarch64 CPU and compile options you tested? The option
I'm using is "-Ofast", with or without "--param avoid-fma-max-bits=512".

Additionally, we found some spec2017 cases with regressions, including
4% on 527.cam4_r (neoverse-n1).

> -----Original Message-----
> From: Gcc-patches <gcc-patches-
> bounces+dizhao=os.amperecomputing....@gcc.gnu.org> On Behalf Of Cui, Lili via
> Gcc-patches
> Sent: Thursday, May 25, 2023 7:30 AM
> To: gcc-patches@gcc.gnu.org
> Cc: richard.guent...@gmail.com; li...@linux.ibm.com; Lili Cui
> <lili....@intel.com>
> Subject: [PATCH] Handle FMA friendly in reassoc pass
> 
> From: Lili Cui <lili....@intel.com>
> 
> Make some changes in reassoc pass to make it more friendly to fma pass later.
> Using FMA instead of mult + add reduces register pressure and the number of
> instructions retired.
> 
> There are mainly two changes
> 1. Put no-mult ops and mult ops alternately at the end of the queue, which is
> conducive to generating more fma and reducing the loss of FMA when breaking
> the chain.
> 2. Rewrite the rewrite_expr_tree_parallel function to try to build parallel
> chains according to the given correlation width, keeping the FMA chance as
> much as possible.
> 
> With the patch applied
> 
> On ICX:
> 507.cactuBSSN_r: Improved by 1.7% for multi-copy .
> 503.bwaves_r   : Improved by  0.60% for single copy .
> 507.cactuBSSN_r: Improved by  1.10% for single copy .
> 519.lbm_r      : Improved by  2.21% for single copy .
> no measurable changes for other benchmarks.
> 
> On aarch64
> 507.cactuBSSN_r: Improved by 1.7% for multi-copy.
> 503.bwaves_r   : Improved by 6.00% for single-copy.
> no measurable changes for other benchmarks.
> 
> TEST1:
> 
> float
> foo (float a, float b, float c, float d, float *e)
> {
>    return  *e  + a * b + c * d ;
> }
> 
> For "-Ofast -mfpmath=sse -mfma" GCC generates:
>         vmulss  %xmm3, %xmm2, %xmm2
>         vfmadd132ss     %xmm1, %xmm2, %xmm0
>         vaddss  (%rdi), %xmm0, %xmm0
>         ret
> 
> With this patch GCC generates:
>         vfmadd213ss   (%rdi), %xmm1, %xmm0
>         vfmadd231ss   %xmm2, %xmm3, %xmm0
>         ret
> 
> TEST2:
> 
> for (int i = 0; i < N; i++)
> {
>   a[i] += b[i]* c[i] + d[i] * e[i] + f[i] * g[i] + h[i] * j[i] + k[i] * l[i]
> + m[i]* o[i] + p[i];
> }
> 
> For "-Ofast -mfpmath=sse -mfma"  GCC generates:
>       vmovapd e(%rax), %ymm4
>       vmulpd  d(%rax), %ymm4, %ymm3
>       addq    $32, %rax
>       vmovapd c-32(%rax), %ymm5
>       vmovapd j-32(%rax), %ymm6
>       vmulpd  h-32(%rax), %ymm6, %ymm2
>       vmovapd a-32(%rax), %ymm6
>       vaddpd  p-32(%rax), %ymm6, %ymm0
>       vmovapd g-32(%rax), %ymm7
>       vfmadd231pd     b-32(%rax), %ymm5, %ymm3
>       vmovapd o-32(%rax), %ymm4
>       vmulpd  m-32(%rax), %ymm4, %ymm1
>       vmovapd l-32(%rax), %ymm5
>       vfmadd231pd     f-32(%rax), %ymm7, %ymm2
>       vfmadd231pd     k-32(%rax), %ymm5, %ymm1
>       vaddpd  %ymm3, %ymm0, %ymm0
>       vaddpd  %ymm2, %ymm0, %ymm0
>       vaddpd  %ymm1, %ymm0, %ymm0
>       vmovapd %ymm0, a-32(%rax)
>       cmpq    $8192, %rax
>       jne     .L4
>       vzeroupper
>       ret
> 
> with this patch applied GCC breaks the chain with width = 2 and generates 6
> fma:
> 
>       vmovapd a(%rax), %ymm2
>       vmovapd c(%rax), %ymm0
>       addq    $32, %rax
>       vmovapd e-32(%rax), %ymm1
>       vmovapd p-32(%rax), %ymm5
>       vmovapd g-32(%rax), %ymm3
>       vmovapd j-32(%rax), %ymm6
>       vmovapd l-32(%rax), %ymm4
>       vmovapd o-32(%rax), %ymm7
>       vfmadd132pd     b-32(%rax), %ymm2, %ymm0
>       vfmadd132pd     d-32(%rax), %ymm5, %ymm1
>       vfmadd231pd     f-32(%rax), %ymm3, %ymm0
>       vfmadd231pd     h-32(%rax), %ymm6, %ymm1
>       vfmadd231pd     k-32(%rax), %ymm4, %ymm0
>       vfmadd231pd     m-32(%rax), %ymm7, %ymm1
>       vaddpd  %ymm1, %ymm0, %ymm0
>       vmovapd %ymm0, a-32(%rax)
>       cmpq    $8192, %rax
>       jne     .L2
>       vzeroupper
>       ret
> 
> gcc/ChangeLog:
> 
>       PR gcc/98350
>       * tree-ssa-reassoc.cc
>       (rewrite_expr_tree_parallel): Rewrite this function.
>       (rank_ops_for_fma): New.
>       (reassociate_bb): Handle new function.
> 
> gcc/testsuite/ChangeLog:
> 
>       PR gcc/98350
>       * gcc.dg/pr98350-1.c: New test.
>       * gcc.dg/pr98350-2.c: Ditto.
> ---
>  gcc/testsuite/gcc.dg/pr98350-1.c |  31 ++++
>  gcc/testsuite/gcc.dg/pr98350-2.c |  11 ++
>  gcc/tree-ssa-reassoc.cc          | 256 +++++++++++++++++++++----------
>  3 files changed, 215 insertions(+), 83 deletions(-)
>  create mode 100644 gcc/testsuite/gcc.dg/pr98350-1.c
>  create mode 100644 gcc/testsuite/gcc.dg/pr98350-2.c
> 
> diff --git a/gcc/testsuite/gcc.dg/pr98350-1.c b/gcc/testsuite/gcc.dg/pr98350-
> 1.c
> new file mode 100644
> index 00000000000..6bcf78a19ab
> --- /dev/null
> +++ b/gcc/testsuite/gcc.dg/pr98350-1.c
> @@ -0,0 +1,31 @@
> +/* { dg-do compile } */
> +/* { dg-options "-Ofast  -fdump-tree-widening_mul" } */
> +
> +/* Test that the compiler properly optimizes multiply and add
> +   to generate more FMA instructions.  */
> +#define N 1024
> +double a[N];
> +double b[N];
> +double c[N];
> +double d[N];
> +double e[N];
> +double f[N];
> +double g[N];
> +double h[N];
> +double j[N];
> +double k[N];
> +double l[N];
> +double m[N];
> +double o[N];
> +double p[N];
> +
> +
> +void
> +foo (void)
> +{
> +  for (int i = 0; i < N; i++)
> +  {
> +    a[i] += b[i] * c[i] + d[i] * e[i] + f[i] * g[i] + h[i] * j[i] + k[i] *
> l[i] + m[i]* o[i] + p[i];
> +  }
> +}
> +/* { dg-final { scan-tree-dump-times { = \.FMA \(} 6 "widening_mul" } } */
> diff --git a/gcc/testsuite/gcc.dg/pr98350-2.c b/gcc/testsuite/gcc.dg/pr98350-
> 2.c
> new file mode 100644
> index 00000000000..333d34f026a
> --- /dev/null
> +++ b/gcc/testsuite/gcc.dg/pr98350-2.c
> @@ -0,0 +1,11 @@
> +/* { dg-do compile } */
> +/* { dg-options "-Ofast -fdump-tree-widening_mul" } */
> +
> +/* Test that the compiler rearranges the ops to generate more FMA.  */
> +
> +float
> +foo1 (float a, float b, float c, float d, float *e)
> +{
> +   return   *e + a * b + c * d ;
> +}
> +/* { dg-final { scan-tree-dump-times { = \.FMA \(} 2 "widening_mul" } } */
> diff --git a/gcc/tree-ssa-reassoc.cc b/gcc/tree-ssa-reassoc.cc
> index 067a3f07f7e..611fb9b1c99 100644
> --- a/gcc/tree-ssa-reassoc.cc
> +++ b/gcc/tree-ssa-reassoc.cc
> @@ -54,6 +54,7 @@ along with GCC; see the file COPYING3.  If not see
>  #include "tree-ssa-reassoc.h"
>  #include "tree-ssa-math-opts.h"
>  #include "gimple-range.h"
> +#include "internal-fn.h"
> 
>  /*  This is a simple global reassociation pass.  It is, in part, based
>      on the LLVM pass of the same name (They do some things more/less
> @@ -5468,14 +5469,24 @@ get_reassociation_width (int ops_num, enum tree_code
> opc,
>    return width;
>  }
> 
> -/* Recursively rewrite our linearized statements so that the operators
> -   match those in OPS[OPINDEX], putting the computation in rank
> -   order and trying to allow operations to be executed in
> -   parallel.  */
> +/* Rewrite statements with dependency chain with regard to the chance to
> generate
> +   FMA.
> +   For the chain with FMA: Try to keep fma opportunity as much as possible.
> +   For the chain without FMA: Putting the computation in rank order and
> trying
> +   to allow operations to be executed in parallel.
> +   E.g.
> +   e + f + g + a * b + c * d;
> +
> +   ssa1 = e + f;
> +   ssa2 = g + a * b;
> +   ssa3 = ssa1 + c * d;
> +   ssa4 = ssa2 + ssa3;
> 
> +   This reassociation approach preserves the chance of fma generation as
> much
> +   as possible.  */
>  static void
> -rewrite_expr_tree_parallel (gassign *stmt, int width,
> -                         const vec<operand_entry *> &ops)
> +rewrite_expr_tree_parallel (gassign *stmt, int width, bool has_fma,
> +                                      const vec<operand_entry *> &ops)
>  {
>    enum tree_code opcode = gimple_assign_rhs_code (stmt);
>    int op_num = ops.length ();
> @@ -5483,10 +5494,11 @@ rewrite_expr_tree_parallel (gassign *stmt, int width,
>    int stmt_num = op_num - 1;
>    gimple **stmts = XALLOCAVEC (gimple *, stmt_num);
>    int op_index = op_num - 1;
> -  int stmt_index = 0;
> -  int ready_stmts_end = 0;
> -  int i = 0;
> -  gimple *stmt1 = NULL, *stmt2 = NULL;
> +  int width_count = width;
> +  int i = 0, j = 0;
> +  tree tmp_op[2], op1;
> +  operand_entry *oe;
> +  gimple *stmt1 = NULL;
>    tree last_rhs1 = gimple_assign_rhs1 (stmt);
> 
>    /* We start expression rewriting from the top statements.
> @@ -5496,91 +5508,87 @@ rewrite_expr_tree_parallel (gassign *stmt, int width,
>    for (i = stmt_num - 2; i >= 0; i--)
>      stmts[i] = SSA_NAME_DEF_STMT (gimple_assign_rhs1 (stmts[i+1]));
> 
> -  for (i = 0; i < stmt_num; i++)
> +  /* Build parallel dependency chain according to width.  */
> +  for (i = 0; i < width; i++)
>      {
> -      tree op1, op2;
> -
> -      /* Determine whether we should use results of
> -      already handled statements or not.  */
> -      if (ready_stmts_end == 0
> -       && (i - stmt_index >= width || op_index < 1))
> -     ready_stmts_end = i;
> -
> -      /* Now we choose operands for the next statement.  Non zero
> -      value in ready_stmts_end means here that we should use
> -      the result of already generated statements as new operand.  */
> -      if (ready_stmts_end > 0)
> -     {
> -       op1 = gimple_assign_lhs (stmts[stmt_index++]);
> -       if (ready_stmts_end > stmt_index)
> -         op2 = gimple_assign_lhs (stmts[stmt_index++]);
> -       else if (op_index >= 0)
> -         {
> -           operand_entry *oe = ops[op_index--];
> -           stmt2 = oe->stmt_to_insert;
> -           op2 = oe->op;
> -         }
> -       else
> -         {
> -           gcc_assert (stmt_index < i);
> -           op2 = gimple_assign_lhs (stmts[stmt_index++]);
> -         }
> +      /* If the chain has FMA, we do not swap two operands.  */
> +      if (op_index > 1 && !has_fma)
> +     swap_ops_for_binary_stmt (ops, op_index - 2);
> 
> -       if (stmt_index >= ready_stmts_end)
> -         ready_stmts_end = 0;
> -     }
> -      else
> +      for (j = 0; j < 2; j++)
>       {
> -       if (op_index > 1)
> -         swap_ops_for_binary_stmt (ops, op_index - 2);
> -       operand_entry *oe2 = ops[op_index--];
> -       operand_entry *oe1 = ops[op_index--];
> -       op2 = oe2->op;
> -       stmt2 = oe2->stmt_to_insert;
> -       op1 = oe1->op;
> -       stmt1 = oe1->stmt_to_insert;
> +       gcc_assert (op_index >= 0);
> +       oe = ops[op_index--];
> +       tmp_op[j] = oe->op;
> +       /* If the stmt that defines operand has to be inserted, insert it
> +          before the use.  */
> +       stmt1 = oe->stmt_to_insert;
> +       if (stmt1)
> +         insert_stmt_before_use (stmts[i], stmt1);
> +       stmt1 = NULL;
>       }
> -
> -      /* If we emit the last statement then we should put
> -      operands into the last statement.  It will also
> -      break the loop.  */
> -      if (op_index < 0 && stmt_index == i)
> -     i = stmt_num - 1;
> +      stmts[i] = build_and_add_sum (TREE_TYPE (last_rhs1),
> +                                 tmp_op[1],
> +                                 tmp_op[0],
> +                                 opcode);
> +      gimple_set_visited (stmts[i], true);
> 
>        if (dump_file && (dump_flags & TDF_DETAILS))
>       {
> -       fprintf (dump_file, "Transforming ");
> +       fprintf (dump_file, " into ");
>         print_gimple_stmt (dump_file, stmts[i], 0);
>       }
> +    }
> 
> -      /* If the stmt that defines operand has to be inserted, insert it
> -      before the use.  */
> -      if (stmt1)
> -     insert_stmt_before_use (stmts[i], stmt1);
> -      if (stmt2)
> -     insert_stmt_before_use (stmts[i], stmt2);
> -      stmt1 = stmt2 = NULL;
> -
> -      /* We keep original statement only for the last one.  All
> -      others are recreated.  */
> -      if (i == stmt_num - 1)
> +  for (i = width; i < stmt_num; i++)
> +    {
> +      /* We keep original statement only for the last one.  All others are
> +      recreated.  */
> +      if ( op_index < 0)
>       {
> -       gimple_assign_set_rhs1 (stmts[i], op1);
> -       gimple_assign_set_rhs2 (stmts[i], op2);
> -       update_stmt (stmts[i]);
> +       if (width_count == 2)
> +         {
> +
> +           /* We keep original statement only for the last one.  All
> +              others are recreated.  */
> +           gimple_assign_set_rhs1 (stmts[i], gimple_assign_lhs (stmts[i-1]));
> +           gimple_assign_set_rhs2 (stmts[i], gimple_assign_lhs (stmts[i-2]));
> +           update_stmt (stmts[i]);
> +         }
> +       else
> +         {
> +
> +           stmts[i] =
> +             build_and_add_sum (TREE_TYPE (last_rhs1),
> +                                gimple_assign_lhs (stmts[i-width_count]),
> +                                gimple_assign_lhs (stmts[i-width_count+1]),
> +                                opcode);
> +           gimple_set_visited (stmts[i], true);
> +           width_count--;
> +         }
>       }
>        else
>       {
> -       stmts[i] = build_and_add_sum (TREE_TYPE (last_rhs1), op1, op2,
> opcode);
> +       /* Attach the rest of the ops to the parallel dependency chain.  */
> +       oe = ops[op_index--];
> +       op1 = oe->op;
> +       stmt1 = oe->stmt_to_insert;
> +       if (stmt1)
> +         insert_stmt_before_use (stmts[i], stmt1);
> +       stmt1 = NULL;
> +       stmts[i] = build_and_add_sum (TREE_TYPE (last_rhs1),
> +                                     gimple_assign_lhs (stmts[i-width]),
> +                                     op1,
> +                                     opcode);
>         gimple_set_visited (stmts[i], true);
>       }
> +
>        if (dump_file && (dump_flags & TDF_DETAILS))
>       {
>         fprintf (dump_file, " into ");
>         print_gimple_stmt (dump_file, stmts[i], 0);
>       }
>      }
> -
>    remove_visited_stmt_chain (last_rhs1);
>  }
> 
> @@ -6649,6 +6657,73 @@ transform_stmt_to_multiply (gimple_stmt_iterator *gsi,
> gimple *stmt,
>      }
>  }
> 
> +/* Rearrange ops may have more FMA when the chain may has more than 2 FMAs.
> +   Put no-mult ops and mult ops alternately at the end of the queue, which
> is
> +   conducive to generating more FMA and reducing the loss of FMA when
> breaking
> +   the chain.
> +   E.g.
> +   a * b + c * d + e generates:
> +
> +   _4  = c_9(D) * d_10(D);
> +   _12 = .FMA (a_7(D), b_8(D), _4);
> +   _11 = e_6(D) + _12;
> +
> +   Rearrange ops to -> e + a * b + c * d generates:
> +
> +   _4  = .FMA (c_7(D), d_8(D), _3);
> +   _11 = .FMA (a_5(D), b_6(D), _4);  */
> +static bool
> +rank_ops_for_fma (vec<operand_entry *> *ops)
> +{
> +  operand_entry *oe;
> +  unsigned int i;
> +  unsigned int ops_length = ops->length ();
> +  auto_vec<operand_entry *> ops_mult;
> +  auto_vec<operand_entry *> ops_others;
> +
> +  FOR_EACH_VEC_ELT (*ops, i, oe)
> +    {
> +      if (TREE_CODE (oe->op) == SSA_NAME)
> +     {
> +       gimple *def_stmt = SSA_NAME_DEF_STMT (oe->op);
> +       if (is_gimple_assign (def_stmt)
> +           && gimple_assign_rhs_code (def_stmt) == MULT_EXPR)
> +         ops_mult.safe_push (oe);
> +       else
> +         ops_others.safe_push (oe);
> +     }
> +      else
> +     ops_others.safe_push (oe);
> +    }
> +  /* 1. When ops_mult.length == 2, like the following case,
> +
> +     a * b + c * d + e.
> +
> +     we need to rearrange the ops.
> +
> +     Putting ops that are not defined by a mult in front can generate more FMAs.
> +
> +     2. If all ops are defined with mult, we don't need to rearrange them.
> */
> +  if (ops_mult.length () >= 2 && ops_mult.length () != ops_length)
> +    {
> +      /* Put no-mult ops and mult ops alternately at the end of the
> +      queue, which is conducive to generating more FMA and reducing the
> +      loss of FMA when breaking the chain.  */
> +      ops->truncate (0);
> +      ops->splice (ops_mult);
> +      int j, opindex = ops->length ();
> +      int others_length = ops_others.length ();
> +      for (j = 0; j < others_length; j++)
> +     {
> +       oe = ops_others.pop ();
> +       ops->quick_insert (opindex, oe);
> +       if (opindex > 0)
> +         opindex--;
> +     }
> +      return true;
> +    }
> +  return false;
> +}
>  /* Reassociate expressions in basic block BB and its post-dominator as
>     children.
> 
> @@ -6813,6 +6888,7 @@ reassociate_bb (basic_block bb)
>                 machine_mode mode = TYPE_MODE (TREE_TYPE (lhs));
>                 int ops_num = ops.length ();
>                 int width;
> +               bool has_fma = false;
> 
>                 /* For binary bit operations, if there are at least 3
>                    operands and the last operand in OPS is a constant,
> @@ -6821,11 +6897,23 @@ reassociate_bb (basic_block bb)
>                    often match a canonical bit test when we get to RTL.  */
>                 if (ops.length () > 2
>                     && (rhs_code == BIT_AND_EXPR
> -                       || rhs_code == BIT_IOR_EXPR
> -                       || rhs_code == BIT_XOR_EXPR)
> +                       || rhs_code == BIT_IOR_EXPR
> +                       || rhs_code == BIT_XOR_EXPR)
>                     && TREE_CODE (ops.last ()->op) == INTEGER_CST)
>                   std::swap (*ops[0], *ops[ops_num - 1]);
> 
> +               optimization_type opt_type = bb_optimization_type (bb);
> +
> +               /* If the target supports FMA, rank_ops_for_fma will detect if
> +                  the chain has fmas and rearrange the ops if so.  */
> +               if (direct_internal_fn_supported_p (IFN_FMA,
> +                                                   TREE_TYPE (lhs),
> +                                                   opt_type)
> +                   && (rhs_code == PLUS_EXPR || rhs_code == MINUS_EXPR))
> +                 {
> +                   has_fma = rank_ops_for_fma (&ops);
> +                 }
> +
>                 /* Only rewrite the expression tree to parallel in the
>                    last reassoc pass to avoid useless work back-and-forth
>                    with initial linearization.  */
> @@ -6839,22 +6927,24 @@ reassociate_bb (basic_block bb)
>                                "Width = %d was chosen for reassociation\n",
>                                width);
>                     rewrite_expr_tree_parallel (as_a <gassign *> (stmt),
> -                                               width, ops);
> +                                               width,
> +                                               has_fma,
> +                                               ops);
>                   }
>                 else
> -                    {
> -                      /* When there are three operands left, we want
> -                         to make sure the ones that get the double
> -                         binary op are chosen wisely.  */
> -                      int len = ops.length ();
> -                      if (len >= 3)
> +                 {
> +                   /* When there are three operands left, we want
> +                      to make sure the ones that get the double
> +                      binary op are chosen wisely.  */
> +                   int len = ops.length ();
> +                   if (len >= 3 && !has_fma)
>                       swap_ops_for_binary_stmt (ops, len - 3);
> 
>                     new_lhs = rewrite_expr_tree (stmt, rhs_code, 0, ops,
>                                                  powi_result != NULL
>                                                  || negate_result,
>                                                  len != orig_len);
> -                    }
> +                 }
> 
>                 /* If we combined some repeated factors into a
>                    __builtin_powi call, multiply that result by the
> --
> 2.25.1

Reply via email to