Following the recent change to the dot_prod standard pattern name, which now encodes two modes (the accumulator mode and the input mode), this patch fixes the aarch64 back-end with the following changes:
1. Add a second mode to all (u|s|us)dot_prod patterns in the .md files.
2. Rewrite the initialization and expansion mechanism for the SIMD builtins.
3. Fix all direct calls to back-end `dot_prod' patterns in the SVE builtins.

Finally, given that the compiler can now differentiate between the two-way
and four-way dot products, we add a test to ensure that autovectorization
picks up dot-product patterns where the result is twice the width of the
operands.

gcc/ChangeLog:

	* config/aarch64/aarch64-builtins.cc (CODE_FOR_aarch64_sdot_prodv8qi)
	(CODE_FOR_aarch64_udot_prodv8qi, CODE_FOR_aarch64_usdot_prodv8qi)
	(CODE_FOR_aarch64_sdot_prodv16qi, CODE_FOR_aarch64_udot_prodv16qi)
	(CODE_FOR_aarch64_usdot_prodv16qi): New constexpr aliases for the
	renamed two-mode dot_prod patterns.
	* config/aarch64/aarch64-simd-builtins.def (sdot_prod, udot_prod)
	(usdot_prod): Change the BUILTIN_VB flag from 10 to 0 so that these
	builtins expand through the new CODE_FOR_aarch64_* aliases.
	* config/aarch64/aarch64-simd.md
	(<sur>dot_prod<vsi2qi><vczle><vczbe>): Renamed to...
	(<sur>dot_prod<mode><vsi2qi><vczle><vczbe>): ...this.
	(usdot_prod<vsi2qi><vczle><vczbe>): Renamed to...
	(usdot_prod<mode><vsi2qi><vczle><vczbe>): ...this.
	(<su>sadv16qi): Adjust call to gen_udot_prod to take the second mode.
	(popcount<mode>2): Use convert_optab_handler rather than
	optab_handler to look up `udot_prod_optab'.
	* config/aarch64/aarch64-sve.md (<sur>dot_prod<vsi2qi>): Renamed to...
	(<sur>dot_prod<mode><vsi2qi>): ...this.
	(@<sur>dot_prod<vsi2qi>): Renamed to...
	(@<sur>dot_prod<mode><vsi2qi>): ...this.
	(<su>sad<vsi2qi>): Adjust call to gen_udot_prod to take the second
	mode.
	* config/aarch64/aarch64-sve2.md (@aarch64_sve_<sur>dotvnx4sivnx8hi):
	Renamed to...
	(<sur>dot_prodvnx4sivnx8hi): ...this.
	* config/aarch64/aarch64-sve-builtins-base.cc (svdot_impl::expand):
	Call convert_optab_handler_for_sign instead of
	direct_optab_handler_for_sign, and use the renamed two-way
	dot-product patterns.
	(svusdot_impl::expand): Add second mode argument in call to
	`code_for_dot_prod'.
	* config/aarch64/aarch64-sve-builtins.cc
	(function_expander::convert_optab_handler_for_sign): New class method.
	* config/aarch64/aarch64-sve-builtins.h (class function_expander):
	Add prototype for new `convert_optab_handler_for_sign' method.

gcc/testsuite/ChangeLog:

	* gcc.target/aarch64/sme/vect-dotprod-twoway.c: New test.
