Hello Lili Cui,

Since I have also been working on improving this recently, I tested your patch on several of our aarch64 machines, including neoverse-n1 and ampere1. However, I could not reproduce the 6.00% improvement on the 503.bwaves_r single-copy run that you mentioned. Could you share more details about the aarch64 CPU and the compile options you used? I am using "-Ofast", both with and without "--param avoid-fma-max-bits=512".
Additionally, we found some spec2017 cases with regressions, including 4% on 527.cam4_r (neoverse-n1). > -----Original Message----- > From: Gcc-patches <gcc-patches- > bounces+dizhao=os.amperecomputing....@gcc.gnu.org> On Behalf Of Cui, Lili via > Gcc-patches > Sent: Thursday, May 25, 2023 7:30 AM > To: gcc-patches@gcc.gnu.org > Cc: richard.guent...@gmail.com; li...@linux.ibm.com; Lili Cui > <lili....@intel.com> > Subject: [PATCH] Handle FMA friendly in reassoc pass > > From: Lili Cui <lili....@intel.com> > > Make some changes in reassoc pass to make it more friendly to fma pass later. > Using FMA instead of mult + add reduces register pressure and insruction > retired. > > There are mainly two changes > 1. Put no-mult ops and mult ops alternately at the end of the queue, which is > conducive to generating more fma and reducing the loss of FMA when breaking > the chain. > 2. Rewrite the rewrite_expr_tree_parallel function to try to build parallel > chains according to the given correlation width, keeping the FMA chance as > much as possible. > > With the patch applied > > On ICX: > 507.cactuBSSN_r: Improved by 1.7% for multi-copy . > 503.bwaves_r : Improved by 0.60% for single copy . > 507.cactuBSSN_r: Improved by 1.10% for single copy . > 519.lbm_r : Improved by 2.21% for single copy . > no measurable changes for other benchmarks. > > On aarch64 > 507.cactuBSSN_r: Improved by 1.7% for multi-copy. > 503.bwaves_r : Improved by 6.00% for single-copy. > no measurable changes for other benchmarks. 
> > TEST1: > > float > foo (float a, float b, float c, float d, float *e) > { > return *e + a * b + c * d ; > } > > For "-Ofast -mfpmath=sse -mfma" GCC generates: > vmulss %xmm3, %xmm2, %xmm2 > vfmadd132ss %xmm1, %xmm2, %xmm0 > vaddss (%rdi), %xmm0, %xmm0 > ret > > With this patch GCC generates: > vfmadd213ss (%rdi), %xmm1, %xmm0 > vfmadd231ss %xmm2, %xmm3, %xmm0 > ret > > TEST2: > > for (int i = 0; i < N; i++) > { > a[i] += b[i]* c[i] + d[i] * e[i] + f[i] * g[i] + h[i] * j[i] + k[i] * l[i] > + m[i]* o[i] + p[i]; > } > > For "-Ofast -mfpmath=sse -mfma" GCC generates: > vmovapd e(%rax), %ymm4 > vmulpd d(%rax), %ymm4, %ymm3 > addq $32, %rax > vmovapd c-32(%rax), %ymm5 > vmovapd j-32(%rax), %ymm6 > vmulpd h-32(%rax), %ymm6, %ymm2 > vmovapd a-32(%rax), %ymm6 > vaddpd p-32(%rax), %ymm6, %ymm0 > vmovapd g-32(%rax), %ymm7 > vfmadd231pd b-32(%rax), %ymm5, %ymm3 > vmovapd o-32(%rax), %ymm4 > vmulpd m-32(%rax), %ymm4, %ymm1 > vmovapd l-32(%rax), %ymm5 > vfmadd231pd f-32(%rax), %ymm7, %ymm2 > vfmadd231pd k-32(%rax), %ymm5, %ymm1 > vaddpd %ymm3, %ymm0, %ymm0 > vaddpd %ymm2, %ymm0, %ymm0 > vaddpd %ymm1, %ymm0, %ymm0 > vmovapd %ymm0, a-32(%rax) > cmpq $8192, %rax > jne .L4 > vzeroupper > ret > > with this patch applied GCC breaks the chain with width = 2 and generates 6 > fma: > > vmovapd a(%rax), %ymm2 > vmovapd c(%rax), %ymm0 > addq $32, %rax > vmovapd e-32(%rax), %ymm1 > vmovapd p-32(%rax), %ymm5 > vmovapd g-32(%rax), %ymm3 > vmovapd j-32(%rax), %ymm6 > vmovapd l-32(%rax), %ymm4 > vmovapd o-32(%rax), %ymm7 > vfmadd132pd b-32(%rax), %ymm2, %ymm0 > vfmadd132pd d-32(%rax), %ymm5, %ymm1 > vfmadd231pd f-32(%rax), %ymm3, %ymm0 > vfmadd231pd h-32(%rax), %ymm6, %ymm1 > vfmadd231pd k-32(%rax), %ymm4, %ymm0 > vfmadd231pd m-32(%rax), %ymm7, %ymm1 > vaddpd %ymm1, %ymm0, %ymm0 > vmovapd %ymm0, a-32(%rax) > cmpq $8192, %rax > jne .L2 > vzeroupper > ret > > gcc/ChangeLog: > > PR gcc/98350 > * tree-ssa-reassoc.cc > (rewrite_expr_tree_parallel): Rewrite this function. 
> (rank_ops_for_fma): New. > (reassociate_bb): Handle new function. > > gcc/testsuite/ChangeLog: > > PR gcc/98350 > * gcc.dg/pr98350-1.c: New test. > * gcc.dg/pr98350-2.c: Ditto. > --- > gcc/testsuite/gcc.dg/pr98350-1.c | 31 ++++ > gcc/testsuite/gcc.dg/pr98350-2.c | 11 ++ > gcc/tree-ssa-reassoc.cc | 256 +++++++++++++++++++++---------- > 3 files changed, 215 insertions(+), 83 deletions(-) > create mode 100644 gcc/testsuite/gcc.dg/pr98350-1.c > create mode 100644 gcc/testsuite/gcc.dg/pr98350-2.c > > diff --git a/gcc/testsuite/gcc.dg/pr98350-1.c b/gcc/testsuite/gcc.dg/pr98350- > 1.c > new file mode 100644 > index 00000000000..6bcf78a19ab > --- /dev/null > +++ b/gcc/testsuite/gcc.dg/pr98350-1.c > @@ -0,0 +1,31 @@ > +/* { dg-do compile } */ > +/* { dg-options "-Ofast -fdump-tree-widening_mul" } */ > + > +/* Test that the compiler properly optimizes multiply and add > + to generate more FMA instructions. */ > +#define N 1024 > +double a[N]; > +double b[N]; > +double c[N]; > +double d[N]; > +double e[N]; > +double f[N]; > +double g[N]; > +double h[N]; > +double j[N]; > +double k[N]; > +double l[N]; > +double m[N]; > +double o[N]; > +double p[N]; > + > + > +void > +foo (void) > +{ > + for (int i = 0; i < N; i++) > + { > + a[i] += b[i] * c[i] + d[i] * e[i] + f[i] * g[i] + h[i] * j[i] + k[i] * > l[i] + m[i]* o[i] + p[i]; > + } > +} > +/* { dg-final { scan-tree-dump-times { = \.FMA \(} 6 "widening_mul" } } */ > diff --git a/gcc/testsuite/gcc.dg/pr98350-2.c b/gcc/testsuite/gcc.dg/pr98350- > 2.c > new file mode 100644 > index 00000000000..333d34f026a > --- /dev/null > +++ b/gcc/testsuite/gcc.dg/pr98350-2.c > @@ -0,0 +1,11 @@ > +/* { dg-do compile } */ > +/* { dg-options "-Ofast -fdump-tree-widening_mul" } */ > + > +/* Test that the compiler rearrange the ops to generate more FMA. 
*/ > + > +float > +foo1 (float a, float b, float c, float d, float *e) > +{ > + return *e + a * b + c * d ; > +} > +/* { dg-final { scan-tree-dump-times { = \.FMA \(} 2 "widening_mul" } } */ > diff --git a/gcc/tree-ssa-reassoc.cc b/gcc/tree-ssa-reassoc.cc > index 067a3f07f7e..611fb9b1c99 100644 > --- a/gcc/tree-ssa-reassoc.cc > +++ b/gcc/tree-ssa-reassoc.cc > @@ -54,6 +54,7 @@ along with GCC; see the file COPYING3. If not see > #include "tree-ssa-reassoc.h" > #include "tree-ssa-math-opts.h" > #include "gimple-range.h" > +#include "internal-fn.h" > > /* This is a simple global reassociation pass. It is, in part, based > on the LLVM pass of the same name (They do some things more/less > @@ -5468,14 +5469,24 @@ get_reassociation_width (int ops_num, enum tree_code > opc, > return width; > } > > -/* Recursively rewrite our linearized statements so that the operators > - match those in OPS[OPINDEX], putting the computation in rank > - order and trying to allow operations to be executed in > - parallel. */ > +/* Rewrite statements with dependency chain with regard the chance to > generate > + FMA. > + For the chain with FMA: Try to keep fma opportunity as much as possible. > + For the chain without FMA: Putting the computation in rank order and > trying > + to allow operations to be executed in parallel. > + E.g. > + e + f + g + a * b + c * d; > + > + ssa1 = e + f; > + ssa2 = g + a * b; > + ssa3 = ssa1 + c * d; > + ssa4 = ssa2 + ssa3; > > + This reassociation approach preserves the chance of fma generation as > much > + as possible. 
*/ > static void > -rewrite_expr_tree_parallel (gassign *stmt, int width, > - const vec<operand_entry *> &ops) > +rewrite_expr_tree_parallel (gassign *stmt, int width, bool has_fma, > + const vec<operand_entry *> &ops) > { > enum tree_code opcode = gimple_assign_rhs_code (stmt); > int op_num = ops.length (); > @@ -5483,10 +5494,11 @@ rewrite_expr_tree_parallel (gassign *stmt, int width, > int stmt_num = op_num - 1; > gimple **stmts = XALLOCAVEC (gimple *, stmt_num); > int op_index = op_num - 1; > - int stmt_index = 0; > - int ready_stmts_end = 0; > - int i = 0; > - gimple *stmt1 = NULL, *stmt2 = NULL; > + int width_count = width; > + int i = 0, j = 0; > + tree tmp_op[2], op1; > + operand_entry *oe; > + gimple *stmt1 = NULL; > tree last_rhs1 = gimple_assign_rhs1 (stmt); > > /* We start expression rewriting from the top statements. > @@ -5496,91 +5508,87 @@ rewrite_expr_tree_parallel (gassign *stmt, int width, > for (i = stmt_num - 2; i >= 0; i--) > stmts[i] = SSA_NAME_DEF_STMT (gimple_assign_rhs1 (stmts[i+1])); > > - for (i = 0; i < stmt_num; i++) > + /* Build parallel dependency chain according to width. */ > + for (i = 0; i < width; i++) > { > - tree op1, op2; > - > - /* Determine whether we should use results of > - already handled statements or not. */ > - if (ready_stmts_end == 0 > - && (i - stmt_index >= width || op_index < 1)) > - ready_stmts_end = i; > - > - /* Now we choose operands for the next statement. Non zero > - value in ready_stmts_end means here that we should use > - the result of already generated statements as new operand. 
*/ > - if (ready_stmts_end > 0) > - { > - op1 = gimple_assign_lhs (stmts[stmt_index++]); > - if (ready_stmts_end > stmt_index) > - op2 = gimple_assign_lhs (stmts[stmt_index++]); > - else if (op_index >= 0) > - { > - operand_entry *oe = ops[op_index--]; > - stmt2 = oe->stmt_to_insert; > - op2 = oe->op; > - } > - else > - { > - gcc_assert (stmt_index < i); > - op2 = gimple_assign_lhs (stmts[stmt_index++]); > - } > + /* If the chain has FAM, we do not swap two operands. */ > + if (op_index > 1 && !has_fma) > + swap_ops_for_binary_stmt (ops, op_index - 2); > > - if (stmt_index >= ready_stmts_end) > - ready_stmts_end = 0; > - } > - else > + for (j = 0; j < 2; j++) > { > - if (op_index > 1) > - swap_ops_for_binary_stmt (ops, op_index - 2); > - operand_entry *oe2 = ops[op_index--]; > - operand_entry *oe1 = ops[op_index--]; > - op2 = oe2->op; > - stmt2 = oe2->stmt_to_insert; > - op1 = oe1->op; > - stmt1 = oe1->stmt_to_insert; > + gcc_assert (op_index >= 0); > + oe = ops[op_index--]; > + tmp_op[j] = oe->op; > + /* If the stmt that defines operand has to be inserted, insert it > + before the use. */ > + stmt1 = oe->stmt_to_insert; > + if (stmt1) > + insert_stmt_before_use (stmts[i], stmt1); > + stmt1 = NULL; > } > - > - /* If we emit the last statement then we should put > - operands into the last statement. It will also > - break the loop. */ > - if (op_index < 0 && stmt_index == i) > - i = stmt_num - 1; > + stmts[i] = build_and_add_sum (TREE_TYPE (last_rhs1), > + tmp_op[1], > + tmp_op[0], > + opcode); > + gimple_set_visited (stmts[i], true); > > if (dump_file && (dump_flags & TDF_DETAILS)) > { > - fprintf (dump_file, "Transforming "); > + fprintf (dump_file, " into "); > print_gimple_stmt (dump_file, stmts[i], 0); > } > + } > > - /* If the stmt that defines operand has to be inserted, insert it > - before the use. 
*/ > - if (stmt1) > - insert_stmt_before_use (stmts[i], stmt1); > - if (stmt2) > - insert_stmt_before_use (stmts[i], stmt2); > - stmt1 = stmt2 = NULL; > - > - /* We keep original statement only for the last one. All > - others are recreated. */ > - if (i == stmt_num - 1) > + for (i = width; i < stmt_num; i++) > + { > + /* We keep original statement only for the last one. All others are > + recreated. */ > + if ( op_index < 0) > { > - gimple_assign_set_rhs1 (stmts[i], op1); > - gimple_assign_set_rhs2 (stmts[i], op2); > - update_stmt (stmts[i]); > + if (width_count == 2) > + { > + > + /* We keep original statement only for the last one. All > + others are recreated. */ > + gimple_assign_set_rhs1 (stmts[i], gimple_assign_lhs (stmts[i-1])); > + gimple_assign_set_rhs2 (stmts[i], gimple_assign_lhs (stmts[i-2])); > + update_stmt (stmts[i]); > + } > + else > + { > + > + stmts[i] = > + build_and_add_sum (TREE_TYPE (last_rhs1), > + gimple_assign_lhs (stmts[i-width_count]), > + gimple_assign_lhs (stmts[i-width_count+1]), > + opcode); > + gimple_set_visited (stmts[i], true); > + width_count--; > + } > } > else > { > - stmts[i] = build_and_add_sum (TREE_TYPE (last_rhs1), op1, op2, > opcode); > + /* Attach the rest of the ops to the parallel dependency chain. */ > + oe = ops[op_index--]; > + op1 = oe->op; > + stmt1 = oe->stmt_to_insert; > + if (stmt1) > + insert_stmt_before_use (stmts[i], stmt1); > + stmt1 = NULL; > + stmts[i] = build_and_add_sum (TREE_TYPE (last_rhs1), > + gimple_assign_lhs (stmts[i-width]), > + op1, > + opcode); > gimple_set_visited (stmts[i], true); > } > + > if (dump_file && (dump_flags & TDF_DETAILS)) > { > fprintf (dump_file, " into "); > print_gimple_stmt (dump_file, stmts[i], 0); > } > } > - > remove_visited_stmt_chain (last_rhs1); > } > > @@ -6649,6 +6657,73 @@ transform_stmt_to_multiply (gimple_stmt_iterator *gsi, > gimple *stmt, > } > } > > +/* Rearrange ops may have more FMA when the chain may has more than 2 FMAs. 
> + Put no-mult ops and mult ops alternately at the end of the queue, which > is > + conducive to generating more FMA and reducing the loss of FMA when > breaking > + the chain. > + E.g. > + a * b + c * d + e generates: > + > + _4 = c_9(D) * d_10(D); > + _12 = .FMA (a_7(D), b_8(D), _4); > + _11 = e_6(D) + _12; > + > + Rearrange ops to -> e + a * b + c * d generates: > + > + _4 = .FMA (c_7(D), d_8(D), _3); > + _11 = .FMA (a_5(D), b_6(D), _4); */ > +static bool > +rank_ops_for_fma (vec<operand_entry *> *ops) > +{ > + operand_entry *oe; > + unsigned int i; > + unsigned int ops_length = ops->length (); > + auto_vec<operand_entry *> ops_mult; > + auto_vec<operand_entry *> ops_others; > + > + FOR_EACH_VEC_ELT (*ops, i, oe) > + { > + if (TREE_CODE (oe->op) == SSA_NAME) > + { > + gimple *def_stmt = SSA_NAME_DEF_STMT (oe->op); > + if (is_gimple_assign (def_stmt) > + && gimple_assign_rhs_code (def_stmt) == MULT_EXPR) > + ops_mult.safe_push (oe); > + else > + ops_others.safe_push (oe); > + } > + else > + ops_others.safe_push (oe); > + } > + /* 1. When ops_mult.length == 2, like the following case, > + > + a * b + c * d + e. > + > + we need to rearrange the ops. > + > + Putting ops that not def from mult in front can generate more FMAs. > + > + 2. If all ops are defined with mult, we don't need to rearrange them. > */ > + if (ops_mult.length () >= 2 && ops_mult.length () != ops_length) > + { > + /* Put no-mult ops and mult ops alternately at the end of the > + queue, which is conducive to generating more FMA and reducing the > + loss of FMA when breaking the chain. 
*/ > + ops->truncate (0); > + ops->splice (ops_mult); > + int j, opindex = ops->length (); > + int others_length = ops_others.length (); > + for (j = 0; j < others_length; j++) > + { > + oe = ops_others.pop (); > + ops->quick_insert (opindex, oe); > + if (opindex > 0) > + opindex--; > + } > + return true; > + } > + return false; > +} > /* Reassociate expressions in basic block BB and its post-dominator as > children. > > @@ -6813,6 +6888,7 @@ reassociate_bb (basic_block bb) > machine_mode mode = TYPE_MODE (TREE_TYPE (lhs)); > int ops_num = ops.length (); > int width; > + bool has_fma = false; > > /* For binary bit operations, if there are at least 3 > operands and the last operand in OPS is a constant, > @@ -6821,11 +6897,23 @@ reassociate_bb (basic_block bb) > often match a canonical bit test when we get to RTL. */ > if (ops.length () > 2 > && (rhs_code == BIT_AND_EXPR > - || rhs_code == BIT_IOR_EXPR > - || rhs_code == BIT_XOR_EXPR) > + || rhs_code == BIT_IOR_EXPR > + || rhs_code == BIT_XOR_EXPR) > && TREE_CODE (ops.last ()->op) == INTEGER_CST) > std::swap (*ops[0], *ops[ops_num - 1]); > > + optimization_type opt_type = bb_optimization_type (bb); > + > + /* If the target support FMA, rank_ops_for_fma will detect if > + the chain has fmas and rearrange the ops if so. */ > + if (direct_internal_fn_supported_p (IFN_FMA, > + TREE_TYPE (lhs), > + opt_type) > + && (rhs_code == PLUS_EXPR || rhs_code == MINUS_EXPR)) > + { > + has_fma = rank_ops_for_fma (&ops); > + } > + > /* Only rewrite the expression tree to parallel in the > last reassoc pass to avoid useless work back-and-forth > with initial linearization. 
*/ > @@ -6839,22 +6927,24 @@ reassociate_bb (basic_block bb) > "Width = %d was chosen for reassociation\n", > width); > rewrite_expr_tree_parallel (as_a <gassign *> (stmt), > - width, ops); > + width, > + has_fma, > + ops); > } > else > - { > - /* When there are three operands left, we want > - to make sure the ones that get the double > - binary op are chosen wisely. */ > - int len = ops.length (); > - if (len >= 3) > + { > + /* When there are three operands left, we want > + to make sure the ones that get the double > + binary op are chosen wisely. */ > + int len = ops.length (); > + if (len >= 3 && !has_fma) > swap_ops_for_binary_stmt (ops, len - 3); > > new_lhs = rewrite_expr_tree (stmt, rhs_code, 0, ops, > powi_result != NULL > || negate_result, > len != orig_len); > - } > + } > > /* If we combined some repeated factors into a > __builtin_powi call, multiply that result by the > -- > 2.25.1