On Mon, 10 Jul 2023, juzhe.zh...@rivai.ai wrote:

> From: Ju-Zhe Zhong <juzhe.zh...@rivai.ai>
>
> Hi, Richard and Richi.
>
> This patch adds cond_len_* operation patterns for targets that support loop
> control with length.
>
> These patterns will be used in the following cases:
>
> 1. Integer division:
> void
> f (int32_t *restrict a, int32_t *restrict b, int32_t *restrict c, int n)
> {
>   for (int i = 0; i < n; ++i)
>     {
>       a[i] = b[i] / c[i];
>     }
> }
>
> ARM SVE IR:
>
> ...
> max_mask_36 = .WHILE_ULT (0, bnd.5_32, { 0, ... });
>
> Loop:
> ...
> # loop_mask_29 = PHI <next_mask_37(4), max_mask_36(3)>
> ...
> vect__4.8_28 = .MASK_LOAD (_33, 32B, loop_mask_29);
> ...
> vect__6.11_25 = .MASK_LOAD (_20, 32B, loop_mask_29);
> vect__8.12_24 = .COND_DIV (loop_mask_29, vect__4.8_28, vect__6.11_25, vect__4.8_28);
> ...
> .MASK_STORE (_1, 32B, loop_mask_29, vect__8.12_24);
> ...
> next_mask_37 = .WHILE_ULT (_2, bnd.5_32, { 0, ... });
> ...
>
> For targets like RVV that support loop control with length, we want to see
> IR as follows:
>
> Loop:
> ...
> # loop_len_29 = SELECT_VL
> ...
> vect__4.8_28 = .LEN_MASK_LOAD (_33, 32B, loop_len_29);
> ...
> vect__6.11_25 = .LEN_MASK_LOAD (_20, 32B, loop_len_29);
> vect__8.12_24 = .COND_LEN_DIV (dummy_mask, vect__4.8_28, vect__6.11_25, vect__4.8_28, loop_len_29, bias);
> ...
> .LEN_MASK_STORE (_1, 32B, loop_len_29, vect__8.12_24);
> ...
> next_mask_37 = .WHILE_ULT (_2, bnd.5_32, { 0, ... });
> ...
>
> Note that here we use dummy_mask = { -1, -1, ..., -1 }, i.e. all lanes active.
>
> 2. Integer conditional division:
> Similar to case (1), but with a condition:
> void
> f (int32_t *restrict a, int32_t *restrict b, int32_t *restrict c,
>    int32_t *cond, int n)
> {
>   for (int i = 0; i < n; ++i)
>     {
>       if (cond[i])
>         a[i] = b[i] / c[i];
>     }
> }
>
> ARM SVE:
> ...
> max_mask_76 = .WHILE_ULT (0, bnd.6_52, { 0, ... });
>
> Loop:
> ...
> # loop_mask_55 = PHI <next_mask_77(5), max_mask_76(4)>
> ...
> vect__4.9_56 = .MASK_LOAD (_51, 32B, loop_mask_55);
> mask__29.10_58 = vect__4.9_56 != { 0, ... };
> vec_mask_and_61 = loop_mask_55 & mask__29.10_58;
> ...
> vect__6.13_62 = .MASK_LOAD (_24, 32B, vec_mask_and_61);
> ...
> vect__8.16_66 = .MASK_LOAD (_1, 32B, vec_mask_and_61);
> vect__10.17_68 = .COND_DIV (vec_mask_and_61, vect__6.13_62, vect__8.16_66, vect__6.13_62);
> ...
> .MASK_STORE (_2, 32B, vec_mask_and_61, vect__10.17_68);
> ...
> next_mask_77 = .WHILE_ULT (_3, bnd.6_52, { 0, ... });
>
> Here, ARM SVE uses vec_mask_and_61 = loop_mask_55 & mask__29.10_58; to
> guarantee the correct result.
>
> However, targets with length control cannot use this elegant flow. For RVV,
> we would expect:
>
> Loop:
> ...
> loop_len_55 = SELECT_VL
> ...
> mask__29.10_58 = vect__4.9_56 != { 0, ... };
> ...
> vect__10.17_68 = .COND_LEN_DIV (mask__29.10_58, vect__6.13_62, vect__8.16_66, vect__6.13_62, loop_len_55, bias);
> ...
>
> Here we expect COND_LEN_DIV to be predicated both by a real mask, which is
> the result of the comparison (mask__29.10_58 = vect__4.9_56 != { 0, ... }),
> and by a real length, which is produced by the loop control
> (loop_len_55 = SELECT_VL).
>
> 3. Conditional floating-point operations (no -ffast-math):
>
> void
> f (float *restrict a, float *restrict b, int32_t *restrict cond, int n)
> {
>   for (int i = 0; i < n; ++i)
>     {
>       if (cond[i])
>         a[i] = b[i] + a[i];
>     }
> }
>
> ARM SVE IR:
> max_mask_70 = .WHILE_ULT (0, bnd.6_46, { 0, ... });
>
> ...
> # loop_mask_49 = PHI <next_mask_71(4), max_mask_70(3)>
> ...
> mask__27.10_52 = vect__4.9_50 != { 0, ... };
> vec_mask_and_55 = loop_mask_49 & mask__27.10_52;
> ...
> vect__9.17_62 = .COND_ADD (vec_mask_and_55, vect__6.13_56, vect__8.16_60, vect__6.13_56);
> ...
> next_mask_71 = .WHILE_ULT (_22, bnd.6_46, { 0, ... });
> ...
>
> For RVV, we would expect IR:
>
> ...
> loop_len_49 = SELECT_VL
> ...
> mask__27.10_52 = vect__4.9_50 != { 0, ... };
> ...
> vect__9.17_62 = .COND_LEN_ADD (mask__27.10_52, vect__6.13_56, vect__8.16_60, vect__6.13_56, loop_len_49, bias);
> ...
>
> 4. Conditional unordered reduction:
>
> int32_t
> f (int32_t *restrict a,
>    int32_t *restrict cond, int n)
> {
>   int32_t result = 0;
>   for (int i = 0; i < n; ++i)
>     {
>       if (cond[i])
>         result += a[i];
>     }
>   return result;
> }
>
> ARM SVE IR:
>
> Loop:
> # vect_result_18.7_37 = PHI <vect__33.16_51(4), { 0, ... }(3)>
> ...
> # loop_mask_40 = PHI <next_mask_58(4), max_mask_57(3)>
> ...
> mask__17.11_43 = vect__4.10_41 != { 0, ... };
> vec_mask_and_46 = loop_mask_40 & mask__17.11_43;
> ...
> vect__33.16_51 = .COND_ADD (vec_mask_and_46, vect_result_18.7_37, vect__7.14_47, vect_result_18.7_37);
> ...
> next_mask_58 = .WHILE_ULT (_15, bnd.6_36, { 0, ... });
> ...
>
> Epilogue:
> _53 = .REDUC_PLUS (vect__33.16_51); [tail call]
>
> For RVV, we expect:
>
> Loop:
> # vect_result_18.7_37 = PHI <vect__33.16_51(4), { 0, ... }(3)>
> ...
> loop_len_40 = SELECT_VL
> ...
> mask__17.11_43 = vect__4.10_41 != { 0, ... };
> ...
> vect__33.16_51 = .COND_LEN_ADD (mask__17.11_43, vect_result_18.7_37, vect__7.14_47, vect_result_18.7_37, loop_len_40, bias);
> ...
> next_mask_58 = .WHILE_ULT (_15, bnd.6_36, { 0, ... });
> ...
>
> Epilogue:
> _53 = .REDUC_PLUS (vect__33.16_51); [tail call]
>
> I named these patterns "cond_len_*" because I want the length operand to
> come after the mask operand, with all other operands in the same order as
> in the "cond_*" patterns. This ordering will make the follow-up loop
> vectorizer support easier.
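The following is a minimal scalar C model of what a cond_len binary operation such as .COND_LEN_DIV (mask, op2, op3, else, len, bias) computes in cases 1 and 2 above. It follows the prose of the md.texi entry in the patch: a lane receives op2 / op3 only when its mask bit is set and its index is below len + bias. Otherwise it receives the else value. The md.texi example loop itself only iterates up to len + bias, so this sketch assumes that tail lanes also take the else value. The function names, the fixed NUNITS lane count, and the plain-array "vectors" are hypothetical and chosen only for illustration. This is not GCC or RVV code. The model makes one point explicit: length and mask are combined lane by lane, which is the combination SVE has to build by hand as vec_mask_and_61 = loop_mask_55 & mask__29.10_58.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical fixed lane count for the model; real RVV/SVE vectors are
   length-agnostic.  */
#define NUNITS 8

/* Scalar model of .COND_LEN_DIV (mask, op2, op3, else, len, bias) on int32_t
   lanes.  A lane is active when its mask bit is set and its index is below
   len + bias; active lanes get op2 / op3, every other lane gets the else
   value.  Giving the else value to lanes at or beyond len + bias follows the
   md.texi prose and is an assumption of this sketch.  */
static void
cond_len_div_model (int32_t dst[NUNITS], const bool mask[NUNITS],
                    const int32_t op2[NUNITS], const int32_t op3[NUNITS],
                    const int32_t els[NUNITS], int64_t len, int64_t bias)
{
  assert (len + bias >= 0 && len + bias <= NUNITS);
  for (int64_t i = 0; i < NUNITS; i++)
    {
      bool active = mask[i] && i < len + bias;
      /* The division is only evaluated for active lanes, so a zero
         divisor in an inactive lane cannot trap.  */
      dst[i] = active ? op2[i] / op3[i] : els[i];
    }
}

/* Hypothetical stand-in for SELECT_VL: handle at most NUNITS elements per
   iteration.  */
static int64_t
select_vl_model (int64_t remaining)
{
  return remaining < NUNITS ? remaining : NUNITS;
}

/* Drives the model like the vectorized loops of cases 1 and 2.  With
   cond == NULL every mask bit is set (the all-ones dummy mask of case 1);
   otherwise the mask is cond[i] != 0 (case 2).  The loop length alone
   controls how many elements each iteration touches.  */
static void
div_loop_model (int32_t *a, const int32_t *b, const int32_t *c,
                const int32_t *cond, int n)
{
  for (int i = 0; i < n;)
    {
      int64_t len = select_vl_model (n - i);
      bool mask[NUNITS] = { false };
      int32_t vb[NUNITS] = { 0 }, vc[NUNITS] = { 0 }, res[NUNITS];

      /* Length-controlled loads: lanes at or beyond len stay zero.  */
      for (int64_t k = 0; k < len; k++)
        {
          mask[k] = cond == NULL || cond[i + k] != 0;
          vb[k] = b[i + k];
          vc[k] = c[i + k];
        }

      /* As in the IR above, the else value is the loaded b vector.  */
      cond_len_div_model (res, mask, vb, vc, vb, len, 0);

      /* Store under both length and mask control.  */
      for (int64_t k = 0; k < len; k++)
        if (mask[k])
          a[i + k] = res[k];

      i += (int) len;
    }
}
```

For example, div_loop_model (a, b, c, NULL, 10) processes one full 8-lane chunk followed by a 2-lane chunk. In the second chunk, lanes 2 to 7 are inactive purely because of the length, which is the role loop_len plays in the IR. Passing a cond array additionally deactivates lanes whose condition is zero, even when their divisor is zero.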
OK.

Thanks,
Richard.

> gcc/ChangeLog:
>
>         * doc/md.texi: Add COND_LEN_* operations for loop control with length.
>         * internal-fn.cc (cond_len_unary_direct): Ditto.
>         (cond_len_binary_direct): Ditto.
>         (cond_len_ternary_direct): Ditto.
>         (expand_cond_len_unary_optab_fn): Ditto.
>         (expand_cond_len_binary_optab_fn): Ditto.
>         (expand_cond_len_ternary_optab_fn): Ditto.
>         (direct_cond_len_unary_optab_supported_p): Ditto.
>         (direct_cond_len_binary_optab_supported_p): Ditto.
>         (direct_cond_len_ternary_optab_supported_p): Ditto.
>         * internal-fn.def (COND_LEN_ADD): Ditto.
>         (COND_LEN_SUB): Ditto.
>         (COND_LEN_MUL): Ditto.
>         (COND_LEN_DIV): Ditto.
>         (COND_LEN_MOD): Ditto.
>         (COND_LEN_RDIV): Ditto.
>         (COND_LEN_MIN): Ditto.
>         (COND_LEN_MAX): Ditto.
>         (COND_LEN_FMIN): Ditto.
>         (COND_LEN_FMAX): Ditto.
>         (COND_LEN_AND): Ditto.
>         (COND_LEN_IOR): Ditto.
>         (COND_LEN_XOR): Ditto.
>         (COND_LEN_SHL): Ditto.
>         (COND_LEN_SHR): Ditto.
>         (COND_LEN_FMA): Ditto.
>         (COND_LEN_FMS): Ditto.
>         (COND_LEN_FNMA): Ditto.
>         (COND_LEN_FNMS): Ditto.
>         (COND_LEN_NEG): Ditto.
>         * optabs.def (OPTAB_D): Ditto.
>
> ---
>  gcc/doc/md.texi     | 80 +++++++++++++++++++++++++++++++++++++++++++++
>  gcc/internal-fn.cc  | 15 +++++++++
>  gcc/internal-fn.def | 38 +++++++++++++++++++++
>  gcc/optabs.def      | 24 ++++++++++++++
>  4 files changed, 157 insertions(+)
>
> diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi
> index b30a824488b..cbcb992e5d7 100644
> --- a/gcc/doc/md.texi
> +++ b/gcc/doc/md.texi
> @@ -7234,6 +7234,86 @@ for (i = 0; i < GET_MODE_NUNITS (@var{m}); i++)
>    op0[i] = op1[i] ? fma (op2[i], op3[i], op4[i]) : op5[i];
>  @end smallexample
>
> +@cindex @code{cond_len_add@var{mode}} instruction pattern
> +@cindex @code{cond_len_sub@var{mode}} instruction pattern
> +@cindex @code{cond_len_mul@var{mode}} instruction pattern
> +@cindex @code{cond_len_div@var{mode}} instruction pattern
> +@cindex @code{cond_len_udiv@var{mode}} instruction pattern
> +@cindex @code{cond_len_mod@var{mode}} instruction pattern
> +@cindex @code{cond_len_umod@var{mode}} instruction pattern
> +@cindex @code{cond_len_and@var{mode}} instruction pattern
> +@cindex @code{cond_len_ior@var{mode}} instruction pattern
> +@cindex @code{cond_len_xor@var{mode}} instruction pattern
> +@cindex @code{cond_len_smin@var{mode}} instruction pattern
> +@cindex @code{cond_len_smax@var{mode}} instruction pattern
> +@cindex @code{cond_len_umin@var{mode}} instruction pattern
> +@cindex @code{cond_len_umax@var{mode}} instruction pattern
> +@cindex @code{cond_len_fmin@var{mode}} instruction pattern
> +@cindex @code{cond_len_fmax@var{mode}} instruction pattern
> +@cindex @code{cond_len_ashl@var{mode}} instruction pattern
> +@cindex @code{cond_len_ashr@var{mode}} instruction pattern
> +@cindex @code{cond_len_lshr@var{mode}} instruction pattern
> +@item @samp{cond_len_add@var{mode}}
> +@itemx @samp{cond_len_sub@var{mode}}
> +@itemx @samp{cond_len_mul@var{mode}}
> +@itemx @samp{cond_len_div@var{mode}}
> +@itemx @samp{cond_len_udiv@var{mode}}
> +@itemx @samp{cond_len_mod@var{mode}}
> +@itemx @samp{cond_len_umod@var{mode}}
> +@itemx @samp{cond_len_and@var{mode}}
> +@itemx @samp{cond_len_ior@var{mode}}
> +@itemx @samp{cond_len_xor@var{mode}}
> +@itemx @samp{cond_len_smin@var{mode}}
> +@itemx @samp{cond_len_smax@var{mode}}
> +@itemx @samp{cond_len_umin@var{mode}}
> +@itemx @samp{cond_len_umax@var{mode}}
> +@itemx @samp{cond_len_fmin@var{mode}}
> +@itemx @samp{cond_len_fmax@var{mode}}
> +@itemx @samp{cond_len_ashl@var{mode}}
> +@itemx @samp{cond_len_ashr@var{mode}}
> +@itemx @samp{cond_len_lshr@var{mode}}
> +When operand 1 is true and element index < operand 5 + operand 6, perform an operation on operands 2 and 3 and
> +store the result in operand 0, otherwise store operand 4 in operand 0.
> +The operation only works for the operands are vectors.
> +
> +@smallexample
> +for (i = 0; i < ops[5] + ops[6]; i++)
> +  op0[i] = op1[i] ? op2[i] @var{op} op3[i] : op4[i];
> +@end smallexample
> +
> +where, for example, @var{op} is @code{+} for @samp{cond_len_add@var{mode}}.
> +
> +When defined for floating-point modes, the contents of @samp{op3[i]}
> +are not interpreted if @samp{op1[i]} is false, just like they would not
> +be in a normal C @samp{?:} condition.
> +
> +Operands 0, 2, 3 and 4 all have mode @var{m}. Operand 1 is a scalar
> +integer if @var{m} is scalar, otherwise it has the mode returned by
> +@code{TARGET_VECTORIZE_GET_MASK_MODE}. Operand 5 has whichever
> +integer mode the target prefers.
> +
> +@samp{cond_@var{op}@var{mode}} generally corresponds to a conditional
> +form of @samp{@var{op}@var{mode}3}. As an exception, the vector forms
> +of shifts correspond to patterns like @code{vashl@var{mode}3} rather
> +than patterns like @code{ashl@var{mode}3}.
> +
> +@cindex @code{cond_len_fma@var{mode}} instruction pattern
> +@cindex @code{cond_len_fms@var{mode}} instruction pattern
> +@cindex @code{cond_len_fnma@var{mode}} instruction pattern
> +@cindex @code{cond_len_fnms@var{mode}} instruction pattern
> +@item @samp{cond_len_fma@var{mode}}
> +@itemx @samp{cond_len_fms@var{mode}}
> +@itemx @samp{cond_len_fnma@var{mode}}
> +@itemx @samp{cond_len_fnms@var{mode}}
> +Like @samp{cond_len_add@var{m}}, except that the conditional operation
> +takes 3 operands rather than two. For example, the vector form of
> +@samp{cond_len_fma@var{mode}} is equivalent to:
> +
> +@smallexample
> +for (i = 0; i < ops[6] + ops[7]; i++)
> +  op0[i] = op1[i] ? fma (op2[i], op3[i], op4[i]) : op5[i];
> +@end smallexample
> +
>  @cindex @code{neg@var{mode}cc} instruction pattern
>  @item @samp{neg@var{mode}cc}
>  Similar to @samp{mov@var{mode}cc} but for conditional negation. Conditionally
> diff --git a/gcc/internal-fn.cc b/gcc/internal-fn.cc
> index 278db7b1805..f9aaf66cf2a 100644
> --- a/gcc/internal-fn.cc
> +++ b/gcc/internal-fn.cc
> @@ -183,6 +183,9 @@ init_internal_fns ()
>  #define cond_unary_direct { 1, 1, true }
>  #define cond_binary_direct { 1, 1, true }
>  #define cond_ternary_direct { 1, 1, true }
> +#define cond_len_unary_direct { 1, 1, true }
> +#define cond_len_binary_direct { 1, 1, true }
> +#define cond_len_ternary_direct { 1, 1, true }
>  #define while_direct { 0, 2, false }
>  #define fold_extract_direct { 2, 2, false }
>  #define fold_left_direct { 1, 1, false }
> @@ -3869,6 +3872,15 @@ expand_convert_optab_fn (internal_fn fn, gcall *stmt, convert_optab optab,
>  #define expand_cond_ternary_optab_fn(FN, STMT, OPTAB) \
>    expand_direct_optab_fn (FN, STMT, OPTAB, 5)
>
> +#define expand_cond_len_unary_optab_fn(FN, STMT, OPTAB) \
> +  expand_direct_optab_fn (FN, STMT, OPTAB, 5)
> +
> +#define expand_cond_len_binary_optab_fn(FN, STMT, OPTAB) \
> +  expand_direct_optab_fn (FN, STMT, OPTAB, 6)
> +
> +#define expand_cond_len_ternary_optab_fn(FN, STMT, OPTAB) \
> +  expand_direct_optab_fn (FN, STMT, OPTAB, 7)
> +
>  #define expand_fold_extract_optab_fn(FN, STMT, OPTAB) \
>    expand_direct_optab_fn (FN, STMT, OPTAB, 3)
>
> @@ -3964,6 +3976,9 @@ multi_vector_optab_supported_p (convert_optab optab, tree_pair types,
>  #define direct_cond_unary_optab_supported_p direct_optab_supported_p
>  #define direct_cond_binary_optab_supported_p direct_optab_supported_p
>  #define direct_cond_ternary_optab_supported_p direct_optab_supported_p
> +#define direct_cond_len_unary_optab_supported_p direct_optab_supported_p
> +#define direct_cond_len_binary_optab_supported_p direct_optab_supported_p
> +#define direct_cond_len_ternary_optab_supported_p direct_optab_supported_p
>  #define direct_mask_load_optab_supported_p convert_optab_supported_p
>  #define direct_load_lanes_optab_supported_p multi_vector_optab_supported_p
>  #define direct_mask_load_lanes_optab_supported_p multi_vector_optab_supported_p
> diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def
> index 238b7ee0bc9..ea750a921ed 100644
> --- a/gcc/internal-fn.def
> +++ b/gcc/internal-fn.def
> @@ -72,6 +72,10 @@ along with GCC; see the file COPYING3. If not see
>     - fold_left: for scalar = FN (scalar, vector), keyed off the vector mode
>     - check_ptrs: used for check_{raw,war}_ptrs
>
> +   - cond_len_unary: a conditional unary optab, such as cond_len_neg<mode>
> +   - cond_len_binary: a conditional binary optab, such as cond_len_add<mode>
> +   - cond_len_ternary: a conditional ternary optab, such as cond_len_fma_rev<mode>
> +
>     DEF_INTERNAL_SIGNED_OPTAB_FN defines an internal function that
>     maps to one of two optabs, depending on the signedness of an input.
>     SIGNED_OPTAB and UNSIGNED_OPTAB are the optabs for signed and
> @@ -248,6 +252,40 @@ DEF_INTERNAL_OPTAB_FN (COND_FNMS, ECF_CONST, cond_fnms, cond_ternary)
>
>  DEF_INTERNAL_OPTAB_FN (COND_NEG, ECF_CONST, cond_neg, cond_unary)
>
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_ADD, ECF_CONST, cond_len_add, cond_len_binary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_SUB, ECF_CONST, cond_len_sub, cond_len_binary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_MUL, ECF_CONST, cond_len_smul, cond_len_binary)
> +DEF_INTERNAL_SIGNED_OPTAB_FN (COND_LEN_DIV, ECF_CONST, first, cond_len_sdiv,
> +                              cond_len_udiv, cond_len_binary)
> +DEF_INTERNAL_SIGNED_OPTAB_FN (COND_LEN_MOD, ECF_CONST, first, cond_len_smod,
> +                              cond_len_umod, cond_len_binary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_RDIV, ECF_CONST, cond_len_sdiv, cond_len_binary)
> +DEF_INTERNAL_SIGNED_OPTAB_FN (COND_LEN_MIN, ECF_CONST, first, cond_len_smin,
> +                              cond_len_umin, cond_len_binary)
> +DEF_INTERNAL_SIGNED_OPTAB_FN (COND_LEN_MAX, ECF_CONST, first, cond_len_smax,
> +                              cond_len_umax, cond_len_binary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_FMIN, ECF_CONST, cond_len_fmin, cond_len_binary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_FMAX, ECF_CONST, cond_len_fmax, cond_len_binary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_AND, ECF_CONST | ECF_NOTHROW, cond_len_and,
> +                       cond_len_binary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_IOR, ECF_CONST | ECF_NOTHROW, cond_len_ior,
> +                       cond_len_binary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_XOR, ECF_CONST | ECF_NOTHROW, cond_len_xor,
> +                       cond_len_binary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_SHL, ECF_CONST | ECF_NOTHROW, cond_len_ashl,
> +                       cond_len_binary)
> +DEF_INTERNAL_SIGNED_OPTAB_FN (COND_LEN_SHR, ECF_CONST | ECF_NOTHROW, first,
> +                              cond_len_ashr, cond_len_lshr, cond_len_binary)
> +
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_FMA, ECF_CONST, cond_len_fma, cond_len_ternary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_FMS, ECF_CONST, cond_len_fms, cond_len_ternary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_FNMA, ECF_CONST, cond_len_fnma,
> +                       cond_len_ternary)
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_FNMS, ECF_CONST, cond_len_fnms,
> +                       cond_len_ternary)
> +
> +DEF_INTERNAL_OPTAB_FN (COND_LEN_NEG, ECF_CONST, cond_len_neg, cond_len_unary)
> +
>  DEF_INTERNAL_OPTAB_FN (RSQRT, ECF_CONST, rsqrt, unary)
>
>  DEF_INTERNAL_OPTAB_FN (REDUC_PLUS, ECF_CONST | ECF_NOTHROW,
> diff --git a/gcc/optabs.def b/gcc/optabs.def
> index 73c9a0c760f..3dae228fba6 100644
> --- a/gcc/optabs.def
> +++ b/gcc/optabs.def
> @@ -254,6 +254,30 @@ OPTAB_D (cond_fms_optab, "cond_fms$a")
>  OPTAB_D (cond_fnma_optab, "cond_fnma$a")
>  OPTAB_D (cond_fnms_optab, "cond_fnms$a")
>  OPTAB_D (cond_neg_optab, "cond_neg$a")
> +OPTAB_D (cond_len_add_optab, "cond_len_add$a")
> +OPTAB_D (cond_len_sub_optab, "cond_len_sub$a")
> +OPTAB_D (cond_len_smul_optab, "cond_len_mul$a")
> +OPTAB_D (cond_len_sdiv_optab, "cond_len_div$a")
> +OPTAB_D (cond_len_smod_optab, "cond_len_mod$a")
> +OPTAB_D (cond_len_udiv_optab, "cond_len_udiv$a")
> +OPTAB_D (cond_len_umod_optab, "cond_len_umod$a")
> +OPTAB_D (cond_len_and_optab, "cond_len_and$a")
> +OPTAB_D (cond_len_ior_optab, "cond_len_ior$a")
> +OPTAB_D (cond_len_xor_optab, "cond_len_xor$a")
> +OPTAB_D (cond_len_ashl_optab, "cond_len_ashl$a")
> +OPTAB_D (cond_len_ashr_optab, "cond_len_ashr$a")
> +OPTAB_D (cond_len_lshr_optab, "cond_len_lshr$a")
> +OPTAB_D (cond_len_smin_optab, "cond_len_smin$a")
> +OPTAB_D (cond_len_smax_optab, "cond_len_smax$a")
> +OPTAB_D (cond_len_umin_optab, "cond_len_umin$a")
> +OPTAB_D (cond_len_umax_optab, "cond_len_umax$a")
> +OPTAB_D (cond_len_fmin_optab, "cond_len_fmin$a")
> +OPTAB_D (cond_len_fmax_optab, "cond_len_fmax$a")
> +OPTAB_D (cond_len_fma_optab, "cond_len_fma$a")
> +OPTAB_D (cond_len_fms_optab, "cond_len_fms$a")
> +OPTAB_D (cond_len_fnma_optab, "cond_len_fnma$a")
> +OPTAB_D (cond_len_fnms_optab, "cond_len_fnms$a")
> +OPTAB_D (cond_len_neg_optab, "cond_len_neg$a")
>  OPTAB_D (cmov_optab, "cmov$a6")
>  OPTAB_D (cstore_optab, "cstore$a4")
>  OPTAB_D (ctrap_optab, "ctrap$a4")
>

-- 
Richard Biener <rguent...@suse.de>
SUSE Software Solutions Germany GmbH, Frankenstrasse 146, 90461 Nuernberg, Germany; GF: Ivo Totev, Andrew Myers, Andrew McDonald, Boudien Moerman; HRB 36809 (AG Nuernberg)