This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 6066f60 feat(core): Introduce Attribute-Carrying Language-Agnostic Enums (#554)
6066f60 is described below
commit 6066f605558ec4bbc9d477e46894439edef60022
Author: Junru Shao
AuthorDate: Sat Apr 18 14:41:32 2026 -0700
feat(core): Introduce Attribute-Carrying Language-Agnostic Enums (#554)
RFC: https://github.com/apache/tvm-ffi/issues/553
Add first-class cross-language enum support to TVM-FFI. An enum is a
registered Object type whose instances are named, frozen singletons —
the same model as `tvm::Op`, generalised into an `Enum` base class
usable from Python and C++ and converging on a single shared registry
per `type_key`.
## At a glance
### Python
```python
from __future__ import annotations
from typing import ClassVar
from tvm_ffi.dataclasses import Enum, auto, entry

# Pure-Python enum: fresh type_key, no C++ involvement.
class Priority(Enum, type_key="my.Priority"):
    low = auto()
    medium = auto()
    high = auto()

# Attribute-carrying enum.
class Activation(Enum, type_key="nn.Activation"):
    output_zero: bool
    is_monotonic: bool
    relu: ClassVar[Activation] = entry(output_zero=True, is_monotonic=True)
    gelu: ClassVar[Activation] = entry(output_zero=False, is_monotonic=False)
    silu: ClassVar[Activation] = entry(output_zero=False, is_monotonic=True)

# Python class binding C++-registered entries.
class Variant(Enum, type_key="testing.TestEnumVariant"):
    Alpha: ClassVar[Variant]  # bound to the C++-registered "Alpha"
    Beta: ClassVar[Variant]   # bound to the C++-registered "Beta"

assert Activation.relu.value == 0            # auto-assigned ordinal
assert Activation.relu.name == "relu"        # auto-populated
assert Activation.relu.output_zero is True   # user field
assert Activation.get("relu") is Activation.relu

# Extensible per-variant attributes, writable from anywhere.
cost = Activation.def_attr("cost", default=0)
cost[Activation.relu] = 1
cost[Activation.gelu] = 4
assert cost[Activation.silu] == 0   # default: silu was never assigned
assert Activation.silu not in cost  # distinguishes a default hit from an explicit set
```
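The split between `cost[...]` returning a default and `in` reporting whether a value was actually written can be modeled over a plain ordinal-indexed list. The sketch below is a hypothetical Python model, not `tvm_ffi`'s `EnumAttrMap`. `AttrColumnView` and `_UNSET` are invented names, and the model is keyed by integer ordinals rather than variant objects. It assumes an unset slot is represented by `None`, as the `Any(nullptr)` padding in the C++ `EnumDef::set_attr` builder suggests.

```python
from typing import Any, Dict, List

# Assumption: an unset slot holds None, mirroring the nullptr padding
# written by the C++ EnumDef::set_attr builder.
_UNSET = None


class AttrColumnView:
    """Hypothetical model of a def_attr()-style view; not tvm_ffi's EnumAttrMap."""

    def __init__(self, attrs: Dict[str, List[Any]], attr_name: str, default: Any = None) -> None:
        # The column lives in shared storage: attr name -> list indexed by ordinal.
        self._column = attrs.setdefault(attr_name, [])
        self._default = default

    def __setitem__(self, ordinal: int, value: Any) -> None:
        # Grow the column with unset slots until the ordinal fits, then write.
        while len(self._column) <= ordinal:
            self._column.append(_UNSET)
        self._column[ordinal] = value

    def __contains__(self, ordinal: int) -> bool:
        # True only for slots that were explicitly written.
        return ordinal < len(self._column) and self._column[ordinal] is not _UNSET

    def __getitem__(self, ordinal: int) -> Any:
        # Fall back to the default for short columns and unset slots.
        return self._column[ordinal] if ordinal in self else self._default


# Usage: ordinals stand in for variant objects (relu=0, gelu=1, silu=2).
attrs: Dict[str, List[Any]] = {}
cost = AttrColumnView(attrs, "cost", default=0)
cost[0] = 1
cost[1] = 4
assert cost[2] == 0 and 2 not in cost
```

Under this assumption, the model cannot tell an explicitly stored `None` apart from a slot that was never written.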
### C++
```cpp
#include <tvm/ffi/enum.h>
#include <tvm/ffi/reflection/enum_def.h>

class ActivationObj : public tvm::ffi::EnumObj {
 public:
  bool output_zero;
  bool is_monotonic;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("nn.Activation", ActivationObj,
                                    tvm::ffi::EnumObj);
};

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::ObjectDef<ActivationObj>(refl::init(false))
      .def_ro("output_zero", &ActivationObj::output_zero)
      .def_ro("is_monotonic", &ActivationObj::is_monotonic);
  refl::EnumDef<ActivationObj>("relu")
      .set_attr("output_zero", true)
      .set_attr("is_monotonic", true);
  refl::EnumDef<ActivationObj>("gelu")
      .set_attr("output_zero", false)
      .set_attr("is_monotonic", false);
}
```
The Python and C++ halves write to the same two `type_index`-keyed
TypeAttr columns, so a Python subclass that binds
`type_key="nn.Activation"` sees every C++-registered entry, and any
later `auto()`/`entry(...)` from Python becomes visible to C++ readers
of the same columns. Entries cross FFI as ordinary `ObjectRef`s — no
wire-format work.
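The convergence described above rests on register-once-then-mutate. The first writer registers an empty container for a `(type_index, attr_name)` slot, and every later writer, in whichever language, fetches and mutates that same object. A minimal Python model follows. `TypeAttrTable`, `ensure_entries`, and `add_entry` are hypothetical names, not the TVM-FFI C API. As a simplification, entries map names to ordinals rather than to singleton objects.

```python
from typing import Any, Dict, Tuple


class TypeAttrTable:
    """Hypothetical stand-in for the process-wide TypeAttr table."""

    def __init__(self) -> None:
        self._slots: Dict[Tuple[int, str], Any] = {}

    def register(self, type_index: int, attr_name: str, value: Any) -> None:
        # A (type_index, attr_name) slot may be written only once.
        key = (type_index, attr_name)
        if key in self._slots:
            raise RuntimeError(f"{attr_name!r} already registered for type {type_index}")
        self._slots[key] = value

    def lookup(self, type_index: int, attr_name: str) -> Any:
        return self._slots.get((type_index, attr_name))


def ensure_entries(table: TypeAttrTable, type_index: int) -> Dict[str, int]:
    """Return the live entries dict, registering an empty one on first use."""
    entries = table.lookup(type_index, "__ffi_enum_entries__")
    if entries is None:
        entries = {}
        table.register(type_index, "__ffi_enum_entries__", entries)
    return entries  # later writers mutate this same object in place


def add_entry(table: TypeAttrTable, type_index: int, name: str) -> int:
    entries = ensure_entries(table, type_index)
    entries[name] = len(entries)  # dense ordinal in registration order
    return entries[name]


table = TypeAttrTable()
ACTIVATION = 7  # hypothetical type_index for "nn.Activation"
add_entry(table, ACTIVATION, "relu")  # e.g. from a C++ static-init block
add_entry(table, ACTIVATION, "gelu")
add_entry(table, ACTIVATION, "silu")  # e.g. from a later Python auto()
assert ensure_entries(table, ACTIVATION) == {"relu": 0, "gelu": 1, "silu": 2}
```

Because the slot is written exactly once, a second `register` call is an error rather than a silent overwrite. That is why the stricter `TVMFFITypeRegisterAttr` behaviour matters to this protocol.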
## Design
- Enum instances are `EnumObj` subclasses. Each carries a dense
auto-assigned `int64_t value` (0-indexed per class, declaration-order
ordinal) and a `String name`. Both are populated at registration;
neither is user-supplied.
- Two per-class TypeAttr columns, shared across all call sites:
- `__ffi_enum_entries__` — `Dict<String, Enum>` mapping instance name →
frozen singleton.
- `__ffi_enum_attrs__` — `Dict<String, List<Any>>` mapping attribute
name → ordinal-indexed list.
- **Register-once-then-mutate.** Each column is registered exactly once
via `TVMFFITypeRegisterAttr`; every subsequent writer fetches the live
container with `TVMFFIGetTypeAttrColumn` and mutates it in place.
Distributed registration across TUs or Python modules converges on one
set of containers.
- **Python variants** are declared in one of four shapes, processed in
`Enum.__init_subclass__`:
1. `name: ClassVar[Cls] = entry(**kwargs)` — registers a Python-side
entry and forwards kwargs to `__init__`.
2. `name = entry(**kwargs)` (no annotation) — same as 1, for
attribute-carrying enums where `ClassVar` is noise.
3. `name = auto()` (or `name: ClassVar[Cls] = auto()`) — registers a
variant with no extra fields; the preferred form for simple enums.
4. Bare `name: ClassVar[Cls]` — binds to a C++-registered entry of the
same name, or registers a blank Python entry if none exists.
Within one class body, bare `ClassVar` binders resolve first (annotation
order), then sentinel assignments (class-body order); auto-ordinals
follow that combined order. Mixing all four forms on a single class is
supported.
- **Auto-detected backend.** `Enum.__init_subclass__(type_key=...)`
routes the subclass through `@c_class` if the type is already registered
in the FFI type system, otherwise through `@py_class`. There is no
separate `py_enum`/`c_enum` opt-in.
- **Integer literals are rejected** on the RHS. The auto-ordinal policy
owns `value`, so `ok = 0` and `entry(0)` would either duplicate or
conflict with the auto-ordinal. `auto()` is the intended replacement.
`entry(value=...)` / `entry(name=...)` raise `TypeError` at class-body
time.
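The ordering and rejection rules above can be sketched as a pure function over one class body. This is a hypothetical model: `Sentinel` and `resolve_variants` are invented names, and the real logic lives in `Enum.__init_subclass__`. Raising `RuntimeError` on a duplicate sentinel name is an assumption borrowed from the C++ builder's duplicate check.

```python
from typing import Any, Dict, List, Tuple


class Sentinel:
    """Hypothetical stand-in for the object returned by entry()/auto()."""

    def __init__(self, **kwargs: Any) -> None:
        # value and name are owned by the auto-ordinal policy.
        for reserved in ("value", "name"):
            if reserved in kwargs:
                raise TypeError(f"entry({reserved}=...) is not allowed: it is auto-assigned")
        self.kwargs = kwargs


def resolve_variants(
    existing: Dict[str, int],
    bare_binders: List[str],
    sentinels: List[Tuple[str, Sentinel]],
) -> Dict[str, int]:
    """Return name -> ordinal after resolving one class body."""
    entries = dict(existing)  # e.g. entries already registered from C++
    # 1) Bare ClassVar annotations, in annotation order: bind to an existing
    #    entry, or register a blank one with the next ordinal.
    for name in bare_binders:
        if name not in entries:
            entries[name] = len(entries)
    # 2) entry()/auto() sentinels, in class-body order: register new entries.
    for name, _sentinel in sentinels:
        if name in entries:
            raise RuntimeError(f"duplicate enum entry {name!r}")
        entries[name] = len(entries)
    return entries


cpp_side = {"Alpha": 0, "Beta": 1}
resolved = resolve_variants(
    cpp_side, ["Beta", "Gamma"], [("delta", Sentinel()), ("eps", Sentinel(code=3))]
)
# "Beta" binds to the C++ entry; "Gamma", "delta", "eps" get ordinals 2, 3, 4.
```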
## New public interfaces
### C++ headers
- `include/tvm/ffi/enum.h` — `EnumObj` (`int64_t value`, `String name`,
both `def_ro`-reflected) and `Enum` (nullable `ObjectRef` wrapper),
registered under type key `ffi.Enum`. The two column names are exposed
as `reflection::type_attr::kEnumEntries` (= `"__ffi_enum_entries__"`) and
`reflection::type_attr::kEnumAttrs` (= `"__ffi_enum_attrs__"`) in
`include/tvm/ffi/reflection/accessor.h`.
- `include/tvm/ffi/reflection/enum_def.h` —
`refl::EnumDef<T>("name").set_attr("key", value)...`. Each call
allocates a fresh ordinal, constructs the instance, and writes it into
the per-class registry. Duplicate names for the same `T` raise
`RuntimeError`. Exposes `.instance()` / `.ordinal()` for tests /
advanced callers.
- `include/tvm/ffi/tvm_ffi.h` transitively includes both new headers.
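The builder's bookkeeping can be mirrored in a few lines. It assigns a fresh ordinal equal to the current entry count, checks for duplicate names, and has `set_attr` pad the attribute list up to this variant's ordinal. `EnumDefModel` below is a hypothetical Python model of what `refl::EnumDef<T>` does according to `include/tvm/ffi/reflection/enum_def.h`. It stores ordinals in place of the constructed singletons, and `my.Op` is an invented type key.

```python
from typing import Any, Dict, List


class EnumDefModel:
    """Hypothetical Python model of refl::EnumDef<T>("name").set_attr(...)."""

    def __init__(self, entries: Dict[str, int], attrs: Dict[str, List[Any]],
                 type_key: str, name: str) -> None:
        if name in entries:
            raise RuntimeError(f"Duplicate enum entry `{name}` for type `{type_key}`")
        # Fresh dense ordinal: the number of entries registered so far.
        self.ordinal = len(entries)
        entries[name] = self.ordinal
        self._attrs = attrs

    def set_attr(self, attr_name: str, value: Any) -> "EnumDefModel":
        column = self._attrs.setdefault(attr_name, [])
        # Pad with None (the C++ code pads with Any(nullptr)) up to this ordinal.
        while len(column) <= self.ordinal:
            column.append(None)
        column[self.ordinal] = value
        return self  # allows chaining, like the C++ builder returning *this


entries: Dict[str, int] = {}
attrs: Dict[str, List[Any]] = {}
EnumDefModel(entries, attrs, "my.Op", "Add").set_attr("has_side_effects", False)
EnumDefModel(entries, attrs, "my.Op", "Nop")  # no extensible attributes
EnumDefModel(entries, attrs, "my.Op", "Print").set_attr("has_side_effects", True)
# attrs now maps "has_side_effects" to [False, None, True]
```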
### Python surface (`tvm_ffi.dataclasses`)
- `Enum` — base class, decorated
`@dataclass_transform(field_specifiers=(Field, field, entry, auto))` so
type checkers recognise `entry()` / `auto()` as dataclass-field
specifiers.
- `entry(**kwargs)`, `auto()` — variant-declaration sentinels.
- `EnumAttrMap` — view over the shared `__ffi_enum_attrs__` column;
`__getitem__` / `__setitem__` / `__contains__` / `get(default=...)`.
- Per-subclass surface: `Cls.get(name)`, `Cls.entries()`,
`Cls.def_attr(name, *, default=...)`, and three live class-level
properties `Cls.by_name` (`Dict[str, Enum]`), `Cls.by_value`
(`List[Enum]` indexed by ordinal), `Cls.attr_dict` (`Dict[str,
List[Any]]`). The class-level properties are backed by an internal
`_ClassProperty` descriptor so they work without a metaclass.
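A class-level property works without a metaclass because a non-data descriptor's `__get__` receives the owning class even when accessed as `Cls.attr`. The sketch below is a minimal, hypothetical `ClassProperty` in that spirit, not the actual `_ClassProperty`. `Base`, `Color`, and `_registry` are invented for illustration.

```python
from typing import Any, Callable, Dict, Optional


class ClassProperty:
    """Minimal read-only class-level property (hypothetical)."""

    def __init__(self, fget: Callable[[type], Any]) -> None:
        self._fget = fget

    def __get__(self, instance: Optional[object], owner: type) -> Any:
        # Python calls this for both Cls.attr (instance is None) and obj.attr,
        # always passing the class, so no metaclass is needed.
        return self._fget(owner)


class Base:
    _registry: Dict[str, int] = {}

    @ClassProperty
    def by_name(cls) -> Dict[str, int]:
        # Recomputed on every access, so the view stays "live".
        return dict(cls._registry)


class Color(Base):
    _registry = {"red": 0, "green": 1}


Color._registry["blue"] = 2
assert Color.by_name == {"red": 0, "green": 1, "blue": 2}
```

One trade-off of avoiding a metaclass is that assigning `Color.by_name = ...` simply replaces the descriptor on the class. Only a metaclass can intercept class-level attribute assignment.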
## Other user-visible changes
- **`TVMFFITypeRegisterAttr` rejects duplicate `(type_index, attr_name)`
writes.** Reverses a previously relaxed "silent overwrite" behaviour.
The enforced invariant is load-bearing for the register-once-then-mutate
protocol; the error message points callers at that protocol.
- **Default repr for `EnumObj` subclasses** is `<type_key>.<name>`
instead of the generic `type_key(field1=..., field2=...)` form. Rendered
by `ReprPrinter` after the `__ffi_repr__` hook check, so explicit
overrides still take precedence.
- **Built-in sentinels `MISSING` / `KWARGS`** now render as `<MISSING>`
/ `<KWARGS>` via pointer-identity dispatch, replacing the generic
`ffi.Object` fallback.
- **C++ test-support type `testing.TestEnumVariant`** (in
`src/ffi/testing/testing.cc`) now extends `EnumObj` and registers
`Alpha` / `Beta` entries with a `code` attribute via `refl::EnumDef`.
This is the canonical end-to-end demonstration of the builder and is
exercised by the Python test suite.
## Testing
- `uv run pytest tests/python/test_dataclass_enum.py -q` — 38/38
passing. Covers all four declaration forms, auto-ordinal assignment,
frozen-singleton identity, rejection of `entry(value=...)` /
`entry(name=...)`, `get` / `entries` / `by_name` / `by_value` /
`attr_dict`, `def_attr` round-trips through the unified column, direct
TypeAttr verification, the C++-backed happy path against
`testing.TestEnumVariant`, mixed C++/Python entry registration, and the
repr / sentinel behaviour.
- `uv run pytest tests/python -q` — 2246 passed, 16 skipped, 3 xfailed.
No regressions.
- `pre-commit run --all-files` — clean.
C++ GoogleTest and Rust suites were not re-run; the enum builder is
exercised end-to-end from the Python tests against
`testing.TestEnumVariant`, and no Rust bindings were touched.
---
CMakeLists.txt | 3 +-
cmake/Utils/Library.cmake | 4 +
include/tvm/ffi/enum.h | 156 +++++++
include/tvm/ffi/reflection/accessor.h | 22 +
include/tvm/ffi/reflection/enum_def.h | 174 +++++++
include/tvm/ffi/tvm_ffi.h | 2 +
python/tvm_ffi/core.pyi | 1 +
python/tvm_ffi/cython/object.pxi | 28 ++
python/tvm_ffi/dataclasses/__init__.py | 5 +
python/tvm_ffi/dataclasses/enum.py | 820 +++++++++++++++++++++++++++++++++
src/ffi/extra/dataclass.cc | 30 ++
src/ffi/object.cc | 18 +-
src/ffi/testing/testing.cc | 28 ++
tests/python/test_dataclass_enum.py | 689 +++++++++++++++++++++++++++
14 files changed, 1974 insertions(+), 6 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7b09ee4..6f07d4a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -150,8 +150,7 @@ if (MSVC)
target_link_libraries(tvm_ffi_objs PRIVATE DbgHelp.lib)
target_link_libraries(tvm_ffi_shared PRIVATE DbgHelp.lib)
target_link_libraries(tvm_ffi_static PRIVATE DbgHelp.lib)
- # /bigobj: printer.cc exceeds default section limit
- target_compile_options(tvm_ffi_objs PRIVATE /bigobj)
+
# produce pdb file
target_link_options(tvm_ffi_shared PRIVATE /DEBUG)
endif ()
diff --git a/cmake/Utils/Library.cmake b/cmake/Utils/Library.cmake
index d5bc2c8..0b12efd 100644
--- a/cmake/Utils/Library.cmake
+++ b/cmake/Utils/Library.cmake
@@ -70,6 +70,10 @@ function (tvm_ffi_add_msvc_flags target_name)
target_compile_definitions(${target_name} PUBLIC
-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
target_compile_definitions(${target_name} PUBLIC -DNOMINMAX)
target_compile_options(${target_name} PRIVATE "/Zi")
+ # Heavy template instantiations in reflection/creator/object.h can exceed MSVC's default
+ # per-object section limit (C1128). Apply /bigobj to every target that uses these flags so
+ # growth in any TU doesn't break the Windows build.
+ target_compile_options(${target_name} PRIVATE "/bigobj")
endif ()
endfunction ()
diff --git a/include/tvm/ffi/enum.h b/include/tvm/ffi/enum.h
new file mode 100644
index 0000000..5a840cb
--- /dev/null
+++ b/include/tvm/ffi/enum.h
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/enum.h
+ * \brief Base class for FFI-registered enum types.
+ */
+#ifndef TVM_FFI_ENUM_H_
+#define TVM_FFI_ENUM_H_
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/dict.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/string.h>
+
+#include <cstdint>
+#include <type_traits>
+#include <utility>
+
+namespace tvm {
+namespace ffi {
+
+class Enum;
+
+/*!
+ * \brief Base class for FFI-registered enums.
+ *
+ * Each registered variant is a unique, process-wide singleton with a
+ * dense ordinal (``value``) and string ``name``. Subclasses may add
+ * *declared fields* — part of the variant's schema, set at registration
+ * time via ``reflection::EnumDef``. Separately, any consumer may
+ * attach *extensible attributes* (per-variant metadata stored outside
+ * the variant's fields) via ``EnumDef::set_attr`` or the Python
+ * ``Enum.def_attr`` surface, without modifying ``EnumClsObj``.
+ *
+ * \sa reflection::EnumDef
+ */
+class EnumObj : public Object {
+ public:
+ /*! \brief Declared field: dense ordinal assigned at registration time (0-indexed per class). */
+ int64_t value;
+ /*! \brief Declared field: instance name (e.g., ``"Add"`` for ``Op.Add``). */
+ String name;
+
+ EnumObj() = default;
+ /*!
+ * \brief Construct an EnumObj with an explicit ordinal and name.
+ * \param value The dense ordinal (0-indexed per enum class).
+ * \param name The instance name key.
+ */
+ EnumObj(int64_t value, String name) : value(value), name(std::move(name)) {}
+
+ /*!
+ * \brief Look up the registered singleton for ``EnumClsObj`` by name.
+ *
+ * Reads from the per-class ``reflection::type_attr::kEnumEntries``
+ * registry populated by ``reflection::EnumDef<EnumClsObj>``. Instances
+ * are unique per ``(type_key, name)`` pair for the life of the process,
+ * so the returned ``Enum`` compares equal (by pointer) to every other
+ * lookup of the same name. Throws ``RuntimeError`` if no instance with
+ * the given name is registered for ``EnumClsObj``.
+ *
+ * \tparam EnumClsObj An ``Object`` subclass deriving from ``EnumObj``.
+ * \param name The instance name to look up (e.g., ``"Add"``).
+ * \return The registered ``Enum`` singleton.
+ */
+ template <typename EnumClsObj>
+ static Enum Get(const String& name);
+
+ /// \cond Doxygen_Suppress
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUniqueInstance;
+ TVM_FFI_DECLARE_OBJECT_INFO("ffi.Enum", EnumObj, Object);
+ /// \endcond
+
+ private:
+ /*!
+ * \brief Return the process-wide ``__ffi_enum_entries__`` column pointer.
+ *
+ * The column is registered at library init via ``EnsureTypeAttrColumn``
+ * and the struct its pointer refers to is stable for the lifetime of the
+ * process, so we cache the lookup in a function-local static.
+ */
+ static const TVMFFITypeAttrColumn* GetEnumEntriesColumn() {
+ constexpr TVMFFIByteArray kAttrName =
+ reflection::AsByteArray(reflection::type_attr::kEnumEntries);
+ static const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kAttrName);
+ return column;
+ }
+};
+
+/*!
+ * \brief ObjectRef wrapper for ``EnumObj``.
+ *
+ * Holds a shared reference to a registered singleton. Two ``Enum``
+ * values compare structurally equal if and only if they point at the
+ * same underlying object (see ``kTVMFFISEqHashKindUniqueInstance``),
+ * which — given the register-once registry — is equivalent to sharing
+ * the same ``(type_key, name)`` pair.
+ *
+ * \sa EnumObj
+ * \sa reflection::EnumDef
+ */
+class Enum : public ObjectRef {
+ public:
+ /// \cond Doxygen_Suppress
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Enum, ObjectRef, EnumObj);
+ /// \endcond
+};
+
+template <typename EnumClsObj>
+inline Enum EnumObj::Get(const String& name) {
+ static_assert(std::is_base_of_v<EnumObj, EnumClsObj>,
+ "EnumObj::Get<T> requires T to be a subclass of EnumObj");
+ const TVMFFITypeAttrColumn* column = GetEnumEntriesColumn();
+ int32_t type_index = EnumClsObj::RuntimeTypeIndex();
+ if (column != nullptr) {
+ int32_t offset = type_index - column->begin_index;
+ if (offset >= 0 && offset < column->size) {
+ const TVMFFIAny* stored = &column->data[offset];
+ if (stored->type_index != kTVMFFINone) {
+ Dict<String, Enum> entries = AnyView::CopyFromTVMFFIAny(*stored).cast<Dict<String, Enum>>();
+ auto it = entries.find(name);
+ if (it != entries.end()) {
+ return (*it).second;
+ }
+ }
+ }
+ }
+ TVM_FFI_THROW(RuntimeError) << "Enum `" << EnumClsObj::_type_key << "` has no instance named `"
+ << name << "`";
+ TVM_FFI_UNREACHABLE();
+}
+
+} // namespace ffi
+} // namespace tvm
+
+#endif // TVM_FFI_ENUM_H_
diff --git a/include/tvm/ffi/reflection/accessor.h b/include/tvm/ffi/reflection/accessor.h
index a403417..9b7950b 100644
--- a/include/tvm/ffi/reflection/accessor.h
+++ b/include/tvm/ffi/reflection/accessor.h
@@ -489,6 +489,28 @@ inline constexpr const char* kDataToJson = "__data_to_json__";
* ``ObjectRef``.
*/
inline constexpr const char* kDataFromJson = "__data_from_json__";
+/*!
+ * \brief Per-class enum entry registry.
+ *
+ * Maps each variant's name to its registered singleton for an
+ * ``EnumObj`` subclass. Populated by ``refl::EnumDef<T>("Name")`` on
+ * the C++ side and by ``Enum`` subclass declarations on the Python
+ * side; both languages share the same underlying storage.
+ *
+ * Value type: ``Dict<String, Enum>``.
+ */
+inline constexpr const char* kEnumEntries = "__ffi_enum_entries__";
+/*!
+ * \brief Per-class column holding extensible attributes for enum variants.
+ *
+ * The outer dict is keyed by extensible-attribute name; each value is a
+ * list indexed by the variant's ordinal (``EnumObj::value``). Written
+ * by ``refl::EnumDef<T>::set_attr(...)`` on the C++ side and by the
+ * ``EnumAttrMap`` returned from Python ``Enum.def_attr(...)``.
+ *
+ * Value type: ``Dict<String, List<Any>>``.
+ */
+inline constexpr const char* kEnumAttrs = "__ffi_enum_attrs__";
} // namespace type_attr
/*!
diff --git a/include/tvm/ffi/reflection/enum_def.h b/include/tvm/ffi/reflection/enum_def.h
new file mode 100644
index 0000000..f0a20b0
--- /dev/null
+++ b/include/tvm/ffi/reflection/enum_def.h
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/reflection/enum_def.h
+ * \brief Builder for registering enum instances on ``EnumObj`` subclasses.
+ */
+#ifndef TVM_FFI_REFLECTION_ENUM_DEF_H_
+#define TVM_FFI_REFLECTION_ENUM_DEF_H_
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/dict.h>
+#include <tvm/ffi/container/list.h>
+#include <tvm/ffi/enum.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/memory.h>
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/string.h>
+
+#include <cstdint>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+namespace tvm {
+namespace ffi {
+namespace reflection {
+
+/*!
+ * \brief Builder that registers a single enum instance on ``EnumClsObj``.
+ *
+ * Each ``EnumDef<EnumClsObj>("Name")`` call allocates a fresh dense ordinal
+ * (``= len(existing entries)``), constructs a variant with ``value`` and
+ * ``name`` populated, and writes it into the per-class registry stored in
+ * the ``type_attr::kEnumEntries`` TypeAttr column. Subsequent
+ * ``.set_attr(...)`` calls write *extensible attributes* — per-variant
+ * metadata attached outside the variant's declared fields — into the
+ * per-class ``type_attr::kEnumAttrs`` column. Python bindings of the
+ * same ``type_key`` see every C++-registered variant and every extensible
+ * attribute through the matching ``Enum.def_attr`` surface.
+ *
+ * \tparam EnumClsObj An ``Object`` subclass deriving from ``EnumObj``.
+ *
+ * \code{.cpp}
+ * namespace refl = ::tvm::ffi::reflection;
+ * refl::EnumDef<OpObj>("Add").set_attr("has_side_effects", false);
+ * refl::EnumDef<OpObj>("Mul").set_attr("has_side_effects", false);
+ * \endcode
+ */
+template <typename EnumClsObj, typename = std::enable_if_t<std::is_base_of_v<EnumObj, EnumClsObj>>>
+class EnumDef : public ReflectionDefBase {
+ public:
+ /*!
+ * \brief Register a new instance named ``instance_name`` on ``EnumClsObj``.
+ * \param instance_name The instance's string name (e.g., ``"Add"``).
+ */
+ explicit EnumDef(const char* instance_name)
+ : type_index_(EnumClsObj::RuntimeTypeIndex()), name_(instance_name) {
+ Dict<String, Enum> entries = EnsureEntriesDict();
+ String name_str(name_);
+ if (entries.count(name_str) != 0) {
+ TVM_FFI_THROW(RuntimeError) << "Duplicate enum entry `" << name_ << "` for type `"
+ << EnumClsObj::_type_key << "`";
+ }
+ ordinal_ = static_cast<int64_t>(entries.size());
+ ObjectPtr<EnumClsObj> obj = make_object<EnumClsObj>();
+ obj->value = ordinal_;
+ obj->name = name_str;
+ instance_ = Enum(ObjectPtr<EnumObj>(std::move(obj)));
+ entries.Set(name_str, instance_);
+ // Ensure the attrs dict exists so later ``set_attr`` calls can mutate it.
+ EnsureAttrsDict();
+ }
+
+ /*!
+ * \brief Write an *extensible attribute* for this enum variant.
+ *
+ * Writes land in the per-class ``type_attr::kEnumAttrs`` column and
+ * are visible to every binder of the same ``type_key`` — including
+ * Python readers via ``Enum.def_attr`` / ``Enum.attr_dict``. Distinct
+ * from declared fields on ``EnumClsObj``: declared fields are part of
+ * the variant's schema and set during construction, whereas
+ * extensible attributes live outside the variant object and may be
+ * attached by any consumer at any time.
+ *
+ * \tparam T The value type.
+ * \param attr_name The extensible-attribute name (e.g.,
+ * ``"has_side_effects"``).
+ * \param value The value to store for this variant's ordinal.
+ * \return Reference to this builder for chaining.
+ */
+ template <typename T>
+ EnumDef& set_attr(const char* attr_name, T value) {
+ Dict<String, List<Any>> attrs = EnsureAttrsDict();
+ String attr_key(attr_name);
+ List<Any> column;
+ auto it = attrs.find(attr_key);
+ if (it == attrs.end()) {
+ column = List<Any>();
+ attrs.Set(attr_key, column);
+ } else {
+ column = (*it).second;
+ }
+ while (static_cast<int64_t>(column.size()) <= ordinal_) {
+ column.push_back(Any(nullptr));
+ }
+ column.Set(ordinal_, Any(std::move(value)));
+ return *this;
+ }
+
+ /*! \brief Return the registered instance (for tests / advanced callers). */
+ Enum instance() const { return instance_; }
+
+ /*! \brief Return the ordinal assigned to this instance. */
+ int64_t ordinal() const { return ordinal_; }
+
+ private:
+ Dict<String, Enum> EnsureEntriesDict() {
+ return EnsureDict<Dict<String, Enum>>(type_attr::kEnumEntries);
+ }
+
+ Dict<String, List<Any>> EnsureAttrsDict() {
+ return EnsureDict<Dict<String, List<Any>>>(type_attr::kEnumAttrs);
+ }
+
+ template <typename DictT>
+ DictT EnsureDict(const char* attr_name) {
+ TVMFFIByteArray name_array = {attr_name, std::char_traits<char>::length(attr_name)};
+ const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&name_array);
+ if (column != nullptr) {
+ int32_t offset = type_index_ - column->begin_index;
+ if (offset >= 0 && offset < column->size) {
+ const TVMFFIAny* stored = &column->data[offset];
+ if (stored->type_index != kTVMFFINone) {
+ return AnyView::CopyFromTVMFFIAny(*stored).cast<DictT>();
+ }
+ }
+ }
+ DictT fresh;
+ TVMFFIAny value_any = AnyView(fresh).CopyToTVMFFIAny();
+ TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any));
+ return fresh;
+ }
+
+ int32_t type_index_;
+ const char* name_;
+ int64_t ordinal_;
+ Enum instance_;
+};
+
+} // namespace reflection
+} // namespace ffi
+} // namespace tvm
+
+#endif // TVM_FFI_REFLECTION_ENUM_DEF_H_
diff --git a/include/tvm/ffi/tvm_ffi.h b/include/tvm/ffi/tvm_ffi.h
index b55350d..57af97a 100644
--- a/include/tvm/ffi/tvm_ffi.h
+++ b/include/tvm/ffi/tvm_ffi.h
@@ -41,6 +41,7 @@
#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/endian.h>
+#include <tvm/ffi/enum.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/expected.h>
#include <tvm/ffi/function.h>
@@ -51,6 +52,7 @@
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/creator.h>
+#include <tvm/ffi/reflection/enum_def.h>
#include <tvm/ffi/reflection/overload.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/rvalue_ref.h>
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index e9b117e..90071c7 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -76,6 +76,7 @@ def _object_type_key_to_index(type_key: str) -> int | None: ...
def _set_type_cls(type_info: TypeInfo, type_cls: type) -> None: ...
def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo: ...
def _lookup_type_attr(type_index: int, attr_key: str) -> Any: ...
+def _register_type_attr(type_index: int, attr_key: str, value: Any) -> None: ...
def _type_cls_to_type_info(type_cls: type) -> TypeInfo | None: ...
class Error(Object):
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 6770a9d..4b8c52d 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -739,6 +739,34 @@ def _lookup_type_attr(type_index: int32_t, attr_key: str) -> Any:
return make_ret(data)
+def _register_type_attr(type_index: int32_t, attr_key: str, value: object) -> None:
+ """Register a value for the ``(type_index, attr_key)`` slot.
+
+ Wraps :c:func:`TVMFFITypeRegisterAttr`, which raises :class:`RuntimeError`
+ if a value is already registered for the slot. To update the stored
+ value, register a mutable container (e.g. ``Dict``/``List``) once and
+ mutate it in place on subsequent calls.
+
+ ``TVMFFIPyPyObjectToFFIAny`` produces a non-owning :c:type:`TVMFFIAny`
+ view of *value*; ``TVMFFITypeRegisterAttr`` incref's the underlying
+ object when it stores the slot, so no explicit refcount management is
+ needed here.
+ """
+ cdef ByteArrayArg attr_key_bytes = ByteArrayArg(c_str(attr_key))
+ cdef TVMFFIAny temp
+ cdef int c_api_ret_code
+ temp.type_index = kTVMFFINone
+ temp.v_int64 = 0
+ TVMFFIPyPyObjectToFFIAny(
+ TVMFFIPyArgSetterFactory_,
+ <PyObject*>value,
+ &temp,
+ &c_api_ret_code,
+ )
+ CHECK_CALL(c_api_ret_code)
+ CHECK_CALL(TVMFFITypeRegisterAttr(type_index, &attr_key_bytes.cdata, &temp))
+
+
def _type_cls_to_type_info(type_cls: type) -> TypeInfo | None:
return TYPE_CLS_TO_INFO.get(type_cls, None)
diff --git a/python/tvm_ffi/dataclasses/__init__.py b/python/tvm_ffi/dataclasses/__init__.py
index 850e9d4..43e5cc1 100644
--- a/python/tvm_ffi/dataclasses/__init__.py
+++ b/python/tvm_ffi/dataclasses/__init__.py
@@ -20,17 +20,22 @@ from tvm_ffi.core import MISSING, Object
from .c_class import c_class
from .common import asdict, astuple, fields, is_dataclass, replace
+from .enum import Enum, EnumAttrMap, auto, entry
from .field import KW_ONLY, Field, field
from .py_class import py_class
__all__ = [
"KW_ONLY",
"MISSING",
+ "Enum",
+ "EnumAttrMap",
"Field",
"Object",
"asdict",
"astuple",
+ "auto",
"c_class",
+ "entry",
"field",
"fields",
"is_dataclass",
diff --git a/python/tvm_ffi/dataclasses/enum.py b/python/tvm_ffi/dataclasses/enum.py
new file mode 100644
index 0000000..cc130ea
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/enum.py
@@ -0,0 +1,820 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Cross-language enum types: named, frozen, ordinal-indexed singletons.
+
+An ``Enum`` subclass has one of two usage modes, distinguished by whether
+its ``type_key`` is already registered in the FFI type system:
+
+* **Closed Python enum** — fresh ``type_key``, variants declared once in
+ the class body. Behavior matches ``enum.Enum``.
+* **Cross-language registry** — ``type_key`` also registered in C++ (or
+ another Python module). Python and C++ both contribute variants to
+ the same per-class registry, and consumers attach *extensible
+ attributes* to variants from any module at any time.
+
+See :class:`Enum` for declaration forms and :meth:`Enum.def_attr` for
+extensible attributes.
+
+Storage layout (mirrors ``include/tvm/ffi/enum.h``):
+
+* ``__ffi_enum_entries__`` — ``Dict[str, Enum]``, name → variant.
+* ``__ffi_enum_attrs__`` — ``Dict[str, List[Any]]``, extensible-attr
+ name → column indexed by each variant's ordinal.
+"""
+
+from __future__ import annotations
+
+import sys
+import typing
+from collections.abc import Callable, Iterator
+from typing import Any, ClassVar
+
+from typing_extensions import dataclass_transform
+
+from .. import core
+from ..container import Dict, List
+from ..core import Object
+from .c_class import c_class
+from .field import Field, field
+from .py_class import py_class
+
+__all__ = [
+ "ENUM_ATTRS_ATTR",
+ "ENUM_ENTRIES_ATTR",
+ "Enum",
+ "EnumAttrMap",
+ "auto",
+ "entry",
+]
+
+#: TypeAttr column storing ``Dict[str, Enum]`` (instance name → singleton).
+ENUM_ENTRIES_ATTR = "__ffi_enum_entries__"
+
+#: TypeAttr column storing ``Dict[str, List[Any]]`` of per-variant attrs.
+ENUM_ATTRS_ATTR = "__ffi_enum_attrs__"
+
+
+# ---------------------------------------------------------------------------
+# entry() sentinel
+# ---------------------------------------------------------------------------
+
+
+class _EnumEntry:
+ """Sentinel produced by :func:`entry`; consumed by ``Enum.__init_subclass__``.
+
+ Holds the positional and keyword arguments forwarded to the subclass's
+ ``__init__`` when the variant is materialized. ``value`` and ``name``
+ are auto-assigned (dense ordinal and class-body name) and must not
+ appear in the captured arguments.
+ """
+
+ __slots__ = ("args", "kwargs")
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.args: tuple[Any, ...] = args
+ self.kwargs: dict[str, Any] = kwargs
+
+ def __repr__(self) -> str:
+ parts = [repr(a) for a in self.args]
+ parts.extend(f"{k}={v!r}" for k, v in self.kwargs.items())
+ return f"entry({', '.join(parts)})"
+
+
+def entry(*args: Any, **kwargs: Any) -> Any:
+ """Declare a new enum variant with values for its declared fields.
+
+ ``entry(...)`` is a class-body sentinel; it never produces a real
+ instance. At class creation, :meth:`Enum.__init_subclass__` scans
+ for these sentinels and, for each one, constructs a singleton variant
+ by forwarding the captured positional and keyword arguments to the
+ subclass's ``__init__``, together with an auto-assigned
+ :attr:`~Enum.value` (dense ordinal) and :attr:`~Enum.name`
+ (class-body name).
+
+ Prefer :func:`auto` when a variant has no declared fields beyond the
+ auto-assigned ordinal and name — it expresses intent without the
+ empty-arg-list noise.
+
+ When the enum's ``type_key`` is C++-backed (registered via
+ ``refl::ObjectDef``), only keyword arguments are supported — field
+ values are assigned via reflected setters keyed by name. Passing
+ positional arguments in that case raises :class:`TypeError`.
+ ``entry(value=...)`` and ``entry(name=...)`` always raise
+ :class:`TypeError` because those fields are auto-assigned.
+
+ Examples
+ --------
+ Variant with declared fields:
+
+ .. code-block:: python
+
+ from typing import ClassVar
+
+
+ class Activation(Enum, type_key="my.Activation"):
+ output_zero: bool
+ is_monotonic: bool
+
+ relu: ClassVar[Activation] = entry(output_zero=True, is_monotonic=True)
+ gelu: ClassVar[Activation] = entry(output_zero=False, is_monotonic=False)
+
+ Returns
+ -------
+ object
+ An opaque sentinel. The declared return type is ``Any`` so that
+ ``ClassVar[Cls] = entry(...)`` type-checks even though the sentinel
+ is not a real ``Cls``.
+
+ """
+ return _EnumEntry(*args, **kwargs)
+
+
+def auto() -> Any:
+ """Declare a new enum variant with no declared fields.
+
+ Semantically equivalent to :func:`entry` called with no arguments, but
+ reads more clearly for the common case where a variant differs from
+ its siblings only by name and ordinal. The resulting singleton has
+ only the auto-assigned :attr:`~Enum.value` and :attr:`~Enum.name`.
+
+ ``auto()`` registers a *new* Python-side variant; it is not the right
+ tool for binding to a pre-existing C++-registered entry (use a bare
+ ``ClassVar[Cls]`` annotation for that — see :class:`Enum`).
+
+ Examples
+ --------
+ .. code-block:: python
+
+ class Status(Enum, type_key="my.Status"):
+ ok = auto()
+ err = auto()
+ retry = auto()
+
+
+ assert Status.ok.value == 0
+ assert Status.err.name == "err"
+
+ Returns
+ -------
+ object
+ An opaque sentinel, the same kind returned by :func:`entry`. The
+ declared return type is ``Any`` so that both ``name = auto()`` and
+ ``name: ClassVar[Cls] = auto()`` type-check.
+
+ """
+ return _EnumEntry()
+
+
+# ---------------------------------------------------------------------------
+# Class-level helpers
+# ---------------------------------------------------------------------------
+
+
+class _ClassProperty:
+ """Read-only descriptor whose getter receives the owning class.
+
+ Used for ``by_name``/``by_value``/``attr_dict`` so they work as class-level
+ attribute access (e.g., ``Op.attr_dict["has_side_effects"]``) without
+ needing a metaclass.
+ """
+
+ __slots__ = ("_fget",)
+
+ def __init__(self, fget: Callable[[type], Any]) -> None:
+ self._fget = fget
+
+ def __get__(self, instance: Any, owner: type | None = None) -> Any:
+ cls = owner if owner is not None else type(instance)
+ return self._fget(cls)
+
+
+# ---------------------------------------------------------------------------
+# Enum base + EnumAttrMap
+# ---------------------------------------------------------------------------
+
+
+@dataclass_transform(
+ eq_default=False,
+ order_default=False,
+ frozen_default=True,
+ field_specifiers=(Field, field, entry, auto),
+)
+@c_class("ffi.Enum", init=False)
+class Enum(Object):
+ """A named-singleton registry with cross-language identity.
+
+ Subclasses declare variants: frozen, named, ordinal-indexed
+ singletons — the familiar enum pattern. Unlike ``enum.Enum``, an
+ ``Enum`` subclass bound to an FFI-registered ``type_key`` has an
+ **open variant set**: C++ translation units and other Python modules
+ binding the same ``type_key`` can contribute variants to the shared
+ registry. Per-variant metadata can also be attached post-hoc via
+ :meth:`def_attr` as an *extensible attribute*, outside the class
+ definition.
+
+ For **closed, Python-only enums**, use a fresh ``type_key`` with
+ :func:`auto` / :func:`entry` — behavior matches ``enum.Enum``.
+
+ Attributes
+ ----------
+ value : int
+ Dense ordinal assigned at registration (0-indexed per class).
+ name : str
+ The variant's string name key (e.g., ``"Add"`` for ``Op.Add``).
+
+ Closed Python enum
+ ------------------
+ Pick a fresh ``type_key`` and list variants with :func:`auto` or
+ :func:`entry`. The variant set is fixed at class-definition time.
+
+ .. code-block:: python
+
+ class Priority(Enum, type_key="my.Priority"):
+ low = auto()
+ medium = auto()
+ high = auto()
+
+
+ # Variants with declared fields — values supplied via entry(...).
+ class Activation(Enum, type_key="my.Activation"):
+ output_zero: bool
+ is_monotonic: bool
+
+ relu: ClassVar[Activation] = entry(output_zero=True, is_monotonic=True)
+ gelu: ClassVar[Activation] = entry(output_zero=False, is_monotonic=False)
+
+ Cross-language registry
+ -----------------------
+ When ``type_key`` is already registered (typically by C++), the
+ Python class *binds* to the existing type rather than creating a
+ new one. Bare ``ClassVar[Cls]`` annotations bind to variants
+ already registered on the C++ side; :func:`entry` / :func:`auto`
+ still register fresh Python variants whose ordinals extend past the
+ C++ ones. All variants — regardless of origin — land in the same
+ per-class registry and are visible to every binder of the same
+ ``type_key``.
+
+ .. code-block:: python
+
+ # Registered in C++ via refl::EnumDef<VariantObj>("Alpha")... .
+ class Variant(Enum, type_key="testing.TestEnumVariant"):
+ Alpha: ClassVar[Variant] # binds to C++-registered "Alpha"
+ Beta: ClassVar[Variant] # binds to C++-registered "Beta"
+
+ Declaration forms
+ -----------------
+ Four shapes are supported in the class body:
+
+ 1. ``name = auto()`` — new variant with no declared fields.
+ 2. ``name: ClassVar[Cls] = entry(**kwargs)`` — new variant; ``kwargs``
+ populate declared fields.
+ 3. ``name = entry(**kwargs)`` — same as (2), without the ``ClassVar``
+ annotation.
+ 4. ``name: ClassVar[Cls]`` — in cross-language mode, binds to an
+ existing C++-registered variant (error if unknown); otherwise
+ registers a new Python variant with only the auto-assigned
+ :attr:`value` and :attr:`name` (equivalent to ``name = auto()``).
+
+ Integer literals (``ok = 0``) are rejected: :attr:`value` is
+ auto-assigned, so a user-supplied ordinal would either silently
+ duplicate or conflict. ``entry(value=...)`` and ``entry(name=...)``
+ raise :class:`TypeError` at class-body time.
+
+ Differences from ``enum.Enum``
+ ------------------------------
+ * **Same**: :attr:`name`, :attr:`value`, iteration, identity
+ comparison; closed-set behavior when ``type_key`` is fresh.
+ * **Extended**: ``entry(**kwargs)`` replaces the tuple-RHS idiom;
+ ``dataclass_transform`` gives native type-checker support; open
+ registry when ``type_key`` is shared across languages.
+ * **Different**: :attr:`value` is always the ordinal (no
+ user-supplied integer values); :meth:`def_attr` adds extensible
+ attributes outside the class schema.
+ * **Not provided**: ``Flag`` / ``IntFlag``, member aliasing,
+ ``_missing_`` hook.
+
+ Subclasses inherit :meth:`get`, :meth:`entries`, :meth:`def_attr`,
+ and the ``by_name`` / ``by_value`` / ``attr_dict`` class-level
+ views.
+
+ """
+
+ __slots__ = ()
+
+ value: int
+ name: str
+
+ def __init_subclass__(
+ cls,
+ *,
+ type_key: str | None = None,
+ frozen: bool = True,
+ init: bool = True,
+ repr: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ super().__init_subclass__(**kwargs)
+ if type_key is None:
+ return
+
+ binders, python_entries = _collect_entry_declarations(cls)
+
+ cxx_backed = core._object_type_key_to_index(type_key) is not None
+ if cxx_backed:
+ c_class(type_key, init=init, repr=repr)(cls)
+ else:
+ py_class(type_key, frozen=frozen)(cls)
+
+ _resolve_entries(cls, binders, python_entries, type_key=type_key, cxx_backed=cxx_backed)
+
+ @classmethod
+ def get(cls, name: str) -> Enum:
+ """Return the variant named *name*, or raise :class:`KeyError`."""
+ entries = _entries_dict(cls)
+ if entries is not None and name in entries:
+ return entries[name]
+ raise KeyError(f"{cls.__name__} has no variant named {name!r}")
+
+ @classmethod
+ def entries(cls) -> Iterator[Enum]:
+ """Iterate over all variants, in ordinal (value) order."""
+ return iter(cls.by_value)
+
+ @_ClassProperty
+ def by_name(cls: type) -> Any:
+ """Live ``Dict[str, Enum]`` mapping each variant's name to the variant singleton."""
+ return _entries_dict(cls) or Dict({})
+
+ @_ClassProperty
+ def by_value(cls: type) -> list[Any]:
+ """Return the variants as a list indexed by ordinal (``variant.value``)."""
+ entries = _entries_dict(cls)
+ if entries is None:
+ return []
+ ordered: list[Any] = [None] * len(entries)
+ for inst in entries.values():
+ ordered[int(inst.value)] = inst
+ return ordered
+
+ @_ClassProperty
+ def attr_dict(cls: type) -> Any:
+ """Live ``Dict[str, List[Any]]`` backing every extensible attribute.
+
+ The outer dict is keyed by extensible-attribute name; each value
+ is a list indexed by variant ordinal. Prefer :meth:`def_attr`
+ for normal per-variant reads and writes; this property is for
+ bulk inspection and for reading values written by C++
+ ``EnumDef::set_attr``.
+ """
+ return _attrs_dict(cls) or Dict({})
+
+ @classmethod
+ def def_attr(
+ cls,
+ name: str,
+ *,
+ default: Any = core.MISSING,
+ ) -> EnumAttrMap:
+ """Declare an *extensible attribute* column on this enum.
+
+ Extensible attributes let any consumer associate per-variant data
+ outside the enum's class-body schema — a lowering function
+ attached to an operator by a code generator, a cost model
+ registered only on some targets, a documentation string added
+ after the fact. Writes are last-write-wins for the same
+ ``(variant, name)`` pair and visible to every consumer that
+ calls :meth:`def_attr` with the same name on the same enum,
+ including C++ code writing via ``EnumDef::set_attr``.
+
+ Extensible attributes differ from **declared fields**:
+
+ ==================== ======================== ==========================
+ Concept              Lives on                 Added by
+ ==================== ======================== ==========================
+ Declared field       The variant object       Enum author, in class body
+ Extensible attribute ``__ffi_enum_attrs__``   Any consumer, any time
+ ==================== ======================== ==========================
+
+ Rule of thumb: if the data is part of *what a variant is*,
+ declare a field in the class body; if it's part of *what a
+ consumer wants to know*, attach it with :meth:`def_attr`.
+
+ Parameters
+ ----------
+ name
+ The extensible-attribute name (e.g., ``"has_side_effects"``).
+ Writes go to ``attr_dict[name]`` as a list indexed by each
+ variant's ordinal.
+ default
+ Value returned by ``attr[variant]`` when nothing was
+ registered for that variant. If left as ``MISSING``, reading an
+ unset variant raises :class:`KeyError` instead. The default
+ is a property of *this* :class:`EnumAttrMap` view, not of
+ the underlying column: calling :meth:`def_attr` twice with
+ the same ``name`` but different defaults creates two views
+ that share every explicit write but may disagree on unset
+ variants — e.g., ``Op.def_attr("cost", default=0)`` and
+ ``Op.def_attr("cost", default=-1)`` return ``0`` and ``-1``
+ respectively for a variant that was never written to.
+
+ Returns
+ -------
+ EnumAttrMap
+ Mutable view over the column. Use ``variant in attr`` to
+ distinguish an explicit write from a default-hit. ``None``
+ is reserved as the "unset" sentinel (matching C++
+ ``EnumDef::set_attr`` padding), so ``attr[variant] = None``
+ raises :class:`TypeError` — store a typed wrapper (e.g. a
+ ``0``/``False`` flag) when you need a falsy-but-present
+ value.
+
+ Notes
+ -----
+ :meth:`def_attr` is not a way to add fields to the enum's
+ schema, subclass frozen variants, or bypass the frozen-instance
+ invariant via ``setattr`` — for that, declare a field in the
+ class body instead.
+
+ """
+ return EnumAttrMap(cls, name, default=default)
+
+
+class EnumAttrMap:
+ """Mutable per-variant view over an extensible-attribute column.
+
+ Returned by :meth:`Enum.def_attr`. Writes go to a ``List[Any]``
+ column keyed by extensible-attribute name inside the per-class
+ ``__ffi_enum_attrs__`` dict. The list is indexed by each variant's
+ ordinal (``variant.value``) and padded with ``None`` as new variants
+ are registered. The column is shared across every consumer —
+ including C++ code writing via ``EnumDef::set_attr`` — and the data
+ is not a field on the variant object. See :meth:`Enum.def_attr` for
+ full semantics.
+
+ ``None`` is reserved as the column's "unset" sentinel (matching the
+ C++ ``Any(nullptr)`` padding used by ``EnumDef::set_attr``), so
+ :meth:`__setitem__` rejects ``None`` with :class:`TypeError` — an
+ explicit ``attr[variant] = None`` would otherwise be
+ indistinguishable from never-written and surprise ``variant in attr``
+ / :meth:`__getitem__` readers. This view cannot clear a written value;
+ when you need clearable state, register a mutable container once and mutate it in place.
+ """
+
+ __slots__ = ("_default", "_enum_cls", "_name")
+
+ def __init__(self, enum_cls: type, name: str, *, default: Any = core.MISSING) -> None:
+ self._enum_cls = enum_cls
+ self._name = name
+ self._default = default
+
+ def _ordinal_of(self, variant: object) -> int:
+ if not isinstance(variant, self._enum_cls):
+ raise TypeError(
+ f"{self._enum_cls.__name__}.def_attr({self._name!r}) expects a "
+ f"{self._enum_cls.__name__} variant, got {type(variant).__name__}"
+ )
+ return int(variant.value) # type: ignore[attr-defined]
+
+ def _column(self, *, create: bool) -> Any | None:
+ """Return the ``List[Any]`` column for this attribute; create if missing.
+
+ Returns ``None`` when ``create`` is false and the column doesn't exist.
+ """
+ attrs = _attrs_dict(self._enum_cls) if not create else _ensure_attrs_dict(self._enum_cls)
+ if attrs is None:
+ return None
+ if self._name in attrs:
+ return attrs[self._name]
+ if not create:
+ return None
+ col = List([])
+ attrs[self._name] = col
+ return col
+
+ def __setitem__(self, variant: object, value: Any) -> None:
+ if value is None:
+ raise TypeError(
+ f"{self._enum_cls.__name__}.def_attr({self._name!r}): "
+ f"None is reserved as the 'unset' sentinel for extensible "
+ f"attributes and cannot be written explicitly."
+ )
+ ordinal = self._ordinal_of(variant)
+ col = self._column(create=True)
+ assert col is not None # create=True always materialises the column.
+ while len(col) <= ordinal:
+ col.append(None)
+ col[ordinal] = value
+
+ def __getitem__(self, variant: object) -> Any:
+ ordinal = self._ordinal_of(variant)
+ col = self._column(create=False)
+ if col is not None and ordinal < len(col):
+ v = col[ordinal]
+ if v is not None:
+ return v
+ if self._default is core.MISSING:
+ raise KeyError(
+ f"{self._enum_cls.__name__}.{variant.name} has no " # type: ignore[attr-defined]
+ f"extensible attribute {self._name!r} set"
+ )
+ return self._default
+
+ def __contains__(self, variant: object) -> bool:
+ if not isinstance(variant, self._enum_cls):
+ return False
+ try:
+ ordinal = self._ordinal_of(variant)
+ except TypeError:
+ return False
+ col = self._column(create=False)
+ return col is not None and ordinal < len(col) and col[ordinal] is not None
+
+ def get(self, variant: object, default: Any = None) -> Any:
+ """Return the value for *variant*, or *default* if unset or foreign."""
+ if not isinstance(variant, self._enum_cls):
+ return default
+ try:
+ return self[variant]
+ except KeyError:
+ return default
+
+ @property
+ def name(self) -> str:
+ """The extensible-attribute name passed to :meth:`Enum.def_attr`."""
+ return self._name
+
+
+# ---------------------------------------------------------------------------
+# TypeAttr accessors
+# ---------------------------------------------------------------------------
+
+
+def _entries_dict(cls: type) -> Any:
+ type_info = getattr(cls, "__tvm_ffi_type_info__", None)
+ if type_info is None:
+ return None
+ return core._lookup_type_attr(type_info.type_index, ENUM_ENTRIES_ATTR)
+
+
+def _attrs_dict(cls: type) -> Any:
+ type_info = getattr(cls, "__tvm_ffi_type_info__", None)
+ if type_info is None:
+ return None
+ return core._lookup_type_attr(type_info.type_index, ENUM_ATTRS_ATTR)
+
+
+def _ensure_entries_dict(cls: type) -> Any:
+ """Return the live ``__ffi_enum_entries__`` dict, registering it if absent."""
+ type_info = cls.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
+ entries = core._lookup_type_attr(type_info.type_index, ENUM_ENTRIES_ATTR)
+ if entries is not None:
+ return entries
+ entries = Dict({})
+ core._register_type_attr(type_info.type_index, ENUM_ENTRIES_ATTR, entries)
+ # Re-read so mutations go through the ref owned by the registry.
+ return core._lookup_type_attr(type_info.type_index, ENUM_ENTRIES_ATTR)
+
+
+def _ensure_attrs_dict(cls: type) -> Any:
+ """Return the live ``__ffi_enum_attrs__`` dict, registering it if absent."""
+ type_info = cls.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
+ attrs = core._lookup_type_attr(type_info.type_index, ENUM_ATTRS_ATTR)
+ if attrs is not None:
+ return attrs
+ attrs = Dict({})
+ core._register_type_attr(type_info.type_index, ENUM_ATTRS_ATTR, attrs)
+ return core._lookup_type_attr(type_info.type_index, ENUM_ATTRS_ATTR)
+
+
+# ---------------------------------------------------------------------------
+# Class-body scanning + entry materialisation
+# ---------------------------------------------------------------------------
+
+
+def _collect_entry_declarations(
+ cls: type,
+) -> tuple[list[str], dict[str, _EnumEntry]]:
+ """Scan *cls.__dict__* for variant declarations.
+
+ Returns ``(binders, python_entries)`` in declaration order:
+
+ * *binders* — names annotated as ``ClassVar[Cls]`` with no assigned value.
+ Each either binds to an existing C++-registered entry with the same
+ name or registers a new blank Python entry.
+ * *python_entries* — names assigned an ``entry(...)`` sentinel (with or
+ without a ``ClassVar`` annotation). Each registers a new Python entry
+ using the captured args/kwargs.
+
+ Matched assignments are removed from ``cls.__dict__`` so that
+ ``@c_class`` / ``@py_class`` don't misinterpret them as field defaults or
+ class constants.
+ """
+ annotations = _own_annotations(cls)
+ dict_keys = set(cls.__dict__.keys())
+
+ binders: list[str] = []
+ for name, ann in annotations.items():
+ if name.startswith("_"):
+ continue
+ if _is_class_var(ann) and name not in dict_keys:
+ binders.append(name)
+
+ python_entries: dict[str, _EnumEntry] = {}
+ for name, value in list(cls.__dict__.items()):
+ if name.startswith("_"):
+ continue
+ if isinstance(value, _EnumEntry):
+ python_entries[name] = value
+ try:
+ delattr(cls, name)
+ except AttributeError:
+ pass
+
+ return binders, python_entries
+
+
+def _resolve_entries(
+ cls: type,
+ binders: list[str],
+ python_entries: dict[str, _EnumEntry],
+ *,
+ type_key: str,
+ cxx_backed: bool,
+) -> None:
+ """Materialise *binders* and *python_entries* into class-attribute singletons.
+
+ ``binders`` are processed first, then ``python_entries``, each in
+ class-body order; this matches overall declaration order whenever the
+ bare ``ClassVar`` annotations precede the ``entry(...)`` assignments.
+ Each newly registered entry gets a dense ordinal equal to the current
+ entries-dict size, so ordinals stay compact and stable across registrations.
+
+ A cxx-backed enum (``type_key`` was already registered in the FFI type
+ system before this Python subclass was created) supports mixing C++ and
+ Python entries: bare ``ClassVar[Cls]`` binders must name an existing
+ C++-registered entry, but ``entry(...)``/``auto()`` sentinels may add
+ fresh Python-side entries whose ordinals extend past the C++ entries.
+ """
+ entries = _ensure_entries_dict(cls)
+
+ for name in binders:
+ if name in entries:
+ # Already materialised — either C++-registered or previously bound.
+ setattr(cls, name, entries[name])
+ continue
+ if cxx_backed:
+ raise _cxx_backed_unknown_binder_error(cls, name, type_key, entries)
+ ordinal = len(entries)
+ instance = _instantiate(cls, args=(), kwargs={}, ordinal=ordinal, name=name)
+ entries[name] = instance
+ setattr(cls, name, instance)
+
+ for name, e in python_entries.items():
+ if name in entries:
+ raise RuntimeError(
+ f"Duplicate enum entry {name!r} for {cls.__name__}: already "
+ f"registered as ordinal {int(entries[name].value)}."
+ )
+ if "value" in e.kwargs or "name" in e.kwargs:
+ raise TypeError(
+ f"{cls.__name__}.{name}: `value` and `name` are auto-assigned "
+ f"and must not appear in entry(...) arguments."
+ )
+ ordinal = len(entries)
+ if cxx_backed:
+ instance = _instantiate_cxx_backed(
+ cls, args=e.args, kwargs=e.kwargs, ordinal=ordinal, name=name
+ )
+ else:
+ instance = _instantiate(cls, args=e.args, kwargs=e.kwargs, ordinal=ordinal, name=name)
+ entries[name] = instance
+ setattr(cls, name, instance)
+
+
+def _cxx_backed_unknown_binder_error(
+ cls: type,
+ name: str,
+ type_key: str,
+ entries: Any,
+) -> RuntimeError:
+ """Build a descriptive error for an unbindable bare ``ClassVar`` binder.
+
+ A bare ``ClassVar[Cls]`` annotation on a cxx-backed enum means "bind to
+ an existing C++ entry with this name" — if the C++ registry has no such
+ entry, the declaration is almost always a typo. For adding a *new*
+ Python-side variant on a cxx-backed enum, use ``entry(...)`` or
+ ``auto()`` instead.
+ """
+ known = list(entries.keys()) if entries is not None else []
+ known_str = ", ".join(repr(k) for k in known) if known else "<none>"
+ return RuntimeError(
+ f"Cannot bind enum variant {name!r} on {cls.__name__}: the FFI "
+ f"type {type_key!r} is already registered in C++ with entries "
+ f"[{known_str}], but has no C++ entry named {name!r}. "
+ f"Bare ``ClassVar[{cls.__name__}]`` binders on a C++-backed enum "
+ f"must name an entry already registered in C++; they cannot "
+ f"introduce new variants from Python. "
+ f"If this was a typo, double-check the spelling against the known "
+ f"entries above (`{name}: ClassVar[{cls.__name__}]`); if you meant "
+ f"to add a new Python-side variant, use `{name} = auto()` or "
+ f"`{name} = entry(...)` instead."
+ )
+
+
+def _instantiate(
+ cls: type,
+ *,
+ args: tuple[Any, ...],
+ kwargs: dict[str, Any],
+ ordinal: int,
+ name: str,
+) -> Any:
+ """Construct a subclass instance with auto-assigned ``value``/``name``."""
+ merged = dict(kwargs)
+ merged["value"] = ordinal
+ merged["name"] = name
+ return cls(*args, **merged)
+
+
+def _instantiate_cxx_backed(
+ cls: type,
+ *,
+ args: tuple[Any, ...],
+ kwargs: dict[str, Any],
+ ordinal: int,
+ name: str,
+) -> Any:
+ """Construct a new variant of a cxx-backed enum without going through ``__init__``.
+
+ C++-backed enums whose underlying type is registered with
+ ``refl::init(false)`` (e.g. any subclass of ``EnumObj`` in C++) have no
+ ``__ffi_init__``, so the usual ``cls(value=..., name=...)`` path is not
+ available. Mirror ``reflection::EnumDef`` by allocating a blank instance
+ via ``__ffi_new__`` and populating fields through the frozen-setter
+ escape hatch exposed on the reflected property descriptors.
+ """
+ if args:
+ raise TypeError(
+ f"{cls.__name__}.{name}: positional `entry(...)` args are not "
+ f"supported when extending a C++-backed enum; use keyword "
+ f"arguments naming reflected fields."
+ )
+ type_info = cls.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
+ ffi_new = core._lookup_type_attr(type_info.type_index, "__ffi_new__")
+ if ffi_new is None:
+ raise RuntimeError(
+ f"Cannot add Python enum variant {name!r} on {cls.__name__}: "
+ f"its C++ type has no ``__ffi_new__`` allocator registered, so "
+ f"blank instances cannot be created from Python."
+ )
+ instance = ffi_new()
+ for key in ("value", "name", *kwargs.keys()):
+ descriptor = getattr(cls, key, None)
+ if descriptor is None or not hasattr(descriptor, "set"):
+ raise TypeError(
+ f"{cls.__name__}.{name}: cannot set field {key!r} on a "
+ f"C++-backed enum — no reflected setter is available."
+ )
+ getattr(cls, "value").set(instance, ordinal)
+ getattr(cls, "name").set(instance, name)
+ for k, v in kwargs.items():
+ getattr(cls, k).set(instance, v)
+ return instance
+
+
+# ---------------------------------------------------------------------------
+# Annotation helpers
+# ---------------------------------------------------------------------------
+
+
+def _own_annotations(cls: type) -> dict[str, Any]:
+ """Return *cls*'s own annotations dict (not inherited)."""
+ if sys.version_info >= (3, 14):
+ return dict(getattr(cls, "__annotations__", {}) or {})
+ return dict(cls.__dict__.get("__annotations__", {}))
+
+
+def _is_class_var(annotation: Any) -> bool:
+ """Return True if *annotation* is ``ClassVar`` or ``ClassVar[...]``."""
+ if annotation is ClassVar:
+ return True
+ if typing.get_origin(annotation) is ClassVar:
+ return True
+ if isinstance(annotation, str):
+ stripped = annotation.replace(" ", "")
+ return stripped.startswith("ClassVar") or stripped.startswith("typing.ClassVar")
+ return False
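The column rules that `EnumAttrMap` implements (one list per attribute name, indexed by variant ordinal and padded with `None`; `None` rejected on write; the default held by each view rather than by the shared column) can be checked with a small pure-Python model. This is a hypothetical sketch, not `tvm_ffi` code: `AttrColumnView` and `_MISSING` are illustrative names, variants are identified by their integer ordinals instead of singletons, and a plain `dict` stands in for the `__ffi_enum_attrs__` TypeAttr.

```python
from __future__ import annotations

from typing import Any

# Hypothetical stand-in for ``core.MISSING``.
_MISSING: Any = object()


class AttrColumnView:
    """Toy model of EnumAttrMap; variants are identified by ordinal."""

    def __init__(self, columns: dict[str, list[Any]], name: str, *, default: Any = _MISSING) -> None:
        self._columns = columns  # shared by all views, like __ffi_enum_attrs__
        self._name = name
        self._default = default  # belongs to this view, not to the column

    def __setitem__(self, ordinal: int, value: Any) -> None:
        if value is None:
            raise TypeError("None is reserved as the 'unset' sentinel")
        col = self._columns.setdefault(self._name, [])
        while len(col) <= ordinal:  # pad new slots with the unset sentinel
            col.append(None)
        col[ordinal] = value

    def __contains__(self, ordinal: int) -> bool:
        col = self._columns.get(self._name)
        return col is not None and ordinal < len(col) and col[ordinal] is not None

    def __getitem__(self, ordinal: int) -> Any:
        if ordinal in self:
            return self._columns[self._name][ordinal]
        if self._default is _MISSING:
            raise KeyError(f"attribute {self._name!r} unset for ordinal {ordinal}")
        return self._default


columns: dict[str, list[Any]] = {}
cost_zero = AttrColumnView(columns, "cost", default=0)
cost_neg = AttrColumnView(columns, "cost", default=-1)
strict = AttrColumnView(columns, "cost")  # no default: unset reads raise KeyError

cost_zero[2] = 5  # writing ordinal 2 pads ordinals 0 and 1 with None
assert columns["cost"] == [None, None, 5]
assert cost_neg[2] == 5  # explicit writes are visible through every view
assert (cost_zero[0], cost_neg[0]) == (0, -1)  # defaults differ per view
assert 0 not in strict and 2 in strict
```

Because the default lives on the view, two `def_attr` calls with the same name share every explicit write yet can disagree on unset variants, which is the behavior the `def_attr` docstring describes.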
diff --git a/src/ffi/extra/dataclass.cc b/src/ffi/extra/dataclass.cc
index 2fcb702..c76117b 100644
--- a/src/ffi/extra/dataclass.cc
+++ b/src/ffi/extra/dataclass.cc
@@ -29,6 +29,7 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/enum.h>
#include <tvm/ffi/extra/dataclass.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/creator.h>
@@ -51,6 +52,8 @@
#include <utility>
#include <vector>
+#include "../object_internal.h"
+
namespace tvm {
namespace ffi {
@@ -809,6 +812,18 @@ class ReprPrinter : public ObjectGraphDFS<ReprPrinter, ReprFrame, std::string> {
return true;
}
}
+ // Built-in sentinel singletons — pointer-identity dispatch ahead of any
+ // type-keyed lookups so the generic ``ffi.Object`` framing is skipped.
+ static const Object* missing_ptr = GetMissingObject().get();
+ static const Object* kwargs_ptr = GetKwargsObject().get();
+ if (obj == missing_ptr) {
+ *out = "<MISSING>";
+ return true;
+ }
+ if (obj == kwargs_ptr) {
+ *out = "<KWARGS>";
+ return true;
+ }
// String/Bytes on heap
if (ti == TypeIndex::kTVMFFIStr) {
String s = details::AnyUnsafe::CopyFromAnyViewAfterCheck<String>(value);
@@ -866,6 +881,21 @@ class ReprPrinter : public ObjectGraphDFS<ReprPrinter, ReprFrame, std::string> {
*out = result;
return true;
}
+ // Default repr for EnumObj subclasses: ``<type_key>.<name>``. Reached only
+ // when no user-registered ``__ffi_repr__`` hook has claimed this type, so
+ // explicit repr overrides on specific enum subclasses still take precedence.
+ if (obj->IsInstance<EnumObj>()) {
+ const EnumObj* enum_obj = static_cast<const EnumObj*>(obj);
+ const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(ti);
+ std::string result(type_info->type_key.data, type_info->type_key.size);
+ result += '.';
+ result.append(enum_obj->name.data(), enum_obj->name.size());
+ if (show_addr_) result += "@" + AddressStr(obj);
+ state_[obj] = State::kDone;
+ repr_cache_[obj] = result;
+ *out = result;
+ return true;
+ }
// Needs a frame
return false;
}
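The `ReprPrinter` change establishes a dispatch order: the `MISSING` and `KWARGS` sentinels are matched by pointer identity before any type-keyed handling, a user-registered `__ffi_repr__` hook still wins for a specific enum type, and otherwise any `EnumObj` subclass prints as `<type_key>.<name>`. The sketch below is a hypothetical Python model of that order only; `MISSING`, `KWARGS`, `EnumLike`, and the `hooks` table stand in for the C++ singletons, `EnumObj`, and the hook registry, and the optional `@address` suffix is left out.

```python
from __future__ import annotations

from typing import Callable

MISSING = object()  # stand-in for the C++ GetMissingObject() singleton
KWARGS = object()  # stand-in for the C++ GetKwargsObject() singleton


class EnumLike:
    """Stand-in for an EnumObj subclass: a type key plus a variant name."""

    type_key = "testing.TestEnumVariant"

    def __init__(self, name: str) -> None:
        self.name = name


def model_repr(obj: object, hooks: dict[type, Callable[[object], str]]) -> str:
    # 1. Sentinels: identity checks run before any type-keyed dispatch.
    if obj is MISSING:
        return "<MISSING>"
    if obj is KWARGS:
        return "<KWARGS>"
    # 2. A user-registered repr hook for the exact type takes precedence.
    hook = hooks.get(type(obj))
    if hook is not None:
        return hook(obj)
    # 3. Default for enum instances: "<type_key>.<name>".
    if isinstance(obj, EnumLike):
        return f"{type(obj).type_key}.{obj.name}"
    return object.__repr__(obj)


class LoudVariant(EnumLike):
    type_key = "testing.Loud"


hooks = {LoudVariant: lambda o: f"LOUD<{o.name}>"}
assert model_repr(MISSING, hooks) == "<MISSING>"
assert model_repr(EnumLike("Alpha"), hooks) == "testing.TestEnumVariant.Alpha"
assert model_repr(LoudVariant("Beta"), hooks) == "LOUD<Beta>"
```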
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index b3ed6b5..30d61a4 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -27,6 +27,7 @@
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/enum.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/memory.h>
@@ -319,11 +320,14 @@ class TypeTable {
column->data = reinterpret_cast<const TVMFFIAny*>(column->data_.data());
column->size = static_cast<int32_t>(column->data_.size());
column->begin_index = 0;
- if (column->data_[type_index - column->begin_index] != nullptr) {
- TVM_FFI_THROW(RuntimeError) << "Type attribute `" << name_str << "` is already set for type `"
- << TypeIndexToTypeKey(type_index) << "`";
+ Any& slot = column->data_[type_index - column->begin_index];
+ if (slot.type_index() != kTVMFFINone) {
+ TVM_FFI_THROW(RuntimeError)
+ << "TypeAttr `" << name_str << "` is already registered for type index " << type_index
+ << ". To update the stored value, register a mutable container (e.g., Dict/List) "
+ << "once and mutate it in place on subsequent calls.";
}
- column->data_[type_index - column->begin_index] = value_view;
+ slot = value_view;
}
const TVMFFITypeAttrColumn* GetTypeAttrColumn(const TVMFFIByteArray* name) {
String name_str(*name);
@@ -645,6 +649,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::TypeAttrDef<ffi::DictObj>().def(
refl::type_attr::kConvert,
 &refl::details::FFIConvertFromAnyViewToObjectRef<ffi::Dict<ffi::Any, ffi::Any>>);
+ refl::ObjectDef<ffi::EnumObj>(refl::init(false))
+ .def_ro("value", &ffi::EnumObj::value, "Ordinal assigned at registration.",
+ refl::AttachFieldFlag::SEqHashIgnore())
+ .def_ro("name", &ffi::EnumObj::name, "Instance name.");
+ refl::EnsureTypeAttrColumn(refl::type_attr::kEnumEntries);
+ refl::EnsureTypeAttrColumn(refl::type_attr::kEnumAttrs);
refl::GlobalDef()
.def_method("ffi.GetRegisteredTypeKeys",
[]() -> ffi::Array<ffi::String> {
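After this change `TypeTable` treats a slot holding `None` as unset and otherwise refuses to overwrite a registered TypeAttr, steering callers toward registering a mutable container once and mutating it in place. `_ensure_entries_dict` and `_ensure_attrs_dict` in `enum.py` follow that pattern: look up, register an empty `Dict` only if absent, then re-read. A hypothetical pure-Python model follows; `TypeAttrColumn`, `ensure_dict`, and type index `42` are illustrative, not `tvm_ffi` APIs.

```python
from __future__ import annotations

from typing import Any


class TypeAttrColumn:
    """Toy model of one TypeAttr column: type_index -> value, write-once."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._slots: dict[int, Any] = {}

    def register(self, type_index: int, value: Any) -> None:
        # A slot holding None counts as unset, mirroring the kTVMFFINone check.
        if self._slots.get(type_index) is not None:
            raise RuntimeError(
                f"TypeAttr `{self.name}` is already registered for type index {type_index}"
            )
        self._slots[type_index] = value

    def lookup(self, type_index: int) -> Any:
        return self._slots.get(type_index)


def ensure_dict(column: TypeAttrColumn, type_index: int) -> dict[str, Any]:
    """Register once, then mutate in place (the _ensure_entries_dict pattern)."""
    existing = column.lookup(type_index)
    if existing is not None:
        return existing
    column.register(type_index, {})
    return column.lookup(type_index)  # re-read the registry-owned object


entries = TypeAttrColumn("__ffi_enum_entries__")
ensure_dict(entries, 42)["Alpha"] = 0  # 42 is a hypothetical type index
ensure_dict(entries, 42)["Beta"] = 1  # the second call reuses the same dict
assert entries.lookup(42) == {"Alpha": 0, "Beta": 1}

try:
    entries.register(42, {})  # a second registration is rejected
except RuntimeError as err:
    print(err)
```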
diff --git a/src/ffi/testing/testing.cc b/src/ffi/testing/testing.cc
index 9418580..4447900 100644
--- a/src/ffi/testing/testing.cc
+++ b/src/ffi/testing/testing.cc
@@ -27,10 +27,12 @@
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/dtype.h>
+#include <tvm/ffi/enum.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/reflection/enum_def.h>
#include <tvm/ffi/reflection/registry.h>
#include <chrono>
@@ -79,6 +81,30 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::type_attr::kConvert,
&refl::details::FFIConvertFromAnyViewToObjectRef<TestIntPair>);
}
+// C++-backed enum used by the Python ``Enum`` tests to exercise both
+// ``EnumDef``-registered entries and the Python ``ClassVar``-based binding.
+class TestEnumVariantObj : public tvm::ffi::EnumObj {
+ public:
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("testing.TestEnumVariant", TestEnumVariantObj,
+ tvm::ffi::EnumObj);
+};
+
+class TestEnumVariant : public tvm::ffi::Enum {
+ public:
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TestEnumVariant, tvm::ffi::Enum, TestEnumVariantObj);
+};
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ // ObjectDef registers the type on destruction, so the temporary is intentional;
+ // silence clang-tidy's bugprone-unused-raii since the RAII finalisation is the point.
+ refl::ObjectDef<TestEnumVariantObj>(refl::init(false)); // NOLINT(bugprone-unused-raii)
+ refl::TypeAttrDef<TestEnumVariantObj>().def(
+ refl::type_attr::kConvert, &refl::details::FFIConvertFromAnyViewToObjectRef<TestEnumVariant>);
+ refl::EnumDef<TestEnumVariantObj>("Alpha").set_attr("code", int64_t{10});
+ refl::EnumDef<TestEnumVariantObj>("Beta").set_attr("code", int64_t{20});
+}
+
class TestObjectBase : public Object {
public:
int64_t v_i64;
@@ -541,6 +567,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
})
.def("testing.optional_tensor_view_has_value",
[](const Optional<TensorView>& t) { return t.has_value(); })
+ .def("testing.enum_variant_get",
+ [](const String& name) -> Enum { return EnumObj::Get<TestEnumVariantObj>(name); })
 .def_method("testing.TestIntPairSum", &TestIntPair::Sum, "Get sum of the pair")
// Container-with-tensor test helpers for DLPack container conversion
// NOLINTBEGIN(performance-unnecessary-value-param)
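The testing enum registers `Alpha` and `Beta` from C++; on the Python side, `_resolve_entries` binds bare `ClassVar` names to existing entries and gives every `entry(...)`/`auto()` variant the next dense ordinal, `len(entries)`. The hypothetical model below reproduces only that name-to-ordinal bookkeeping, with a plain `dict` in place of `__ffi_enum_entries__`. It assumes `EnumDef` gave `Alpha` and `Beta` ordinals 0 and 1 in registration order, and `Gamma` is an invented variant name.

```python
from __future__ import annotations


def resolve_ordinals(
    entries: dict[str, int],
    binders: list[str],
    new_variants: list[str],
    *,
    cxx_backed: bool,
) -> dict[str, int]:
    """Toy model of _resolve_entries: name -> ordinal bookkeeping only."""
    for name in binders:  # bare ``name: ClassVar[Cls]`` annotations
        if name in entries:
            continue  # bind to the already-registered variant
        if cxx_backed:
            raise RuntimeError(f"no C++ entry named {name!r}; use auto() or entry(...)")
        entries[name] = len(entries)  # Python-only enum: register a blank variant
    for name in new_variants:  # ``name = auto()`` / ``name = entry(...)``
        if name in entries:
            raise RuntimeError(f"duplicate enum entry {name!r}")
        entries[name] = len(entries)  # next dense ordinal
    return entries


# Assumed state after the C++ EnumDef<TestEnumVariantObj>("Alpha") / ("Beta") calls.
cxx_entries = {"Alpha": 0, "Beta": 1}
resolved = resolve_ordinals(cxx_entries, ["Alpha", "Beta"], ["Gamma"], cxx_backed=True)
assert resolved == {"Alpha": 0, "Beta": 1, "Gamma": 2}

try:
    resolve_ordinals({"Alpha": 0}, ["Alhpa"], [], cxx_backed=True)  # typo'd binder
except RuntimeError as err:
    print(err)
```

On a C++-backed type an unknown bare binder is treated as a typo rather than a new variant, which is why the real implementation raises a descriptive error listing the known entries.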
diff --git a/tests/python/test_dataclass_enum.py b/tests/python/test_dataclass_enum.py
new file mode 100644
index 0000000..af16601
--- /dev/null
+++ b/tests/python/test_dataclass_enum.py
@@ -0,0 +1,689 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for ``Enum`` subclasses with ``type_key``-parameterized inheritance."""
+
+from __future__ import annotations
+
+import itertools
+from typing import ClassVar
+
+import pytest
+import tvm_ffi
+from tvm_ffi import core
+from tvm_ffi.core import Object
+from tvm_ffi.dataclasses import Enum, EnumAttrMap, auto, entry
+from tvm_ffi.dataclasses.enum import (
+ ENUM_ATTRS_ATTR,
+ ENUM_ENTRIES_ATTR,
+ _EnumEntry,
+)
+
+# ---------------------------------------------------------------------------
+# Unique type-key generator — ensures no collisions across tests.
+# ---------------------------------------------------------------------------
+_counter = itertools.count()
+
+
+def _unique_key(base: str) -> str:
+ return f"testing.py_enum_sub.{base}_{next(_counter)}"
+
+
+# ---------------------------------------------------------------------------
+# Attribute-carrying form — entry(...) with extra field kwargs
+# ---------------------------------------------------------------------------
+
+
+def test_attribute_carrying_basic() -> None:
+ class Activation(Enum, type_key=_unique_key("Activation")):
+ output_zero: bool
+ is_monotonic: bool
+
+ relu: ClassVar[Activation] = entry(output_zero=True, is_monotonic=True)
+ gelu: ClassVar[Activation] = entry(output_zero=False, is_monotonic=False)
+ silu: ClassVar[Activation] = entry(output_zero=False, is_monotonic=True)
+
+ assert isinstance(Activation.relu, Activation)
+ assert isinstance(Activation.gelu, Activation)
+ assert isinstance(Activation.silu, Activation)
+
+ assert Activation.relu.output_zero is True # ty: ignore[unresolved-attribute]
+ assert Activation.relu.is_monotonic is True # ty: ignore[unresolved-attribute]
+ assert Activation.gelu.output_zero is False # ty: ignore[unresolved-attribute]
+ assert Activation.silu.is_monotonic is True # ty: ignore[unresolved-attribute]
+
+ # Ordinals auto-assigned in declaration order.
+ assert Activation.relu.value == 0
+ assert Activation.gelu.value == 1
+ assert Activation.silu.value == 2
+ assert Activation.relu.name == "relu"
+ assert Activation.gelu.name == "gelu"
+
+ assert Activation.get("relu").same_as(Activation.relu)
+ assert Activation.get("gelu").same_as(Activation.gelu)
+ assert Activation.get("silu").same_as(Activation.silu)
+
+
+def test_entry_rejects_value_kwarg() -> None:
+ """``entry(value=...)`` conflicts with the auto-assigned ordinal."""
+ with pytest.raises(TypeError):
+
+ class _Bad(Enum, type_key=_unique_key("BadValue")):
+ flag: bool
+ a: ClassVar[_Bad] = entry(flag=True, value=5)
+
+
+def test_entry_rejects_name_kwarg() -> None:
+ """``entry(name=...)`` conflicts with the auto-assigned declaration key."""
+ with pytest.raises(TypeError):
+
+ class _Bad(Enum, type_key=_unique_key("BadName")):
+ flag: bool
+ a: ClassVar[_Bad] = entry(flag=True, name="other")
+
+
+def test_get_missing_raises() -> None:
+ class Missing(Enum, type_key=_unique_key("Missing")):
+ flag: bool
+ yes: ClassVar[Missing] = entry(flag=True)
+
+ with pytest.raises(KeyError):
+ Missing.get("no-such-entry")
+
+
+def test_entries_iteration_order() -> None:
+ class Ordered(Enum, type_key=_unique_key("Ordered")):
+ tag: str
+ a: ClassVar[Ordered] = entry(tag="first")
+ b: ClassVar[Ordered] = entry(tag="second")
+ c: ClassVar[Ordered] = entry(tag="third")
+
+ values = list(Ordered.entries())
+ assert len(values) == 3
+ assert values[0].same_as(Ordered.a)
+ assert values[1].same_as(Ordered.b)
+ assert values[2].same_as(Ordered.c)
+
+
+def test_frozen_variants() -> None:
+ class Frozen(Enum, type_key=_unique_key("Frozen")):
+ flag: bool
+ yes: ClassVar[Frozen] = entry(flag=True)
+
+ with pytest.raises(AttributeError):
+ Frozen.yes.flag = False # ty: ignore[invalid-assignment]
+
+
+# ---------------------------------------------------------------------------
+# Bare ClassVar[Cls] (no assignment) — Python-side blank entries
+# ---------------------------------------------------------------------------
+
+
+def test_bare_classvar_without_cxx_entries() -> None:
+ """``ClassVar[Cls]`` with no value registers a new blank Python entry
+ when the type key has no C++ backing.
+ """
+
+ class Status(Enum, type_key=_unique_key("Status")):
+ ok: ClassVar[Status]
+ err: ClassVar[Status]
+ retry: ClassVar[Status]
+
+ assert isinstance(Status.ok, Status)
+ assert Status.ok.value == 0
+ assert Status.err.value == 1
+ assert Status.retry.value == 2
+ assert Status.ok.name == "ok"
+ assert Status.err.name == "err"
+ assert list(Status.entries()) == [Status.ok, Status.err, Status.retry]
+ assert Status.get("ok").same_as(Status.ok)
+
+
+def test_bare_classvar_mixed_with_entry() -> None:
+ """Bare ``ClassVar`` and ``ClassVar = entry(...)`` may mix in one class."""
+
+ class Kind(Enum, type_key=_unique_key("Kind")):
+ blank: ClassVar[Kind]
+ tag: str = "" # ordinary field; extra fields follow
+ named: ClassVar[Kind] = entry(tag="hi")
+
+ # ``ClassVar`` binders are processed before ``entry(...)`` assignments.
+ assert Kind.blank.value == 0
+ assert Kind.named.value == 1
+ assert Kind.blank.name == "blank"
+ assert Kind.named.name == "named"
+ assert Kind.named.tag == "hi" # ty: ignore[unresolved-attribute]
+
+
+# ---------------------------------------------------------------------------
+# Bare ``name = entry(...)`` sugar (no ClassVar annotation)
+# ---------------------------------------------------------------------------
+
+
+def test_bare_entry_sugar_form() -> None:
+ """``name = entry(...)`` without a ``ClassVar`` annotation is picked up."""
+
+ class Activation(Enum, type_key=_unique_key("ActivationBare")):
+ output_zero: bool
+ is_monotonic: bool
+
+ relu = entry(output_zero=True, is_monotonic=True)
+ gelu = entry(output_zero=False, is_monotonic=False)
+
+ assert isinstance(Activation.relu, Activation)
+ assert Activation.relu.output_zero is True # ty: ignore[unresolved-attribute]
+ assert Activation.gelu.output_zero is False # ty: ignore[unresolved-attribute]
+ assert list(Activation.entries()) == [Activation.relu, Activation.gelu]
+
+
+# ---------------------------------------------------------------------------
+# ``auto()`` — simple Python-side entries without ClassVar annotation
+# ---------------------------------------------------------------------------
+
+
+def test_auto_basic_no_annotation() -> None:
+ """``name = auto()`` registers a py-side entry with dense auto-ordinals."""
+
+ class Priority(Enum, type_key=_unique_key("Priority")):
+ low = auto()
+ medium = auto()
+ high = auto()
+
+ assert isinstance(Priority.low, Priority)
+ assert Priority.low.value == 0
+ assert Priority.medium.value == 1
+ assert Priority.high.value == 2
+ assert Priority.low.name == "low"
+ assert Priority.high.name == "high"
+ assert list(Priority.entries()) == [Priority.low, Priority.medium, Priority.high]
+
+
+def test_auto_with_classvar_annotation() -> None:
+ """``name: ClassVar[Cls] = auto()`` is equivalent to the annotation-less form."""
+
+ class Stage(Enum, type_key=_unique_key("Stage")):
+ init: ClassVar[Stage] = auto()
+ run: ClassVar[Stage] = auto()
+ done: ClassVar[Stage] = auto()
+
+ assert Stage.init.value == 0
+ assert Stage.run.value == 1
+ assert Stage.done.value == 2
+
+
+def test_auto_mixed_with_bare_classvar() -> None:
+ """``auto()`` may coexist with bare ``ClassVar`` binders in one class.
+
+ Bare ``ClassVar`` binders are processed first (in annotation order),
+ then ``auto()`` / ``entry(...)`` sentinels in class-body order — so
+ ordinals reflect that deterministic two-phase order.
+ """
+
+ class Mixed(Enum, type_key=_unique_key("Mixed")):
+ alpha: ClassVar[Mixed]
+ beta = auto()
+ gamma: ClassVar[Mixed]
+
+ # Binders (alpha, gamma) come first in annotation order, then sentinels.
+ assert Mixed.alpha.value == 0
+ assert Mixed.gamma.value == 1
+ assert Mixed.beta.value == 2
+ assert {v.name for v in Mixed.entries()} == {"alpha", "beta", "gamma"}
+
+
+def test_auto_mixed_with_entry() -> None:
+ """``auto()`` and ``entry(...)`` compose on an attribute-carrying enum."""
+
+ class Op(Enum, type_key=_unique_key("OpMixedAuto")):
+ arity: int = 0
+ noop = auto()
+ add = entry(arity=2)
+ neg = entry(arity=1)
+
+ assert Op.noop.value == 0
+ assert Op.add.value == 1
+ assert Op.neg.value == 2
+ assert Op.noop.arity == 0 # ty: ignore[unresolved-attribute]
+ assert Op.add.arity == 2 # ty: ignore[unresolved-attribute]
+
+
+def test_auto_rejects_already_registered_name() -> None:
+ """``auto()`` on a name already in the entries dict is rejected.
+
+ ``testing.TestEnumVariant`` pre-registers ``Alpha`` / ``Beta`` from C++,
+ so attempting to *register* (rather than bind) ``Alpha`` via ``auto()``
+ must fail — bare ``ClassVar[Cls]`` is the way to bind to an existing
+ entry.
+ """
+ with pytest.raises(RuntimeError):
+
+ class _Shadow(Enum, type_key="testing.TestEnumVariant"):
+ Alpha = auto()
+
+
+def test_auto_returns_fresh_sentinels() -> None:
+ """Each ``auto()`` call returns a distinct sentinel instance."""
+ a, b = auto(), auto()
+ assert isinstance(a, _EnumEntry)
+ assert isinstance(b, _EnumEntry)
+ assert a is not b
+ assert a.args == ()
+ assert a.kwargs == {}
+
+
+# ---------------------------------------------------------------------------
+# by_name / by_value / attr_dict
+# ---------------------------------------------------------------------------
+
+
+def test_by_name_is_live_dict() -> None:
+ class K(Enum, type_key=_unique_key("ByName")):
+ a: ClassVar[K]
+ b: ClassVar[K]
+
+ assert set(K.by_name.keys()) == {"a", "b"}
+ assert K.by_name["a"].same_as(K.a)
+
+
+def test_by_value_indexed_by_ordinal() -> None:
+ class K(Enum, type_key=_unique_key("ByValue")):
+ a: ClassVar[K]
+ b: ClassVar[K]
+ c: ClassVar[K]
+
+ by_val = K.by_value
+ assert len(by_val) == 3
+ assert by_val[0].same_as(K.a)
+ assert by_val[1].same_as(K.b)
+ assert by_val[2].same_as(K.c)
+
+
+def test_attr_dict_direct_access() -> None:
+ """The ``attr_dict`` class property returns the live per-variant attrs map."""
+
+ class Op(Enum, type_key=_unique_key("OpDirect")):
+ arity: int
+ add: ClassVar[Op] = entry(arity=2)
+ neg: ClassVar[Op] = entry(arity=1)
+
+ has_side_effects = Op.def_attr("has_side_effects", default=False)
+ has_side_effects[Op.add] = False
+ has_side_effects[Op.neg] = True
+
+ # Direct read via class-level property.
+ column = Op.attr_dict["has_side_effects"]
+ assert column[Op.add.value] is False
+ assert column[Op.neg.value] is True
+
+
+# ---------------------------------------------------------------------------
+# EnumAttrMap / def_attr
+# ---------------------------------------------------------------------------
+
+
+def test_def_attr_basic_get_set() -> None:
+ class Op(Enum, type_key=_unique_key("Op")):
+ arity: int
+ add: ClassVar[Op] = entry(arity=2)
+ neg: ClassVar[Op] = entry(arity=1)
+
+ cost = Op.def_attr("cost", default=0)
+ assert isinstance(cost, EnumAttrMap)
+
+ assert cost[Op.add] == 0
+ assert cost[Op.neg] == 0
+
+ cost[Op.add] = 5
+ cost[Op.neg] = 2
+ assert cost[Op.add] == 5
+ assert cost[Op.neg] == 2
+
+
+def test_def_attr_missing_raises_without_default() -> None:
+ class OpStrict(Enum, type_key=_unique_key("OpStrict")):
+ arity: int
+ add: ClassVar[OpStrict] = entry(arity=2)
+
+ attr = OpStrict.def_attr("strict_attr")
+ with pytest.raises(KeyError):
+ _ = attr[OpStrict.add]
+
+
+def test_def_attr_get_method_default() -> None:
+ class Op(Enum, type_key=_unique_key("OpGetDefault")):
+ arity: int
+ add: ClassVar[Op] = entry(arity=2)
+
+ attr = Op.def_attr("cost")
+ assert attr.get(Op.add, -1) == -1
+ attr[Op.add] = 9
+ assert attr.get(Op.add, -1) == 9
+
+
+def test_def_attr_rejects_foreign_variant() -> None:
+ class Left(Enum, type_key=_unique_key("LeftEnum")):
+ v: int
+ one: ClassVar[Left] = entry(v=1)
+
+ class Right(Enum, type_key=_unique_key("RightEnum")):
+ v: int
+ one: ClassVar[Right] = entry(v=1)
+
+ attr = Left.def_attr("x", default=0)
+ with pytest.raises(TypeError):
+ attr[Right.one] = 1
+
+
+def test_def_attr_contains() -> None:
+ class Op(Enum, type_key=_unique_key("OpC")):
+ arity: int
+ add: ClassVar[Op] = entry(arity=2)
+ neg: ClassVar[Op] = entry(arity=1)
+
+ cost = Op.def_attr("cost", default=0)
+ assert Op.add not in cost
+ cost[Op.add] = 5
+ assert Op.add in cost
+ assert Op.neg not in cost
+
+
+def test_def_attr_rejects_none_write() -> None:
+ """``None`` is reserved as the column's "unset" sentinel."""
+
+ class Op(Enum, type_key=_unique_key("OpNone")):
+ arity: int
+ add: ClassVar[Op] = entry(arity=2)
+
+ cost = Op.def_attr("cost", default=0)
+ with pytest.raises(TypeError, match="reserved as the 'unset' sentinel"):
+ cost[Op.add] = None
+ assert Op.add not in cost
+
+
+def test_def_attr_accepts_fresh_wrapper_from_get() -> None:
+ """Variants returned by ``Cls.get(...)`` may be fresh Python wrappers
+ whose ``id`` differs from the cached class attribute. Ordinal-indexed
+ lookup must still resolve correctly.
+ """
+
+ class Op(Enum, type_key=_unique_key("OpFresh")):
+ arity: int
+ add: ClassVar[Op] = entry(arity=2)
+
+ cost = Op.def_attr("cost", default=0)
+ cost[Op.add] = 7
+
+ fresh = Op.get("add")
+ assert fresh.same_as(Op.add)
+ assert cost[fresh] == 7
+ assert fresh in cost
+
+
+# ---------------------------------------------------------------------------
+# `entry` sentinel sanity checks
+# ---------------------------------------------------------------------------
+
+
+def test_entry_sentinel_reprs() -> None:
+ e = entry(1, 2, name_key="x")
+ assert isinstance(e, _EnumEntry)
+ assert e.args == (1, 2)
+ assert e.kwargs == {"name_key": "x"}
+ assert "entry(" in repr(e)
+
+
+def test_entry_attribute_access_outside_class_body() -> None:
+ """A naked ``entry()`` call returns the sentinel — never a real instance."""
+ e = entry(output_zero=True)
+ assert not isinstance(e, Object)
+
+
+# ---------------------------------------------------------------------------
+# TypeAttr-level verification
+# ---------------------------------------------------------------------------
+
+
+def test_enum_entries_typeattr_is_mapping() -> None:
+ class WithAttr(Enum, type_key=_unique_key("WithAttr")):
+ v: int
+ one: ClassVar[WithAttr] = entry(v=1)
+
+ tinfo = WithAttr.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
+ stored = core._lookup_type_attr(tinfo.type_index, ENUM_ENTRIES_ATTR)
+ assert stored is not None
+ assert "one" in stored
+
+
+def test_enum_attrs_typeattr_stored_under_unified_column() -> None:
+ """``def_attr`` writes into the ``__ffi_enum_attrs__`` Dict column."""
+
+ class WithAttr(Enum, type_key=_unique_key("UnifiedAttrs")):
+ v: int
+ one: ClassVar[WithAttr] = entry(v=1)
+
+ attr = WithAttr.def_attr("color", default="?")
+ attr[WithAttr.one] = "red"
+
+ tinfo = WithAttr.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
+ stored = core._lookup_type_attr(tinfo.type_index, ENUM_ATTRS_ATTR)
+ assert stored is not None
+ assert "color" in stored
+ assert stored["color"][WithAttr.one.value] == "red"
+
+
+# ---------------------------------------------------------------------------
+# C++-backed enum — auto-detected when type_key is already registered
+# ---------------------------------------------------------------------------
+
+
+def test_cxx_enum_obj_get_returns_singleton() -> None:
+ """``EnumObj::Get<TestEnumVariantObj>`` (wired as ``testing.enum_variant_get``)
+ returns the same singleton as ``Enum.get`` and the Python-side binder.
+ """
+ cxx_get = tvm_ffi.get_global_func("testing.enum_variant_get")
+
+ class Variant(Enum, type_key="testing.TestEnumVariant"):
+ Alpha: ClassVar[Variant]
+ Beta: ClassVar[Variant]
+
+ assert cxx_get("Alpha").same_as(Variant.Alpha)
+ assert cxx_get("Beta").same_as(Variant.Beta)
+
+ with pytest.raises(RuntimeError, match="no instance named"):
+ cxx_get("Nope")
+
+
+def test_cxx_backed_classvar_binds_to_existing_entries() -> None:
+ """``ClassVar[Cls]`` (no assignment) binds to entries registered on the
+ C++ side via ``refl::EnumDef``.
+
+ ``testing.TestEnumVariant`` is registered in C++ with two entries,
+ ``Alpha`` and ``Beta``, each with a ``code`` attr attached via
+ ``set_attr``. The Python subclass picks these up without declaring any
+ new entries of its own.
+ """
+
+ class Variant(Enum, type_key="testing.TestEnumVariant"):
+ Alpha: ClassVar[Variant]
+ Beta: ClassVar[Variant]
+
+ assert isinstance(Variant.Alpha, Variant)
+ assert Variant.Alpha.name == "Alpha"
+ assert Variant.Beta.name == "Beta"
+ # Ordinals come from C++ (registered in Alpha, Beta declaration order).
+ assert Variant.Alpha.value == 0
+ assert Variant.Beta.value == 1
+ assert Variant.get("Alpha").same_as(Variant.Alpha)
+
+ # C++-stored `code` attr is visible through attr_dict.
+ code_col = Variant.attr_dict["code"]
+ assert code_col[Variant.Alpha.value] == 10
+ assert code_col[Variant.Beta.value] == 20
+
+
+def test_cxx_backed_reads_entries_typeattr() -> None:
+ class Variant2(Enum, type_key="testing.TestEnumVariant"):
+ Alpha: ClassVar[Variant2]
+
+ tinfo = Variant2.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
+ stored = core._lookup_type_attr(tinfo.type_index, ENUM_ENTRIES_ATTR)
+ assert stored is not None
+ assert "Alpha" in stored
+
+
+def test_cxx_backed_binder_typo_raises_descriptive_error() -> None:
+ """A ``ClassVar[Cls]`` binder naming an entry that isn't in the C++
+ registry raises a ``RuntimeError`` that names both the typo and the
+ known C++-registered entries, rather than falling through to the
+ ``Enum``-base ``init=False`` guard.
+ """
+ with pytest.raises(RuntimeError) as excinfo:
+
+ class _Typo(Enum, type_key="testing.TestEnumVariant"):
+ Alpha: ClassVar[_Typo]
+ Neta: ClassVar[_Typo] # intended to be "Beta"
+
+ msg = str(excinfo.value)
+ assert "'Neta'" in msg
+ assert "'testing.TestEnumVariant'" in msg
+ assert "'Alpha'" in msg
+ assert "'Beta'" in msg
+ assert "C++" in msg
+ assert "ClassVar" in msg
+
+
+def test_cxx_backed_mixed_entries_via_auto() -> None:
+ """A Python subclass of a cxx-backed enum may add new Python-side entries
+ via ``auto()`` alongside bare ``ClassVar`` binders for existing C++ entries.
+
+ Ordinals for the new Python entries continue from the count of existing
+ entries, preserving the dense-ordinal invariant across the mixed set.
+ The ``__ffi_enum_entries__`` dict lives at the type-index level and is
+ shared with every other Python subclass of the same ``type_key`` — so
+ we pick unique variant names (``Mixed*``) to avoid collisions with
+ other tests that also bind to ``testing.TestEnumVariant``.
+ """
+
+ class Mixed(Enum, type_key="testing.TestEnumVariant"):
+ Alpha: ClassVar[Mixed]
+ Beta: ClassVar[Mixed]
+ MixedOne = auto()
+ MixedTwo = auto()
+
+ # Alpha/Beta come from C++ with ordinals 0 and 1.
+ assert Mixed.Alpha.value == 0
+ assert Mixed.Beta.value == 1
+ assert Mixed.Alpha.name == "Alpha"
+ assert Mixed.Beta.name == "Beta"
+
+ # Python-side entries extend the dense ordinal sequence from the C++ count.
+ assert Mixed.MixedOne.name == "MixedOne"
+ assert Mixed.MixedTwo.name == "MixedTwo"
+ assert Mixed.MixedOne.value == Mixed.Beta.value + 1
+ assert Mixed.MixedTwo.value == Mixed.Beta.value + 2
+
+ # Round-trip through ``get`` / ``by_name`` / ``entries``.
+ assert Mixed.get("MixedOne").same_as(Mixed.MixedOne)
+ assert Mixed.get("MixedTwo").same_as(Mixed.MixedTwo)
+ assert {"Alpha", "Beta", "MixedOne", "MixedTwo"}.issubset(Mixed.by_name.keys())
+
+ # Python-side variants are real subclass instances of the cxx-backed type.
+ assert isinstance(Mixed.MixedOne, Mixed)
+ assert isinstance(Mixed.MixedTwo, Mixed)
+
+ # Existing C++ attrs remain unaffected; new Python variants have no attrs yet.
+ code = Mixed.attr_dict["code"]
+ assert code[Mixed.Alpha.value] == 10
+ assert code[Mixed.Beta.value] == 20
+
+
+def test_cxx_backed_python_entry_accepts_def_attr() -> None:
+ """``def_attr`` writes still work for Python-side variants on a cxx-backed enum."""
+
+ class WithPy(Enum, type_key="testing.TestEnumVariant"):
+ Alpha: ClassVar[WithPy]
+ AttrOne = auto()
+
+ tag = WithPy.def_attr("tag", default=None)
+ tag[WithPy.AttrOne] = "py-side"
+ assert tag[WithPy.AttrOne] == "py-side"
+ # Column was widened to the new ordinal; C++-registered entries retain default.
+ assert tag.get(WithPy.Alpha) is None
+
+
+# ---------------------------------------------------------------------------
+# Default ReprPrint for EnumObj subclasses + MISSING/KWARGS sentinels
+# ---------------------------------------------------------------------------
+
+
+def test_default_repr_python_backed() -> None:
+ """Python-only enum subclasses format each variant as ``<type_key>.<name>``."""
+ key = _unique_key("ReprPy")
+
+ class Priority(Enum, type_key=key):
+ low = auto()
+ medium = auto()
+ high = auto()
+
+ assert repr(Priority.low) == f"{key}.low"
+ assert repr(Priority.medium) == f"{key}.medium"
+ assert repr(Priority.high) == f"{key}.high"
+
+
+def test_default_repr_cxx_backed() -> None:
+ """C++-registered enum subclasses format with the C++ type_key."""
+
+ class Variant(Enum, type_key="testing.TestEnumVariant"):
+ Alpha: ClassVar[Variant]
+ Beta: ClassVar[Variant]
+
+ assert repr(Variant.Alpha) == "testing.TestEnumVariant.Alpha"
+ assert repr(Variant.Beta) == "testing.TestEnumVariant.Beta"
+
+
+def test_default_repr_in_nested_container() -> None:
+ """Enum repr applies recursively when a variant is nested inside a Dict/List."""
+ key = _unique_key("ReprNested")
+
+ class Color(Enum, type_key=key):
+ red = auto()
+ green = auto()
+
+ by_name_repr = repr(Color.by_name)
+ assert f"{key}.red" in by_name_repr
+ assert f"{key}.green" in by_name_repr
+
+ by_value_entries = [repr(v) for v in Color.by_value]
+ assert by_value_entries == [f"{key}.red", f"{key}.green"]
+
+
+def test_default_repr_with_attribute_carrying_variant() -> None:
+ """Attribute-carrying entries still render with the ``<type_key>.<name>`` form."""
+ key = _unique_key("ReprWithAttrs")
+
+ class Op(Enum, type_key=key):
+ arity: int
+ add: ClassVar[Op] = entry(arity=2)
+ neg: ClassVar[Op] = entry(arity=1)
+
+ assert repr(Op.add) == f"{key}.add"
+ assert repr(Op.neg) == f"{key}.neg"
+
+
+def test_missing_and_kwargs_sentinel_repr() -> None:
+ """The built-in MISSING and KWARGS singletons render with angle-bracket tags."""
+ assert repr(core.MISSING) == "<MISSING>"
+ assert repr(core.KWARGS) == "<KWARGS>"