I think it's better to move 'get_non_overloaded_instance' into function_base.

+      /* To avoid API conflicting, we use void return type and void argument
+ for the overloaded function register, like aarch64-sve.  */

Please rewrite the comment so that it does not mention aarch64 SVE.

Could you run your RVV intrinsic API CI with this patch?
I am worried that the resolving logic will break the existing API support.




juzhe.zh...@rivai.ai
 
From: pan2.li
Date: 2023-09-12 15:20
To: gcc-patches
CC: juzhe.zhong; pan2.li; yanzhang.wang; kito.cheng
Subject: [PATCH v2] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV 
intrinsic
From: Pan Li <pan2...@intel.com>
 
Update in v2:
 
* Add get_non_overloaded_instance for function instance.
* Fix overload check for policy function.
* Enrich the test cases check.
 
Original log:
 
This patch adds the framework to support the RVV overloaded
intrinsic API in riscv-xxx-xxx-gcc, as riscv-xxx-xxx-g++ already does.
 
It mainly leverages the TARGET_RESOLVE_OVERLOADED_BUILTIN hook, in the
following steps:
 
* Register overloaded functions.
* Add function_resolver for overloaded function resolving.
* Add resolve API for function shape with default implementation.
* Implement HOOK for navigating the overloaded API to non-overloaded API.
 
We validated this framework with the vmv_v intrinsic API(s), and we will
add support for more intrinsic APIs in follow-up patches.
 
gcc/ChangeLog:
 
* config/riscv/riscv-c.cc
(riscv_resolve_overloaded_builtin): New function for the hook.
(riscv_register_pragmas): Register the hook.
* config/riscv/riscv-protos.h (resolve_overloaded_builtin): New decl.
* config/riscv/riscv-vector-builtins-shapes.cc (build_one):
Register overloaded function.
(struct overloaded_base): New struct for overloaded shape.
(struct non_overloaded_base): New struct for non overloaded shape.
(struct move_def): Inherit overloaded shape.
* config/riscv/riscv-vector-builtins.cc
(function_instance::get_non_overloaded_instance): New API impl.
(function_builder::add_function): Add overloaded arg.
(function_resolver::function_resolver): New constructor.
(function_builder::add_overloaded_function): New API impl.
(function_resolver::resolve): Ditto.
(function_resolver::lookup): Ditto.
(function_resolver::get_sub_code): Ditto.
(resolve_overloaded_builtin): New function impl.
* config/riscv/riscv-vector-builtins.h:
(class function_resolver): New class.
 
gcc/testsuite/ChangeLog:
 
* gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c: New test.
* gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c: New test.
* gcc.target/riscv/rvv/base/overloaded_vmv_v.h: New test.
 
Signed-off-by: Pan Li <pan2...@intel.com>
---
gcc/config/riscv/riscv-c.cc                   |  36 ++++
gcc/config/riscv/riscv-protos.h               |   1 +
.../riscv/riscv-vector-builtins-shapes.cc     |  20 ++-
gcc/config/riscv/riscv-vector-builtins.cc     | 155 +++++++++++++++++-
gcc/config/riscv/riscv-vector-builtins.h      |  35 +++-
.../riscv/rvv/base/overloaded_rv32_vmv_v.c    |   8 +
.../riscv/rvv/base/overloaded_rv64_vmv_v.c    |   8 +
.../riscv/rvv/base/overloaded_vmv_v.h         |  27 +++
8 files changed, 287 insertions(+), 3 deletions(-)
create mode 100644 
gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
create mode 100644 
gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
 
