v3: Modify warning message in riscv.cc. v2: Rebase. According to the intrinsic doc, this patch adds the 'Zvfbfmin' and 'Zvfbfwma' intrinsic functions.
Signed-off-by: Feng Wang <wangf...@eswincomputing.com> gcc/ChangeLog: * config/riscv/riscv-vector-builtins-bases.cc (class vfncvtbf16_f): Add 'Zvfbfmin' intrinsic in bases. (class vfwcvtbf16_f): Ditto. (class vfwmaccbf16): Add 'Zvfbfwma' intrinsic in bases. (BASE): Add BASE macro for 'Zvfbfmin' and 'Zvfbfwma'. * config/riscv/riscv-vector-builtins-bases.h: Add declaration for 'Zvfbfmin' and 'Zvfbfwma'. * config/riscv/riscv-vector-builtins-functions.def (REQUIRED_EXTENSIONS): Add builtins def for 'Zvfbfmin' and 'Zvfbfwma'. (vfncvtbf16_f): Ditto. (vfncvtbf16_f_frm): Ditto. (vfwcvtbf16_f): Ditto. (vfwmaccbf16): Ditto. (vfwmaccbf16_frm): Ditto. * config/riscv/riscv-vector-builtins-shapes.cc (supports_vectype_p): Add vector intrinsic build judgment for BFloat16. (build_all): Ditto. (BASE_NAME_MAX_LEN): Adjust max length. * config/riscv/riscv-vector-builtins-types.def (DEF_RVV_F32_OPS): Add new operand type for BFloat16. (vfloat32mf2_t): Ditto. (vfloat32m1_t): Ditto. (vfloat32m2_t): Ditto. (vfloat32m4_t): Ditto. (vfloat32m8_t): Ditto. * config/riscv/riscv-vector-builtins.cc (DEF_RVV_F32_OPS): Ditto. (validate_instance_type_required_extensions): Add required_ext checking for 'Zvfbfmin' and 'Zvfbfwma'. * config/riscv/riscv-vector-builtins.h (enum required_ext): Add required_ext declaration for 'Zvfbfmin' and 'Zvfbfwma'. (reqired_ext_to_isa_name): Ditto. (required_extensions_specified): Ditto. (struct function_group_info): Add match case for 'Zvfbfmin' and 'Zvfbfwma'. * config/riscv/riscv.cc (riscv_validate_vector_type): Add required_ext checking for 'Zvfbfmin' and 'Zvfbfwma'. 
--- .../riscv/riscv-vector-builtins-bases.cc | 69 +++++++++++++++++++ .../riscv/riscv-vector-builtins-bases.h | 7 ++ .../riscv/riscv-vector-builtins-functions.def | 15 ++++ .../riscv/riscv-vector-builtins-shapes.cc | 31 ++++++++- .../riscv/riscv-vector-builtins-types.def | 13 ++++ gcc/config/riscv/riscv-vector-builtins.cc | 67 ++++++++++++++++++ gcc/config/riscv/riscv-vector-builtins.h | 34 ++++++--- gcc/config/riscv/riscv.cc | 13 ++-- 8 files changed, 232 insertions(+), 17 deletions(-) diff --git a/gcc/config/riscv/riscv-vector-builtins-bases.cc b/gcc/config/riscv/riscv-vector-builtins-bases.cc index 6483faba39c..193392fbcc2 100644 --- a/gcc/config/riscv/riscv-vector-builtins-bases.cc +++ b/gcc/config/riscv/riscv-vector-builtins-bases.cc @@ -2417,6 +2417,60 @@ public: } }; +/* Implements vfncvtbf16_f (narrow FP32 to BF16). */ +template <enum frm_op_type FRM_OP = NO_FRM> +class vfncvtbf16_f : public function_base +{ +public: + bool has_rounding_mode_operand_p () const override + { + return FRM_OP == HAS_FRM; + } + + bool may_require_frm_p () const override { return true; } + + rtx expand (function_expander &e) const override + { + return e.use_exact_insn (code_for_pred_trunc_to_bf16 (e.vector_mode ())); + } +}; + +/* Implements vfwcvtbf16_f (widen BF16 to FP32). */ +class vfwcvtbf16_f : public function_base +{ +public: + rtx expand (function_expander &e) const override + { + return e.use_exact_insn (code_for_pred_extend_bf16_to (e.vector_mode ())); + } +}; + +/* Implements vfwmaccbf16 (widening BF16 multiply-accumulate; vv and vf forms).
*/ +template <enum frm_op_type FRM_OP = NO_FRM> +class vfwmaccbf16 : public function_base +{ +public: + bool has_rounding_mode_operand_p () const override + { + return FRM_OP == HAS_FRM; + } + + bool may_require_frm_p () const override { return true; } + + bool has_merge_operand_p () const override { return false; } + + rtx expand (function_expander &e) const override + { + if (e.op_info->op == OP_TYPE_vf) + return e.use_widen_ternop_insn ( + code_for_pred_widen_bf16_mul_scalar (e.vector_mode ())); + if (e.op_info->op == OP_TYPE_vv) + return e.use_widen_ternop_insn ( + code_for_pred_widen_bf16_mul (e.vector_mode ())); + gcc_unreachable (); + } +}; + static CONSTEXPR const vsetvl<false> vsetvl_obj; static CONSTEXPR const vsetvl<true> vsetvlmax_obj; static CONSTEXPR const loadstore<false, LST_UNIT_STRIDE, false> vle_obj; @@ -2734,6 +2788,14 @@ static CONSTEXPR const crypto_vv<UNSPEC_VSM4R> vsm4r_obj; static CONSTEXPR const vsm3me vsm3me_obj; static CONSTEXPR const vaeskf2_vsm3c<UNSPEC_VSM3C> vsm3c_obj; +/* Zvfbfmin */ +static CONSTEXPR const vfncvtbf16_f<NO_FRM> vfncvtbf16_f_obj; +static CONSTEXPR const vfncvtbf16_f<HAS_FRM> vfncvtbf16_f_frm_obj; +static CONSTEXPR const vfwcvtbf16_f vfwcvtbf16_f_obj; +/* Zvfbfwma */ +static CONSTEXPR const vfwmaccbf16<NO_FRM> vfwmaccbf16_obj; +static CONSTEXPR const vfwmaccbf16<HAS_FRM> vfwmaccbf16_frm_obj; + /* Declare the function base NAME, pointing it to an instance of class <NAME>_obj.
*/ #define BASE(NAME) \ @@ -3054,4 +3116,11 @@ BASE (vsm4k) BASE (vsm4r) BASE (vsm3me) BASE (vsm3c) +/* Zvfbfmin */ +BASE (vfncvtbf16_f) +BASE (vfncvtbf16_f_frm) +BASE (vfwcvtbf16_f) +/* Zvfbfwma */ +BASE (vfwmaccbf16) +BASE (vfwmaccbf16_frm) } // end namespace riscv_vector diff --git a/gcc/config/riscv/riscv-vector-builtins-bases.h b/gcc/config/riscv/riscv-vector-builtins-bases.h index 1f2c94d3541..af1cb1af50f 100644 --- a/gcc/config/riscv/riscv-vector-builtins-bases.h +++ b/gcc/config/riscv/riscv-vector-builtins-bases.h @@ -339,6 +339,13 @@ extern const function_base *const vsm4k; extern const function_base *const vsm4r; extern const function_base *const vsm3me; extern const function_base *const vsm3c; +/* Zvfbfmin */ +extern const function_base *const vfncvtbf16_f; +extern const function_base *const vfncvtbf16_f_frm; +extern const function_base *const vfwcvtbf16_f; +/* Zvfbfwma */ +extern const function_base *const vfwmaccbf16; +extern const function_base *const vfwmaccbf16_frm; } } // end namespace riscv_vector diff --git a/gcc/config/riscv/riscv-vector-builtins-functions.def b/gcc/config/riscv/riscv-vector-builtins-functions.def index f742c98be8a..b69cf3cae29 100644 --- a/gcc/config/riscv/riscv-vector-builtins-functions.def +++ b/gcc/config/riscv/riscv-vector-builtins-functions.def @@ -747,4 +747,19 @@ DEF_RVV_FUNCTION (vsm4r, crypto_vv, none_tu_preds, u_vvs_crypto_sew32_lmul_x16_o DEF_RVV_FUNCTION (vsm3me, no_mask_policy, none_tu_preds, u_vvv_crypto_sew32_ops) DEF_RVV_FUNCTION (vsm3c, crypto_vi, none_tu_preds, u_vvv_size_crypto_sew32_ops) #undef REQUIRED_EXTENSIONS + +/* Zvfbfmin */ +#define REQUIRED_EXTENSIONS ZVFBFMIN_EXT +DEF_RVV_FUNCTION (vfncvtbf16_f, narrow_alu, full_preds, f32_to_bf16_f_w_ops) +DEF_RVV_FUNCTION (vfncvtbf16_f_frm, narrow_alu_frm, full_preds, f32_to_bf16_f_w_ops) +DEF_RVV_FUNCTION (vfwcvtbf16_f, alu, full_preds, bf16_to_f32_f_v_ops) +#undef REQUIRED_EXTENSIONS + +/* Zvfbfwma */ +#define REQUIRED_EXTENSIONS ZVFBFWMA_EXT +DEF_RVV_FUNCTION
(vfwmaccbf16, alu, full_preds, f32_wwvv_ops) +DEF_RVV_FUNCTION (vfwmaccbf16, alu, full_preds, f32_wwfv_ops) +DEF_RVV_FUNCTION (vfwmaccbf16_frm, alu_frm, full_preds, f32_wwvv_ops) +DEF_RVV_FUNCTION (vfwmaccbf16_frm, alu_frm, full_preds, f32_wwfv_ops) +#undef REQUIRED_EXTENSIONS #undef DEF_RVV_FUNCTION diff --git a/gcc/config/riscv/riscv-vector-builtins-shapes.cc b/gcc/config/riscv/riscv-vector-builtins-shapes.cc index a3ffa92e967..33395414aae 100644 --- a/gcc/config/riscv/riscv-vector-builtins-shapes.cc +++ b/gcc/config/riscv/riscv-vector-builtins-shapes.cc @@ -78,6 +78,30 @@ build_one (function_builder &b, const function_group_info &group, argument_types, group.required_extensions); } +/* Return true if the functions in GROUP should be registered for the + vector type at index VEC_TYPE_IDX of GROUP's operand types.  BF16 types are detected by an enum range check, which relies on VECTOR_TYPE_vbfloat16mf4_t..VECTOR_TYPE_vbfloat16m8_t being contiguous. */ +static bool +supports_vectype_p (const function_group_info &group, unsigned int vec_type_idx) +{ + int index = group.ops_infos.types[vec_type_idx].index; + if (index < VECTOR_TYPE_vbfloat16mf4_t || index > VECTOR_TYPE_vbfloat16m8_t) + return true; + /* Non-BF16 types were accepted above; BF16 vector types are only registered for the (segment/indexed) load/store, fault-load, vundefined, misc and vset/vget/vcreate shapes. */ + if (*group.shape == shapes::loadstore + || *group.shape == shapes::indexed_loadstore + || *group.shape == shapes::vundefined + || *group.shape == shapes::misc + || *group.shape == shapes::vset + || *group.shape == shapes::vget + || *group.shape == shapes::vcreate + || *group.shape == shapes::fault_load + || *group.shape == shapes::seg_loadstore + || *group.shape == shapes::seg_indexed_loadstore + || *group.shape == shapes::seg_fault_load) + return true; + return false; +} + /* Add a function instance for every operand && predicate && args combination in GROUP.
Take the function base name from GROUP && operand suffix from operand_suffixes && mode suffix from type_suffixes && predication @@ -91,7 +115,10 @@ build_all (function_builder &b, const function_group_info &group) for (unsigned int vec_type_idx = 0; group.ops_infos.types[vec_type_idx].index != NUM_VECTOR_TYPES; ++vec_type_idx) - build_one (b, group, pred_idx, vec_type_idx); + { + if (supports_vectype_p (group, vec_type_idx)) + build_one (b, group, pred_idx, vec_type_idx); + } } /* Declare the function shape NAME, pointing it to an instance @@ -100,7 +127,7 @@ build_all (function_builder &b, const function_group_info &group) static CONSTEXPR const DEF##_def VAR##_obj; \ namespace shapes { const function_shape *const VAR = &VAR##_obj; } -#define BASE_NAME_MAX_LEN 16 +#define BASE_NAME_MAX_LEN 17 /* Base class for build. */ struct build_base : public function_shape diff --git a/gcc/config/riscv/riscv-vector-builtins-types.def b/gcc/config/riscv/riscv-vector-builtins-types.def index e7fca4cca79..e85ca27bcf5 100644 --- a/gcc/config/riscv/riscv-vector-builtins-types.def +++ b/gcc/config/riscv/riscv-vector-builtins-types.def @@ -133,6 +133,12 @@ along with GCC; see the file COPYING3. If not see #define DEF_RVV_WCONVERT_F_OPS(TYPE, REQUIRE) #endif +/* Use "DEF_RVV_F32_OPS" macro include all float32 vector types which are + used by the BFloat16 intrinsic functions. */ +#ifndef DEF_RVV_F32_OPS +#define DEF_RVV_F32_OPS(TYPE, REQUIRE) +#endif + /* Use "DEF_RVV_WI_OPS" macro include all signed integer can be widened which will be iterated and registered as intrinsic functions.
*/ #ifndef DEF_RVV_WI_OPS @@ -615,6 +621,12 @@ DEF_RVV_WCONVERT_F_OPS (vfloat64m2_t, RVV_REQUIRE_ELEN_FP_64) DEF_RVV_WCONVERT_F_OPS (vfloat64m4_t, RVV_REQUIRE_ELEN_FP_64) DEF_RVV_WCONVERT_F_OPS (vfloat64m8_t, RVV_REQUIRE_ELEN_FP_64) +DEF_RVV_F32_OPS (vfloat32mf2_t, RVV_REQUIRE_ELEN_FP_32 | RVV_REQUIRE_MIN_VLEN_64) +DEF_RVV_F32_OPS (vfloat32m1_t, RVV_REQUIRE_ELEN_FP_32) +DEF_RVV_F32_OPS (vfloat32m2_t, RVV_REQUIRE_ELEN_FP_32) +DEF_RVV_F32_OPS (vfloat32m4_t, RVV_REQUIRE_ELEN_FP_32) +DEF_RVV_F32_OPS (vfloat32m8_t, RVV_REQUIRE_ELEN_FP_32) + DEF_RVV_WI_OPS (vint8mf8_t, RVV_REQUIRE_MIN_VLEN_64) DEF_RVV_WI_OPS (vint8mf4_t, 0) DEF_RVV_WI_OPS (vint8mf2_t, 0) @@ -1481,3 +1493,4 @@ DEF_RVV_CRYPTO_SEW64_OPS (vuint64m8_t, RVV_REQUIRE_ELEN_64) #undef DEF_RVV_TUPLE_OPS #undef DEF_RVV_CRYPTO_SEW32_OPS #undef DEF_RVV_CRYPTO_SEW64_OPS +#undef DEF_RVV_F32_OPS diff --git a/gcc/config/riscv/riscv-vector-builtins.cc b/gcc/config/riscv/riscv-vector-builtins.cc index 720436dfbc9..9b375127bbb 100644 --- a/gcc/config/riscv/riscv-vector-builtins.cc +++ b/gcc/config/riscv/riscv-vector-builtins.cc @@ -242,6 +242,12 @@ static const rvv_type_info wconvert_f_ops[] = { #include "riscv-vector-builtins-types.def" {NUM_VECTOR_TYPES, 0}}; +/* A list of all float32 vector types used by the BFloat16 intrinsic functions. */ +static const rvv_type_info f32_ops[] = { +#define DEF_RVV_F32_OPS(TYPE, REQUIRE) {VECTOR_TYPE_##TYPE, REQUIRE}, +#include "riscv-vector-builtins-types.def" + {NUM_VECTOR_TYPES, 0}}; + /* A list of all integer will be registered for intrinsic functions. */ static const rvv_type_info iu_ops[] = { #define DEF_RVV_I_OPS(TYPE, REQUIRE) {VECTOR_TYPE_##TYPE, REQUIRE}, @@ -757,6 +763,25 @@ static CONSTEXPR const rvv_arg_type_info trunc_f_v_args[] static CONSTEXPR const rvv_arg_type_info w_v_args[] = {rvv_arg_type_info (RVV_BASE_double_trunc_vector), rvv_arg_type_info_end}; +/* A list of args for vector_type func (double demote bf16 vector_type) function.
*/ +static CONSTEXPR const rvv_arg_type_info bf_w_v_args[] + = {rvv_arg_type_info (RVV_BASE_double_trunc_bfloat_vector), + rvv_arg_type_info_end}; + +/* A list of args for vector_type func (vector_type, double demote bf16 vector_type, double demote bf16 vector_type) function. */ +static CONSTEXPR const rvv_arg_type_info bf_wwvv_args[] + = {rvv_arg_type_info (RVV_BASE_vector), + rvv_arg_type_info (RVV_BASE_double_trunc_bfloat_vector), + rvv_arg_type_info (RVV_BASE_double_trunc_bfloat_vector), + rvv_arg_type_info_end}; + +/* A list of args for vector_type func (vector_type, double demote bf16 scalar_type, double demote bf16 vector_type) function. */ +static CONSTEXPR const rvv_arg_type_info bf_wwxv_args[] + = {rvv_arg_type_info (RVV_BASE_vector), + rvv_arg_type_info (RVV_BASE_double_trunc_bfloat_scalar), + rvv_arg_type_info (RVV_BASE_double_trunc_bfloat_vector), + rvv_arg_type_info_end}; + /* A list of args for vector_type func (vector_type) function. */ static CONSTEXPR const rvv_arg_type_info m_args[] = {rvv_arg_type_info (RVV_BASE_mask), rvv_arg_type_info_end}; @@ -1749,6 +1774,38 @@ static CONSTEXPR const rvv_op_info f_to_nf_f_w_ops rvv_arg_type_info (RVV_BASE_double_trunc_float_vector), /* Return type */ v_args /* Args */}; +/* A static operand information for double demote bf16 vector_type func (vector_type) + * function registration. */ +static CONSTEXPR const rvv_op_info f32_to_bf16_f_w_ops + = {f32_ops, /* Types */ + OP_TYPE_f_w, /* Suffix */ + rvv_arg_type_info (RVV_BASE_double_trunc_bfloat_vector), /* Return type */ + v_args /* Args */}; + +/* A static operand information for vector_type func (double demote bf16 vector_type) + * function registration. */ +static CONSTEXPR const rvv_op_info bf16_to_f32_f_v_ops + = {f32_ops, /* Types */ + OP_TYPE_f_v, /* Suffix */ + rvv_arg_type_info (RVV_BASE_vector), /* Return type */ + bf_w_v_args /* Args */}; + +/* A static operand information for vector_type func (vector_type, double demote + * type, double demote type) function registration.
*/ +static CONSTEXPR const rvv_op_info f32_wwvv_ops + = {f32_ops, /* Types */ + OP_TYPE_vv, /* Suffix */ + rvv_arg_type_info (RVV_BASE_vector), /* Return type */ + bf_wwvv_args /* Args */}; + +/* A static operand information for vector_type func (vector_type, double demote + * scalar_type, double demote type) function registration. */ +static CONSTEXPR const rvv_op_info f32_wwfv_ops + = {f32_ops, /* Types */ + OP_TYPE_vf, /* Suffix */ + rvv_arg_type_info (RVV_BASE_vector), /* Return type */ + bf_wwxv_args /* Args */}; + /* A static operand information for vector_type func (vector_type) * function registration. */ static CONSTEXPR const rvv_op_info all_v_ops @@ -4643,6 +4700,16 @@ validate_instance_type_required_extensions (const rvv_type_info type, { uint64_t exts = type.required_extensions; + if ((exts & RVV_REQUIRE_ELEN_BF_16) + && !TARGET_VECTOR_ELEN_BF_16_P (riscv_vector_elen_flags)) + { + error_at (EXPR_LOCATION (exp), + "built-in function %qE requires the " + "zvfbfmin or zvfbfwma ISA extension", + exp); + return false; + } + if ((exts & RVV_REQUIRE_ELEN_FP_16) && !TARGET_VECTOR_ELEN_FP_16_P (riscv_vector_elen_flags)) { diff --git a/gcc/config/riscv/riscv-vector-builtins.h b/gcc/config/riscv/riscv-vector-builtins.h index 56dbe2cf0e2..ef4148380c2 100644 --- a/gcc/config/riscv/riscv-vector-builtins.h +++ b/gcc/config/riscv/riscv-vector-builtins.h @@ -114,17 +114,19 @@ static const unsigned int CP_WRITE_CSR = 1U << 5; /* Enumerates the required extensions.
*/ enum required_ext { - VECTOR_EXT, /* Vector extension */ - ZVBB_EXT, /* Cryto vector Zvbb sub-ext */ - ZVBB_OR_ZVKB_EXT, /* Cryto vector Zvbb or zvkb sub-ext */ - ZVBC_EXT, /* Crypto vector Zvbc sub-ext */ - ZVKG_EXT, /* Crypto vector Zvkg sub-ext */ - ZVKNED_EXT, /* Crypto vector Zvkned sub-ext */ + VECTOR_EXT, /* Vector extension */ + ZVBB_EXT, /* Crypto vector Zvbb sub-ext */ + ZVBB_OR_ZVKB_EXT, /* Crypto vector Zvbb or zvkb sub-ext */ + ZVBC_EXT, /* Crypto vector Zvbc sub-ext */ + ZVKG_EXT, /* Crypto vector Zvkg sub-ext */ + ZVKNED_EXT, /* Crypto vector Zvkned sub-ext */ ZVKNHA_OR_ZVKNHB_EXT, /* Crypto vector Zvknh[ab] sub-ext */ - ZVKNHB_EXT, /* Crypto vector Zvknhb sub-ext */ - ZVKSED_EXT, /* Crypto vector Zvksed sub-ext */ - ZVKSH_EXT, /* Crypto vector Zvksh sub-ext */ - XTHEADVECTOR_EXT, /* XTheadVector extension */ + ZVKNHB_EXT, /* Crypto vector Zvknhb sub-ext */ + ZVKSED_EXT, /* Crypto vector Zvksed sub-ext */ + ZVKSH_EXT, /* Crypto vector Zvksh sub-ext */ + XTHEADVECTOR_EXT, /* XTheadVector extension */ + ZVFBFMIN_EXT, /* Zvfbfmin extension */ + ZVFBFWMA_EXT, /* Zvfbfwma extension */ /* Please update below to isa_name func when add or remove enum type(s).
*/ }; @@ -154,6 +156,10 @@ static inline const char * reqired_ext_to_isa_name (enum required_ext required) return "zvksh"; case XTHEADVECTOR_EXT: return "xthreadvector"; + case ZVFBFMIN_EXT: + return "zvfbfmin"; + case ZVFBFWMA_EXT: + return "zvfbfwma"; default: gcc_unreachable (); } @@ -187,6 +193,10 @@ static inline bool required_extensions_specified (enum required_ext required) return TARGET_ZVKSH; case XTHEADVECTOR_EXT: return TARGET_XTHEADVECTOR; + case ZVFBFMIN_EXT: + return TARGET_ZVFBFMIN; + case ZVFBFWMA_EXT: + return TARGET_ZVFBFWMA; default: gcc_unreachable (); } @@ -323,6 +333,10 @@ struct function_group_info return TARGET_ZVKSH; case XTHEADVECTOR_EXT: return TARGET_XTHEADVECTOR; + case ZVFBFMIN_EXT: + return TARGET_ZVFBFMIN; + case ZVFBFWMA_EXT: + return TARGET_ZVFBFWMA; default: gcc_unreachable (); } diff --git a/gcc/config/riscv/riscv.cc b/gcc/config/riscv/riscv.cc index 38ed773c222..87e9ca817db 100644 --- a/gcc/config/riscv/riscv.cc +++ b/gcc/config/riscv/riscv.cc @@ -6014,11 +6014,25 @@ riscv_validate_vector_type (const_tree type, const char *hint) bool float_type_p = riscv_vector_float_type_p (type); - if (float_type_p && element_bitsize == 16 - && !TARGET_VECTOR_ELEN_FP_16_P (riscv_vector_elen_flags)) - { - error_at (input_location, - "%s %qT requires the zvfhmin or zvfh ISA extension", - hint, type); - return; - } + if (float_type_p && element_bitsize == 16) + { + /* BF16 vector types (whose names contain "bfloat16") and FP16 vector + types share the 16-bit element width but are enabled by different + extensions, so check each kind against its own extension. */ + const char *name = IDENTIFIER_POINTER (DECL_NAME (TYPE_NAME (type))); + bool bf16_type_p = strstr (name, "bfloat16") != NULL; + if (bf16_type_p && !TARGET_VECTOR_ELEN_BF_16_P (riscv_vector_elen_flags)) + { + error_at (input_location, + "%s %qT requires the zvfbfmin or zvfbfwma ISA extension", + hint, type); + return; + } + if (!bf16_type_p && !TARGET_VECTOR_ELEN_FP_16_P (riscv_vector_elen_flags)) + { + error_at (input_location, + "%s %qT requires the zvfhmin or zvfh ISA extension", + hint, type); + return; + } + } -- 2.17.1