--- gcc/config/aarch64/aarch64-builtins.cc | 7 ++++++ gcc/config/aarch64/aarch64-simd-builtins.def | 6 ++--- gcc/config/aarch64/aarch64-simd.md | 9 ++++--- .../aarch64/aarch64-sve-builtins-base.cc | 13 +++++----- gcc/config/aarch64/aarch64-sve-builtins.cc | 17 +++++++++++++ gcc/config/aarch64/aarch64-sve-builtins.h | 3 +++ gcc/config/aarch64/aarch64-sve.md | 6 ++--- gcc/config/aarch64/aarch64-sve2.md | 2 +- .../aarch64/sme/vect-dotprod-twoway.c | 25 +++++++++++++++++++ 9 files changed, 71 insertions(+), 17 deletions(-) create mode 100644 gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c diff --git a/gcc/config/aarch64/aarch64-builtins.cc b/gcc/config/aarch64/aarch64-builtins.cc index 30669f8aa18..8af646ab066 100644 --- a/gcc/config/aarch64/aarch64-builtins.cc +++ b/gcc/config/aarch64/aarch64-builtins.cc @@ -458,6 +458,13 @@ aarch64_types_storestruct_lane_p_qualifiers[SIMD_MAX_BUILTIN_ARGS] qualifier_poly, qualifier_struct_load_store_lane_index }; #define TYPES_STORESTRUCT_LANE_P (aarch64_types_storestruct_lane_p_qualifiers) +constexpr insn_code CODE_FOR_aarch64_sdot_prodv8qi = CODE_FOR_sdot_prodv2siv8qi; +constexpr insn_code CODE_FOR_aarch64_udot_prodv8qi = CODE_FOR_udot_prodv2siv8qi; +constexpr insn_code CODE_FOR_aarch64_usdot_prodv8qi = CODE_FOR_usdot_prodv2siv8qi; +constexpr insn_code CODE_FOR_aarch64_sdot_prodv16qi = CODE_FOR_sdot_prodv4siv16qi; +constexpr insn_code CODE_FOR_aarch64_udot_prodv16qi = CODE_FOR_udot_prodv4siv16qi; +constexpr insn_code CODE_FOR_aarch64_usdot_prodv16qi = CODE_FOR_usdot_prodv4siv16qi; + #define CF0(N, X) CODE_FOR_aarch64_##N##X #define CF1(N, X) CODE_FOR_##N##X##1 #define CF2(N, X) CODE_FOR_##N##X##2 diff --git a/gcc/config/aarch64/aarch64-simd-builtins.def b/gcc/config/aarch64/aarch64-simd-builtins.def index e65f73d7ba2..0814f8ba14f 100644 --- a/gcc/config/aarch64/aarch64-simd-builtins.def +++ b/gcc/config/aarch64/aarch64-simd-builtins.def @@ -418,9 +418,9 @@ BUILTIN_VSDQ_I_DI (BINOP_UUS, urshl, 0, NONE) /* Implemented by 
<sur><dotprod>_prod<dot_mode>. */ - BUILTIN_VB (TERNOP, sdot_prod, 10, NONE) - BUILTIN_VB (TERNOPU, udot_prod, 10, NONE) - BUILTIN_VB (TERNOP_SUSS, usdot_prod, 10, NONE) + BUILTIN_VB (TERNOP, sdot_prod, 0, NONE) + BUILTIN_VB (TERNOPU, udot_prod, 0, NONE) + BUILTIN_VB (TERNOP_SUSS, usdot_prod, 0, NONE) /* Implemented by aarch64_<sur><dotprod>_lane{q}<dot_mode>. */ BUILTIN_VB (QUADOP_LANE, sdot_lane, 0, NONE) BUILTIN_VB (QUADOPU_LANE, udot_lane, 0, NONE) diff --git a/gcc/config/aarch64/aarch64-simd.md b/gcc/config/aarch64/aarch64-simd.md index cc612ec2ca0..e15e547b000 100644 --- a/gcc/config/aarch64/aarch64-simd.md +++ b/gcc/config/aarch64/aarch64-simd.md @@ -568,7 +568,7 @@ (define_expand "cmul<conj_op><mode>3" ;; ... ;; ;; and so the vectorizer provides r, in which the result has to be accumulated. -(define_insn "<sur>dot_prod<vsi2qi><vczle><vczbe>" +(define_insn "<sur>dot_prod<mode><vsi2qi><vczle><vczbe>" [(set (match_operand:VS 0 "register_operand" "=w") (plus:VS (unspec:VS [(match_operand:<VSI2QI> 1 "register_operand" "w") @@ -582,7 +582,7 @@ (define_insn "<sur>dot_prod<vsi2qi><vczle><vczbe>" ;; These instructions map to the __builtins for the Armv8.6-a I8MM usdot ;; (vector) Dot Product operation and the vectorized optab. -(define_insn "usdot_prod<vsi2qi><vczle><vczbe>" +(define_insn "usdot_prod<mode><vsi2qi><vczle><vczbe>" [(set (match_operand:VS 0 "register_operand" "=w") (plus:VS (unspec:VS [(match_operand:<VSI2QI> 1 "register_operand" "w") @@ -1075,7 +1075,7 @@ (define_expand "<su>sadv16qi" rtx ones = force_reg (V16QImode, CONST1_RTX (V16QImode)); rtx abd = gen_reg_rtx (V16QImode); emit_insn (gen_aarch64_<su>abdv16qi (abd, operands[1], operands[2])); - emit_insn (gen_udot_prodv16qi (operands[0], abd, ones, operands[3])); + emit_insn (gen_udot_prodv4siv16qi (operands[0], abd, ones, operands[3])); DONE; } rtx reduc = gen_reg_rtx (V8HImode); @@ -3528,6 +3528,7 @@ (define_expand "popcount<mode>2" /* Generate a byte popcount. 
*/ machine_mode mode = <bitsize> == 64 ? V8QImode : V16QImode; + machine_mode mode2 = <bitsize> == 64 ? V2SImode : V4SImode; rtx tmp = gen_reg_rtx (mode); auto icode = optab_handler (popcount_optab, mode); emit_insn (GEN_FCN (icode) (tmp, gen_lowpart (mode, operands[1]))); @@ -3538,7 +3539,7 @@ (define_expand "popcount<mode>2" /* For V4SI and V2SI, we can generate a UDOT with a 0 accumulator and a 1 multiplicand. For V2DI, another UAADDLP is needed. */ rtx ones = force_reg (mode, CONST1_RTX (mode)); - auto icode = optab_handler (udot_prod_optab, mode); + auto icode = convert_optab_handler (udot_prod_optab, mode2, mode); mode = <bitsize> == 64 ? V2SImode : V4SImode; rtx dest = mode == <MODE>mode ? operands[0] : gen_reg_rtx (mode); rtx zeros = force_reg (mode, CONST0_RTX (mode)); diff --git a/gcc/config/aarch64/aarch64-sve-builtins-base.cc b/gcc/config/aarch64/aarch64-sve-builtins-base.cc index d55bee0b72f..42e9cec57ad 100644 --- a/gcc/config/aarch64/aarch64-sve-builtins-base.cc +++ b/gcc/config/aarch64/aarch64-sve-builtins-base.cc @@ -804,15 +804,16 @@ public: e.rotate_inputs_left (0, 3); insn_code icode; if (e.type_suffix_ids[1] == NUM_TYPE_SUFFIXES) - icode = e.direct_optab_handler_for_sign (sdot_prod_optab, - udot_prod_optab, - 0, GET_MODE (e.args[0])); + icode = e.convert_optab_handler_for_sign (sdot_prod_optab, + udot_prod_optab, + 0, e.result_mode (), + GET_MODE (e.args[0])); else icode = (e.type_suffix (0).float_p ? CODE_FOR_aarch64_sve_fdotvnx4sfvnx8hf : e.type_suffix (0).unsigned_p - ? CODE_FOR_aarch64_sve_udotvnx4sivnx8hi - : CODE_FOR_aarch64_sve_sdotvnx4sivnx8hi); + ? CODE_FOR_udot_prodvnx4sivnx8hi + : CODE_FOR_sdot_prodvnx4sivnx8hi); return e.use_unpred_insn (icode); } }; @@ -2861,7 +2862,7 @@ public: Hence we do the same rotation on arguments as svdot_impl does. 
*/ e.rotate_inputs_left (0, 3); machine_mode mode = e.vector_mode (0); - insn_code icode = code_for_dot_prod (UNSPEC_USDOT, mode); + insn_code icode = code_for_dot_prod (UNSPEC_USDOT, e.result_mode (), mode); return e.use_exact_insn (icode); } diff --git a/gcc/config/aarch64/aarch64-sve-builtins.cc b/gcc/config/aarch64/aarch64-sve-builtins.cc index 0a560eaedca..975eca0bbd6 100644 --- a/gcc/config/aarch64/aarch64-sve-builtins.cc +++ b/gcc/config/aarch64/aarch64-sve-builtins.cc @@ -3745,6 +3745,23 @@ function_expander::direct_optab_handler_for_sign (optab signed_op, return ::direct_optab_handler (op, mode); } +/* Choose between signed and unsigned convert optabs SIGNED_OP and + UNSIGNED_OP based on the signedness of type suffix SUFFIX_I, then + pick the appropriate optab handler for the mode. Use MODE as the + mode if given, otherwise use the mode of type suffix SUFFIX_I. */ +insn_code +function_expander::convert_optab_handler_for_sign (optab signed_op, + optab unsigned_op, + unsigned int suffix_i, + machine_mode to_mode, + machine_mode from_mode) +{ + if (from_mode == VOIDmode) + from_mode = vector_mode (suffix_i); + optab op = type_suffix (suffix_i).unsigned_p ? unsigned_op : signed_op; + return ::convert_optab_handler (op, to_mode, from_mode); +} + /* Return true if X overlaps any input. 
*/ bool function_expander::overlaps_input_p (rtx x) diff --git a/gcc/config/aarch64/aarch64-sve-builtins.h b/gcc/config/aarch64/aarch64-sve-builtins.h index 9ab6f202c30..7534a58c3d7 100644 --- a/gcc/config/aarch64/aarch64-sve-builtins.h +++ b/gcc/config/aarch64/aarch64-sve-builtins.h @@ -659,6 +659,9 @@ public: insn_code direct_optab_handler (optab, unsigned int = 0); insn_code direct_optab_handler_for_sign (optab, optab, unsigned int = 0, machine_mode = E_VOIDmode); + insn_code convert_optab_handler_for_sign (optab, optab, unsigned int = 0, + machine_mode = E_VOIDmode, + machine_mode = E_VOIDmode); machine_mode result_mode () const; diff --git a/gcc/config/aarch64/aarch64-sve.md b/gcc/config/aarch64/aarch64-sve.md index a5cd42be9d5..2fe18bdacfe 100644 --- a/gcc/config/aarch64/aarch64-sve.md +++ b/gcc/config/aarch64/aarch64-sve.md @@ -7197,7 +7197,7 @@ (define_insn_and_rewrite "*cond_fnma<mode>_any" ;; ------------------------------------------------------------------------- ;; Four-element integer dot-product with accumulation. 
-(define_insn "<sur>dot_prod<vsi2qi>" +(define_insn "<sur>dot_prod<mode><vsi2qi>" [(set (match_operand:SVE_FULL_SDI 0 "register_operand") (plus:SVE_FULL_SDI (unspec:SVE_FULL_SDI @@ -7235,7 +7235,7 @@ (define_insn "@aarch64_<sur>dot_prod_lane<SVE_FULL_SDI:mode><SVE_FULL_BHI:mode>" } ) -(define_insn "@<sur>dot_prod<vsi2qi>" +(define_insn "@<sur>dot_prod<mode><vsi2qi>" [(set (match_operand:VNx4SI_ONLY 0 "register_operand") (plus:VNx4SI_ONLY (unspec:VNx4SI_ONLY @@ -7293,7 +7293,7 @@ (define_expand "<su>sad<vsi2qi>" rtx ones = force_reg (<VSI2QI>mode, CONST1_RTX (<VSI2QI>mode)); rtx diff = gen_reg_rtx (<VSI2QI>mode); emit_insn (gen_<su>abd<vsi2qi>3 (diff, operands[1], operands[2])); - emit_insn (gen_udot_prod<vsi2qi> (operands[0], diff, ones, operands[3])); + emit_insn (gen_udot_prod<mode><vsi2qi> (operands[0], diff, ones, operands[3])); DONE; } ) diff --git a/gcc/config/aarch64/aarch64-sve2.md b/gcc/config/aarch64/aarch64-sve2.md index 972b03a4fef..725092cc95f 100644 --- a/gcc/config/aarch64/aarch64-sve2.md +++ b/gcc/config/aarch64/aarch64-sve2.md @@ -2021,7 +2021,7 @@ (define_insn "@aarch64_sve_qsub_<sve_int_op>_lane_<mode>" ) ;; Two-way dot-product. 
-(define_insn "@aarch64_sve_<sur>dotvnx4sivnx8hi" +(define_insn "<sur>dot_prodvnx4sivnx8hi" [(set (match_operand:VNx4SI 0 "register_operand") (plus:VNx4SI (unspec:VNx4SI diff --git a/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c b/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c new file mode 100644 index 00000000000..453f3a75e6f --- /dev/null +++ b/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c @@ -0,0 +1,25 @@ +/* { dg-additional-options "-march=armv9.2-a+sme2 -O2 -ftree-vectorize" } */ + +#include <stdint.h> + +uint32_t udot2(int n, uint16_t* data) __arm_streaming +{ + uint32_t sum = 0; + for (int i=0; i<n; i+=1) { + sum += data[i] * data[i]; + } + return sum; +} + +int32_t sdot2(int n, int16_t* data) __arm_streaming +{ + int32_t sum = 0; + for (int i=0; i<n; i+=1) { + sum += data[i] * data[i]; + } + return sum; +} + +/* { dg-final { scan-assembler-times {\tudot\tz[0-9]+\.s, z[0-9]+\.h, z[0-9]+\.h\n} 5 } } */ +/* { dg-final { scan-assembler-times {\tsdot\tz[0-9]+\.s, z[0-9]+\.h, z[0-9]+\.h\n} 5 } } */ +/* { dg-final { scan-assembler-times {\twhilelo\t} 4 } } */ -- 2.34.1