diff --git a/gcc/config/riscv/riscv-c.cc b/gcc/config/riscv/riscv-c.cc
index 283052ae313..060edd3129d 100644
--- a/gcc/config/riscv/riscv-c.cc
+++ b/gcc/config/riscv/riscv-c.cc
@@ -220,11 +220,47 @@ riscv_check_builtin_call (location_t loc, vec<location_t> 
arg_loc, tree fndecl,
   gcc_unreachable ();
}
+/* Implement TARGET_RESOLVE_OVERLOADED_BUILTIN.  Return a call to the
+   non-overloaded builtin chosen for FNDECL and the arguments in
+   UNCAST_ARGLIST, or NULL_TREE if FNDECL cannot be resolved.  */
+static tree
+riscv_resolve_overloaded_builtin (unsigned int uncast_location, tree fndecl,
+                                  void *uncast_arglist)
+{
+  location_t loc = (location_t) uncast_location;
+  vec<tree, va_gc> *args = (vec<tree, va_gc> *) uncast_arglist;
+
+  /* Treat a null argument list as an empty one.  */
+  vec<tree, va_gc> no_args = {};
+  if (args == NULL)
+    args = &no_args;
+
+  unsigned int code = DECL_MD_FUNCTION_CODE (fndecl);
+  tree resolved = NULL_TREE;
+
+  switch (code & RISCV_BUILTIN_CLASS)
+    {
+    case RISCV_BUILTIN_VECTOR:
+      resolved = riscv_vector::resolve_overloaded_builtin
+        (loc, code >> RISCV_BUILTIN_SHIFT, args);
+      break;
+    case RISCV_BUILTIN_GENERAL:
+      /* Nothing to resolve for general builtins.  */
+      break;
+    default:
+      gcc_unreachable ();
+    }
+
+  return resolved == NULL_TREE
+         ? NULL_TREE
+         : build_function_call_vec (loc, vNULL, resolved, args, NULL, fndecl);
+}
+
/* Implement REGISTER_TARGET_PRAGMAS.  */
void
riscv_register_pragmas (void)
{
+  /* Let the front end map calls to overloaded RVV intrinsics onto the
+     matching non-overloaded intrinsics.  */
+  targetm.resolve_overloaded_builtin = riscv_resolve_overloaded_builtin;
   targetm.check_builtin_call = riscv_check_builtin_call;
+
   c_register_pragma ("riscv", "intrinsic", riscv_pragma_intrinsic);
}
diff --git a/gcc/config/riscv/riscv-protos.h b/gcc/config/riscv/riscv-protos.h
index 6dbf6b9f943..5d2492dd031 100644
--- a/gcc/config/riscv/riscv-protos.h
+++ b/gcc/config/riscv/riscv-protos.h
@@ -381,6 +381,7 @@ gimple *gimple_fold_builtin (unsigned int, 
gimple_stmt_iterator *, gcall *);
rtx expand_builtin (unsigned int, tree, rtx);
bool check_builtin_call (location_t, vec<location_t>, unsigned int,
   tree, unsigned int, tree *);
+tree resolve_overloaded_builtin (location_t, unsigned int, vec<tree, va_gc> *);
bool const_vec_all_same_in_range_p (rtx, HOST_WIDE_INT, HOST_WIDE_INT);
bool legitimize_move (rtx, rtx);
void emit_vlmax_vsetvl (machine_mode, rtx);
diff --git a/gcc/config/riscv/riscv-vector-builtins-shapes.cc 
b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
index f8fdec863e6..1c1a2cc9488 100644
--- a/gcc/config/riscv/riscv-vector-builtins-shapes.cc
+++ b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
@@ -49,6 +49,8 @@ build_one (function_builder &b, const function_group_info 
&group,
     group.ops_infos.types[vec_type_idx].index);
   b.allocate_argument_types (function_instance, argument_types);
   b.apply_predication (function_instance, return_type, argument_types);
