Given the recent changes to the dot_prod standard pattern names, this patch fixes the aarch64 back-end by implementing the following changes:
  1. Add the second mode to all (u|s|us)dot_prod patterns in the .md files.
  2. Rewrite the initialization and function expansion mechanism for the simd
     builtins.
  3. Fix all direct calls to the back-end `dot_prod' patterns in the SVE
     builtins.

Finally, given that it is now possible for the compiler to differentiate
between the two-way and four-way dot product, we add a test to ensure that
autovectorization picks up on dot-product patterns where the result is twice
the width of the operands.  An illustrative example of the resulting
builtin-to-pattern mapping is included after the ChangeLog below.

gcc/ChangeLog:

	* config/aarch64/aarch64-builtins.cc (enum aarch64_builtins): New
	AARCH64_BUILTIN_* enum values: SDOTV8QI, SDOTV16QI, UDOTV8QI,
	UDOTV16QI, USDOTV8QI, USDOTV16QI.
	(aarch64_init_builtin_dotprod_functions): New.
	(aarch64_init_simd_builtins): Add call to
	`aarch64_init_builtin_dotprod_functions'.
	(aarch64_general_gimple_fold_builtin): Add DOT_PROD_EXPR handling.
	* config/aarch64/aarch64-simd-builtins.def: Remove macro
	expansion-based initialization and expansion of the (u|s|us)dot_prod
	builtins.
	* config/aarch64/aarch64-simd.md
	(<sur>dot_prod<vsi2qi><vczle><vczbe>): Deleted.
	(<sur>dot_prod<mode><vsi2qi><vczle><vczbe>): New.
	(usdot_prod<vsi2qi><vczle><vczbe>): Deleted.
	(usdot_prod<mode><vsi2qi><vczle><vczbe>): New.
	(<su>sadv16qi): Adjust call to gen_udot_prod to take the second mode.
	(popcount<mode>2): Fix use of `udot_prod_optab'.
	* config/aarch64/aarch64-sve-builtins-base.cc (svdot_impl::expand):
	Call `convert_optab_handler_for_sign' instead of
	`direct_optab_handler_for_sign'.
	(svusdot_impl::expand): Add the second mode argument in the call to
	`code_for_dot_prod'.
	* config/aarch64/aarch64-sve-builtins.cc
	(function_expander::convert_optab_handler_for_sign): New class
	method.
	* config/aarch64/aarch64-sve-builtins.h (class function_expander):
	Add prototype for the new `convert_optab_handler_for_sign' method.
	* config/aarch64/aarch64-sve.md (<sur>dot_prod<vsi2qi>): Deleted.
	(<sur>dot_prod<mode><vsi2qi>): New.
	(@<sur>dot_prod<vsi2qi>): Deleted.
	(@<sur>dot_prod<mode><vsi2qi>): New.
	(<su>sad<vsi2qi>): Adjust call to gen_udot_prod to take the second
	mode.
	* config/aarch64/aarch64-sve2.md
	(@aarch64_sve_<sur>dotvnx4sivnx8hi): Deleted.
	(<sur>dot_prodvnx4sivnx8hi): New.

gcc/testsuite/ChangeLog:

	* gcc.target/aarch64/sme/vect-dotprod-twoway.c: New test.
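As an illustration of how the pieces fit together (this example is not part
of the patch and the function name is made up), a call to one of the
re-registered Advanced SIMD builtins now reaches the renamed two-mode pattern
via the new DOT_PROD_EXPR fold rather than a direct pattern expansion.  The
builtin name and argument order below follow the build_function_type_list
call in aarch64_init_builtin_dotprod_functions, and the example assumes a
target with UDOT available (e.g. -march=armv8.2-a+dotprod):

#include <arm_neon.h>

/* Hypothetical example only.  The gimple fold turns this call into a
   DOT_PROD_EXPR, which expand then maps onto the udot_prodv4siv16qi
   pattern (V4SI result, V16QI inputs) instead of the old single-mode
   udot_prodv16qi.  Argument order is (multiplicand, multiplicand,
   accumulator), as registered above.  */
uint32x4_t
four_way_udot (uint32x4_t acc, uint8x16_t a, uint8x16_t b)
{
  return __builtin_aarch64_udot_prodv16qi_uuuu (a, b, acc);
}

The V8QI/V2SI and mixed-sign (usdot) variants registered in the same table
follow the same path.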
---
 gcc/config/aarch64/aarch64-builtins.cc        | 71 +++++++++++++++++++
 gcc/config/aarch64/aarch64-simd-builtins.def  |  4 --
 gcc/config/aarch64/aarch64-simd.md            |  9 +--
 .../aarch64/aarch64-sve-builtins-base.cc      | 13 ++--
 gcc/config/aarch64/aarch64-sve-builtins.cc    | 17 +++++
 gcc/config/aarch64/aarch64-sve-builtins.h     |  3 +
 gcc/config/aarch64/aarch64-sve.md             |  6 +-
 gcc/config/aarch64/aarch64-sve2.md            |  2 +-
 gcc/config/aarch64/iterators.md               |  1 +
 .../aarch64/sme/vect-dotprod-twoway.c         | 25 +++++++
 10 files changed, 133 insertions(+), 18 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c

diff --git a/gcc/config/aarch64/aarch64-builtins.cc b/gcc/config/aarch64/aarch64-builtins.cc
index 30669f8aa18..6c7c86d0e6e 100644
--- a/gcc/config/aarch64/aarch64-builtins.cc
+++ b/gcc/config/aarch64/aarch64-builtins.cc
@@ -783,6 +783,12 @@ enum aarch64_builtins
   AARCH64_SIMD_PATTERN_START = AARCH64_SIMD_BUILTIN_LANE_CHECK + 1,
   AARCH64_SIMD_BUILTIN_MAX = AARCH64_SIMD_PATTERN_START
			      + ARRAY_SIZE (aarch64_simd_builtin_data) - 1,
+  AARCH64_BUILTIN_SDOTV8QI,
+  AARCH64_BUILTIN_SDOTV16QI,
+  AARCH64_BUILTIN_UDOTV8QI,
+  AARCH64_BUILTIN_UDOTV16QI,
+  AARCH64_BUILTIN_USDOTV8QI,
+  AARCH64_BUILTIN_USDOTV16QI,
   AARCH64_CRC32_BUILTIN_BASE,
   AARCH64_CRC32_BUILTINS
   AARCH64_CRC32_BUILTIN_MAX,
@@ -1642,6 +1648,60 @@ handle_arm_neon_h (void)
   aarch64_init_simd_intrinsics ();
 }
 
+void
+aarch64_init_builtin_dotprod_functions (void)
+{
+  tree fndecl = NULL;
+  tree ftype = NULL;
+
+  tree uv8qi = aarch64_simd_builtin_type (V8QImode, qualifier_unsigned);
+  tree sv8qi = aarch64_simd_builtin_type (V8QImode, qualifier_none);
+  tree uv16qi = aarch64_simd_builtin_type (V16QImode, qualifier_unsigned);
+  tree sv16qi = aarch64_simd_builtin_type (V16QImode, qualifier_none);
+  tree uv2si = aarch64_simd_builtin_type (V2SImode, qualifier_unsigned);
+  tree sv2si = aarch64_simd_builtin_type (V2SImode, qualifier_none);
+  tree uv4si = aarch64_simd_builtin_type (V4SImode, qualifier_unsigned);
+  tree sv4si = aarch64_simd_builtin_type (V4SImode, qualifier_none);
+
+  struct builtin_decls_data
+  {
+    tree out_type_node;
+    tree in_type1_node;
+    tree in_type2_node;
+    const char *builtin_name;
+    int function_code;
+  };
+
+#define NAME(A) "__builtin_aarch64_" #A
+#define ENUM(B) AARCH64_BUILTIN_##B
+
+  builtin_decls_data bdda[] =
+  {
+    { sv2si, sv8qi, sv8qi, NAME (sdot_prodv8qi), ENUM (SDOTV8QI) },
+    { uv2si, uv8qi, uv8qi, NAME (udot_prodv8qi_uuuu), ENUM (UDOTV8QI) },
+    { sv2si, uv8qi, sv8qi, NAME (usdot_prodv8qi_suss), ENUM (USDOTV8QI) },
+    { sv4si, sv16qi, sv16qi, NAME (sdot_prodv16qi), ENUM (SDOTV16QI) },
+    { uv4si, uv16qi, uv16qi, NAME (udot_prodv16qi_uuuu), ENUM (UDOTV16QI) },
+    { sv4si, uv16qi, sv16qi, NAME (usdot_prodv16qi_suss), ENUM (USDOTV16QI) },
+  };
+
+#undef NAME
+#undef ENUM
+
+  builtin_decls_data *bdd = bdda;
+  builtin_decls_data *bdd_end = bdd + (ARRAY_SIZE (bdda));
+
+  for (; bdd < bdd_end; bdd++)
+    {
+      ftype = build_function_type_list (bdd->out_type_node, bdd->in_type1_node,
+					bdd->in_type2_node, bdd->out_type_node,
+					NULL_TREE);
+      fndecl = aarch64_general_add_builtin (bdd->builtin_name,
+					    ftype, bdd->function_code);
+      aarch64_builtin_decls[bdd->function_code] = fndecl;
+    }
+}
+
 static void
 aarch64_init_simd_builtins (void)
 {
@@ -1654,6 +1714,8 @@ aarch64_init_simd_builtins (void)
   aarch64_init_simd_builtin_scalar_types ();
   aarch64_init_simd_builtin_functions (false);
+  aarch64_init_builtin_dotprod_functions ();
+
   if (in_lto_p)
     handle_arm_neon_h ();
@@ -3676,6 +3738,15 @@ aarch64_general_gimple_fold_builtin (unsigned int fcode, gcall *stmt,
	    new_stmt = gimple_build_nop ();
	  }
	break;
+    case AARCH64_BUILTIN_SDOTV8QI:
+    case AARCH64_BUILTIN_SDOTV16QI:
+    case AARCH64_BUILTIN_UDOTV8QI:
+    case AARCH64_BUILTIN_UDOTV16QI:
+    case AARCH64_BUILTIN_USDOTV8QI:
+    case AARCH64_BUILTIN_USDOTV16QI:
+      new_stmt = gimple_build_assign (gimple_call_lhs (stmt),
+				      DOT_PROD_EXPR, args[0],
+				      args[1], args[2]);
     default:
       break;
     }
diff --git a/gcc/config/aarch64/aarch64-simd-builtins.def b/gcc/config/aarch64/aarch64-simd-builtins.def
index e65f73d7ba2..ea774ba1d49 100644
--- a/gcc/config/aarch64/aarch64-simd-builtins.def
+++ b/gcc/config/aarch64/aarch64-simd-builtins.def
@@ -417,10 +417,6 @@
   BUILTIN_VSDQ_I_DI (BINOP, srshl, 0, NONE)
   BUILTIN_VSDQ_I_DI (BINOP_UUS, urshl, 0, NONE)
 
-  /* Implemented by <sur><dotprod>_prod<dot_mode>.  */
-  BUILTIN_VB (TERNOP, sdot_prod, 10, NONE)
-  BUILTIN_VB (TERNOPU, udot_prod, 10, NONE)
-  BUILTIN_VB (TERNOP_SUSS, usdot_prod, 10, NONE)
   /* Implemented by aarch64_<sur><dotprod>_lane{q}<dot_mode>.  */
   BUILTIN_VB (QUADOP_LANE, sdot_lane, 0, NONE)
   BUILTIN_VB (QUADOPU_LANE, udot_lane, 0, NONE)
diff --git a/gcc/config/aarch64/aarch64-simd.md b/gcc/config/aarch64/aarch64-simd.md
index bbeee221f37..e8e1539fcf3 100644
--- a/gcc/config/aarch64/aarch64-simd.md
+++ b/gcc/config/aarch64/aarch64-simd.md
@@ -568,7 +568,7 @@ (define_expand "cmul<conj_op><mode>3"
 ;; ...
 ;;
 ;; and so the vectorizer provides r, in which the result has to be accumulated.
-(define_insn "<sur>dot_prod<vsi2qi><vczle><vczbe>"
+(define_insn "<sur>dot_prod<mode><vsi2qi><vczle><vczbe>"
   [(set (match_operand:VS 0 "register_operand" "=w")
	(plus:VS (unspec:VS [(match_operand:<VSI2QI> 1 "register_operand" "w")
@@ -582,7 +582,7 @@ (define_insn "<sur>dot_prod<vsi2qi><vczle><vczbe>"
 
 ;; These instructions map to the __builtins for the Armv8.6-a I8MM usdot
 ;; (vector) Dot Product operation and the vectorized optab.
-(define_insn "usdot_prod<vsi2qi><vczle><vczbe>"
+(define_insn "usdot_prod<mode><vsi2qi><vczle><vczbe>"
   [(set (match_operand:VS 0 "register_operand" "=w")
	(plus:VS (unspec:VS [(match_operand:<VSI2QI> 1 "register_operand" "w")
@@ -1075,7 +1075,7 @@ (define_expand "<su>sadv16qi"
       rtx ones = force_reg (V16QImode, CONST1_RTX (V16QImode));
       rtx abd = gen_reg_rtx (V16QImode);
       emit_insn (gen_aarch64_<su>abdv16qi (abd, operands[1], operands[2]));
-      emit_insn (gen_udot_prodv16qi (operands[0], abd, ones, operands[3]));
+      emit_insn (gen_udot_prodv4siv16qi (operands[0], abd, ones, operands[3]));
       DONE;
     }
   rtx reduc = gen_reg_rtx (V8HImode);
@@ -3510,6 +3510,7 @@ (define_expand "popcount<mode>2"
 {
   /* Generate a byte popcount.  */
   machine_mode mode = <bitsize> == 64 ? V8QImode : V16QImode;
+  machine_mode mode2 = <bitsize> == 64 ? V2SImode : V4SImode;
   rtx tmp = gen_reg_rtx (mode);
   auto icode = optab_handler (popcount_optab, mode);
   emit_insn (GEN_FCN (icode) (tmp, gen_lowpart (mode, operands[1])));
@@ -3520,7 +3521,7 @@ (define_expand "popcount<mode>2"
   /* For V4SI and V2SI, we can generate a UDOT with a 0 accumulator and a
      1 multiplicand.  For V2DI, another UAADDLP is needed.  */
   rtx ones = force_reg (mode, CONST1_RTX (mode));
-  auto icode = optab_handler (udot_prod_optab, mode);
+  auto icode = convert_optab_handler (udot_prod_optab, mode2, mode);
   mode = <bitsize> == 64 ? V2SImode : V4SImode;
   rtx dest = mode == <MODE>mode ? operands[0] : gen_reg_rtx (mode);
   rtx zeros = force_reg (mode, CONST0_RTX (mode));
diff --git a/gcc/config/aarch64/aarch64-sve-builtins-base.cc b/gcc/config/aarch64/aarch64-sve-builtins-base.cc
index aa26370d397..12ffaf7e6ca 100644
--- a/gcc/config/aarch64/aarch64-sve-builtins-base.cc
+++ b/gcc/config/aarch64/aarch64-sve-builtins-base.cc
@@ -757,15 +757,16 @@ public:
     e.rotate_inputs_left (0, 3);
     insn_code icode;
     if (e.type_suffix_ids[1] == NUM_TYPE_SUFFIXES)
-      icode = e.direct_optab_handler_for_sign (sdot_prod_optab,
-					       udot_prod_optab,
-					       0, GET_MODE (e.args[0]));
+      icode = e.convert_optab_handler_for_sign (sdot_prod_optab,
+						udot_prod_optab,
+						0, e.result_mode (),
+						GET_MODE (e.args[0]));
     else
       icode = (e.type_suffix (0).float_p
	       ? CODE_FOR_aarch64_sve_fdotvnx4sfvnx8hf
	       : e.type_suffix (0).unsigned_p
-	       ? CODE_FOR_aarch64_sve_udotvnx4sivnx8hi
-	       : CODE_FOR_aarch64_sve_sdotvnx4sivnx8hi);
+	       ? CODE_FOR_udot_prodvnx4sivnx8hi
+	       : CODE_FOR_sdot_prodvnx4sivnx8hi);
     return e.use_unpred_insn (icode);
   }
 };
@@ -2814,7 +2815,7 @@ public:
        Hence we do the same rotation on arguments as svdot_impl does.  */
     e.rotate_inputs_left (0, 3);
     machine_mode mode = e.vector_mode (0);
-    insn_code icode = code_for_dot_prod (UNSPEC_USDOT, mode);
+    insn_code icode = code_for_dot_prod (UNSPEC_USDOT, e.result_mode (), mode);
     return e.use_exact_insn (icode);
   }
 
diff --git a/gcc/config/aarch64/aarch64-sve-builtins.cc b/gcc/config/aarch64/aarch64-sve-builtins.cc
index f3983a123e3..0650e1d0a4d 100644
--- a/gcc/config/aarch64/aarch64-sve-builtins.cc
+++ b/gcc/config/aarch64/aarch64-sve-builtins.cc
@@ -3745,6 +3745,23 @@ function_expander::direct_optab_handler_for_sign (optab signed_op,
   return ::direct_optab_handler (op, mode);
 }
 
+/* Choose between signed and unsigned convert optabs SIGNED_OP and
+   UNSIGNED_OP based on the signedness of type suffix SUFFIX_I, then
+   pick the appropriate optab handler for the mode.  Use MODE as the
+   mode if given, otherwise use the mode of type suffix SUFFIX_I.  */
+insn_code
+function_expander::convert_optab_handler_for_sign (optab signed_op,
+						   optab unsigned_op,
+						   unsigned int suffix_i,
+						   machine_mode to_mode,
+						   machine_mode from_mode)
+{
+  if (from_mode == VOIDmode)
+    from_mode = vector_mode (suffix_i);
+  optab op = type_suffix (suffix_i).unsigned_p ? unsigned_op : signed_op;
+  return ::convert_optab_handler (op, to_mode, from_mode);
+}
+
 /* Return true if X overlaps any input.  */
 bool
 function_expander::overlaps_input_p (rtx x)
diff --git a/gcc/config/aarch64/aarch64-sve-builtins.h b/gcc/config/aarch64/aarch64-sve-builtins.h
index 9cc07d5fa3d..c277632e1dc 100644
--- a/gcc/config/aarch64/aarch64-sve-builtins.h
+++ b/gcc/config/aarch64/aarch64-sve-builtins.h
@@ -659,6 +659,9 @@ public:
   insn_code direct_optab_handler (optab, unsigned int = 0);
   insn_code direct_optab_handler_for_sign (optab, optab, unsigned int = 0,
					    machine_mode = E_VOIDmode);
+  insn_code convert_optab_handler_for_sign (optab, optab, unsigned int = 0,
+					    machine_mode = E_VOIDmode,
+					    machine_mode = E_VOIDmode);
 
   machine_mode result_mode () const;
diff --git a/gcc/config/aarch64/aarch64-sve.md b/gcc/config/aarch64/aarch64-sve.md
index 5331e7121d5..ce83a109725 100644
--- a/gcc/config/aarch64/aarch64-sve.md
+++ b/gcc/config/aarch64/aarch64-sve.md
@@ -7196,7 +7196,7 @@ (define_insn_and_rewrite "*cond_fnma<mode>_any"
 ;; -------------------------------------------------------------------------
 
 ;; Four-element integer dot-product with accumulation.
-(define_insn "<sur>dot_prod<vsi2qi>"
+(define_insn "<sur>dot_prod<mode><vsi2qi>"
   [(set (match_operand:SVE_FULL_SDI 0 "register_operand")
	(plus:SVE_FULL_SDI
	  (unspec:SVE_FULL_SDI
@@ -7234,7 +7234,7 @@ (define_insn "@aarch64_<sur>dot_prod_lane<SVE_FULL_SDI:mode><SVE_FULL_BHI:mode>"
   }
 )
 
-(define_insn "@<sur>dot_prod<vsi2qi>"
+(define_insn "@<sur>dot_prod<mode><vsi2qi>"
   [(set (match_operand:VNx4SI_ONLY 0 "register_operand")
	(plus:VNx4SI_ONLY
	  (unspec:VNx4SI_ONLY
@@ -7292,7 +7292,7 @@ (define_expand "<su>sad<vsi2qi>"
     rtx ones = force_reg (<VSI2QI>mode, CONST1_RTX (<VSI2QI>mode));
     rtx diff = gen_reg_rtx (<VSI2QI>mode);
     emit_insn (gen_<su>abd<vsi2qi>3 (diff, operands[1], operands[2]));
-    emit_insn (gen_udot_prod<vsi2qi> (operands[0], diff, ones, operands[3]));
+    emit_insn (gen_udot_prod<mode><vsi2qi> (operands[0], diff, ones, operands[3]));
     DONE;
   }
 )
diff --git a/gcc/config/aarch64/aarch64-sve2.md b/gcc/config/aarch64/aarch64-sve2.md
index 972b03a4fef..725092cc95f 100644
--- a/gcc/config/aarch64/aarch64-sve2.md
+++ b/gcc/config/aarch64/aarch64-sve2.md
@@ -2021,7 +2021,7 @@ (define_insn "@aarch64_sve_qsub_<sve_int_op>_lane_<mode>"
 )
 
 ;; Two-way dot-product.
-(define_insn "@aarch64_sve_<sur>dotvnx4sivnx8hi"
+(define_insn "<sur>dot_prodvnx4sivnx8hi"
   [(set (match_operand:VNx4SI 0 "register_operand")
	(plus:VNx4SI
	  (unspec:VNx4SI
diff --git a/gcc/config/aarch64/iterators.md b/gcc/config/aarch64/iterators.md
index f527b2cfeb8..1864462ccfc 100644
--- a/gcc/config/aarch64/iterators.md
+++ b/gcc/config/aarch64/iterators.md
@@ -2119,6 +2119,7 @@ (define_mode_attr vp [(V8QI "v") (V16QI "v")
 (define_mode_attr vsi2qi [(V2SI "v8qi") (V4SI "v16qi")
			   (VNx4SI "vnx16qi") (VNx2DI "vnx8hi")])
+
 (define_mode_attr VSI2QI [(V2SI "V8QI") (V4SI "V16QI")
			   (VNx4SI "VNx16QI") (VNx2DI "VNx8HI")])
diff --git a/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c b/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c
new file mode 100644
index 00000000000..453f3a75e6f
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c
@@ -0,0 +1,25 @@
+/* { dg-additional-options "-march=armv9.2-a+sme2 -O2 -ftree-vectorize" } */
+
+#include <stdint.h>
+
+uint32_t udot2(int n, uint16_t* data) __arm_streaming
+{
+  uint32_t sum = 0;
+  for (int i=0; i<n; i+=1) {
+    sum += data[i] * data[i];
+  }
+  return sum;
+}
+
+int32_t sdot2(int n, int16_t* data) __arm_streaming
+{
+  int32_t sum = 0;
+  for (int i=0; i<n; i+=1) {
+    sum += data[i] * data[i];
+  }
+  return sum;
+}
+
+/* { dg-final { scan-assembler-times {\tudot\tz[0-9]+\.s, z[0-9]+\.h, z[0-9]+\.h\n} 5 } } */
+/* { dg-final { scan-assembler-times {\tsdot\tz[0-9]+\.s, z[0-9]+\.h, z[0-9]+\.h\n} 5 } } */
+/* { dg-final { scan-assembler-times {\twhilelo\t} 4 } } */
-- 
2.34.1