This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e159487b0e [REFACTOR][IR] attrs.h follow-up cleanup: drop legacy
vtable / rename / phase out AttrFieldInfo (#19615)
e159487b0e is described below
commit e159487b0e4131b6874622bf03c546e837ae84c6
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue May 26 22:09:35 2026 -0400
[REFACTOR][IR] attrs.h follow-up cleanup: drop legacy vtable / rename /
phase out AttrFieldInfo (#19615)
## Summary
Follow-up to #19607 that continues trimming `attrs.h` and adjacent
files. The six commits land independently and each builds clean.
- Move `AttrFieldInfo` out of `attrs.h`: it is renamed to `ArgumentInfo`
(FFI type key `"ir.ArgumentInfo"`) and relocated to `op.h`, next to its
users `OpNode::arguments` (now `ffi::Array<ArgumentInfo>`) and
`OpRegEntry::add_argument`. The `.add_argument(name, type, description)`
registration API and its signature are unchanged.
- Drop the unused virtual destructor on `BaseAttrsNode`: `ffi::Object`
releases objects through a deleter captured with the concrete type at
allocation time, so no virtual dispatch is needed. Also inline the
trivial 3-line `DictAttrs(Map)` constructor into the header.
- Rename `BaseAttrsNode` → `AttrsNode`; the `Base` prefix existed only
to distinguish from the `AttrsNodeReflAdapter` shim that #19607
removed. The `"ir.Attrs"` FFI registry key is unchanged.
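- Promote `DictAttrs` to NOTNULLABLE. The
`TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE` expansion is written out
by hand (alongside the COW macro) so that `DictAttrs` can define its own
move constructor and move assignment, which leave the moved-from object
holding an empty dict rather than null. The no-arg `DictAttrs()`
constructor already created an empty backing, so every existing call
site already produced a defined object; ~15 defensive `attrs.defined()`
checks (and a defensive Python `None` fallback in `Function`) are now
redundant.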
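- Inline the `WithAttr(DictAttrs, ...)`, `WithAttrs(DictAttrs, ...)` and
`WithoutAttr(DictAttrs, ...)` free-function overloads into the
TFunc-template wrappers — those overloads had no external callers (no
TVM_DLL, no Python binding). The templates still re-initialize an
undefined `attrs` to an empty `DictAttrs()` before mutating it.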
- Rename `AttrsWithDefaultValues<T>` → `PassConfigWithDefaults<T>` and
move from `attrs.h` to `transform.h`; all 9 consumers are pass-config
classes registered via `TVM_REGISTER_PASS_CONFIG_OPTION`.
`attrs.h` shrinks from 363 → 262 lines.
---
include/tvm/ir/attrs.h | 213 +++++++++++---------------
include/tvm/ir/op.h | 41 ++++-
include/tvm/ir/transform.h | 21 +++
include/tvm/relax/attrs/ccl.h | 12 +-
include/tvm/relax/attrs/create.h | 8 +-
include/tvm/relax/attrs/datatype.h | 8 +-
include/tvm/relax/attrs/distributed.h | 5 +-
include/tvm/relax/attrs/image.h | 12 +-
include/tvm/relax/attrs/index.h | 9 +-
include/tvm/relax/attrs/linear_algebra.h | 8 +-
include/tvm/relax/attrs/manipulate.h | 74 +++++----
include/tvm/relax/attrs/nn.h | 106 +++++++------
include/tvm/relax/attrs/op.h | 21 ++-
include/tvm/relax/attrs/qdq.h | 4 +-
include/tvm/relax/attrs/sampling.h | 4 +-
include/tvm/relax/attrs/search.h | 9 +-
include/tvm/relax/attrs/sorting.h | 12 +-
include/tvm/relax/attrs/statistical.h | 9 +-
include/tvm/relax/attrs/vision.h | 24 +--
include/tvm/target/virtual_device.h | 4 +-
python/tvm/relax/expr.py | 4 +
src/ir/attrs.cc | 35 +----
src/ir/op.cc | 5 +-
src/relax/backend/contrib/clml/codegen.cc | 2 +-
src/relax/backend/contrib/tensorrt/codegen.cc | 2 +-
src/relax/ir/dataflow_matcher.cc | 2 +-
src/relax/ir/expr.cc | 4 -
src/relax/script/printer/function.cc | 5 +-
src/s_tir/transform/hoist_expression.cc | 4 +-
src/s_tir/transform/inject_double_buffer.cc | 2 +-
src/s_tir/transform/loop_partition.cc | 2 +-
src/script/printer/ir/ir.cc | 2 +-
src/target/cuda/codegen_cuda.cc | 2 +-
src/tirx/analysis/verify_tirx_well_formed.cc | 3 +-
src/tirx/ir/function.cc | 4 -
src/tirx/script/printer/buffer.cc | 2 +-
src/tirx/script/printer/function.cc | 10 +-
src/tirx/transform/ir_utils.cc | 4 -
src/tirx/transform/remove_no_op.cc | 5 +-
src/tirx/transform/split_host_device.cc | 2 +-
src/tirx/transform/stmt_simplify.cc | 2 +-
src/tirx/transform/unroll_loop.cc | 2 +-
42 files changed, 341 insertions(+), 368 deletions(-)
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index c549fcdbc1..96eec4616b 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -23,7 +23,7 @@
* This module enables declaration of named attributes
* which support default value setup and bound checking.
*
- * \sa BaseAttrsNode, AttrsWithDefaultValues
+ * \sa AttrsNode
*/
#ifndef TVM_IR_ATTRS_H_
#define TVM_IR_ATTRS_H_
@@ -43,59 +43,23 @@
namespace tvm {
-/*!
- * \brief Information about attribute fields in string representations.
- */
-class AttrFieldInfoNode : public ffi::Object {
- public:
- /*! \brief name of the field */
- ffi::String name;
- /*! \brief type docstring information in str. */
- ffi::String type_info;
- /*! \brief detailed description of the type */
- ffi::String description;
-
- static void RegisterReflection() {
- namespace rfl = ffi::reflection;
- rfl::ObjectDef<AttrFieldInfoNode>()
- .def_ro("name", &AttrFieldInfoNode::name)
- .def_ro("type_info", &AttrFieldInfoNode::type_info)
- .def_ro("description", &AttrFieldInfoNode::description);
- }
-
- static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
-
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.AttrFieldInfo", AttrFieldInfoNode,
ffi::Object);
-};
-
-/*! \brief AttrFieldInfo */
-class AttrFieldInfo : public ffi::ObjectRef {
- public:
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrFieldInfo, ffi::ObjectRef,
AttrFieldInfoNode);
-};
-
/*!
* \brief Base class of all attribute class
- * \note Do not subclass AttrBaseNode directly,
- * subclass AttrsNode instead.
- * \sa AttrsNode
+ * \sa Attrs
*/
-class BaseAttrsNode : public ffi::Object {
+class AttrsNode : public ffi::Object {
public:
- /*! \brief virtual destructor */
- virtual ~BaseAttrsNode() {}
-
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
- TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, ffi::Object);
+ TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", AttrsNode, ffi::Object);
};
/*!
- * \brief Managed reference to BaseAttrsNode.
- * \sa AttrsNode, BaseAttrsNode
+ * \brief Managed reference to AttrsNode.
+ * \sa AttrsNode
*/
class Attrs : public ffi::ObjectRef {
public:
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Attrs, ffi::ObjectRef,
BaseAttrsNode);
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Attrs, ffi::ObjectRef, AttrsNode);
};
/*!
@@ -104,7 +68,7 @@ class Attrs : public ffi::ObjectRef {
* its fields are directly accessible via object.field_name
* like other normal nodes.
*/
-class DictAttrsNode : public BaseAttrsNode {
+class DictAttrsNode : public AttrsNode {
public:
/*! \brief internal attrs map */
ffi::Map<ffi::String, ffi::Any> dict;
@@ -115,28 +79,70 @@ class DictAttrsNode : public BaseAttrsNode {
}
// type info
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, AttrsNode);
};
/*!
* \brief Managed reference to DictAttrsNode
* \sa DictAttrsNode.
+ *
+ * \note DictAttrs is NOTNULLABLE: every instance must hold a backing
+ * DictAttrsNode. The class enforces this end-to-end by:
+ * - the default constructor (no args) allocating an empty backing,
+ * - the copy/move ctors and assignments leaving the moved-from
+ * instance in a defined-but-empty state rather than null,
+ * - the FFI type traits rejecting None at deserialization boundaries
+ * (since `_type_is_nullable == false`), and
+ * - the FFI lambda for ``ir.IRModule`` explicitly normalizing a
+ * missing/None attrs argument to ``DictAttrs()`` before forwarding
+ * to the C++ constructor.
+ * Callers (including third-party code via templates like ``WithAttr``)
+ * can therefore rely on ``attrs->dict`` being safe to dereference
+ * without a ``.defined()`` guard.
*/
class DictAttrs : public Attrs {
public:
/*!
- * \brief constructor with UnsafeInit
+ * \brief Construct a DictAttrs backed by DictAttrsNode.
+ *
+ * The no-argument form constructs an empty (but always defined) DictAttrs.
+ * \param dict The attributes.
+ */
+ explicit DictAttrs(ffi::Map<ffi::String, Any> dict = {}) {
+ ffi::ObjectPtr<DictAttrsNode> n = ffi::make_object<DictAttrsNode>();
+ n->dict = std::move(dict);
+ data_ = std::move(n);
+ }
+
+ /*!
+ * \brief Move constructor that leaves the source in a defined-but-empty
+ * state rather than null, preserving the NOTNULLABLE invariant
+ * even after `std::move`.
*/
- explicit DictAttrs(ffi::UnsafeInit tag) : Attrs(tag) {}
+ DictAttrs(DictAttrs&& other) noexcept : Attrs(ffi::UnsafeInit{}) {
+ data_ = std::move(other.data_);
+ other.data_ = ffi::make_object<DictAttrsNode>();
+ }
+
/*!
- * \brief Consruct a Attrs backed by DictAttrsNode.
- * \param dict The attributes.
+ * \brief Move assignment that leaves the source in a defined-but-empty
+ * state rather than null, preserving the NOTNULLABLE invariant
+ * even after `std::move`.
*/
- TVM_DLL explicit DictAttrs(ffi::Map<ffi::String, Any> dict = {});
+ DictAttrs& operator=(DictAttrs&& other) noexcept {
+ if (this != &other) {
+ data_ = std::move(other.data_);
+ other.data_ = ffi::make_object<DictAttrsNode>();
+ }
+ return *this;
+ }
+
+ // Explicit copy ctor/assign defaults. Declaring the move members above
+ // would otherwise suppress the implicit copy members.
+ DictAttrs(const DictAttrs& other) = default;
+ DictAttrs& operator=(const DictAttrs& other) = default;
// Utils for accessing attributes
- // This needs to be on DictAttrs, not DictAttrsNode because we return the
default
- // value if DictAttrsNode is not defined.
/*!
* \brief Get a function attribute.
*
@@ -160,8 +166,7 @@ class DictAttrs : public Attrs {
ffi::Optional<TObjectRef> GetAttr(
const std::string& attr_key,
ffi::Optional<TObjectRef> default_value =
ffi::Optional<TObjectRef>(std::nullopt)) const {
- if (!defined()) return default_value;
- const DictAttrsNode* node = this->as<DictAttrsNode>();
+ const DictAttrsNode* node = get();
auto it = node->dict.find(attr_key);
if (it != node->dict.end()) {
return (*it).second.cast<TObjectRef>();
@@ -197,57 +202,19 @@ class DictAttrs : public Attrs {
return GetAttr<int64_t>(attr_key, 0).value_or(0) != 0;
}
- explicit DictAttrs(::tvm::ffi::ObjectPtr<DictAttrsNode> n) : Attrs(n) {}
- DictAttrs(const DictAttrs&) = default;
- DictAttrs(DictAttrs&&) = default;
- DictAttrs& operator=(const DictAttrs&) = default;
- DictAttrs& operator=(DictAttrs&&) = default;
- const DictAttrsNode* operator->() const { return static_cast<const
DictAttrsNode*>(data_.get()); }
- const DictAttrsNode* get() const { return operator->(); }
+ // Inline-expand TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE here, minus
+ // the default copy/move it normally injects (we define our own move members
+ // above so the moved-from instance stays defined-but-empty).
+ explicit DictAttrs(::tvm::ffi::UnsafeInit tag) : Attrs(tag) {}
+ using __PtrType =
+ std::conditional_t<DictAttrsNode::_type_mutable, DictAttrsNode*, const
DictAttrsNode*>;
+ __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); }
+ __PtrType get() const { return static_cast<__PtrType>(data_.get()); }
+ static constexpr bool _type_is_nullable = false;
using ContainerType = DictAttrsNode;
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};
-/*!
- * \brief Copy the DictAttrs, but overrides attributes with the
- * entries from \p attrs.
- *
- * \param attrs The DictAttrs to update
- *
- * \param new_attrs Key/values attributes to add to \p attrs.
- *
- * \returns The new DictAttrs with updated attributes.
- */
-DictAttrs WithAttrs(DictAttrs attrs, ffi::Map<ffi::String, Any> new_attrs);
-
-/*!
- * \brief Copy the DictAttrs, but overrides a single attribute.
- *
- * \param attrs The DictAttrs to update
- *
- * \param key The update to insert or update.
- *
- * \param value The new value of the attribute
- *
- * \returns The new DictAttrs with updated attributes.
- */
-DictAttrs WithAttr(DictAttrs attrs, ffi::String key, Any value);
-
-inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, Any value) {
- return WithAttr(std::move(attrs), ffi::String(key), std::move(value));
-}
-
-/*!
- * \brief Copy the DictAttrs, but without a specific attribute.
- *
- * \param attrs The DictAttrs to update
- *
- * \param key The key to remove
- *
- * \returns The new DictAttrs with updated attributes.
- */
-DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key);
-
/*!
* \brief Copy the function or module, but overrides
* the attribute value key with the value.
@@ -280,7 +247,10 @@ inline TFunc WithAttr(TFunc input, const std::string&
attr_key, Any attr_value)
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
- node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value);
+ // node->attrs is NOTNULLABLE by contract, but defend against a caller
+ // that left a moved-from DictAttrs in place by re-initializing here.
+ if (!node->attrs.defined()) node->attrs = DictAttrs();
+ node->attrs.CopyOnWrite()->dict.Set(attr_key, std::move(attr_value));
return input;
}
@@ -298,10 +268,15 @@ template <typename TFunc>
inline TFunc WithAttrs(TFunc input, ffi::Map<ffi::String, Any> attrs) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
+ if (attrs.empty()) return input;
TNode* node = input.CopyOnWrite();
-
- node->attrs = WithAttrs(std::move(node->attrs), attrs);
-
+ // node->attrs is NOTNULLABLE by contract, but defend against a caller
+ // that left a moved-from DictAttrs in place by re-initializing here.
+ if (!node->attrs.defined()) node->attrs = DictAttrs();
+ auto* dict_node = node->attrs.CopyOnWrite();
+ for (const auto& [k, v] : attrs) {
+ dict_node->dict.Set(k, v);
+ }
return input;
}
@@ -335,29 +310,17 @@ template <typename TFunc>
inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
-
TNode* node = input.CopyOnWrite();
- node->attrs = WithoutAttr(std::move(node->attrs), attr_key);
-
+ // node->attrs is NOTNULLABLE by contract, but defend against a caller
+ // that left a moved-from DictAttrs in place; nothing to erase from an
+ // empty dict.
+ if (!node->attrs.defined()) {
+ node->attrs = DictAttrs();
+ return input;
+ }
+ node->attrs.CopyOnWrite()->dict.erase(attr_key);
return input;
}
-/*!
- * \brief Create an object with all default values, using the reflection
defaults.
- * \tparam TObj the ObjectRef type to be created.
- * \return An instance with all reflection-defined default values applied.
- */
-template <typename TObj>
-inline TObj AttrsWithDefaultValues() {
- static_assert(std::is_base_of_v<ffi::ObjectRef, TObj>, "Can only create
ObjectRef-derived types");
- using ContainerType = typename TObj::ContainerType;
- static auto finit_object =
ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs");
- AnyView packed_args[1];
- packed_args[0] = ContainerType::RuntimeTypeIndex();
- ffi::Any rv;
- finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv);
- return rv.cast<TObj>();
-}
-
} // namespace tvm
#endif // TVM_IR_ATTRS_H_
diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h
index dc8f99cd47..3fd39c1060 100644
--- a/include/tvm/ir/op.h
+++ b/include/tvm/ir/op.h
@@ -44,6 +44,41 @@ namespace tvm {
template <typename>
class OpAttrMap;
+/*!
+ * \brief Information about an input field of an Op (name, type, description).
+ *
+ * Populated via OpRegEntry::add_argument and consumed both by
+ * internal sanity checks / error messages and by external tooling
+ * that wants to introspect an Op's argument schema.
+ */
+class ArgumentInfoNode : public ffi::Object {
+ public:
+ /*! \brief name of the field */
+ ffi::String name;
+ /*! \brief type docstring information in str. */
+ ffi::String type_info;
+ /*! \brief detailed description of the type */
+ ffi::String description;
+
+ static void RegisterReflection() {
+ namespace rfl = ffi::reflection;
+ rfl::ObjectDef<ArgumentInfoNode>()
+ .def_ro("name", &ArgumentInfoNode::name)
+ .def_ro("type_info", &ArgumentInfoNode::type_info)
+ .def_ro("description", &ArgumentInfoNode::description);
+ }
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
+
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.ArgumentInfo", ArgumentInfoNode,
ffi::Object);
+};
+
+/*! \brief Managed reference to ArgumentInfoNode. */
+class ArgumentInfo : public ffi::ObjectRef {
+ public:
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ArgumentInfo, ffi::ObjectRef,
ArgumentInfoNode);
+};
+
// TODO(tvm-team): migrate low-level intrinsics to use Op
/*!
* \brief Primitive Op(builtin intrinsics)
@@ -68,7 +103,7 @@ class OpNode : public RelaxExprNode {
*/
ffi::String description;
/* \brief Information of input arguments to the operator */
- ffi::Array<AttrFieldInfo> arguments;
+ ffi::Array<ArgumentInfo> arguments;
/*!
* \brief The type key of the attribute field
* This can be empty, in which case it defaults to anything.
@@ -330,11 +365,11 @@ inline OpRegEntry& OpRegEntry::describe(const
std::string& descr) { // NOLINT(*
inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const
std::string& type,
const std::string& description) {
- auto n = ffi::make_object<AttrFieldInfoNode>();
+ auto n = ffi::make_object<ArgumentInfoNode>();
n->name = name;
n->type_info = type;
n->description = description;
- get()->arguments.push_back(AttrFieldInfo(n));
+ get()->arguments.push_back(ArgumentInfo(n));
return *this;
}
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 436987ae78..f929f1654b 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -57,6 +57,7 @@
#define TVM_IR_TRANSFORM_H_
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/creator.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
@@ -66,6 +67,7 @@
#include <tvm/ir/with_context.h>
#include <string>
+#include <type_traits>
#include <utility>
namespace tvm {
@@ -300,6 +302,25 @@ class PassContext : public ffi::ObjectRef {
friend class With<PassContext>;
};
+/*!
+ * \brief Create a pass-config object with all default values, using the
+ * reflection defaults.
+ * \tparam TConfig the ObjectRef type to be created.
+ * \return An instance with all reflection-defined default values applied.
+ */
+template <typename TConfig>
+inline TConfig PassConfigWithDefaults() {
+ static_assert(std::is_base_of_v<ffi::ObjectRef, TConfig>,
+ "Can only create ObjectRef-derived types");
+ using ContainerType = typename TConfig::ContainerType;
+ static auto finit_object =
ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs");
+ ffi::AnyView packed_args[1];
+ packed_args[0] = ContainerType::RuntimeTypeIndex();
+ ffi::Any rv;
+ finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv);
+ return rv.cast<TConfig>();
+}
+
#define TVM_PASS_CTX_CONFIG_VAR_DEF [[maybe_unused]] static uint32_t
__make_PassContext_tid
/*!
diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h
index 7e0624706b..031a1de493 100644
--- a/include/tvm/relax/attrs/ccl.h
+++ b/include/tvm/relax/attrs/ccl.h
@@ -31,7 +31,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in allreduce operators */
-struct AllReduceAttrs : public tvm::BaseAttrsNode {
+struct AllReduceAttrs : public tvm::AttrsNode {
ffi::String op_type;
bool in_group;
@@ -45,11 +45,11 @@ struct AllReduceAttrs : public tvm::BaseAttrsNode {
"Whether the reduction operation performs in group or globally
or in group as "
"default.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllReduceAttrs",
AllReduceAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllReduceAttrs",
AllReduceAttrs, AttrsNode);
}; // struct AllReduceAttrs
/*! \brief Attributes used in allgather operators */
-struct AllGatherAttrs : public tvm::BaseAttrsNode {
+struct AllGatherAttrs : public tvm::AttrsNode {
int num_workers;
bool in_group;
@@ -63,11 +63,11 @@ struct AllGatherAttrs : public tvm::BaseAttrsNode {
"Whether the allgather operation performs in group or globally
or in group as "
"default.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllGatherAttrs",
AllGatherAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllGatherAttrs",
AllGatherAttrs, AttrsNode);
}; // struct AllGatherAttrs
/*! \brief Attributes used in scatter operators */
-struct ScatterCollectiveAttrs : public tvm::BaseAttrsNode {
+struct ScatterCollectiveAttrs : public tvm::AttrsNode {
int num_workers;
int axis;
@@ -82,7 +82,7 @@ struct ScatterCollectiveAttrs : public tvm::BaseAttrsNode {
"this axis.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterCollectiveAttrs",
ScatterCollectiveAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct ScatterCollectiveAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/create.h b/include/tvm/relax/attrs/create.h
index 9a9e453263..14a3402f25 100644
--- a/include/tvm/relax/attrs/create.h
+++ b/include/tvm/relax/attrs/create.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in full/full_like, ones/ones_like, and
zeros/zeros_like operators */
-struct InitAttrs : public BaseAttrsNode {
+struct InitAttrs : public AttrsNode {
DataType dtype;
static void RegisterReflection() {
@@ -38,11 +38,11 @@ struct InitAttrs : public BaseAttrsNode {
refl::ObjectDef<InitAttrs>().def_ro("dtype", &InitAttrs::dtype,
"The data type of the created
tensor.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InitAttrs", InitAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InitAttrs", InitAttrs,
AttrsNode);
}; // struct InitAttrs
/*! \brief Attributes used in tril and triu operator */
-struct TriluAttrs : public BaseAttrsNode {
+struct TriluAttrs : public AttrsNode {
int k;
static void RegisterReflection() {
@@ -51,7 +51,7 @@ struct TriluAttrs : public BaseAttrsNode {
"k", &TriluAttrs::k,
"The number of diagonals above or below the main diagonal to exclude
or include.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TriluAttrs", TriluAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TriluAttrs", TriluAttrs,
AttrsNode);
}; // struct TriluAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/datatype.h
b/include/tvm/relax/attrs/datatype.h
index a187059703..f67223edb5 100644
--- a/include/tvm/relax/attrs/datatype.h
+++ b/include/tvm/relax/attrs/datatype.h
@@ -30,25 +30,25 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in astype operator */
-struct AstypeAttrs : public BaseAttrsNode {
+struct AstypeAttrs : public AttrsNode {
DataType dtype;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AstypeAttrs>().def_ro("dtype", &AstypeAttrs::dtype,
"Target data type");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AstypeAttrs", AstypeAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AstypeAttrs", AstypeAttrs,
AttrsNode);
}; // struct AstypeAttrs.
/*! \brief Attributes used in wrap_param operator */
-struct WrapParamAttrs : public BaseAttrsNode {
+struct WrapParamAttrs : public AttrsNode {
DataType dtype;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<WrapParamAttrs>().def_ro("dtype", &WrapParamAttrs::dtype,
"Target data type");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.WrapParamAttrs",
WrapParamAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.WrapParamAttrs",
WrapParamAttrs, AttrsNode);
}; // struct WrapParamAttrs.
} // namespace relax
diff --git a/include/tvm/relax/attrs/distributed.h
b/include/tvm/relax/attrs/distributed.h
index cce508ef1d..23b698eb36 100644
--- a/include/tvm/relax/attrs/distributed.h
+++ b/include/tvm/relax/attrs/distributed.h
@@ -32,7 +32,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes for redistribute and annotate_sharding operator */
-struct DistributionAttrs : public BaseAttrsNode {
+struct DistributionAttrs : public AttrsNode {
distributed::DeviceMesh device_mesh;
distributed::Placement placement;
@@ -44,8 +44,7 @@ struct DistributionAttrs : public BaseAttrsNode {
.def_ro("placement", &DistributionAttrs::placement,
"The placement of a tensor's distribution plan");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DistributionAttrs",
DistributionAttrs,
- BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DistributionAttrs",
DistributionAttrs, AttrsNode);
}; // struct DistributionAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h
index 8cc5e36734..eacbea7180 100644
--- a/include/tvm/relax/attrs/image.h
+++ b/include/tvm/relax/attrs/image.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in image resize2d operator */
-struct Resize2DAttrs : public BaseAttrsNode {
+struct Resize2DAttrs : public AttrsNode {
ffi::Array<FloatImm> roi;
ffi::String layout;
ffi::String method;
@@ -75,11 +75,11 @@ struct Resize2DAttrs : public BaseAttrsNode {
"The dtype of the output tensor. It it is not specified, the
output will have the same "
"dtype as input if not specified.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize2DAttrs",
Resize2DAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize2DAttrs",
Resize2DAttrs, AttrsNode);
}; // struct Resize2dAttrs
/*! \brief Attributes used in image resize3d operator */
-struct Resize3DAttrs : public BaseAttrsNode {
+struct Resize3DAttrs : public AttrsNode {
ffi::Array<FloatImm> roi;
ffi::String layout;
ffi::String method;
@@ -124,11 +124,11 @@ struct Resize3DAttrs : public BaseAttrsNode {
"The dtype of the output tensor. It it is not specified, the
output will have the same "
"dtype as input if not specified.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize3DAttrs",
Resize3DAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize3DAttrs",
Resize3DAttrs, AttrsNode);
}; // struct Resize3DAttrs
/*! \brief Attributes used in image grid_sample operator */
-struct GridSampleAttrs : public BaseAttrsNode {
+struct GridSampleAttrs : public AttrsNode {
ffi::String method;
ffi::String layout;
ffi::String padding_mode;
@@ -146,7 +146,7 @@ struct GridSampleAttrs : public BaseAttrsNode {
.def_ro("align_corners", &GridSampleAttrs::align_corners,
"If True, the corner pixels of the input and output tensors
are aligned.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GridSampleAttrs",
GridSampleAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GridSampleAttrs",
GridSampleAttrs, AttrsNode);
}; // struct GridSampleAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h
index 7b4c446bb8..6133a6f580 100644
--- a/include/tvm/relax/attrs/index.h
+++ b/include/tvm/relax/attrs/index.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in take operator */
-struct TakeAttrs : public BaseAttrsNode {
+struct TakeAttrs : public AttrsNode {
ffi::Optional<int64_t> axis;
ffi::String mode;
@@ -41,11 +41,11 @@ struct TakeAttrs : public BaseAttrsNode {
.def_ro("mode", &TakeAttrs::mode, "The mode for handling out-of-bounds
indices.",
refl::DefaultValue("fast"));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TakeAttrs", TakeAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TakeAttrs", TakeAttrs,
AttrsNode);
}; // struct TakeAttrs
/*! \brief Attributes used in strided_slice operator */
-struct StridedSliceAttrs : public BaseAttrsNode {
+struct StridedSliceAttrs : public AttrsNode {
bool assume_inbound;
static void RegisterReflection() {
@@ -56,8 +56,7 @@ struct StridedSliceAttrs : public BaseAttrsNode {
"out of bound indices will be clipped to the bound.",
refl::DefaultValue(true));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StridedSliceAttrs",
StridedSliceAttrs,
- BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StridedSliceAttrs",
StridedSliceAttrs, AttrsNode);
}; // struct StridedSliceAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/linear_algebra.h
b/include/tvm/relax/attrs/linear_algebra.h
index 2627dafcf6..817885edb8 100644
--- a/include/tvm/relax/attrs/linear_algebra.h
+++ b/include/tvm/relax/attrs/linear_algebra.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes for matmul operator */
-struct MatmulAttrs : public BaseAttrsNode {
+struct MatmulAttrs : public AttrsNode {
DataType out_dtype;
static void RegisterReflection() {
@@ -38,11 +38,11 @@ struct MatmulAttrs : public BaseAttrsNode {
refl::ObjectDef<MatmulAttrs>().def_ro("out_dtype", &MatmulAttrs::out_dtype,
"The data type of the output
tensor");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MatmulAttrs", MatmulAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MatmulAttrs", MatmulAttrs,
AttrsNode);
}; // struct MatmulAttrs
/*! \brief Attributes used in einsum operator */
-struct EinsumAttrs : public BaseAttrsNode {
+struct EinsumAttrs : public AttrsNode {
ffi::String subscripts;
static void RegisterReflection() {
@@ -50,7 +50,7 @@ struct EinsumAttrs : public BaseAttrsNode {
refl::ObjectDef<EinsumAttrs>().def_ro("subscripts",
&EinsumAttrs::subscripts,
"The einsum expression string");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.EinsumAttrs", EinsumAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.EinsumAttrs", EinsumAttrs,
AttrsNode);
}; // struct EinsumAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/manipulate.h
b/include/tvm/relax/attrs/manipulate.h
index cc651207fa..7897b860e1 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -31,7 +31,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in concat operators */
-struct ConcatAttrs : public BaseAttrsNode {
+struct ConcatAttrs : public AttrsNode {
ffi::Optional<int64_t> axis;
static void RegisterReflection() {
@@ -40,11 +40,11 @@ struct ConcatAttrs : public BaseAttrsNode {
"The axis at which the input arrays
are concatenated."
"Should lie in range `[-ndim,
ndim)`.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ConcatAttrs", ConcatAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ConcatAttrs", ConcatAttrs,
AttrsNode);
}; // struct ConcatAttrs
/*! \brief Attributes used in expand_dims operators */
-struct ExpandDimsAttrs : public BaseAttrsNode {
+struct ExpandDimsAttrs : public AttrsNode {
ffi::Array<int64_t> axis;
static void RegisterReflection() {
@@ -55,11 +55,11 @@ struct ExpandDimsAttrs : public BaseAttrsNode {
"All values are required to lie in range `[-data.ndim - 1,
data.ndim]`, "
"with the convention of negative indexing.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ExpandDimsAttrs",
ExpandDimsAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ExpandDimsAttrs",
ExpandDimsAttrs, AttrsNode);
}; // struct ExpandDimsAttrs
/*! \brief Attributes used in layout_transform operator */
-struct LayoutTransformAttrs : public BaseAttrsNode {
+struct LayoutTransformAttrs : public AttrsNode {
tirx::IndexMap index_map;
// pad_value is chosen to be of PrimValue type, as it represents constant
TIR POD expression. This
// needs to be revisited in case PrimValue is evolved to represent symbolic
expression in future.
@@ -93,11 +93,11 @@ struct LayoutTransformAttrs : public BaseAttrsNode {
"The separators between axes to regenerate output");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayoutTransformAttrs",
LayoutTransformAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct LayoutTransformAttrs
/*! \brief Attributes used in permute_dims operator */
-struct PermuteDimsAttrs : public BaseAttrsNode {
+struct PermuteDimsAttrs : public AttrsNode {
ffi::Optional<ffi::Array<int64_t>> axes;
static void RegisterReflection() {
@@ -105,12 +105,11 @@ struct PermuteDimsAttrs : public BaseAttrsNode {
refl::ObjectDef<PermuteDimsAttrs>().def_ro(
"axes", &PermuteDimsAttrs::axes, "The target axes order, reverse order
if not specified.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PermuteDimsAttrs",
PermuteDimsAttrs,
- BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PermuteDimsAttrs",
PermuteDimsAttrs, AttrsNode);
}; // struct PermuteDimsAttrs
/*! \brief Attributes used in split operator */
-struct SplitAttrs : public BaseAttrsNode {
+struct SplitAttrs : public AttrsNode {
ffi::ObjectRef indices_or_sections;
int axis;
@@ -121,11 +120,11 @@ struct SplitAttrs : public BaseAttrsNode {
"The input array of indices or the number of split sections.")
.def_ro("axis", &SplitAttrs::axis, "The axis to be splitted");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SplitAttrs", SplitAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SplitAttrs", SplitAttrs,
AttrsNode);
}; // struct SplitAttrs
/*! \brief Attributes used in squeeze operators */
-struct SqueezeAttrs : public BaseAttrsNode {
+struct SqueezeAttrs : public AttrsNode {
ffi::Optional<ffi::Array<int64_t>> axis;
static void RegisterReflection() {
@@ -136,11 +135,11 @@ struct SqueezeAttrs : public BaseAttrsNode {
"Else, the dimension in axes get
squeezed."
"It is an error if an axis does not
has dimension 1.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SqueezeAttrs", SqueezeAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SqueezeAttrs", SqueezeAttrs,
AttrsNode);
}; // struct SqueezeAttrs
/*! \brief Attributes used in stack operators */
-struct StackAttrs : public BaseAttrsNode {
+struct StackAttrs : public AttrsNode {
ffi::Optional<int64_t> axis;
static void RegisterReflection() {
@@ -152,11 +151,11 @@ struct StackAttrs : public BaseAttrsNode {
"so it must be in range [-ndim-1, ndim] where ndim is the "
"number of dimensions of the input tensors.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StackAttrs", StackAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StackAttrs", StackAttrs,
AttrsNode);
}; // struct StackAttrs
/*! \brief Attributes used in repeat operators */
-struct RepeatAttrs : public BaseAttrsNode {
+struct RepeatAttrs : public AttrsNode {
int repeats;
ffi::Optional<int64_t> axis;
@@ -169,11 +168,11 @@ struct RepeatAttrs : public BaseAttrsNode {
"counting from the backward. By default, use the flattened
input array, and "
"return a flat output array.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RepeatAttrs", RepeatAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RepeatAttrs", RepeatAttrs,
AttrsNode);
}; // struct RepeatAttrs
/*! \brief Attributes used in tile operators */
-struct TileAttrs : public BaseAttrsNode {
+struct TileAttrs : public AttrsNode {
ffi::Array<int64_t> repeats;
static void RegisterReflection() {
@@ -181,11 +180,11 @@ struct TileAttrs : public BaseAttrsNode {
refl::ObjectDef<TileAttrs>().def_ro("repeats", &TileAttrs::repeats,
"The number of repetitions of data
along each axis.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TileAttrs", TileAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TileAttrs", TileAttrs,
AttrsNode);
}; // struct TileAttrs
/*! \brief Attributes used in flip operators */
-struct FlipAttrs : public BaseAttrsNode {
+struct FlipAttrs : public AttrsNode {
int64_t axis;
static void RegisterReflection() {
@@ -193,11 +192,11 @@ struct FlipAttrs : public BaseAttrsNode {
refl::ObjectDef<FlipAttrs>().def_ro("axis", &FlipAttrs::axis,
"The axis along which to flip over.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs,
AttrsNode);
}; // struct FlipAttrs
/*! \brief Attributes used in gather_elements operators */
-struct GatherElementsAttrs : public BaseAttrsNode {
+struct GatherElementsAttrs : public AttrsNode {
int64_t axis;
static void RegisterReflection() {
@@ -207,11 +206,11 @@ struct GatherElementsAttrs : public BaseAttrsNode {
refl::DefaultValue(0));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherElementsAttrs",
GatherElementsAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct GatherElementsAttrs
/*! \brief Attributes used in gather_nd operators */
-struct GatherNDAttrs : public BaseAttrsNode {
+struct GatherNDAttrs : public AttrsNode {
int64_t batch_dims;
static void RegisterReflection() {
@@ -219,11 +218,11 @@ struct GatherNDAttrs : public BaseAttrsNode {
refl::ObjectDef<GatherNDAttrs>().def_ro("batch_dims",
&GatherNDAttrs::batch_dims,
"The number of batch dims.",
refl::DefaultValue(0));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherNDAttrs",
GatherNDAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherNDAttrs",
GatherNDAttrs, AttrsNode);
}; // struct GatherNDAttrs
/*! \brief Attributes used in index_put operator */
-struct IndexPutAttrs : public BaseAttrsNode {
+struct IndexPutAttrs : public AttrsNode {
bool accumulate;
static void RegisterReflection() {
@@ -235,11 +234,11 @@ struct IndexPutAttrs : public BaseAttrsNode {
"otherwise performs tensor[indices] = values.",
refl::DefaultValue(false));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.IndexPutAttrs",
IndexPutAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.IndexPutAttrs",
IndexPutAttrs, AttrsNode);
}; // struct IndexPutAttrs
/*! \brief Attribute used in meshgrid operator */
-struct MeshgridAttrs : public BaseAttrsNode {
+struct MeshgridAttrs : public AttrsNode {
ffi::Optional<ffi::String> indexing;
static void RegisterReflection() {
@@ -247,11 +246,11 @@ struct MeshgridAttrs : public BaseAttrsNode {
refl::ObjectDef<MeshgridAttrs>().def_ro("indexing",
&MeshgridAttrs::indexing,
"Specifies how the grid dimensions
are ordered.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MeshgridAttrs",
MeshgridAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MeshgridAttrs",
MeshgridAttrs, AttrsNode);
};
/*! \brief Attributes used in scatter_elements operators */
-struct ScatterElementsAttrs : public BaseAttrsNode {
+struct ScatterElementsAttrs : public AttrsNode {
int64_t axis;
ffi::String reduction;
@@ -266,11 +265,11 @@ struct ScatterElementsAttrs : public BaseAttrsNode {
refl::DefaultValue("update"));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterElementsAttrs",
ScatterElementsAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct ScatterElementsAttrs
/*! \brief Attributes used in scatter_nd operators */
-struct ScatterNDAttrs : public BaseAttrsNode {
+struct ScatterNDAttrs : public AttrsNode {
ffi::String reduction;
static void RegisterReflection() {
@@ -281,11 +280,11 @@ struct ScatterNDAttrs : public BaseAttrsNode {
"either \"update\", \"add\", \"mul\", \"min\" or \"max\".",
refl::DefaultValue("update"));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterNDAttrs",
ScatterNDAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterNDAttrs",
ScatterNDAttrs, AttrsNode);
}; // struct ScatterNDAttrs
/*! \brief Attributes used in slice_scatter operator */
-struct SliceScatterAttrs : public BaseAttrsNode {
+struct SliceScatterAttrs : public AttrsNode {
int axis;
static void RegisterReflection() {
@@ -294,12 +293,11 @@ struct SliceScatterAttrs : public BaseAttrsNode {
"the dimension to insert the
slice into ",
refl::DefaultValue(0));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SliceScatterAttrs",
SliceScatterAttrs,
- BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SliceScatterAttrs",
SliceScatterAttrs, AttrsNode);
}; // struct SliceScatterAttrs
/*! \brief Attributes used in one_hot operator */
-struct OneHotAttrs : public BaseAttrsNode {
+struct OneHotAttrs : public AttrsNode {
int depth;
int axis;
@@ -309,7 +307,7 @@ struct OneHotAttrs : public BaseAttrsNode {
.def_ro("depth", &OneHotAttrs::depth, "Depth of the one hot
dimension.")
.def_ro("axis", &OneHotAttrs::axis, "Axis to fill.",
refl::DefaultValue(-1));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.OneHotAttrs", OneHotAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.OneHotAttrs", OneHotAttrs,
AttrsNode);
}; // struct OneHotAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index b483d3e233..52d9c40d74 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in Conv1d operator */
-struct Conv1DAttrs : public BaseAttrsNode {
+struct Conv1DAttrs : public AttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> dilation;
@@ -70,11 +70,11 @@ struct Conv1DAttrs : public BaseAttrsNode {
.def_ro("out_dtype", &Conv1DAttrs::out_dtype,
"Output data type, set to explicit type under mixed precision
setting");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv1DAttrs", Conv1DAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv1DAttrs", Conv1DAttrs,
AttrsNode);
}; // struct Conv1dAttrs
/*! \brief Attributes used in Conv2d operator */
-struct Conv2DAttrs : public BaseAttrsNode {
+struct Conv2DAttrs : public AttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> dilation;
@@ -116,11 +116,11 @@ struct Conv2DAttrs : public BaseAttrsNode {
.def_ro("out_dtype", &Conv2DAttrs::out_dtype,
"Output data type, set to explicit type under mixed precision
setting");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv2DAttrs", Conv2DAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv2DAttrs", Conv2DAttrs,
AttrsNode);
}; // struct Conv2dAttrs
/*! \brief Attributes used in Conv3d operator */
-struct Conv3DAttrs : public BaseAttrsNode {
+struct Conv3DAttrs : public AttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> dilation;
@@ -164,11 +164,11 @@ struct Conv3DAttrs : public BaseAttrsNode {
.def_ro("out_dtype", &Conv3DAttrs::out_dtype,
"Output data type, set to explicit type under mixed precision
setting");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DAttrs", Conv3DAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DAttrs", Conv3DAttrs,
AttrsNode);
}; // struct Conv3dAttrs
/*! \brief Attributes used in Conv1DTranspose operator */
-struct Conv1DTransposeAttrs : public BaseAttrsNode {
+struct Conv1DTransposeAttrs : public AttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> output_padding;
@@ -213,11 +213,11 @@ struct Conv1DTransposeAttrs : public BaseAttrsNode {
"Output data type, set to explicit type under mixed precision
setting");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv1DTransposeAttrs",
Conv1DTransposeAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct Conv1DTransposeAttrs
/*! \brief Attributes used in Conv2d operator */
-struct Conv2DTransposeAttrs : public BaseAttrsNode {
+struct Conv2DTransposeAttrs : public AttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> output_padding;
@@ -264,11 +264,11 @@ struct Conv2DTransposeAttrs : public BaseAttrsNode {
"Output data type, set to explicit type under mixed precision
setting");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv2DTransposeAttrs",
Conv2DTransposeAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct Conv2DTransposeAttrs
/*! \brief Attributes used in Conv3dTranspose operator */
-struct Conv3DTransposeAttrs : public BaseAttrsNode {
+struct Conv3DTransposeAttrs : public AttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> output_padding;
@@ -317,11 +317,11 @@ struct Conv3DTransposeAttrs : public BaseAttrsNode {
"Output data type, set to explicit type under mixed precision
setting");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DTransposeAttrs",
Conv3DTransposeAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct Conv3DTransposeAttrs
/*! \brief Attributes used in max_pool1d and avg_pool1d operator */
-struct Pool1DAttrs : public BaseAttrsNode {
+struct Pool1DAttrs : public AttrsNode {
ffi::Array<int64_t> pool_size;
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
@@ -358,11 +358,11 @@ struct Pool1DAttrs : public BaseAttrsNode {
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Pooling is applied on the 'W'
dimensions.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool1DAttrs", Pool1DAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool1DAttrs", Pool1DAttrs,
AttrsNode);
}; // struct Pool1dAttrs
/*! \brief Attributes used in max_pool2d and avg_pool2d operator */
-struct Pool2DAttrs : public BaseAttrsNode {
+struct Pool2DAttrs : public AttrsNode {
ffi::Array<int64_t> pool_size;
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
@@ -401,11 +401,11 @@ struct Pool2DAttrs : public BaseAttrsNode {
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool2DAttrs", Pool2DAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool2DAttrs", Pool2DAttrs,
AttrsNode);
}; // struct Pool2dAttrs
/*! \brief Attributes used in max_pool3d and avg_pool3d operator */
-struct Pool3DAttrs : public BaseAttrsNode {
+struct Pool3DAttrs : public AttrsNode {
ffi::Array<int64_t> pool_size;
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
@@ -444,11 +444,11 @@ struct Pool3DAttrs : public BaseAttrsNode {
"dimensions respectively. Pooling is applied on the 'D', 'H'
and"
"'W' dimensions.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool3DAttrs", Pool3DAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool3DAttrs", Pool3DAttrs,
AttrsNode);
}; // struct Pool3dAttrs
/*! \brief Attributes for 1d adaptive pool operator */
-struct AdaptivePool1DAttrs : public BaseAttrsNode {
+struct AdaptivePool1DAttrs : public AttrsNode {
ffi::Optional<ffi::Array<int64_t>> output_size;
ffi::String layout;
ffi::String out_layout;
@@ -469,11 +469,11 @@ struct AdaptivePool1DAttrs : public BaseAttrsNode {
"'W' dimensions.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool1DAttrs",
AdaptivePool1DAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct AdaptivePool1DAttrs
/*! \brief Attributes for 2d adaptive pool operator */
-struct AdaptivePool2DAttrs : public BaseAttrsNode {
+struct AdaptivePool2DAttrs : public AttrsNode {
ffi::Optional<ffi::Array<int64_t>> output_size;
ffi::String layout;
ffi::String out_layout;
@@ -494,11 +494,11 @@ struct AdaptivePool2DAttrs : public BaseAttrsNode {
"'W' dimensions.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool2DAttrs",
AdaptivePool2DAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct AdaptivePool2DAttrs
/*! \brief Attributes for 3d adaptive pool operator */
-struct AdaptivePool3DAttrs : public BaseAttrsNode {
+struct AdaptivePool3DAttrs : public AttrsNode {
ffi::Optional<ffi::Array<int64_t>> output_size;
ffi::String layout;
ffi::String out_layout;
@@ -519,11 +519,11 @@ struct AdaptivePool3DAttrs : public BaseAttrsNode {
"'W' dimensions.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool3DAttrs",
AdaptivePool3DAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct AdaptivePool3DAttrs
/*! \brief Attributes used in softmax operators */
-struct SoftmaxAttrs : public BaseAttrsNode {
+struct SoftmaxAttrs : public AttrsNode {
int axis;
static void RegisterReflection() {
@@ -531,11 +531,11 @@ struct SoftmaxAttrs : public BaseAttrsNode {
refl::ObjectDef<SoftmaxAttrs>().def_ro("axis", &SoftmaxAttrs::axis,
"The axis to sum over when
computing softmax.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftmaxAttrs", SoftmaxAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftmaxAttrs", SoftmaxAttrs,
AttrsNode);
};
/*! \brief Attributes used in softmax operators */
-struct LeakyReluAttrs : public BaseAttrsNode {
+struct LeakyReluAttrs : public AttrsNode {
double alpha;
static void RegisterReflection() {
@@ -543,11 +543,11 @@ struct LeakyReluAttrs : public BaseAttrsNode {
refl::ObjectDef<LeakyReluAttrs>().def_ro("alpha", &LeakyReluAttrs::alpha,
"The slope of the negative
part.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LeakyReluAttrs",
LeakyReluAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LeakyReluAttrs",
LeakyReluAttrs, AttrsNode);
};
/*! \brief Attributes used in softplus operators */
-struct SoftplusAttrs : public BaseAttrsNode {
+struct SoftplusAttrs : public AttrsNode {
double beta;
double threshold;
@@ -559,11 +559,11 @@ struct SoftplusAttrs : public BaseAttrsNode {
.def_ro("threshold", &SoftplusAttrs::threshold,
"Value determining when to use linear approximation for
numerical stability.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftplusAttrs",
SoftplusAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftplusAttrs",
SoftplusAttrs, AttrsNode);
};
/*! \brief Attributes used in PReLU operator */
-struct PReluAttrs : public BaseAttrsNode {
+struct PReluAttrs : public AttrsNode {
int axis;
static void RegisterReflection() {
@@ -571,11 +571,11 @@ struct PReluAttrs : public BaseAttrsNode {
refl::ObjectDef<PReluAttrs>().def_ro("axis", &PReluAttrs::axis,
"The axis along which the alpha
values are applied.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PReluAttrs", PReluAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PReluAttrs", PReluAttrs,
AttrsNode);
};
/*! \brief Attributes used in batch_norm operator */
-struct BatchNormAttrs : public BaseAttrsNode {
+struct BatchNormAttrs : public AttrsNode {
int axis;
double epsilon;
bool center;
@@ -598,11 +598,11 @@ struct BatchNormAttrs : public BaseAttrsNode {
.def_ro("training", &BatchNormAttrs::training,
"Whether we are training (i.e., not in eval mode).");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BatchNormAttrs",
BatchNormAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BatchNormAttrs",
BatchNormAttrs, AttrsNode);
}; // struct BatchNormAttrs
/*! \brief Attributes used in layer_norm operator */
-struct LayerNormAttrs : public BaseAttrsNode {
+struct LayerNormAttrs : public AttrsNode {
ffi::Array<int64_t> axes;
double epsilon;
bool center;
@@ -620,11 +620,11 @@ struct LayerNormAttrs : public BaseAttrsNode {
.def_ro("scale", &LayerNormAttrs::scale,
"Indicating if the gamma scale will be multiplied.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayerNormAttrs",
LayerNormAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayerNormAttrs",
LayerNormAttrs, AttrsNode);
}; // struct LayerNormAttrs
/*! \brief Attributes used in group_norm operator */
-struct GroupNormAttrs : public BaseAttrsNode {
+struct GroupNormAttrs : public AttrsNode {
int num_groups;
int channel_axis;
ffi::Array<int64_t> axes;
@@ -649,11 +649,11 @@ struct GroupNormAttrs : public BaseAttrsNode {
.def_ro("scale", &GroupNormAttrs::scale,
"Indicating if the gamma scale will be multiplied.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GroupNormAttrs",
GroupNormAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GroupNormAttrs",
GroupNormAttrs, AttrsNode);
}; // struct GroupNormAttrs
/*! \brief Attributes used in instance_norm operator */
-struct InstanceNormAttrs : public BaseAttrsNode {
+struct InstanceNormAttrs : public AttrsNode {
int channel_axis;
ffi::Array<int64_t> axes;
double epsilon;
@@ -674,12 +674,11 @@ struct InstanceNormAttrs : public BaseAttrsNode {
.def_ro("scale", &InstanceNormAttrs::scale,
"Indicating if the gamma scale will be multiplied.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InstanceNormAttrs",
InstanceNormAttrs,
- BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InstanceNormAttrs",
InstanceNormAttrs, AttrsNode);
}; // struct InstanceNormAttrs
/*! \brief Attributes used in rms_norm operator */
-struct RMSNormAttrs : public BaseAttrsNode {
+struct RMSNormAttrs : public AttrsNode {
ffi::Array<int64_t> axes;
double epsilon;
@@ -691,11 +690,11 @@ struct RMSNormAttrs : public BaseAttrsNode {
.def_ro("epsilon", &RMSNormAttrs::epsilon,
"Small float added to variance to avoid dividing by zero");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RMSNormAttrs", RMSNormAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RMSNormAttrs", RMSNormAttrs,
AttrsNode);
}; // struct RMSNormAttrs
/*! \brief Attributes used in nll_loss operator */
-struct NLLLossAttrs : public BaseAttrsNode {
+struct NLLLossAttrs : public AttrsNode {
ffi::String reduction;
int ignore_index;
@@ -708,11 +707,11 @@ struct NLLLossAttrs : public BaseAttrsNode {
refl::DefaultValue("mean"))
.def_ro("ignore_index", &NLLLossAttrs::ignore_index, "The target value
to ignore.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NLLLossAttrs", NLLLossAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NLLLossAttrs", NLLLossAttrs,
AttrsNode);
}; // struct NLLLossAttrs
/*! \brief Attributes used in dropout operator */
-struct DropoutAttrs : public BaseAttrsNode {
+struct DropoutAttrs : public AttrsNode {
double rate;
static void RegisterReflection() {
@@ -721,11 +720,11 @@ struct DropoutAttrs : public BaseAttrsNode {
"rate", &DropoutAttrs::rate,
"Fraction of the input that gets dropped out during training time");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DropoutAttrs", DropoutAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DropoutAttrs", DropoutAttrs,
AttrsNode);
}; // struct DropoutAttrs
/*! \brief Attributes used in Attention operator */
-struct AttentionAttrs : public BaseAttrsNode {
+struct AttentionAttrs : public AttrsNode {
ffi::Optional<FloatImm> scale;
ffi::Optional<ffi::String> causal_mask;
ffi::Optional<IntImm> window_size;
@@ -741,11 +740,11 @@ struct AttentionAttrs : public BaseAttrsNode {
.def_ro("window_size", &AttentionAttrs::window_size,
"The size of the window for sliding-window attention.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AttentionAttrs",
AttentionAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AttentionAttrs",
AttentionAttrs, AttrsNode);
}; // struct AttentionAttrs
/*! \brief Attributes used for the padding operator */
-struct PadAttrs : public BaseAttrsNode {
+struct PadAttrs : public AttrsNode {
ffi::Array<int64_t> pad_width;
double pad_value = 0.0;
tvm::ffi::String pad_mode;
@@ -764,11 +763,11 @@ struct PadAttrs : public BaseAttrsNode {
"\"reflect\" pads by reflecting values with respect to the
edges.",
refl::DefaultValue("constant"));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PadAttrs", PadAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PadAttrs", PadAttrs,
AttrsNode);
};
/*! \brief Attributes used for the pixel shuffle operator */
-struct PixelShuffleAttrs : public BaseAttrsNode {
+struct PixelShuffleAttrs : public AttrsNode {
int upscale_factor;
static void RegisterReflection() {
@@ -777,8 +776,7 @@ struct PixelShuffleAttrs : public BaseAttrsNode {
&PixelShuffleAttrs::upscale_factor,
"Scale factor for spatial
upsampling.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PixelShuffleAttrs",
PixelShuffleAttrs,
- BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PixelShuffleAttrs",
PixelShuffleAttrs, AttrsNode);
};
} // namespace relax
diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h
index 54970e0eab..4c1451c3dc 100644
--- a/include/tvm/relax/attrs/op.h
+++ b/include/tvm/relax/attrs/op.h
@@ -31,7 +31,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in call_tir_with_grad */
-struct CallTIRWithGradAttrs : public BaseAttrsNode {
+struct CallTIRWithGradAttrs : public AttrsNode {
ffi::String te_grad_name;
ffi::Map<ffi::String, Any> te_grad_kwargs;
@@ -45,11 +45,11 @@ struct CallTIRWithGradAttrs : public BaseAttrsNode {
"The keyword arguments passed to the te gradient function.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.CallTIRWithGradAttrs",
CallTIRWithGradAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct CallTIRAttrs
/*! \brief Attributes used in call_tir_inplace */
-struct CallTIRInplaceAttrs : public BaseAttrsNode {
+struct CallTIRInplaceAttrs : public AttrsNode {
/*!
* \brief Indices that describe which input corresponds to which output.
*
@@ -65,11 +65,11 @@ struct CallTIRInplaceAttrs : public BaseAttrsNode {
&CallTIRInplaceAttrs::inplace_indices);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.CallTIRInplaceAttrs",
CallTIRInplaceAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct CallTIRInplaceAttrs
/*! \brief Attributes used in call_inplace_packed */
-struct CallInplacePackedAttrs : public BaseAttrsNode {
+struct CallInplacePackedAttrs : public AttrsNode {
/*!
* \brief Indices that describe which input corresponds to which output.
*
@@ -85,11 +85,11 @@ struct CallInplacePackedAttrs : public BaseAttrsNode {
&CallInplacePackedAttrs::inplace_indices);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.CallInplacePackedAttrs",
CallInplacePackedAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct CallInplacePackedAttrs
/*! \brief Attributes used in to_vdevice */
-struct ToVDeviceAttrs : public BaseAttrsNode {
+struct ToVDeviceAttrs : public AttrsNode {
VDevice dst_vdevice;
static void RegisterReflection() {
@@ -97,11 +97,11 @@ struct ToVDeviceAttrs : public BaseAttrsNode {
refl::ObjectDef<ToVDeviceAttrs>().def_ro("dst_vdevice",
&ToVDeviceAttrs::dst_vdevice,
"The destination device where the
data is copied to.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ToVDeviceAttrs",
ToVDeviceAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ToVDeviceAttrs",
ToVDeviceAttrs, AttrsNode);
}; // struct ToVDeviceAttrs
/*! \brief Attributes used in hint_on_device */
-struct HintOnDeviceAttrs : public BaseAttrsNode {
+struct HintOnDeviceAttrs : public AttrsNode {
int32_t device_type;
int32_t index;
MemoryScope memory_scope;
@@ -114,8 +114,7 @@ struct HintOnDeviceAttrs : public BaseAttrsNode {
.def_ro("index", &HintOnDeviceAttrs::index, "The device id.")
.def_ro("memory_scope", &HintOnDeviceAttrs::memory_scope, "The device
memory scope.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.HintOnDeviceAttrs",
HintOnDeviceAttrs,
- BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.HintOnDeviceAttrs",
HintOnDeviceAttrs, AttrsNode);
}; // struct HintOnDeviceAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/qdq.h b/include/tvm/relax/attrs/qdq.h
index 08bc054dc5..83ec2223c3 100644
--- a/include/tvm/relax/attrs/qdq.h
+++ b/include/tvm/relax/attrs/qdq.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes for relax.quantize/relax.dequantize operator */
-struct QuantizeAttrs : public BaseAttrsNode {
+struct QuantizeAttrs : public AttrsNode {
DataType out_dtype;
int axis;
@@ -43,7 +43,7 @@ struct QuantizeAttrs : public BaseAttrsNode {
"Default value is -1, which corresponds to the last axis.",
refl::DefaultValue(-1));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.QuantizeAttrs",
QuantizeAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.QuantizeAttrs",
QuantizeAttrs, AttrsNode);
}; // QuantizeAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/sampling.h
b/include/tvm/relax/attrs/sampling.h
index 2d7421cc20..11bbfb6eba 100644
--- a/include/tvm/relax/attrs/sampling.h
+++ b/include/tvm/relax/attrs/sampling.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in multinomial_from_uniform operator */
-struct MultinomialFromUniformAttrs : public BaseAttrsNode {
+struct MultinomialFromUniformAttrs : public AttrsNode {
DataType dtype;
static void RegisterReflection() {
@@ -40,7 +40,7 @@ struct MultinomialFromUniformAttrs : public BaseAttrsNode {
refl::DefaultValue(DataType::Int(64)));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MultinomialFromUniformAttrs",
- MultinomialFromUniformAttrs,
BaseAttrsNode);
+ MultinomialFromUniformAttrs, AttrsNode);
}; // struct MultinomialFromUniformAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h
index 015e5d8edc..6b3ee4860a 100644
--- a/include/tvm/relax/attrs/search.h
+++ b/include/tvm/relax/attrs/search.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes for search operators */
-struct ArgmaxArgminAttrs : public BaseAttrsNode {
+struct ArgmaxArgminAttrs : public AttrsNode {
ffi::Optional<int64_t> axis;
bool keepdims;
@@ -44,12 +44,11 @@ struct ArgmaxArgminAttrs : public BaseAttrsNode {
"with size "
"one.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgmaxArgminAttrs",
ArgmaxArgminAttrs,
- BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgmaxArgminAttrs",
ArgmaxArgminAttrs, AttrsNode);
}; // struct ArgmaxArgminAttrs
/*! \brief Attributes for bucketize operator */
-struct BucketizeAttrs : public tvm::BaseAttrsNode {
+struct BucketizeAttrs : public tvm::AttrsNode {
bool out_int32;
bool right;
@@ -61,7 +60,7 @@ struct BucketizeAttrs : public tvm::BaseAttrsNode {
.def_ro("right", &BucketizeAttrs::right,
"Determines the behavior for values in boundaries");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BucketizeAttrs",
BucketizeAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BucketizeAttrs",
BucketizeAttrs, AttrsNode);
}; // struct BucketizeAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/sorting.h
b/include/tvm/relax/attrs/sorting.h
index e32d47239f..e8bf65d55a 100644
--- a/include/tvm/relax/attrs/sorting.h
+++ b/include/tvm/relax/attrs/sorting.h
@@ -31,7 +31,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in sort operator */
-struct SortAttrs : public BaseAttrsNode {
+struct SortAttrs : public AttrsNode {
int axis;
bool descending;
@@ -47,11 +47,11 @@ struct SortAttrs : public BaseAttrsNode {
"If it is not specified, it defaults to the ascending order.",
refl::DefaultValue(false));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SortAttrs", SortAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SortAttrs", SortAttrs,
AttrsNode);
}; // struct SortAttrs
/*! \brief Attributes used in argsort operator */
-struct ArgsortAttrs : public BaseAttrsNode {
+struct ArgsortAttrs : public AttrsNode {
int axis;
bool descending;
DataType dtype;
@@ -70,11 +70,11 @@ struct ArgsortAttrs : public BaseAttrsNode {
.def_ro("dtype", &ArgsortAttrs::dtype, "DType of the output indices.",
refl::DefaultValue(DataType::Void()));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgsortAttrs", ArgsortAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgsortAttrs", ArgsortAttrs,
AttrsNode);
}; // struct ArgsortAttrs
/*! \brief Attributes used in topk operator */
-struct TopKAttrs : public BaseAttrsNode {
+struct TopKAttrs : public AttrsNode {
int k;
int axis;
bool largest;
@@ -100,7 +100,7 @@ struct TopKAttrs : public BaseAttrsNode {
.def_ro("dtype", &TopKAttrs::dtype, "Data type of the output indices.",
refl::DefaultValue(DataType::Void()));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TopKAttrs", TopKAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TopKAttrs", TopKAttrs,
AttrsNode);
}; // struct TopKAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/statistical.h
b/include/tvm/relax/attrs/statistical.h
index 884946402a..66996c802c 100644
--- a/include/tvm/relax/attrs/statistical.h
+++ b/include/tvm/relax/attrs/statistical.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes for statistical operators */
-struct StatisticalAttrs : public BaseAttrsNode {
+struct StatisticalAttrs : public AttrsNode {
ffi::Optional<ffi::Array<int64_t>> axis;
bool keepdims;
@@ -44,12 +44,11 @@ struct StatisticalAttrs : public BaseAttrsNode {
"with size "
"one.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StatisticalAttrs",
StatisticalAttrs,
- BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StatisticalAttrs",
StatisticalAttrs, AttrsNode);
}; // struct StatisticalAttrs
/*! \brief Attributes used in scan operators like cumsum, cumprod */
-struct ScanopAttrs : public BaseAttrsNode {
+struct ScanopAttrs : public AttrsNode {
ffi::Optional<int64_t> axis;
DataType dtype;
bool exclusive = false;
@@ -66,7 +65,7 @@ struct ScanopAttrs : public BaseAttrsNode {
.def_ro("exclusive", &ScanopAttrs::exclusive, "The first element is
not included",
refl::DefaultValue(false));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScanopAttrs", ScanopAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScanopAttrs", ScanopAttrs,
AttrsNode);
}; // struct ScanopAttrs
} // namespace relax
diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 37ec77cbbf..f4b1830669 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -32,7 +32,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in AllClassNonMaximumSuppression operator */
-struct AllClassNonMaximumSuppressionAttrs : public BaseAttrsNode {
+struct AllClassNonMaximumSuppressionAttrs : public AttrsNode {
ffi::String output_format;
static void RegisterReflection() {
@@ -43,11 +43,11 @@ struct AllClassNonMaximumSuppressionAttrs : public
BaseAttrsNode {
"consumed by each frontend.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllClassNonMaximumSuppressionAttrs",
- AllClassNonMaximumSuppressionAttrs,
BaseAttrsNode);
+ AllClassNonMaximumSuppressionAttrs,
AttrsNode);
}; // struct AllClassNonMaximumSuppressionAttrs
/*! \brief Attributes used in ROIAlign operator */
-struct ROIAlignAttrs : public BaseAttrsNode {
+struct ROIAlignAttrs : public AttrsNode {
ffi::Array<int64_t> pooled_size;
double spatial_scale;
int sample_ratio;
@@ -68,11 +68,11 @@ struct ROIAlignAttrs : public BaseAttrsNode {
.def_ro("layout", &ROIAlignAttrs::layout, "Dimension ordering of the
input data.")
.def_ro("mode", &ROIAlignAttrs::mode, "Mode for ROI Align. Can be
'avg' or 'max'.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs",
ROIAlignAttrs, BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs",
ROIAlignAttrs, AttrsNode);
}; // struct ROIAlignAttrs
/*! \brief Attributes used in ROIPool operator */
-struct ROIPoolAttrs : public BaseAttrsNode {
+struct ROIPoolAttrs : public AttrsNode {
ffi::Array<int64_t> pooled_size;
double spatial_scale;
ffi::String layout;
@@ -85,11 +85,11 @@ struct ROIPoolAttrs : public BaseAttrsNode {
"Ratio of input feature map height (or width) to raw image
height (or width).")
.def_ro("layout", &ROIPoolAttrs::layout, "Dimension ordering of the
input data.");
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIPoolAttrs", ROIPoolAttrs,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIPoolAttrs", ROIPoolAttrs,
AttrsNode);
}; // struct ROIPoolAttrs
/*! \brief Attributes used in GetValidCounts operator */
-struct GetValidCountsAttrs : public BaseAttrsNode {
+struct GetValidCountsAttrs : public AttrsNode {
double score_threshold;
int id_index;
int score_index;
@@ -105,11 +105,11 @@ struct GetValidCountsAttrs : public BaseAttrsNode {
"Index of the scores/confidence of boxes.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GetValidCountsAttrs",
GetValidCountsAttrs,
- BaseAttrsNode);
+ AttrsNode);
}; // struct GetValidCountsAttrs
/*! \brief Attributes used in NonMaximumSuppression operator */
-struct NonMaximumSuppressionAttrs : public BaseAttrsNode {
+struct NonMaximumSuppressionAttrs : public AttrsNode {
int max_output_size;
double iou_threshold;
bool force_suppress;
@@ -149,11 +149,11 @@ struct NonMaximumSuppressionAttrs : public BaseAttrsNode {
"Score threshold for soft-NMS validity check; 0.0 when
unused.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NonMaximumSuppressionAttrs",
- NonMaximumSuppressionAttrs, BaseAttrsNode);
+ NonMaximumSuppressionAttrs, AttrsNode);
}; // struct NonMaximumSuppressionAttrs
/*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box
decode). */
-struct MultiboxTransformLocAttrs : public BaseAttrsNode {
+struct MultiboxTransformLocAttrs : public AttrsNode {
bool clip;
double threshold;
ffi::Array<double> variances;
@@ -173,7 +173,7 @@ struct MultiboxTransformLocAttrs : public BaseAttrsNode {
"If false, force output scores[:,0,:] to 0 (background
class).");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MultiboxTransformLocAttrs",
- MultiboxTransformLocAttrs, BaseAttrsNode);
+ MultiboxTransformLocAttrs, AttrsNode);
}; // struct MultiboxTransformLocAttrs
} // namespace relax
diff --git a/include/tvm/target/virtual_device.h
b/include/tvm/target/virtual_device.h
index b791387306..83c7f5655a 100644
--- a/include/tvm/target/virtual_device.h
+++ b/include/tvm/target/virtual_device.h
@@ -169,7 +169,7 @@ constexpr int kInvalidDeviceType = -1;
* These operations are needed during device planning.
*/
-class VirtualDeviceNode : public BaseAttrsNode {
+class VirtualDeviceNode : public AttrsNode {
private:
/*!
* \brief The \p DLDeviceType (represented as an int) of the virtual device.
If \p target is
@@ -257,7 +257,7 @@ class VirtualDeviceNode : public BaseAttrsNode {
"The area of memory w.r.t. the virtual device where data is
stored.",
refl::DefaultValue(""));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("target.VirtualDevice", VirtualDeviceNode,
BaseAttrsNode);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("target.VirtualDevice", VirtualDeviceNode,
AttrsNode);
friend class VirtualDevice;
};
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 5a75e43b12..6dffaab8f4 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -1010,6 +1010,8 @@ class Function(BaseFunc, Scriptable):
attrs: tvm.ir.DictAttrs | None = None,
span: Span | None = None,
) -> None:
+ if attrs is None:
+ attrs = tvm.ir.DictAttrs({})
self.__init_handle_by_constructor__(
_ffi_api.Function,
params,
@@ -1029,6 +1031,8 @@ class Function(BaseFunc, Scriptable):
span: Span | None = None,
):
"""Construct a relax.Function but without body"""
+ if attrs is None:
+ attrs = tvm.ir.DictAttrs({})
return _ffi_api.FunctionCreateEmpty(params, ret_struct_info, is_pure,
attrs, span) # type: ignore
def __call__(self, *args):
diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc
index e7d9b90828..b58c183c7a 100644
--- a/src/ir/attrs.cc
+++ b/src/ir/attrs.cc
@@ -26,40 +26,9 @@
namespace tvm {
-TVM_FFI_STATIC_INIT_BLOCK() {
- AttrFieldInfoNode::RegisterReflection();
- DictAttrsNode::RegisterReflection();
-}
-
-DictAttrs WithAttrs(DictAttrs attrs, ffi::Map<ffi::String, ffi::Any>
new_attrs) {
- if (new_attrs.empty()) {
- return attrs;
- }
-
- auto* write_ptr = attrs.CopyOnWrite();
- for (const auto& [key, value] : new_attrs) {
- write_ptr->dict.Set(key, value);
- }
- return attrs;
-}
-
-DictAttrs WithAttr(DictAttrs attrs, ffi::String key, ffi::Any value) {
- attrs.CopyOnWrite()->dict.Set(key, value);
- return attrs;
-}
-
-DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) {
- attrs.CopyOnWrite()->dict.erase(key);
- return attrs;
-}
-
-DictAttrs::DictAttrs(ffi::Map<ffi::String, Any> dict) {
- ffi::ObjectPtr<DictAttrsNode> n = ffi::make_object<DictAttrsNode>();
- n->dict = std::move(dict);
- data_ = std::move(n);
-}
+TVM_FFI_STATIC_INIT_BLOCK() { DictAttrsNode::RegisterReflection(); }
-TVM_FFI_STATIC_INIT_BLOCK() {
tvm::ffi::reflection::ObjectDef<BaseAttrsNode>(); }
+TVM_FFI_STATIC_INIT_BLOCK() { tvm::ffi::reflection::ObjectDef<AttrsNode>(); }
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
diff --git a/src/ir/op.cc b/src/ir/op.cc
index f6078e30d9..3684298e4a 100644
--- a/src/ir/op.cc
+++ b/src/ir/op.cc
@@ -33,7 +33,10 @@
namespace tvm {
-TVM_FFI_STATIC_INIT_BLOCK() { OpNode::RegisterReflection(); }
+TVM_FFI_STATIC_INIT_BLOCK() {
+ ArgumentInfoNode::RegisterReflection();
+ OpNode::RegisterReflection();
+}
using ffi::Any;
using ffi::Function;
diff --git a/src/relax/backend/contrib/clml/codegen.cc
b/src/relax/backend/contrib/clml/codegen.cc
index 5fd04c05bf..c58c2ee9aa 100644
--- a/src/relax/backend/contrib/clml/codegen.cc
+++ b/src/relax/backend/contrib/clml/codegen.cc
@@ -267,7 +267,7 @@ class OpenCLMLJSONSerializer : public JSONSerializer {
auto ctx = transform::PassContext::Current();
auto cfg =
ctx->GetConfig<OpenCLMLCompilerConfig>("relax.ext.clml.options");
if (!cfg.defined()) {
- cfg = AttrsWithDefaultValues<OpenCLMLCompilerConfig>();
+ cfg = transform::PassConfigWithDefaults<OpenCLMLCompilerConfig>();
}
node->SetAttr("clml_version",
static_cast<int64_t>(cfg.value()->clml_version.IntValue()));
}
diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc
b/src/relax/backend/contrib/tensorrt/codegen.cc
index 8720c77b43..7fa6d48bdc 100644
--- a/src/relax/backend/contrib/tensorrt/codegen.cc
+++ b/src/relax/backend/contrib/tensorrt/codegen.cc
@@ -180,7 +180,7 @@ class TensorRTJSONSerializer : public JSONSerializer {
auto ctx = transform::PassContext::Current();
auto cfg =
ctx->GetConfig<TensorRTCompilerConfig>("relax.ext.tensorrt.options");
if (!cfg.defined()) {
- cfg = AttrsWithDefaultValues<TensorRTCompilerConfig>();
+ cfg = transform::PassConfigWithDefaults<TensorRTCompilerConfig>();
}
TVM_FFI_ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3);
ffi::Array<int64_t> tensorrt_version = {cfg.value()->tensorrt_version[0],
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 22e3a7bbc3..e8eafde317 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -209,7 +209,7 @@ bool DFPatternMatcher::VisitDFPattern_(const
AttrPatternNode* attr_pattern, cons
} else if (auto* op = expr.as<FunctionNode>()) {
matches = true;
for (auto kv : attributes) {
- if (matches && op->attrs.defined() && op->attrs->dict.count(kv.first)) {
+ if (matches && op->attrs->dict.count(kv.first)) {
matches &= ffi::StructuralEqual()(kv.second,
op->attrs->dict[kv.first]);
} else {
matches = false;
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index c2f404d41f..5c2419209b 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -542,10 +542,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
Function::Function(ffi::Array<Var> params, Expr body,
ffi::Optional<StructInfo> ret_struct_info,
bool is_pure, DictAttrs attrs, Span span) {
- if (!attrs.defined()) {
- attrs = DictAttrs();
- }
-
// Set the function type.
// For function, we take a conservative approach and require the function
type
// to be known at construction time.
diff --git a/src/relax/script/printer/function.cc
b/src/relax/script/printer/function.cc
index e30a2b0bf4..4c0d84f9f6 100644
--- a/src/relax/script/printer/function.cc
+++ b/src/relax/script/printer/function.cc
@@ -84,7 +84,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
// Step 3. Clean up func variables
(*f)->func_vars = nullptr;
// Step 4. Print attributes
- if (n->attrs.defined() && !n->attrs->dict.empty()) {
+ if (!n->attrs->dict.empty()) {
// If the function is a global function and has a global symbol,
// then don't print the global symbol (it will be implicit from not
being private).
// For a function without an IR module whose global symbol
@@ -119,8 +119,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
// if the function is global or is not in a module and does not have a
global symbol,
// indicate that it's private
- if (AtTopLevelFunction(d) &&
- (!n->attrs.defined() ||
!n->attrs->dict.count(tvm::attr::kGlobalSymbol))) {
+ if (AtTopLevelFunction(d) &&
!n->attrs->dict.count(tvm::attr::kGlobalSymbol)) {
dec_keys.push_back("private");
dec_values.push_back(LiteralDoc::Boolean(true,
ffi::Optional<AccessPath>()));
}
diff --git a/src/s_tir/transform/hoist_expression.cc
b/src/s_tir/transform/hoist_expression.cc
index 8fe1845029..5cb851ca2a 100644
--- a/src/s_tir/transform/hoist_expression.cc
+++ b/src/s_tir/transform/hoist_expression.cc
@@ -568,7 +568,7 @@ Pass HoistExpression() {
auto cfg = ctx->GetConfig<HoistExpressionConfig>("s_tir.HoistExpression");
if (!cfg.defined()) {
- cfg = AttrsWithDefaultValues<HoistExpressionConfig>();
+ cfg = tvm::transform::PassConfigWithDefaults<HoistExpressionConfig>();
}
n->body = ExpressionHoister::Hoist(std::move(n->body), cfg.value());
return f;
@@ -602,7 +602,7 @@ static Pass HoistIfThenElseImpl() {
return f;
}
if (!cfg.defined()) {
- cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>();
+ cfg = tvm::transform::PassConfigWithDefaults<HoistIfThenElseConfig>();
}
int block_var = static_cast<int>(cfg.value()->support_block_scope_hoisting
? HoistedConditionals::kUsingBlockVar
diff --git a/src/s_tir/transform/inject_double_buffer.cc
b/src/s_tir/transform/inject_double_buffer.cc
index 0c934ddbcd..ac2f25a629 100644
--- a/src/s_tir/transform/inject_double_buffer.cc
+++ b/src/s_tir/transform/inject_double_buffer.cc
@@ -332,7 +332,7 @@ Pass InjectDoubleBuffer() {
auto* n = f.CopyOnWrite();
auto cfg =
ctx->GetConfig<InjectDoubleBufferConfig>("s_tir.InjectDoubleBuffer");
if (!cfg.defined()) {
- cfg = AttrsWithDefaultValues<InjectDoubleBufferConfig>();
+ cfg = tvm::transform::PassConfigWithDefaults<InjectDoubleBufferConfig>();
}
n->body =
DoubleBufferInjector(cfg.value()->split_loop).Inject(std::move(n->body));
return f;
diff --git a/src/s_tir/transform/loop_partition.cc
b/src/s_tir/transform/loop_partition.cc
index bf2dca776c..8eb444dcfd 100644
--- a/src/s_tir/transform/loop_partition.cc
+++ b/src/s_tir/transform/loop_partition.cc
@@ -817,7 +817,7 @@ Pass LoopPartition() {
auto* n = f.CopyOnWrite();
auto cfg = ctx->GetConfig<LoopPartitionConfig>("s_tir.LoopPartition");
if (!cfg.defined()) {
- cfg = AttrsWithDefaultValues<LoopPartitionConfig>();
+ cfg = tvm::transform::PassConfigWithDefaults<LoopPartitionConfig>();
}
n->body = s_tir::LoopPartition(std::move(n->body),
cfg.value()->partition_const_loop,
cfg.value()->no_unroll_loop_with_extent_one,
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index d49f3123d9..640bc6c57e 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -75,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
(*f)->AddDispatchToken(d, "ir");
IdDoc module_doc = d->Define(mod, f(),
GetBindingName(d).value_or("Module"));
(*f)->global_infos = &mod->global_infos;
- if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
+ if (!mod->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(IR(d, "module_attrs") //
->Call({d->AsDoc<ExprDoc>(mod->attrs,
p->Attr("attrs"))})));
diff --git a/src/target/cuda/codegen_cuda.cc b/src/target/cuda/codegen_cuda.cc
index 863cc4eb20..27cad36735 100644
--- a/src/target/cuda/codegen_cuda.cc
+++ b/src/target/cuda/codegen_cuda.cc
@@ -211,7 +211,7 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f,
std::ostream& os) {
extractor(f->body);
// Also check PrimFunc attrs for persistent kernel (decorator-level)
bool is_persistent = extractor.is_persistent_kernel;
- if (!is_persistent && f->attrs.defined() &&
f->attrs->dict.count(tirx::attr::kPersistentKernel)) {
+ if (!is_persistent && f->attrs->dict.count(tirx::attr::kPersistentKernel)) {
is_persistent = true;
}
arith::Analyzer analyzer;
diff --git a/src/tirx/analysis/verify_tirx_well_formed.cc
b/src/tirx/analysis/verify_tirx_well_formed.cc
index 64ede04f20..f9063bd2d2 100644
--- a/src/tirx/analysis/verify_tirx_well_formed.cc
+++ b/src/tirx/analysis/verify_tirx_well_formed.cc
@@ -251,8 +251,7 @@ bool VerifyTIRxWellFormed(const IRModule& mod, bool
assert_mode, bool device_fun
for (const auto& [gvar, base_func] : mod->functions) {
if (auto prim_func = base_func.as<PrimFunc>()) {
// s_tir=True PrimFuncs use s_tir semantics — defer to VerifyWellFormed.
- if (prim_func.value()->attrs.defined() &&
- prim_func.value()->attrs->dict.count(tvm::attr::kSTir)) {
+ if (prim_func.value()->attrs->dict.count(tvm::attr::kSTir)) {
if (!VerifyWellFormed(prim_func.value(), assert_mode)) return false;
continue;
}
diff --git a/src/tirx/ir/function.cc b/src/tirx/ir/function.cc
index a92767c85a..273ed1ae3c 100644
--- a/src/tirx/ir/function.cc
+++ b/src/tirx/ir/function.cc
@@ -77,10 +77,6 @@ relax::StructInfo InferStructInfo(const PrimFunc& prim_func)
{
// Get the function type of a PrimFunc
PrimFunc::PrimFunc(ffi::Array<tirx::Var> params, Stmt body, Type ret_type,
ffi::Map<tirx::Var, Buffer> buffer_map, DictAttrs attrs,
Span span) {
- if (!attrs.defined()) {
- attrs = DictAttrs();
- }
-
if (!ret_type.defined()) {
ret_type = VoidType();
}
diff --git a/src/tirx/script/printer/buffer.cc
b/src/tirx/script/printer/buffer.cc
index 72f3f9f9df..32d50a8f8d 100644
--- a/src/tirx/script/printer/buffer.cc
+++ b/src/tirx/script/printer/buffer.cc
@@ -193,7 +193,7 @@ ffi::Map<ffi::String, ExprDoc> BufferAttrs(tirx::Buffer
buffer, const AccessPath
for (const auto& f : d->frames) {
if (const auto* tir_f = f.as<TIRFrameNode>()) {
if (auto func = tir_f->tirx.as<tirx::PrimFuncNode>()) {
- if (func->attrs.defined() &&
func->attrs->dict.count(tvm::attr::kSTir)) {
+ if (func->attrs->dict.count(tvm::attr::kSTir)) {
enclosing_s_tir = true;
}
break;
diff --git a/src/tirx/script/printer/function.cc
b/src/tirx/script/printer/function.cc
index 41b561e739..30912034da 100644
--- a/src/tirx/script/printer/function.cc
+++ b/src/tirx/script/printer/function.cc
@@ -106,7 +106,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (d->cfg->syntax_sugar && CountVarOccurrence(func, var) == 2 &&
func->buffer_map.count(var)) {
tirx::Buffer buffer = func->buffer_map[var];
- bool s_tir = func->attrs.defined() &&
func->attrs->dict.count(tvm::attr::kSTir);
+ bool s_tir = func->attrs->dict.count(tvm::attr::kSTir);
if (IsSimpleBuffer(buffer, s_tir) &&
buffer_data_counter.at(buffer->data.get()) == 1) {
AccessPath buffer_p = p->Attr("buffer_map")->MapItem(var);
IdDoc lhs = DefineBuffer(buffer, *f, d);
@@ -120,7 +120,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
args.push_back(AssignDoc(DefineVar(var, *f, d), std::nullopt, a));
}
// Step 2. Handle `func->attrs`
- if (func->attrs.defined() && !func->attrs->dict.empty()) {
+ if (!func->attrs->dict.empty()) {
// for global symbol, don't display it if it matches the func name
std::unordered_set<ffi::String> keys_to_remove;
if (func->attrs->dict.count(tvm::attr::kGlobalSymbol) &&
@@ -214,15 +214,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
ffi::Array<ffi::String, void> kwargs_keys;
ffi::Array<ExprDoc, void> kwargs_values;
// mark private if there is no global symbol
- if (!func->attrs.defined() ||
!func->attrs->dict.count(tvm::attr::kGlobalSymbol)) {
+ if (!func->attrs->dict.count(tvm::attr::kGlobalSymbol)) {
kwargs_keys.push_back("private");
kwargs_values.push_back(LiteralDoc::Boolean(true,
ffi::Optional<AccessPath>()));
}
- if (func->attrs.defined() && func->attrs->dict.count(tvm::attr::kSTir)) {
+ if (func->attrs->dict.count(tvm::attr::kSTir)) {
kwargs_keys.push_back("s_tir");
kwargs_values.push_back(LiteralDoc::Boolean(true,
ffi::Optional<AccessPath>()));
}
- if (func->attrs.defined() &&
func->attrs->dict.count(tirx::attr::kPersistentKernel)) {
+ if (func->attrs->dict.count(tirx::attr::kPersistentKernel)) {
kwargs_keys.push_back("persistent");
kwargs_values.push_back(LiteralDoc::Boolean(true,
ffi::Optional<AccessPath>()));
}
diff --git a/src/tirx/transform/ir_utils.cc b/src/tirx/transform/ir_utils.cc
index 8582968f0e..281e53d76c 100644
--- a/src/tirx/transform/ir_utils.cc
+++ b/src/tirx/transform/ir_utils.cc
@@ -158,10 +158,6 @@ class IRConvertSSA final : public StmtExprMutator {
}();
auto attrs = [&]() -> DictAttrs {
- if (!func->attrs.defined()) {
- return DictAttrs();
- }
-
ffi::Map<ffi::String, ffi::Any> dict;
bool made_change = false;
diff --git a/src/tirx/transform/remove_no_op.cc
b/src/tirx/transform/remove_no_op.cc
index aa22802154..133cfa9d9a 100644
--- a/src/tirx/transform/remove_no_op.cc
+++ b/src/tirx/transform/remove_no_op.cc
@@ -271,8 +271,9 @@ namespace transform {
Pass RemoveNoOp() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
- RemoveNoOpConfig config =
ctx->GetConfig<RemoveNoOpConfig>("tirx.RemoveNoOp")
-
.value_or(AttrsWithDefaultValues<RemoveNoOpConfig>());
+ RemoveNoOpConfig config =
+ ctx->GetConfig<RemoveNoOpConfig>("tirx.RemoveNoOp")
+
.value_or(tvm::transform::PassConfigWithDefaults<RemoveNoOpConfig>());
arith::Analyzer analyzer;
analyzer.rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps);
diff --git a/src/tirx/transform/split_host_device.cc
b/src/tirx/transform/split_host_device.cc
index 6a07306b38..70c44ba66c 100644
--- a/src/tirx/transform/split_host_device.cc
+++ b/src/tirx/transform/split_host_device.cc
@@ -112,7 +112,7 @@ class HostDeviceSplitter : public StmtMutator {
device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget,
device_target},
{tirx::attr::kNoAlias,
true},
{tirx::attr::kIsGlobalFunc, true}});
- if (cur_func_->attrs.defined() &&
cur_func_->attrs->dict.count(tvm::attr::kSTir)) {
+ if (cur_func_->attrs->dict.count(tvm::attr::kSTir)) {
device_func = WithAttr(std::move(device_func), tvm::attr::kSTir, true);
}
auto num_inputs = cur_func_->GetAttr<int64_t>(tvm::attr::kNumInputs);
diff --git a/src/tirx/transform/stmt_simplify.cc
b/src/tirx/transform/stmt_simplify.cc
index 2238625255..9ebbcab9e1 100644
--- a/src/tirx/transform/stmt_simplify.cc
+++ b/src/tirx/transform/stmt_simplify.cc
@@ -89,7 +89,7 @@ class StmtSimplifyConfig : public ffi::ObjectRef {
};
static StmtSimplifyConfig MakeDefaultStmtSimplifyConfig() {
- return AttrsWithDefaultValues<StmtSimplifyConfig>();
+ return tvm::transform::PassConfigWithDefaults<StmtSimplifyConfig>();
}
TVM_FFI_STATIC_INIT_BLOCK() { StmtSimplifyConfigNode::RegisterReflection(); }
diff --git a/src/tirx/transform/unroll_loop.cc
b/src/tirx/transform/unroll_loop.cc
index faf1ec2d67..ae99410cee 100644
--- a/src/tirx/transform/unroll_loop.cc
+++ b/src/tirx/transform/unroll_loop.cc
@@ -285,7 +285,7 @@ Pass UnrollLoop() {
auto* n = f.CopyOnWrite();
auto cfg = ctx->GetConfig<UnrollLoopConfig>("tirx.UnrollLoop");
if (!cfg.defined()) {
- cfg = AttrsWithDefaultValues<UnrollLoopConfig>();
+ cfg = tvm::transform::PassConfigWithDefaults<UnrollLoopConfig>();
}
n->body = UnrollLoop(std::move(f->body), cfg.value());
return f;