+
+  b.add_overloaded_function (function_instance, *group.shape);
   b.add_unique_function (function_instance, (*group.shape), return_type,
argument_types);
}
@@ -87,6 +89,22 @@ struct build_base : public function_shape
   }
};
+/* Base class for shapes that have an overloaded form: an overloaded call
+   is resolved by looking up the matching non-overloaded function.  */
+struct overloaded_base : public build_base
+{
+  tree resolve (function_resolver &resolver) const override
+  {
+    tree decl = resolver.lookup ();
+    return decl;
+  }
+};
+
+/* Base class for shapes that are not expected to be resolved as
+   overloads; reaching resolve for them indicates a bug.  */
+struct non_overloaded_base : public build_base
+{
+  tree resolve (function_resolver &) const override { gcc_unreachable (); }
+};
+
/* vsetvl_def class.  */
struct vsetvl_def : public build_base
{
@@ -525,7 +543,7 @@ struct narrow_alu_def : public build_base
};
/* move_def class. Handle vmv.v.v/vmv.v.x.  */
-struct move_def : public build_base
+struct move_def : public overloaded_base
{
   char *get_name (function_builder &b, const function_instance &instance,
  bool overloaded_p) const override
diff --git a/gcc/config/riscv/riscv-vector-builtins.cc 
b/gcc/config/riscv/riscv-vector-builtins.cc
index 6d99f970ead..41ecbb48461 100644
--- a/gcc/config/riscv/riscv-vector-builtins.cc
+++ b/gcc/config/riscv/riscv-vector-builtins.cc
@@ -80,6 +80,10 @@ public:
   /* The decl itself.  */
   tree GTY ((skip)) decl;
+
+  /* True if the decl represents an overloaded function that needs to be
+     resolved by function_resolver.  */
+  bool overloaded_p;
};
/* Hash traits for registered_function.  */
@@ -3196,6 +3200,77 @@ function_instance::could_trap_p () const
   return false;
}
+/* Try to get the non-overloaded function instance.
+   After the overloaded functions are registered, the registered functions
+   table may look like:
+
+   +--------+---------------------------+-------------------+
+   | index  | name                      | kind              |
+   +--------+---------------------------+-------------------+
+   | 124733 | __riscv_vmv_v             | Overloaded        | <- Hook fun code
+   +--------+---------------------------+-------------------+
+   | 124735 | __riscv_vmv_v_v_i8mf8     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124737 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124739 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124741 | __riscv_vmv_v_v_i8mf4     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124743 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124745 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124747 | __riscv_vmv_v_v_i8mf2     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124749 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124751 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124753 | __riscv_vmv_v_v_i8m1      | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124755 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+
+   When we resolve the overloaded API from the hook, we always get the first
+   function code of one API group (vmv_v in the table above).  We search
+   forward from that index for the first non-overloaded API that has the
+   same base and predication as this overloaded instance and whose
+   arguments match ARGLIST.  Return NULL if there is no such instance.
+
+   An argument matches when it has the same machine mode AND the same
+   signedness as the candidate's argument.  Signed and unsigned vector
+   types (e.g. vint32m1_t and vuint32m1_t) share a mode, so comparing modes
+   alone would resolve an unsigned call to the signed intrinsic.
+
+   NOTE(review): only the operands in op_info->args are compared, starting
+   at ARGLIST[0]; arguments that are added by predication are not taken
+   into account here -- confirm that every predicated variant lines up.  */
+function_instance *
+function_instance::get_non_overloaded_instance (unsigned int code,
+                                                vec<tree, va_gc> &arglist) const
+{
+  unsigned int code_limit = vec_safe_length (registered_functions);
+
+  for (unsigned int fun_code = code; fun_code < code_limit; fun_code++)
+    {
+      registered_function *rfun = (*registered_functions)[fun_code];
+      /* Bind by reference; there is no need to copy the instance.  */
+      const function_instance &instance = rfun->instance;
+
+      if (rfun->overloaded_p)
+        continue;
+
+      /* Never resolve to a different intrinsic or a different predication
+         (e.g. __riscv_vmv_v_tu must only pick a _tu API), even if the
+         argument types happen to match.  */
+      if (instance.base != base || instance.pred != pred)
+        continue;
+
+      const rvv_arg_type_info *args = instance.op_info->args;
+      unsigned int k;
+
+      for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+        {
+          if (k >= arglist.length ())
+            break;
+
+          tree formal_type = instance.get_arg_type (k);
+          tree actual_type = TREE_TYPE (arglist[k]);
+
+          if (TYPE_MODE (formal_type) != TYPE_MODE (actual_type)
+              || TYPE_UNSIGNED (formal_type) != TYPE_UNSIGNED (actual_type))
+            break;
+        }
+
+      /* Every operand of the candidate matched.  */
+      if (args[k].base_type == NUM_BASE_TYPES)
+        return &rfun->instance;
+    }
+
+  return NULL;
+}
+
function_builder::function_builder ()
{
   m_direct_overloads = lang_GNU_CXX ();
@@ -3357,7 +3432,8 @@ function_builder::get_attributes (const function_instance 
&instance)
registered_function &
function_builder::add_function (const function_instance &instance,
const char *name, tree fntype, tree attrs,
- bool placeholder_p)
+ bool placeholder_p,
+ bool overloaded_p = false)
{
   unsigned int code = vec_safe_length (registered_functions);
   code = (code << RISCV_BUILTIN_SHIFT) + RISCV_BUILTIN_VECTOR;
@@ -3383,6 +3459,7 @@ function_builder::add_function (const function_instance 
&instance,
   registered_function &rfn = *ggc_alloc<registered_function> ();
   rfn.instance = instance;
   rfn.decl = decl;
+  rfn.overloaded_p = overloaded_p;
   vec_safe_push (registered_functions, &rfn);
   return rfn;
@@ -3432,6 +3509,26 @@ function_builder::add_unique_function (const 
function_instance &instance,
   obstack_free (&m_string_obstack, name);
}
+/* Register the overloaded alias of INSTANCE, using the overloaded name
+   provided by SHAPE, so that calls through that name can later be resolved
+   to a non-overloaded function by function_resolver.  Nothing is
+   registered if check_required_extensions rejects INSTANCE or if SHAPE
+   provides no overloaded name for it.  */
+void
+function_builder::add_overloaded_function (const function_instance &instance,
+                                           const function_shape *shape)
+{
+  if (!check_required_extensions (instance))
+    return;
+
+  char *name = shape->get_name (*this, instance, true);
+  if (!name)
+    return;
+
+  /* The same overloaded name is registered for many instances with
+     different return and argument types, so no single prototype describes
+     it.  Register it with a void (void) signature so that the repeated
+     registrations do not conflict; the real types become known only after
+     the call has been resolved to a non-overloaded function.  */
+  tree fntype = build_function_type (void_type_node, void_list_node);
+  add_function (instance, name, fntype, NULL_TREE, m_direct_overloads, true);
+  obstack_free (&m_string_obstack, name);
+}
+
function_call_info::function_call_info (location_t location_in,
const function_instance &instance_in,
tree fndecl_in)
@@ -3852,6 +3949,13 @@ function_checker::function_checker (location_t location,
     m_nargs (nargs), m_args (args)
{}
+/* Construct a resolver for a call at LOCATION to the overloaded function
+   FNDECL, described by INSTANCE, with the actual arguments in ARGLIST.
+   ARGLIST is held by reference, so it must outlive the resolver.  */
+function_resolver::function_resolver (location_t location,
+       const function_instance &instance,
+       tree fndecl,
+       vec<tree, va_gc> &arglist)
+  : function_call_info (location, instance, fndecl), m_arglist (arglist)
+{}
+
/* Report that LOCATION has a call to FNDECL in which argument ARGNO
    was not an integer constant expression.  ARGNO counts from zero.  */
void
@@ -3967,6 +4071,39 @@ function_checker::check ()
   return shape->check (*this);
}
+/* Return the index of FNDECL in registered_functions, i.e. its function
+   code with the builtin class bits shifted out.  */
+unsigned int
+function_resolver::get_sub_code ()
+{
+  return DECL_MD_FUNCTION_CODE (fndecl) >> RISCV_BUILTIN_SHIFT;
+}
+
+/* Let the function shape decide how the overloaded call is resolved.  */
+tree
+function_resolver::resolve ()
+{
+  return shape->resolve (*this);
+}
+
+/* Find the non-overloaded function whose arguments match m_arglist,
+   searching from this overloaded function's entry in registered_functions,
+   and return its decl.  Return NULL_TREE if there is no match.  */
+tree
+function_resolver::lookup ()
+{
+  function_instance *match
+    = get_non_overloaded_instance (get_sub_code (), m_arglist);
+  if (match == NULL)
+    return NULL_TREE;
+
+  registered_function *entry
+    = function_table->find_with_hash (*match, match->hash ());
+  return entry != NULL ? entry->decl : NULL_TREE;
+}
+
inline hashval_t
registered_function_hasher::hash (value_type value)
{
@@ -4196,6 +4333,22 @@ check_builtin_call (location_t location, 
vec<location_t>, unsigned int code,
   TREE_TYPE (rfn.decl), nargs, args).check ();
}
+/* Try to resolve the call at LOC to the RVV builtin with subcode CODE,
+   whose arguments are ARGLIST.  Return the decl of the non-overloaded
+   builtin to call, or NULL_TREE if CODE does not name an overloaded
+   builtin or the call cannot be resolved.  */
+tree
+resolve_overloaded_builtin (location_t loc, unsigned int code,
+                            vec<tree, va_gc> *arglist)
+{
+  const registered_function *rfun
+    = code < vec_safe_length (registered_functions)
+      ? (*registered_functions)[code]
+      : NULL;
+  if (rfun == NULL || !rfun->overloaded_p)
+    return NULL_TREE;
+
+  function_resolver resolver (loc, rfun->instance, rfun->decl, *arglist);
+  return resolver.resolve ();
+}
+
function_instance
get_read_vl_instance (void)
{
diff --git a/gcc/config/riscv/riscv-vector-builtins.h 
b/gcc/config/riscv/riscv-vector-builtins.h
index e358a8e4d91..5f8f7a97315 100644
--- a/gcc/config/riscv/riscv-vector-builtins.h
+++ b/gcc/config/riscv/riscv-vector-builtins.h
@@ -256,6 +256,10 @@ public:
   tree get_return_type () const;
   tree get_arg_type (unsigned opno) const;
+  function_instance * get_non_overloaded_instance (unsigned int,
+    vec<tree, va_gc> &arglist)
+    const;
+
   /* The properties of the function.  (The explicit "enum"s are required
      for gengtype.)  */
   const char *base_name;
@@ -277,6 +281,8 @@ public:
   void apply_predication (const function_instance &, tree, vec<tree> &) const;
   void add_unique_function (const function_instance &, const function_shape *,
    tree, vec<tree> &);
+  void add_overloaded_function (const function_instance &,
+ const function_shape *);
   void register_function_group (const function_group_info &);
   void append_name (const char *);
   void append_base_name (const char *);
@@ -288,7 +294,7 @@ private:
   tree get_attributes (const function_instance &);
   registered_function &add_function (const function_instance &, const char *,
-      tree, tree, bool);
+      tree, tree, bool, bool);
   /* True if we should create a separate decl for each instance of an
      overloaded function, instead of using function_builder.  */
@@ -462,6 +468,29 @@ private:
   tree *m_args;
};
+/* A class for resolving a call to an overloaded function to the
+   corresponding non-overloaded function.  */
+class function_resolver : public function_call_info
+{
+public:
+  function_resolver (location_t, const function_instance &, tree,
+      vec<tree, va_gc> &);
+
+  /* Resolve the call by dispatching to the function shape.  Return the
+     decl of the corresponding non-overloaded function from the
+     registered_functions table, or NULL_TREE on failure.  */
+  tree resolve ();
+
+  /* Look up the non-overloaded function that matches the call arguments
+     in the registered function table.  */
+  tree lookup ();
+
+  /* Return the sub code of the fndecl, i.e. its function code with the
+     builtin class bits shifted out.  */
+  unsigned int get_sub_code ();
+
+private:
+  /* The arguments to the overloaded function.  Held by reference, so the
+     vector must outlive the resolver.  */
+  vec<tree, va_gc> &m_arglist;
+};
+
/* Classifies functions into "shapes" base on:
    - Base name of the intrinsic function.
@@ -486,6 +515,10 @@ public:
   /* Check whether the given call is semantically valid.  Return true
    if it is, otherwise report an error and return false.  */
   virtual bool check (function_checker &) const { return true; }
+
+  /* Try to resolve the overloaded call.  Return the non-overloaded
+     function decl on success and NULL_TREE on failure.  */
+  virtual tree resolve (function_resolver &) const { return NULL_TREE; };
};
extern const char *const operand_suffixes[NUM_OP_TYPES];
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c 
b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
new file mode 100644
index 00000000000..56154da155b
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv32gcv_zvfh -mabi=ilp32 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times 
{vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times 
{vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times 
{vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c 
b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
new file mode 100644
index 00000000000..f4a63c9585d
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv64gcv_zvfh -mabi=lp64 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times 
{vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times 
{vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times 
{vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h 
b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
new file mode 100644
index 00000000000..8756c5e17b7
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
@@ -0,0 +1,27 @@
+#include "riscv_vector.h"
+
+/* Calls through the overloaded names; each must be resolved to the
+   non-overloaded intrinsic that matches the argument types.  */
+
+vint32m1_t test_vmv_overloaded_0 (vint32m1_t src, size_t vl) {
+  return __riscv_vmv_v (src, vl);
+}
+
+vfloat16m1_t test_vmv_overloaded_1 (vfloat16m1_t src, size_t vl) {
+  return __riscv_vmv_v (src, vl);
+}
+
+/* Tail-undisturbed policy variant: MASKEDOFF comes first.  */
+vint8m4_t test_vmv_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+ size_t vl) {
+  return __riscv_vmv_v_tu (maskedoff, src, vl);
+}
+
+/* The same operations through the non-overloaded names, for comparison:
+   the scan-assembler counts in the including tests expect each vsetvli
+   configuration twice (once per overloaded/non-overloaded pair).  */
+
+vint32m1_t test_vmv_non_overloaded_0 (vint32m1_t src, size_t vl) {
+  return __riscv_vmv_v_v_i32m1 (src, vl);
+}
+
+vfloat16m1_t test_vmv_non_overloaded_1 (vfloat16m1_t src, size_t vl) {
+  return __riscv_vmv_v_v_f16m1 (src, vl);
+}
+
+vint8m4_t test_vmv_non_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+      size_t vl) {
+  return __riscv_vmv_v_v_i8m4_tu (maskedoff, src, vl);
+}
-- 
2.34.1
 
 

Reply via email to