From: Ju-Zhe Zhong <juzhe.zh...@rivai.ai> Hi, Richard and Richi.
This patch adds cond_len_* operation patterns for targets that support loop control with length. These patterns will be used in the following cases: 1. Integer division: void f (int32_t *restrict a, int32_t *restrict b, int32_t *restrict c, int n) { for (int i = 0; i < n; ++i) { a[i] = b[i] / c[i]; } } ARM SVE IR: ... max_mask_36 = .WHILE_ULT (0, bnd.5_32, { 0, ... }); Loop: ... # loop_mask_29 = PHI <next_mask_37(4), max_mask_36(3)> ... vect__4.8_28 = .MASK_LOAD (_33, 32B, loop_mask_29); ... vect__6.11_25 = .MASK_LOAD (_20, 32B, loop_mask_29); vect__8.12_24 = .COND_DIV (loop_mask_29, vect__4.8_28, vect__6.11_25, vect__4.8_28); ... .MASK_STORE (_1, 32B, loop_mask_29, vect__8.12_24); ... next_mask_37 = .WHILE_ULT (_2, bnd.5_32, { 0, ... }); ... For targets like RVV that support loop control with length, we want to see IR as follows: Loop: ... # loop_len_29 = SELECT_VL ... vect__4.8_28 = .LEN_MASK_LOAD (_33, 32B, loop_len_29); ... vect__6.11_25 = .LEN_MASK_LOAD (_20, 32B, loop_len_29); vect__8.12_24 = .COND_LEN_DIV (dummy_mask, vect__4.8_28, vect__6.11_25, vect__4.8_28, loop_len_29); ... .LEN_MASK_STORE (_1, 32B, loop_len_29, vect__8.12_24); ... next_mask_37 = .WHILE_ULT (_2, bnd.5_32, { 0, ... }); ... Note that here we use an all-true dummy_mask = { -1, -1, .... , -1 }, so only the length operand limits which elements are processed. 2. Integer conditional division: Similar to case (1) but with a condition: void f (int32_t *restrict a, int32_t *restrict b, int32_t *restrict c, int32_t * cond, int n) { for (int i = 0; i < n; ++i) { if (cond[i]) a[i] = b[i] / c[i]; } } ARM SVE IR: ... max_mask_76 = .WHILE_ULT (0, bnd.6_52, { 0, ... }); Loop: ... # loop_mask_55 = PHI <next_mask_77(5), max_mask_76(4)> ... vect__4.9_56 = .MASK_LOAD (_51, 32B, loop_mask_55); mask__29.10_58 = vect__4.9_56 != { 0, ... }; vec_mask_and_61 = loop_mask_55 & mask__29.10_58; ... vect__6.13_62 = .MASK_LOAD (_24, 32B, vec_mask_and_61); ... 
vect__8.16_66 = .MASK_LOAD (_1, 32B, vec_mask_and_61); vect__10.17_68 = .COND_DIV (vec_mask_and_61, vect__6.13_62, vect__8.16_66, vect__6.13_62); ... .MASK_STORE (_2, 32B, vec_mask_and_61, vect__10.17_68); ... next_mask_77 = .WHILE_ULT (_3, bnd.6_52, { 0, ... }); Here, ARM SVE uses vec_mask_and_61 = loop_mask_55 & mask__29.10_58; to guarantee the correct result. However, targets with length control cannot use this flow directly; for RVV, we would expect: Loop: ... loop_len_55 = SELECT_VL ... mask__29.10_58 = vect__4.9_56 != { 0, ... }; ... vect__10.17_68 = .COND_LEN_DIV (mask__29.10_58, vect__6.13_62, vect__8.16_66, vect__6.13_62, loop_len_55); ... Here we expect COND_LEN_DIV to be predicated both by a real mask, which is the outcome of the comparison: mask__29.10_58 = vect__4.9_56 != { 0, ... }; and by a real length, which is produced by loop control: loop_len_55 = SELECT_VL 3. Conditional floating-point operations (no -ffast-math): void f (float *restrict a, float *restrict b, int32_t *restrict cond, int n) { for (int i = 0; i < n; ++i) { if (cond[i]) a[i] = b[i] + a[i]; } } ARM SVE IR: max_mask_70 = .WHILE_ULT (0, bnd.6_46, { 0, ... }); ... # loop_mask_49 = PHI <next_mask_71(4), max_mask_70(3)> ... mask__27.10_52 = vect__4.9_50 != { 0, ... }; vec_mask_and_55 = loop_mask_49 & mask__27.10_52; ... vect__9.17_62 = .COND_ADD (vec_mask_and_55, vect__6.13_56, vect__8.16_60, vect__6.13_56); ... next_mask_71 = .WHILE_ULT (_22, bnd.6_46, { 0, ... }); ... For RVV, we would expect IR: ... loop_len_49 = SELECT_VL ... mask__27.10_52 = vect__4.9_50 != { 0, ... }; ... vect__9.17_62 = .COND_LEN_ADD (mask__27.10_52, vect__6.13_56, vect__8.16_60, vect__6.13_56, loop_len_49); ... 4. Conditional unordered reduction: int32_t f (int32_t *restrict a, int32_t *restrict cond, int n) { int32_t result = 0; for (int i = 0; i < n; ++i) { if (cond[i]) result += a[i]; } return result; } ARM SVE IR: Loop: # vect_result_18.7_37 = PHI <vect__33.16_51(4), { 0, ... }(3)> ... 
# loop_mask_40 = PHI <next_mask_58(4), max_mask_57(3)> ... mask__17.11_43 = vect__4.10_41 != { 0, ... }; vec_mask_and_46 = loop_mask_40 & mask__17.11_43; ... vect__33.16_51 = .COND_ADD (vec_mask_and_46, vect_result_18.7_37, vect__7.14_47, vect_result_18.7_37); ... next_mask_58 = .WHILE_ULT (_15, bnd.6_36, { 0, ... }); ... Epilogue: _53 = .REDUC_PLUS (vect__33.16_51); [tail call] For RVV, we expect: Loop: # vect_result_18.7_37 = PHI <vect__33.16_51(4), { 0, ... }(3)> ... loop_len_40 = SELECT_VL ... mask__17.11_43 = vect__4.10_41 != { 0, ... }; ... vect__33.16_51 = .COND_LEN_ADD (mask__17.11_43, vect_result_18.7_37, vect__7.14_47, vect_result_18.7_37, loop_len_40); ... next_mask_58 = .WHILE_ULT (_15, bnd.6_36, { 0, ... }); ... Epilogue: _53 = .REDUC_PLUS (vect__33.16_51); [tail call] I named these patterns "cond_len_*" because they take the same operands, in the same order, as the corresponding "cond_*" patterns, plus one extra length operand that comes last, after the else operand. Keeping this order will make the follow-up loop vectorizer support easier. gcc/ChangeLog: * doc/md.texi: Add COND_LEN_* operations for loop control with length. * internal-fn.cc (cond_len_unary_direct): Ditto. (cond_len_binary_direct): Ditto. (cond_len_ternary_direct): Ditto. (expand_cond_len_unary_optab_fn): Ditto. (expand_cond_len_binary_optab_fn): Ditto. (expand_cond_len_ternary_optab_fn): Ditto. (direct_cond_len_unary_optab_supported_p): Ditto. (direct_cond_len_binary_optab_supported_p): Ditto. (direct_cond_len_ternary_optab_supported_p): Ditto. * internal-fn.def (COND_LEN_ADD): Ditto. (COND_LEN_SUB): Ditto. (COND_LEN_MUL): Ditto. (COND_LEN_DIV): Ditto. (COND_LEN_MOD): Ditto. (COND_LEN_RDIV): Ditto. (COND_LEN_MIN): Ditto. (COND_LEN_MAX): Ditto. (COND_LEN_FMIN): Ditto. (COND_LEN_FMAX): Ditto. (COND_LEN_AND): Ditto. (COND_LEN_IOR): Ditto. (COND_LEN_XOR): Ditto. (COND_LEN_SHL): Ditto. (COND_LEN_SHR): Ditto. (COND_LEN_FMA): Ditto. (COND_LEN_FMS): Ditto. (COND_LEN_FNMA): Ditto. (COND_LEN_FNMS): Ditto. 
(COND_LEN_NEG): Ditto. * optabs.def (OPTAB_D): Ditto. --- gcc/doc/md.texi | 80 +++++++++++++++++++++++++++++++++++++++++++++ gcc/internal-fn.cc | 15 +++++++++ gcc/internal-fn.def | 38 +++++++++++++++++++++ gcc/optabs.def | 24 ++++++++++++++ 4 files changed, 157 insertions(+) diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi index b30a824488b..287726d642b 100644 --- a/gcc/doc/md.texi +++ b/gcc/doc/md.texi @@ -7234,6 +7234,86 @@ for (i = 0; i < GET_MODE_NUNITS (@var{m}); i++) op0[i] = op1[i] ? fma (op2[i], op3[i], op4[i]) : op5[i]; @end smallexample +@cindex @code{cond_len_add@var{mode}} instruction pattern +@cindex @code{cond_len_sub@var{mode}} instruction pattern +@cindex @code{cond_len_mul@var{mode}} instruction pattern +@cindex @code{cond_len_div@var{mode}} instruction pattern +@cindex @code{cond_len_udiv@var{mode}} instruction pattern +@cindex @code{cond_len_mod@var{mode}} instruction pattern +@cindex @code{cond_len_umod@var{mode}} instruction pattern +@cindex @code{cond_len_and@var{mode}} instruction pattern +@cindex @code{cond_len_ior@var{mode}} instruction pattern +@cindex @code{cond_len_xor@var{mode}} instruction pattern +@cindex @code{cond_len_smin@var{mode}} instruction pattern +@cindex @code{cond_len_smax@var{mode}} instruction pattern +@cindex @code{cond_len_umin@var{mode}} instruction pattern +@cindex @code{cond_len_umax@var{mode}} instruction pattern +@cindex @code{cond_len_fmin@var{mode}} instruction pattern +@cindex @code{cond_len_fmax@var{mode}} instruction pattern +@cindex @code{cond_len_ashl@var{mode}} instruction pattern +@cindex @code{cond_len_ashr@var{mode}} instruction pattern +@cindex @code{cond_len_lshr@var{mode}} instruction pattern +@item @samp{cond_len_add@var{mode}} +@itemx @samp{cond_len_sub@var{mode}} +@itemx @samp{cond_len_mul@var{mode}} +@itemx @samp{cond_len_div@var{mode}} +@itemx @samp{cond_len_udiv@var{mode}} +@itemx @samp{cond_len_mod@var{mode}} +@itemx @samp{cond_len_umod@var{mode}} +@itemx @samp{cond_len_and@var{mode}} +@itemx
@samp{cond_len_ior@var{mode}} +@itemx @samp{cond_len_xor@var{mode}} +@itemx @samp{cond_len_smin@var{mode}} +@itemx @samp{cond_len_smax@var{mode}} +@itemx @samp{cond_len_umin@var{mode}} +@itemx @samp{cond_len_umax@var{mode}} +@itemx @samp{cond_len_fmin@var{mode}} +@itemx @samp{cond_len_fmax@var{mode}} +@itemx @samp{cond_len_ashl@var{mode}} +@itemx @samp{cond_len_ashr@var{mode}} +@itemx @samp{cond_len_lshr@var{mode}} +For each element index less than operand 5, perform an operation on operands 2 and 3 +and store the result in operand 0 if operand 1 is true, otherwise store operand 4 in operand 0. +These patterns are only defined for vector modes. + +@smallexample +for (i = 0; i < ops[5]; i++) + op0[i] = op1[i] ? op2[i] @var{op} op3[i] : op4[i]; +@end smallexample + +where, for example, @var{op} is @code{+} for @samp{cond_len_add@var{mode}}. + +When defined for floating-point modes, the contents of @samp{op3[i]} +are not interpreted if @samp{op1[i]} is false, just like they would not +be in a normal C @samp{?:} condition. + +Operands 0, 2, 3 and 4 all have mode @var{m}. Operand 1 has the mode +returned by @code{TARGET_VECTORIZE_GET_MASK_MODE}. +Operand 5 specifies the number of elements to process and has +whichever integer mode the target prefers. + +@samp{cond_len_@var{op}@var{mode}} generally corresponds to a conditional +form of @samp{@var{op}@var{mode}3}. As an exception, the vector forms +of shifts correspond to patterns like @code{vashl@var{mode}3} rather +than patterns like @code{ashl@var{mode}3}.
+ +@cindex @code{cond_len_fma@var{mode}} instruction pattern +@cindex @code{cond_len_fms@var{mode}} instruction pattern +@cindex @code{cond_len_fnma@var{mode}} instruction pattern +@cindex @code{cond_len_fnms@var{mode}} instruction pattern +@item @samp{cond_len_fma@var{mode}} +@itemx @samp{cond_len_fms@var{mode}} +@itemx @samp{cond_len_fnma@var{mode}} +@itemx @samp{cond_len_fnms@var{mode}} +Like @samp{cond_len_add@var{m}}, except that the conditional operation +takes 3 operands rather than two, so the length is operand 6. For example, +the vector form of @samp{cond_len_fma@var{mode}} is equivalent to: + +@smallexample +for (i = 0; i < ops[6]; i++) + op0[i] = op1[i] ? fma (op2[i], op3[i], op4[i]) : op5[i]; +@end smallexample @cindex @code{neg@var{mode}cc} instruction pattern @item @samp{neg@var{mode}cc} Similar to @samp{mov@var{mode}cc} but for conditional negation. Conditionally diff --git a/gcc/internal-fn.cc b/gcc/internal-fn.cc index 278db7b1805..b0700aa1998 100644 --- a/gcc/internal-fn.cc +++ b/gcc/internal-fn.cc @@ -183,6 +183,9 @@ init_internal_fns () #define cond_unary_direct { 1, 1, true } #define cond_binary_direct { 1, 1, true } #define cond_ternary_direct { 1, 1, true } +#define cond_len_unary_direct { 1, 1, true } +#define cond_len_binary_direct { 1, 1, true } +#define cond_len_ternary_direct { 1, 1, true } #define while_direct { 0, 2, false } #define fold_extract_direct { 2, 2, false } #define fold_left_direct { 1, 1, false } @@ -3869,6 +3872,15 @@ expand_convert_optab_fn (internal_fn fn, gcall *stmt, convert_optab optab, #define expand_cond_ternary_optab_fn(FN, STMT, OPTAB) \ expand_direct_optab_fn (FN, STMT, OPTAB, 5) +#define expand_cond_len_unary_optab_fn(FN, STMT, OPTAB) \ + expand_direct_optab_fn (FN, STMT, OPTAB, 4) + +#define expand_cond_len_binary_optab_fn(FN, STMT, OPTAB) \ + expand_direct_optab_fn (FN, STMT, OPTAB, 5) + +#define expand_cond_len_ternary_optab_fn(FN, STMT, OPTAB) \ + expand_direct_optab_fn (FN, STMT, OPTAB, 6) + #define expand_fold_extract_optab_fn(FN,
STMT, OPTAB) \ expand_direct_optab_fn (FN, STMT, OPTAB, 3) @@ -3964,6 +3976,9 @@ multi_vector_optab_supported_p (convert_optab optab, tree_pair types, #define direct_cond_unary_optab_supported_p direct_optab_supported_p #define direct_cond_binary_optab_supported_p direct_optab_supported_p #define direct_cond_ternary_optab_supported_p direct_optab_supported_p +#define direct_cond_len_unary_optab_supported_p direct_optab_supported_p +#define direct_cond_len_binary_optab_supported_p direct_optab_supported_p +#define direct_cond_len_ternary_optab_supported_p direct_optab_supported_p #define direct_mask_load_optab_supported_p convert_optab_supported_p #define direct_load_lanes_optab_supported_p multi_vector_optab_supported_p #define direct_mask_load_lanes_optab_supported_p multi_vector_optab_supported_p diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def index 238b7ee0bc9..ea750a921ed 100644 --- a/gcc/internal-fn.def +++ b/gcc/internal-fn.def @@ -72,6 +72,10 @@ along with GCC; see the file COPYING3. If not see - fold_left: for scalar = FN (scalar, vector), keyed off the vector mode - check_ptrs: used for check_{raw,war}_ptrs + - cond_len_unary: a conditional unary optab, such as cond_len_neg<mode> + - cond_len_binary: a conditional binary optab, such as cond_len_add<mode> + - cond_len_ternary: a conditional ternary optab, such as cond_len_fma<mode> + DEF_INTERNAL_SIGNED_OPTAB_FN defines an internal function that maps to one of two optabs, depending on the signedness of an input.
SIGNED_OPTAB and UNSIGNED_OPTAB are the optabs for signed and @@ -248,6 +252,40 @@ DEF_INTERNAL_OPTAB_FN (COND_FNMS, ECF_CONST, cond_fnms, cond_ternary) DEF_INTERNAL_OPTAB_FN (COND_NEG, ECF_CONST, cond_neg, cond_unary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_ADD, ECF_CONST, cond_len_add, cond_len_binary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_SUB, ECF_CONST, cond_len_sub, cond_len_binary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_MUL, ECF_CONST, cond_len_smul, cond_len_binary) +DEF_INTERNAL_SIGNED_OPTAB_FN (COND_LEN_DIV, ECF_CONST, first, cond_len_sdiv, + cond_len_udiv, cond_len_binary) +DEF_INTERNAL_SIGNED_OPTAB_FN (COND_LEN_MOD, ECF_CONST, first, cond_len_smod, + cond_len_umod, cond_len_binary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_RDIV, ECF_CONST, cond_len_sdiv, cond_len_binary) +DEF_INTERNAL_SIGNED_OPTAB_FN (COND_LEN_MIN, ECF_CONST, first, cond_len_smin, + cond_len_umin, cond_len_binary) +DEF_INTERNAL_SIGNED_OPTAB_FN (COND_LEN_MAX, ECF_CONST, first, cond_len_smax, + cond_len_umax, cond_len_binary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_FMIN, ECF_CONST, cond_len_fmin, cond_len_binary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_FMAX, ECF_CONST, cond_len_fmax, cond_len_binary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_AND, ECF_CONST | ECF_NOTHROW, cond_len_and, + cond_len_binary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_IOR, ECF_CONST | ECF_NOTHROW, cond_len_ior, + cond_len_binary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_XOR, ECF_CONST | ECF_NOTHROW, cond_len_xor, + cond_len_binary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_SHL, ECF_CONST | ECF_NOTHROW, cond_len_ashl, + cond_len_binary) +DEF_INTERNAL_SIGNED_OPTAB_FN (COND_LEN_SHR, ECF_CONST | ECF_NOTHROW, first, + cond_len_ashr, cond_len_lshr, cond_len_binary) + +DEF_INTERNAL_OPTAB_FN (COND_LEN_FMA, ECF_CONST, cond_len_fma, cond_len_ternary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_FMS, ECF_CONST, cond_len_fms, cond_len_ternary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_FNMA, ECF_CONST, cond_len_fnma, + cond_len_ternary) +DEF_INTERNAL_OPTAB_FN (COND_LEN_FNMS, ECF_CONST, cond_len_fnms, + cond_len_ternary)
+ +DEF_INTERNAL_OPTAB_FN (COND_LEN_NEG, ECF_CONST, cond_len_neg, cond_len_unary) + DEF_INTERNAL_OPTAB_FN (RSQRT, ECF_CONST, rsqrt, unary) DEF_INTERNAL_OPTAB_FN (REDUC_PLUS, ECF_CONST | ECF_NOTHROW, diff --git a/gcc/optabs.def b/gcc/optabs.def index 73c9a0c760f..3dae228fba6 100644 --- a/gcc/optabs.def +++ b/gcc/optabs.def @@ -254,6 +254,30 @@ OPTAB_D (cond_fms_optab, "cond_fms$a") OPTAB_D (cond_fnma_optab, "cond_fnma$a") OPTAB_D (cond_fnms_optab, "cond_fnms$a") OPTAB_D (cond_neg_optab, "cond_neg$a") +OPTAB_D (cond_len_add_optab, "cond_len_add$a") +OPTAB_D (cond_len_sub_optab, "cond_len_sub$a") +OPTAB_D (cond_len_smul_optab, "cond_len_mul$a") +OPTAB_D (cond_len_sdiv_optab, "cond_len_div$a") +OPTAB_D (cond_len_smod_optab, "cond_len_mod$a") +OPTAB_D (cond_len_udiv_optab, "cond_len_udiv$a") +OPTAB_D (cond_len_umod_optab, "cond_len_umod$a") +OPTAB_D (cond_len_and_optab, "cond_len_and$a") +OPTAB_D (cond_len_ior_optab, "cond_len_ior$a") +OPTAB_D (cond_len_xor_optab, "cond_len_xor$a") +OPTAB_D (cond_len_ashl_optab, "cond_len_ashl$a") +OPTAB_D (cond_len_ashr_optab, "cond_len_ashr$a") +OPTAB_D (cond_len_lshr_optab, "cond_len_lshr$a") +OPTAB_D (cond_len_smin_optab, "cond_len_smin$a") +OPTAB_D (cond_len_smax_optab, "cond_len_smax$a") +OPTAB_D (cond_len_umin_optab, "cond_len_umin$a") +OPTAB_D (cond_len_umax_optab, "cond_len_umax$a") +OPTAB_D (cond_len_fmin_optab, "cond_len_fmin$a") +OPTAB_D (cond_len_fmax_optab, "cond_len_fmax$a") +OPTAB_D (cond_len_fma_optab, "cond_len_fma$a") +OPTAB_D (cond_len_fms_optab, "cond_len_fms$a") +OPTAB_D (cond_len_fnma_optab, "cond_len_fnma$a") +OPTAB_D (cond_len_fnms_optab, "cond_len_fnms$a") +OPTAB_D (cond_len_neg_optab, "cond_len_neg$a") OPTAB_D (cmov_optab, "cmov$a6") OPTAB_D (cstore_optab, "cstore$a4") OPTAB_D (ctrap_optab, "ctrap$a4") -- 2.36.1