More information: for PRED_TYPE_tumu, the analysis is easy. We just need to count how many arguments are in the arglist. If the arglist has 5 arguments (mask, merge, op1, op2, len), then it must be TUMU.
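A minimal sketch of the argument-counting idea, assuming a binary operation such as vadd. The names `pred_type`, `tumu_binary_arg_count` and `guess_pred_from_arg_count` are hypothetical and are not part of the patch; only the "5 arguments means TUMU" rule comes from the text above, so every other count is reported as undetermined.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>

// Hypothetical stand-in for GCC's predication enum. Only the value
// discussed in this thread is modelled here.
enum class pred_type { tumu };

// (mask, merge, op1, op2, len) for a binary operation: five arguments.
constexpr std::size_t tumu_binary_arg_count = 5;

// Guess the predication type from the number of call arguments.
// Returns std::nullopt when the count alone is not enough to decide.
std::optional<pred_type> guess_pred_from_arg_count(std::size_t nargs) {
  if (nargs == tumu_binary_arg_count)
    return pred_type::tumu;
  return std::nullopt;
}
```

Returning an empty optional for the other counts keeps the sketch honest: it does not claim how the remaining predication types would be told apart.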
What I mean is that we should be able to quickly compute the arguments needed to construct the function_instance. Then we can get the non-overloaded function.

juzhe.zh...@rivai.ai

From: juzhe.zh...@rivai.ai
Date: 2023-09-15 10:02
To: pan2.li; gcc-patches
CC: pan2.li; yanzhang.wang; kito.cheng
Subject: Re: [PATCH v3] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV intrinsic

Sorry to comment again. I am not happy with the current get_non_overloaded_instance function. I think the searching approach is very inefficient:

+function_instance *
+function_base::get_non_overloaded_instance (unsigned int code,
+                                            vec<tree, va_gc> &arglist) const
+{
+  unsigned int code_limit = vec_safe_length (registered_functions);
+
+  for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
+    {
+      registered_function *rfun = (*registered_functions)[fun_code];
+      function_instance instance = rfun->instance;
+
+      if (rfun->overloaded_p)
+        continue;
+
+      unsigned k;
+      const rvv_arg_type_info *args = instance.op_info->args;
+
+      for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+        {
+          if (k >= arglist.length ())
+            break;
+
+          if (TYPE_MODE (instance.get_arg_type (k))
+              != TYPE_MODE (TREE_TYPE (arglist[k])))
+            break;
+        }
+
+      if (args[k].base_type == NUM_BASE_TYPES)
+        return &rfun->instance;
+    }
+
+  return NULL;
+}

Instead, I think we should build up a table that maps the arguments to the non-overloaded function, so that we can get the "instance" efficiently.

E.g. for the vint8mf8_t tumu vadd intrinsic, the instance looks like this:

function_instance ("vadd", bases::vadd, shapes::alu, iu_ops[VECTOR_TYPE_vuint8mf8_t], PRED_TYPE_tumu, &iu_vvv_ops);

get_non_overloaded_instance is already a member function of the base class, so the first 3 arguments ("vadd", bases::vadd, shapes::alu) should already be known, because it is a known function_base. The last 3 arguments may need some careful analysis or a map table so that they can be looked up quickly.

So, I think we should consider this framework seriously.
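The map-table idea can be sketched roughly as follows. This is only an illustration under assumptions, not the design the patch ended up with: `mode_id`, `instance_entry` and `instance_table` are hypothetical stand-ins, with a plain integer playing the role of the machine mode that TYPE_MODE returns and a string playing the role of the registered non-overloaded function.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a machine mode (what TYPE_MODE would return).
using mode_id = int;
using arg_modes = std::vector<mode_id>;

// Hypothetical stand-in for one registered non-overloaded function.
struct instance_entry {
  std::string name;  // e.g. "__riscv_vmv_v_v_i8mf8"
  arg_modes modes;   // argument modes, in call order
};

// Table for one overloaded API group, keyed on the argument modes.
class instance_table {
public:
  // Record one non-overloaded instance under its argument-mode signature.
  // Returns false if that signature is already present.
  bool add(const instance_entry &entry) {
    return m_by_modes.emplace(entry.modes, entry.name).second;
  }

  // Find the instance whose argument modes match exactly, if any.
  std::optional<std::string> find(const arg_modes &modes) const {
    auto it = m_by_modes.find(modes);
    if (it == m_by_modes.end())
      return std::nullopt;
    return it->second;
  }

private:
  std::map<arg_modes, std::string> m_by_modes;
};
```

Compared with scanning every registered function from the hook's code onwards, a lookup here costs one ordered-map search per call. Whether the key should be the argument modes, the vector type index plus predication type, or something else is left open, as in the discussion above.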
juzhe.zh...@rivai.ai

From: pan2.li
Date: 2023-09-12 16:46
To: gcc-patches
CC: juzhe.zhong; pan2.li; yanzhang.wang; kito.cheng
Subject: [PATCH v3] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV intrinsic

From: Pan Li <pan2...@intel.com>

Update in v3:

* Rewrite comment for overloaded function add.
* Move get_non_overloaded_instance to function_base.

Update in v2:

* Add get_non_overloaded_instance for function instance.
* Fix overload check for policy function.
* Enrich the test cases check.

Original log:

This patch adds the framework to support the RVV overloaded intrinsic API in riscv-xxx-xxx-gcc, as riscv-xxx-xxx-g++ already does. It mostly leverages the hook TARGET_RESOLVE_OVERLOADED_BUILTIN, with the steps below.

* Register overloaded functions.
* Add function_resolver for overloaded function resolving.
* Add resolve API for function shape with default implementation.
* Implement HOOK for navigating the overloaded API to non-overloaded API.

We validated this framework with the vmv_v intrinsic API(s), and we will add support for more intrinsic APIs in follow-up patches.

gcc/ChangeLog:

        * config/riscv/riscv-c.cc (riscv_resolve_overloaded_builtin): New function for the hook.
        (riscv_register_pragmas): Register the hook.
        * config/riscv/riscv-protos.h (resolve_overloaded_builtin): New decl.
        * config/riscv/riscv-vector-builtins-shapes.cc (build_one): Register overloaded function.
        (struct overloaded_base): New struct for overloaded shape.
        (struct non_overloaded_base): New struct for non overloaded shape.
        (struct move_def): Inherit overloaded shape.
        * config/riscv/riscv-vector-builtins.cc (function_base::get_non_overloaded_instance): New API impl.
        (function_builder::add_function): Add overloaded arg.
        (function_resolver::function_resolver): New constructor.
        (function_builder::add_overloaded_function): New API impl.
        (function_resolver::resolve): Ditto.
        (function_resolver::lookup): Ditto.
        (function_resolver::get_sub_code): Ditto.
        (resolve_overloaded_builtin): New function impl.
        * config/riscv/riscv-vector-builtins.h:
        (class function_resolver): New class.

gcc/testsuite/ChangeLog:

        * gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c: New test.
        * gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c: New test.
        * gcc.target/riscv/rvv/base/overloaded_vmv_v.h: New test.

Signed-off-by: Pan Li <pan2...@intel.com>
---
 gcc/config/riscv/riscv-c.cc                   |  36 ++++
 gcc/config/riscv/riscv-protos.h               |   1 +
 .../riscv/riscv-vector-builtins-shapes.cc     |  20 ++-
 gcc/config/riscv/riscv-vector-builtins.cc     | 155 +++++++++++++++++-
 gcc/config/riscv/riscv-vector-builtins.h      |  36 +++-
 .../riscv/rvv/base/overloaded_rv32_vmv_v.c    |   8 +
 .../riscv/rvv/base/overloaded_rv64_vmv_v.c    |   8 +
 .../riscv/rvv/base/overloaded_vmv_v.h         |  27 +++
 8 files changed, 288 insertions(+), 3 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h

diff --git a/gcc/config/riscv/riscv-c.cc b/gcc/config/riscv/riscv-c.cc
index 283052ae313..060edd3129d 100644
--- a/gcc/config/riscv/riscv-c.cc
+++ b/gcc/config/riscv/riscv-c.cc
@@ -220,11 +220,47 @@ riscv_check_builtin_call (location_t loc, vec<location_t> arg_loc, tree fndecl,
   gcc_unreachable ();
 }
 
+/* Implement TARGET_RESOLVE_OVERLOADED_BUILTIN.  */
+static tree
+riscv_resolve_overloaded_builtin (unsigned int uncast_location, tree fndecl,
+                                  void *uncast_arglist)
+{
+  vec<tree, va_gc> empty = {};
+  location_t loc = (location_t) uncast_location;
+  vec<tree, va_gc> *arglist = (vec<tree, va_gc> *) uncast_arglist;
+  unsigned int code = DECL_MD_FUNCTION_CODE (fndecl);
+  unsigned int subcode = code >> RISCV_BUILTIN_SHIFT;
+  tree new_fndecl = NULL_TREE;
+
+  if (!arglist)
+    arglist = &empty;
+
+  switch (code & RISCV_BUILTIN_CLASS)
+    {
+    case RISCV_BUILTIN_GENERAL:
+      break;
+    case RISCV_BUILTIN_VECTOR:
+      new_fndecl = riscv_vector::resolve_overloaded_builtin (loc, subcode,
+                                                             arglist);
+      break;
+    default:
+      gcc_unreachable ();
+    }
+
+  if (new_fndecl == NULL_TREE)
+    return new_fndecl;
+
+  return build_function_call_vec (loc, vNULL, new_fndecl, arglist, NULL,
+                                  fndecl);
+}
+
 /* Implement REGISTER_TARGET_PRAGMAS.  */
 
 void
 riscv_register_pragmas (void)
 {
+  targetm.resolve_overloaded_builtin = riscv_resolve_overloaded_builtin;
   targetm.check_builtin_call = riscv_check_builtin_call;
+
   c_register_pragma ("riscv", "intrinsic", riscv_pragma_intrinsic);
 }
diff --git a/gcc/config/riscv/riscv-protos.h b/gcc/config/riscv/riscv-protos.h
index 6dbf6b9f943..5d2492dd031 100644
--- a/gcc/config/riscv/riscv-protos.h
+++ b/gcc/config/riscv/riscv-protos.h
@@ -381,6 +381,7 @@ gimple *gimple_fold_builtin (unsigned int, gimple_stmt_iterator *, gcall *);
 rtx expand_builtin (unsigned int, tree, rtx);
 bool check_builtin_call (location_t, vec<location_t>, unsigned int,
                          tree, unsigned int, tree *);
+tree resolve_overloaded_builtin (location_t, unsigned int, vec<tree, va_gc> *);
 bool const_vec_all_same_in_range_p (rtx, HOST_WIDE_INT, HOST_WIDE_INT);
 bool legitimize_move (rtx, rtx);
 void emit_vlmax_vsetvl (machine_mode, rtx);
diff --git a/gcc/config/riscv/riscv-vector-builtins-shapes.cc b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
index f8fdec863e6..1c1a2cc9488 100644
--- a/gcc/config/riscv/riscv-vector-builtins-shapes.cc
+++ b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
@@ -49,6 +49,8 @@ build_one (function_builder &b, const function_group_info &group,
     group.ops_infos.types[vec_type_idx].index);
   b.allocate_argument_types (function_instance, argument_types);
   b.apply_predication (function_instance, return_type, argument_types);
+
+  b.add_overloaded_function (function_instance, *group.shape);
   b.add_unique_function (function_instance, (*group.shape), return_type,
                          argument_types);
 }
@@ -87,6 +89,22 @@ struct build_base : public function_shape
   }
 };
 
+struct overloaded_base : public build_base
+{
+  tree resolve (function_resolver &r) const override
+  {
+    return r.lookup ();
+  }
+};
+
+struct non_overloaded_base : public build_base
+{
+  tree resolve (function_resolver &) const override
+  {
+    gcc_unreachable ();
+  }
+};
+
 /* vsetvl_def class.  */
 struct vsetvl_def : public build_base
 {
@@ -525,7 +543,7 @@ struct narrow_alu_def : public build_base
 };
 
 /* move_def class.  Handle vmv.v.v/vmv.v.x.  */
-struct move_def : public build_base
+struct move_def : public overloaded_base
 {
   char *get_name (function_builder &b, const function_instance &instance,
                   bool overloaded_p) const override
diff --git a/gcc/config/riscv/riscv-vector-builtins.cc b/gcc/config/riscv/riscv-vector-builtins.cc
index 6d99f970ead..4f6fbdc3e28 100644
--- a/gcc/config/riscv/riscv-vector-builtins.cc
+++ b/gcc/config/riscv/riscv-vector-builtins.cc
@@ -80,6 +80,10 @@ public:
 
   /* The decl itself.  */
   tree GTY ((skip)) decl;
+
+  /* True if the decl represents an overloaded function that needs to be
+     resolved by function_resolver.  */
+  bool overloaded_p;
 };
 
 /* Hash traits for registered_function.  */
@@ -3196,6 +3200,77 @@ function_instance::could_trap_p () const
   return false;
 }
 
+/* Try to get the non-overloaded function instance.
+   After we register the overloaded the functions, the registered functions
+   table may look like:
+
+   +--------+---------------------------+-------------------+
+   | index  | name                      | kind              |
+   +--------+---------------------------+-------------------+
+   | 124733 | __riscv_vmv_v             | Overloaded        | <- Hook fun code
+   +--------+---------------------------+-------------------+
+   | 124735 | __riscv_vmv_v_v_i8mf8     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124737 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124739 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124741 | __riscv_vmv_v_v_i8mf4     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124743 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124745 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124747 | __riscv_vmv_v_v_i8mf2     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124749 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124751 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124753 | __riscv_vmv_v_v_i8m1      | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124755 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+
+   When we resolve the overloaded API from the hook, we always get the first
+   function code of one API group (aka vmv_v as above table).  We will search
+   start from that index to find the only one non-overloaded API with exactly
+   the same arglist.  Or NULL instance will be returned.
+*/
+function_instance *
+function_base::get_non_overloaded_instance (unsigned int code,
+                                            vec<tree, va_gc> &arglist) const
+{
+  unsigned int code_limit = vec_safe_length (registered_functions);
+
+  for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
+    {
+      registered_function *rfun = (*registered_functions)[fun_code];
+      function_instance instance = rfun->instance;
+
+      if (rfun->overloaded_p)
+        continue;
+
+      unsigned k;
+      const rvv_arg_type_info *args = instance.op_info->args;
+
+      for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+        {
+          if (k >= arglist.length ())
+            break;
+
+          if (TYPE_MODE (instance.get_arg_type (k))
+              != TYPE_MODE (TREE_TYPE (arglist[k])))
+            break;
+        }
+
+      if (args[k].base_type == NUM_BASE_TYPES)
+        return &rfun->instance;
+    }
+
+  return NULL;
+}
+
 function_builder::function_builder ()
 {
   m_direct_overloads = lang_GNU_CXX ();
@@ -3357,7 +3432,8 @@ function_builder::get_attributes (const function_instance &instance)
 registered_function &
 function_builder::add_function (const function_instance &instance,
                                 const char *name, tree fntype, tree attrs,
-                                bool placeholder_p)
+                                bool placeholder_p,
+                                bool overloaded_p = false)
 {
   unsigned int code = vec_safe_length (registered_functions);
   code = (code << RISCV_BUILTIN_SHIFT) + RISCV_BUILTIN_VECTOR;
@@ -3383,6 +3459,7 @@ function_builder::add_function (const function_instance &instance,
   registered_function &rfn = *ggc_alloc<registered_function> ();
   rfn.instance = instance;
   rfn.decl = decl;
+  rfn.overloaded_p = overloaded_p;
   vec_safe_push (registered_functions, &rfn);
 
   return rfn;
@@ -3432,6 +3509,26 @@ function_builder::add_unique_function (const function_instance &instance,
   obstack_free (&m_string_obstack, name);
 }
 
+void
+function_builder::add_overloaded_function (const function_instance &instance,
+                                           const function_shape *shape)
+{
+  if (!check_required_extensions (instance))
+    return;
+
+  char *name = shape->get_name (*this, instance, true);
+
+  if (name)
+    {
+      /* To avoid API conflicting, take void return type and void argument
+         for the overloaded function.  */
+      tree fntype = build_function_type (void_type_node, void_list_node);
+      add_function (instance, name, fntype, NULL_TREE, m_direct_overloads,
+                    true);
+      obstack_free (&m_string_obstack, name);
+    }
+}
+
 function_call_info::function_call_info (location_t location_in,
                                         const function_instance &instance_in,
                                         tree fndecl_in)
@@ -3852,6 +3949,13 @@ function_checker::function_checker (location_t location,
     m_nargs (nargs), m_args (args)
 {}
 
+function_resolver::function_resolver (location_t location,
+                                      const function_instance &instance,
+                                      tree fndecl,
+                                      vec<tree, va_gc> &arglist)
+  : function_call_info (location, instance, fndecl), m_arglist (arglist)
+{}
+
 /* Report that LOCATION has a call to FNDECL in which argument ARGNO
    was not an integer constant expression.  ARGNO counts from zero.  */
 void
@@ -3967,6 +4071,39 @@ function_checker::check ()
   return shape->check (*this);
 }
 
+unsigned int
+function_resolver::get_sub_code ()
+{
+  unsigned int fun_code = DECL_MD_FUNCTION_CODE (fndecl);
+
+  return fun_code >> RISCV_BUILTIN_SHIFT;
+}
+
+tree
+function_resolver::resolve ()
+{
+  return shape->resolve (*this);
+}
+
+tree
+function_resolver::lookup ()
+{
+  unsigned int fun_code = get_sub_code ();
+  function_instance *instance
+    = base->get_non_overloaded_instance (fun_code, m_arglist);
+
+  if (!instance)
+    return NULL_TREE;
+
+  hashval_t hash = instance->hash ();
+  registered_function *rfun = function_table->find_with_hash (*instance, hash);
+
+  if (!rfun)
+    return NULL_TREE;
+
+  return rfun->decl;
+}
+
 inline hashval_t
 registered_function_hasher::hash (value_type value)
 {
@@ -4196,6 +4333,22 @@ check_builtin_call (location_t location, vec<location_t>, unsigned int code,
                            TREE_TYPE (rfn.decl), nargs, args).check ();
 }
 
+tree
+resolve_overloaded_builtin (location_t loc, unsigned int code,
+                            vec<tree, va_gc> *arglist)
+{
+  if (code >= vec_safe_length (registered_functions))
+    return NULL_TREE;
+
+  const registered_function *rfun = (*registered_functions)[code];
+
+  if (!rfun || !rfun->overloaded_p)
+    return NULL_TREE;
+
+  return function_resolver (loc, rfun->instance, rfun->decl, *arglist)
+    .resolve ();
+}
+
 function_instance
 get_read_vl_instance (void)
 {
diff --git a/gcc/config/riscv/riscv-vector-builtins.h b/gcc/config/riscv/riscv-vector-builtins.h
index e358a8e4d91..e20f0f14ce4 100644
--- a/gcc/config/riscv/riscv-vector-builtins.h
+++ b/gcc/config/riscv/riscv-vector-builtins.h
@@ -277,6 +277,8 @@ public:
   void apply_predication (const function_instance &, tree, vec<tree> &) const;
   void add_unique_function (const function_instance &, const function_shape *,
                             tree, vec<tree> &);
+  void add_overloaded_function (const function_instance &,
+                                const function_shape *);
   void register_function_group (const function_group_info &);
   void append_name (const char *);
   void append_base_name (const char *);
@@ -288,7 +290,7 @@ private:
   tree get_attributes (const function_instance &);
 
   registered_function &add_function (const function_instance &, const char *,
-                                     tree, tree, bool);
+                                     tree, tree, bool, bool);
 
   /* True if we should create a separate decl for each instance of an
      overloaded function, instead of using function_builder.  */
@@ -424,6 +426,11 @@ public:
   /* Expand the given call into rtl.  Return the result of the function,
      or an arbitrary value if the function doesn't return a result.  */
   virtual rtx expand (function_expander &) const = 0;
+
+  /* Return the non-overloaded function instance from the registered
+     function table if success, or NULL will be returned.  */
+  virtual function_instance * get_non_overloaded_instance (
+    unsigned int, vec<tree, va_gc> &arglist) const;
 };
 
 /* A class for checking that the semantic constraints on a function call are
@@ -462,6 +469,29 @@ private:
   tree *m_args;
 };
 
+/* A class for resolving an overloaded function call.  */
+class function_resolver : public function_call_info
+{
+public:
+  function_resolver (location_t, const function_instance &, tree,
+                     vec<tree, va_gc> &);
+
+  /* Resolve the correlated non-overloaded function from the
+     the registered_functions table.  */
+  tree resolve ();
+
+  /* Lookup the non-overloaded function from the registered
+     function table.  */
+  tree lookup ();
+
+  /* Return the sub code of the fndecl.  */
+  unsigned int get_sub_code ();
+
+private:
+  /* The arguments to the overloaded function.  */
+  vec<tree, va_gc> &m_arglist;
+};
+
 /* Classifies functions into "shapes" base on:
 
    - Base name of the intrinsic function.
@@ -486,6 +516,10 @@ public:
   /* Check whether the given call is semantically valid.  Return true if it
      is, otherwise report an error and return false.  */
   virtual bool check (function_checker &) const { return true; }
+
+  /* Try to resolve the overloaded call.  Return the non-overloaded
+     function decl on success and NULL_TREE on failure.  */
+  virtual tree resolve (function_resolver &) const { return NULL_TREE; };
 };
 
 extern const char *const operand_suffixes[NUM_OP_TYPES];
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
new file mode 100644
index 00000000000..56154da155b
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv32gcv_zvfh -mabi=ilp32 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
new file mode 100644
index 00000000000..f4a63c9585d
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv64gcv_zvfh -mabi=lp64 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
new file mode 100644
index 00000000000..8756c5e17b7
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
@@ -0,0 +1,27 @@
+#include "riscv_vector.h"
+
+vint32m1_t test_vmv_overloaded_0 (vint32m1_t src, size_t vl) {
+  return __riscv_vmv_v (src, vl);
+}
+
+vfloat16m1_t test_vmv_overloaded_1 (vfloat16m1_t src, size_t vl) {
+  return __riscv_vmv_v (src, vl);
+}
+
+vint8m4_t test_vmv_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+                                 size_t vl) {
+  return __riscv_vmv_v_tu (maskedoff, src, vl);
+}
+
+vint32m1_t test_vmv_non_overloaded_0 (vint32m1_t src, size_t vl) {
+  return __riscv_vmv_v_v_i32m1 (src, vl);
+}
+
+vfloat16m1_t test_vmv_non_overloaded_1 (vfloat16m1_t src, size_t vl) {
+  return __riscv_vmv_v_v_f16m1 (src, vl);
+}
+
+vint8m4_t test_vmv_non_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+                                     size_t vl) {
+  return __riscv_vmv_v_v_i8m4_tu (maskedoff, src, vl);
+}
-- 
2.34.1