From: Abhishek Kaushik <[email protected]>

The FMA folds in match.pd currently only match (negate @0) directly.
When the negated operand is wrapped in a type conversion
(e.g. (convert (negate @0))), the simplification to IFN_FNMA does not
trigger.

This prevents folding of patterns such as:

*c = *c - (v8u)(*a * *b);

when the multiply operands undergo vector type conversions before being
passed to FMA. In such cases the expression lowers to neg + mla/mad
instead of the more optimal msb/mls on AArch64 SVE, because the current
fold cannot see through the casts.

Extend the match pattern to allow an optional no-op conversion on the
negated operand:

(fmas:c (nop_convert? (negate @0)) @1 @2)

When the inner operand has a signed type, the fold rewrites the
expression directly as:

IFN_FNMA (convert @0) @1 @2

For unsigned inner types, to preserve the original unsigned semantics,
the fold performs the multiply-subtract in the unsigned domain and
converts the result back to the original type:

convert (IFN_FNMA @0 (convert:t @1) (convert:t @2))

The match is restricted to nop_convert on the negated operand to avoid
folding through value-changing conversions. This enables recognition of
the subtraction-of-product form even when vector element type casts are
present.

With this change, AArch64 SVE code generation is able to select msb/mls
instead of emitting a separate neg followed by mla/mad.

This patch was bootstrapped and regression tested on aarch64-linux-gnu.

gcc/
        PR target/123897
        * match.pd: Allow conversions in FMA-to-FNMA fold.

gcc/testsuite/
        PR target/123897
        * gcc.target/aarch64/sve/fnma_match.c: New test.
        * gcc.target/aarch64/sve/pr123897.c: Update the test to scan for
        FNMA in the tree dump.
---
 gcc/match.pd                                  |  8 ++-
 .../gcc.target/aarch64/sve/fnma_match.c       | 59 +++++++++++++++++++
 .../gcc.target/aarch64/sve/pr123897.c         |  3 +-
 3 files changed, 68 insertions(+), 2 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/aarch64/sve/fnma_match.c

diff --git a/gcc/match.pd b/gcc/match.pd
index 7f16fd4e081..e99714d1e93 100644
--- a/gcc/match.pd
+++ b/gcc/match.pd
@@ -10266,7 +10266,13 @@ DEFINE_INT_AND_FLOAT_ROUND_FN (RINT)
   (simplify
    (negate (fmas@3 @0 @1 @2))
    (if (!HONOR_SIGN_DEPENDENT_ROUNDING (type) && single_use (@3))
-    (IFN_FNMS @0 @1 @2))))
+    (IFN_FNMS @0 @1 @2)))
+  (simplify
+   (fmas:c (nop_convert? (negate @0)) @1 @2)
+   (with { tree t = TREE_TYPE (@0); }
+    (if (!TYPE_UNSIGNED(t))
+      (IFN_FNMA (convert @0) @1 @2)
+      (convert (IFN_FNMA @0 (convert:t @1) (convert:t @2)))))))
 
  (simplify
   (IFN_FMS:c (negate @0) @1 @2)
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/fnma_match.c b/gcc/testsuite/gcc.target/aarch64/sve/fnma_match.c
new file mode 100644
index 00000000000..9b6d6fe6e3e
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/sve/fnma_match.c
@@ -0,0 +1,59 @@
+/* { dg-do compile } */
+/* { dg-options "-O2 -march=armv9-a -msve-vector-bits=256" } */
+
+typedef __attribute__((__vector_size__(sizeof(int)*8))) signed int v8i;
+typedef __attribute__((__vector_size__(sizeof(int)*8))) unsigned int v8u;
+
+void g(v8i *a, v8i *b, v8u *c)
+{
+  *c = *c - (v8u)(*a * *b);
+}
+
+v8u g_(v8i a, v8i b, v8u c)
+{
+  return c - (v8u)(a * b);
+}
+
+void h(v8u *a, v8u *b, v8i *c)
+{
+  *c = *c - (v8i)(*a * *b);
+}
+
+v8i h_(v8u a, v8u b, v8i c)
+{
+  return c - (v8i)(a * b);
+}
+
+void x(v8u *a, v8u *b, v8i *c)
+{
+  *c = *c + ((v8i)(-*a) * (v8i)*b);
+}
+
+v8i x_(v8u a, v8u b, v8i c)
+{
+  return c + ((v8i)(-a) * (v8i)b);
+}
+
+void y(v8u *a, v8i *b,v8i *c)
+{
+  *c = *c + ((v8i)(-*a) * *b);
+}
+
+v8i y_(v8u a, v8i b, v8i c)
+{
+  return c + ((v8i)(-a) * b);
+}
+
+void z(v8i *a, v8u *b, v8u *c)
+{
+  *c = *c + ((v8u)(-*a) * *b);
+}
+
+v8u z_(v8i a, v8u b, v8u c)
+{
+  return c + ((v8u)(-a) * b);
+}
+
+/* { dg-final { scan-assembler-times "\\tmsb\\t" 5 } } */
+/* { dg-final { scan-assembler-times "\\tmls\\t" 5 } } */
+/* { dg-final { scan-assembler-not "\\tneg\\t" } } */
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/pr123897.c b/gcc/testsuite/gcc.target/aarch64/sve/pr123897.c
index d74efabb7f8..45bc52522a9 100644
--- a/gcc/testsuite/gcc.target/aarch64/sve/pr123897.c
+++ b/gcc/testsuite/gcc.target/aarch64/sve/pr123897.c
@@ -13,4 +13,5 @@ void g(v8i *a,v8i *b,v8u *c)
   *c = *c - (v8u)(*a * *b);
 }
 
-/* { dg-final { scan-tree-dump-times "\.FMA" 2 "widening_mul" } } */
+/* { dg-final { scan-tree-dump-times "\.FMA" 1 "widening_mul" } } */
+/* { dg-final { scan-tree-dump-times "\.FNMA" 1 "widening_mul" } } */
-- 
2.43.0

Reply via email to