This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new d70c905 refactor(enum)!: hide identity fields, add IntEnum/StrEnum,
consolidate accessors (#564)
d70c905 is described below
commit d70c905d16d16a93fedd7417e1500aaf11dfb530
Author: Junru Shao
AuthorDate: Tue Apr 21 09:35:03 2026 -0700
refactor(enum)!: hide identity fields, add IntEnum/StrEnum, consolidate
accessors (#564)
## Summary
This PR completes the enum identity-vs-payload separation: it hides the
intrinsic identity fields by renaming them to `_value` / `_name`,
introduces `IntEnum` / `StrEnum` payload bases, consolidates the
class-level entry accessors into a single `all_entries()` iterator, and
adds a payload-aware fast path in the Cython type converter so that
`py_class` fields typed as a payload enum accept raw `int` / `str` values.
## Architecture
- **Rename `EnumObj::value` / `name` to `_value` / `_name`** across the
C++ layout, reflection schema (`ObjectDef<EnumObj>::def_ro`), repr
printer, and `reflection::EnumDef`. The leading underscore frees the
public `value` symbol for user payload fields while keeping the dense
ordinal and declaration key as the canonical variant identity.
- **Add `IntEnum` / `StrEnum` base classes** in
`tvm_ffi.dataclasses.enum` that reserve a declared `value: int` (resp.
`str`) payload field via `_prepare_payload_enum_subclass`. The helper
injects or validates the `value` annotation *before*
`Enum.__init_subclass__` runs, so reflection sees the payload as a
first-class declared field rather than an extensible attribute, and
stamps `__ffi_enum_payload_value_type__` onto the class for downstream
dispatch.
- **Tighten `py_class` parent-info lookup** to fall back to
`base.__tvm_ffi_type_info__` when the core registry lookup misses, so
Python subclasses descending through `IntEnum` / `StrEnum` register with
the correct parent type info.
- **Consolidate class-level entry accessors** into a single
`all_entries()` iterator backed by `_ordered_entries`. The `by_name` /
`by_value` descriptors and the older `entries()` method are removed;
`_EnumMeta.__iter__` / `__len__` and the public accessor now share one
canonical ordering path.
- **Accept payload-literal variant declarations** on `IntEnum` /
`StrEnum` (`red = 10`, `add = "+"`) as sugar for `entry(value=...)`.
`_collect_entry_declarations` treats bare assignments on payload enums
as `_EnumEntry` values, and invalid payload types surface as a
`TypeError` via the new `_instantiate_entry` error-normalizing wrapper.
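- **Payload-aware Cython dispatch**: new `_tc_convert_int_enum` /
`_tc_convert_str_enum` converters are selected at schema-build time when
the registered target class has `IntEnum` / `StrEnum` in its MRO. At
call time they coerce raw `int` / `str` payloads into the matching
variant, first through an O(1) lookup in the new
`__ffi_enum_value_entries__` TypeAttr column
(`type_attr::kEnumValueEntries` / `ENUM_VALUE_ENTRIES_ATTR`) and, when
that column has no match, through a linear scan over `all_entries()`.
The object passthrough and the pack/convert fallback are preserved by
factoring the marshaled path into `_tc_convert_object_marshaled`.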
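The literal-sugar behavior described above can be approximated in plain Python. The sketch below is a hypothetical stand-in: `_PayloadEnumSketch`, `IntEnumSketch`, and `StrEnumSketch` are not tvm_ffi names, variants are ordinary Python objects rather than FFI singletons, every public non-callable class attribute is assumed to be a payload literal, and validation is a bare `isinstance` check. Only the `<Cls>.<name>: invalid enum entry:` error prefix is taken from this PR.

```python
from typing import Any, List


class _PayloadEnumSketch:
    """Hypothetical model of IntEnum/StrEnum literal sugar (NOT tvm_ffi code).

    Bare payload literals in a subclass body become singleton variants with an
    auto-assigned ordinal (``_value``), key (``_name``), and payload (``value``).
    """

    _payload_type: type = object  # plays the role of __ffi_enum_payload_value_type__

    def __init__(self, _value: int, _name: str, value: Any) -> None:
        self._value = _value  # dense ordinal, auto-assigned
        self._name = _name  # class-body name, auto-assigned
        self.value = value  # user-visible payload

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if cls.__dict__.get("_is_payload_base", False):
            return  # the payload base classes themselves declare no variants
        entries: List["_PayloadEnumSketch"] = []
        for name, literal in list(cls.__dict__.items()):
            if name.startswith("_") or callable(literal):
                continue  # skip dunders, private names, and methods
            if not isinstance(literal, cls._payload_type):
                # Same "<Cls>.<name>: invalid enum entry: ..." prefix as the PR.
                raise TypeError(
                    f"{cls.__name__}.{name}: invalid enum entry: expected "
                    f"{cls._payload_type.__name__}, got {type(literal).__name__}"
                )
            variant = cls(len(entries), name, literal)  # next dense ordinal
            entries.append(variant)
            setattr(cls, name, variant)  # replace the literal with the singleton
        cls._entries = entries


class IntEnumSketch(_PayloadEnumSketch):
    _is_payload_base = True
    _payload_type = int


class StrEnumSketch(_PayloadEnumSketch):
    _is_payload_base = True
    _payload_type = str


class Priority(IntEnumSketch):
    low = 10
    high = 20


class BinOp(StrEnumSketch):
    add = "+"


try:
    class Bad(IntEnumSketch):
        nope = "x"
except TypeError as err:
    print(err)  # Bad.nope: invalid enum entry: expected int, got str
```

Replacing each literal with its variant on the class keeps attribute access (`Priority.low`) returning the singleton, which is the property the real implementation relies on as well.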
## Public Interfaces (BREAKING)
- **C++ API break** (source-level; member types and order are
unchanged): `EnumObj::value` / `EnumObj::name` no longer exist;
consumers must read `EnumObj::_value` / `EnumObj::_name`.
- **Python API break**: on plain `Enum` subclasses, `instance.value` /
`instance.name` are no longer surfaced (they raise `AttributeError`).
Use `_value` / `_name` for the intrinsic ordinal and key, or migrate to
`IntEnum` / `StrEnum` to obtain a user-visible `value` payload.
- **Python API break**: `Enum.entries()` is renamed to
`Enum.all_entries()`; `Enum.by_name` and `Enum.by_value` class
properties are removed. Replace `Cls.entries()` with
`Cls.all_entries()`, `Cls.by_value` with `list(Cls.all_entries())`, and
`Cls.by_name` with `{e._name: e for e in Cls.all_entries()}`.
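- **New exports**: `tvm_ffi.dataclasses.IntEnum` and
`tvm_ffi.dataclasses.StrEnum` (added to `__all__`), plus the
payload-index column name `ENUM_VALUE_ENTRIES_ATTR` in
`tvm_ffi.dataclasses.enum` and its C++ counterpart
`type_attr::kEnumValueEntries`.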
- `entry(...)` now rejects `_value=` / `_name=` kwargs (previously
`value=` / `name=`).
- `IntEnum` / `StrEnum` subclasses additionally accept `red = 10` /
`add = "+"` as sugar for `entry(value=...)`.
- `py_class` fields typed as `IntEnum` / `StrEnum` subclasses now accept
raw `int` / `str` payloads at construction and assignment; the converter
resolves them to the singleton variant and rejects unknown payloads with
the usual `expected <Type>` error.
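A rough pure-Python model of that coercion path, assuming a hypothetical `Variant` record and `coerce` helper (none of these names exist in tvm_ffi). It follows the order in the diff: pass existing variants through, normalize int-like input to an int64, try the O(1) payload index, fall back to a linear scan, and otherwise report `expected <Type>`. The real converter hands unmatched values to the marshaled pack/convert path rather than raising directly.

```python
from dataclasses import dataclass
from numbers import Integral
from typing import Any, Dict, Iterable, Optional

INT64_MIN = -(1 << 63)
INT64_MAX = (1 << 63) - 1


class ConvertError(TypeError):
    """Hypothetical stand-in for the converter's internal error type."""


@dataclass(frozen=True)
class Variant:
    """Hypothetical variant record: ordinal, key, and user-visible payload."""

    _value: int
    _name: str
    value: Any


def normalize_int_payload(value: Any) -> Optional[int]:
    """Return *value* as an int64 payload, or None if it is not int-like."""
    if not isinstance(value, Integral):  # bool is Integral too: True -> 1
        return None
    ivalue = int(value)
    if not (INT64_MIN <= ivalue <= INT64_MAX):
        raise ConvertError(f"integer {ivalue} out of int64 range [{INT64_MIN}, {INT64_MAX}]")
    return ivalue


def find_variant(value_entries: Optional[Dict[Any, Variant]],
                 entries: Iterable[Variant], payload: Any) -> Optional[Variant]:
    """O(1) index lookup first, then a linear scan for unindexed variants."""
    if value_entries is not None:
        hit = value_entries.get(payload)
        if hit is not None:
            return hit
    for variant in entries:
        if variant.value == payload:
            return variant
    return None


def coerce(value: Any, *, type_name: str, payload_type: type,
           value_entries: Optional[Dict[Any, Variant]],
           entries: Iterable[Variant]) -> Variant:
    if isinstance(value, Variant):
        return value  # stands in for the type-index passthrough check
    if payload_type is int:
        payload = normalize_int_payload(value)
    else:
        payload = value if isinstance(value, str) else None
    if payload is not None:
        variant = find_variant(value_entries, entries, payload)
        if variant is not None:
            return variant
    raise ConvertError(f"expected {type_name}, got {type(value).__name__}")


low, high = Variant(0, "low", 10), Variant(1, "high", 20)
index = {10: low}  # deliberately partial: payload 20 is found by the linear scan
print(coerce(20, type_name="Priority", payload_type=int,
             value_entries=index, entries=[low, high])._name)  # high
```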
## Behavioral Changes
- Reading `.value` / `.name` on a plain `Enum` variant raises
`AttributeError`.
- Repr still produces `<type_key>.<_name>` (wire-identical).
- `iter(EnumCls)` and `len(EnumCls)` now route through
`_ordered_entries` instead of the removed `by_value` descriptor;
observable ordering is unchanged (dense ordinal order).
- Defining a payload-enum subclass with an invalid literal payload (e.g.,
`class Bad(IntEnum): nope = "x"`) now raises `TypeError` with a
`<Cls>.<name>: invalid enum entry: ...` prefix via `_instantiate_entry`,
rather than leaking the raw converter error.
- `py_class` field assignments of the form `inst.priority = 20` (where
`priority: Priority` is an `IntEnum`) now succeed, binding to the
corresponding enum variant.
- **Python 3.8/3.9 compatibility**: `_prepare_payload_enum_subclass`
uses `_own_annotations(cls)` rather than
`getattr(cls, "__annotations__", {})`. On 3.8/3.9 the previous lookup
followed the MRO when the subclass had no annotations of its own and
pulled in the parent `Enum`'s `_value` / `_name`, which then
re-registered those fields on the subclass and produced
`ValueError: duplicate parameter name: '_value'` when constructing the
`inspect.Signature` for `__init__`. Python 3.10+ was unaffected because
reading `__annotations__` on a class there returns the class's own
annotations dict (creating an empty one on demand) instead of inheriting
the parent's.
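The MRO pitfall can be avoided with a small helper. `own_annotations` below is a hypothetical stand-in for the PR's `_own_annotations`, whose body is not part of this excerpt; it assumes `inspect.get_annotations` on Python 3.10+ and a direct `cls.__dict__` read on 3.8/3.9.

```python
import inspect
import sys
from typing import Any, Dict


def own_annotations(cls: type) -> Dict[str, Any]:
    """Return only the annotations declared in *cls*'s own body (never inherited)."""
    if sys.version_info >= (3, 10):
        # inspect.get_annotations never follows the MRO for classes.
        return dict(inspect.get_annotations(cls))
    # On 3.8/3.9, getattr(cls, "__annotations__", {}) would return a base
    # class's dict when cls declares nothing, so read cls's own namespace.
    return dict(cls.__dict__.get("__annotations__", {}))


class Base:
    _value: int
    _name: str


class Child(Base):  # no annotations of its own
    pass


class Payload(Base):
    value: int


print(own_annotations(Child))    # {}
print(own_annotations(Payload))  # {'value': <class 'int'>}
```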
## Docs
- Doxygen and Python docstrings updated in-place in
`include/tvm/ffi/enum.h`, `include/tvm/ffi/reflection/enum_def.h`,
`include/tvm/ffi/reflection/accessor.h`, and
`python/tvm_ffi/dataclasses/enum.py`.
- No doc-site page referenced the old public `EnumObj::value` / `name`
or the removed `entries()` / `by_name` / `by_value` accessors (verified
via `rg` over `docs/`), so no further changes in `docs/` were required.
## Test plan
- [x] `uv pip install --force-reinstall --verbose -e .` rebuild after
Cython changes in `pyclass_type_converter.pxi`.
- [x] `uv run pytest -vvs tests/python/test_dataclass_enum.py
tests/python/test_dataclass_py_class.py
tests/python/test_type_converter.py` — 1008 passed, 2 xfailed.
- [x] New tests:
  - `test_int_enum_payload_literal_sugar`,
    `test_str_enum_payload_literal_sugar`,
    `test_payload_literal_sugar_preserves_annotated_field_defaults`,
    `test_int_enum_payload_literal_sugar_rejects_invalid_payload`,
    `test_str_enum_payload_literal_sugar_rejects_invalid_payload`.
  - `test_all_entries_iteration_order`,
    `test_all_entries_indexed_by_ordinal` (replacing the removed
    `test_by_name_is_live_dict` / `test_by_value_indexed_by_ordinal`).
  - `test_public_value_and_name_are_hidden` (identity-hidden assertion on
    plain `Enum`).
  - `test_int_enum_payload_value` / `test_str_enum_payload_value`
    (end-to-end payload).
  - `TestPayloadEnums` group in `test_type_converter.py` covering int/str
    payload conversion, object passthrough, and rejection of unknown
    payloads.
  - `test_payload_enum_fields_end_to_end` in `test_dataclass_py_class.py`
    covering construction, assignment, and error behavior for `IntEnum` /
    `StrEnum`-typed `py_class` fields.
- [x] Python 3.8 MRO-lookup behavior reproduced on CPython 3.8.20 and
the `_own_annotations` fix validated against it.
- [ ] CI C++ tests for `refl::EnumDef` (Linux x86_64/aarch64, macOS
arm64, Windows) — not rerun locally; relying on CI to exercise the
renamed `_value` / `_name` compile path.
- [ ] CI Rust job — no local Rust coverage for downstream `EnumObj`
field consumers.
- [ ] Cross-language (C++-backed) `IntEnum` / `StrEnum` payload coercion
is not specifically exercised; existing `test_cxx_backed_*` cases cover
the object-passthrough path, and the new literal-sugar path targets
pure-Python subclasses.
## Migration
- Replace `enum_instance.value` / `.name` with `._value` / `._name`,
**or** switch the base class to `IntEnum` / `StrEnum` to keep a
user-visible `value` field.
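- Drop `value=` / `name=` from `entry(...)` calls on plain `Enum`
subclasses: both identity fields are auto-assigned, and the renamed
`entry(_value=...)` / `entry(_name=...)` forms raise `TypeError`. On
`IntEnum` / `StrEnum`, `entry(value=...)` now sets the user-visible
payload.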
- Replace `Cls.entries()` with `Cls.all_entries()`.
- Replace `Cls.by_value` with `list(Cls.all_entries())`.
- Replace `Cls.by_name` with `{e._name: e for e in Cls.all_entries()}`.
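The `by_value` / `by_name` replacements can be checked against a model of `_ordered_entries`, which places each variant at its ordinal. The `Variant` class and `registry` dict below are hypothetical stand-ins for the per-class `__ffi_enum_entries__` column.

```python
from typing import Any, Dict, Iterator, List


class Variant:
    """Hypothetical variant carrying only the hidden identity fields."""

    def __init__(self, ordinal: int, name: str) -> None:
        self._value = ordinal
        self._name = name


def ordered_entries(registry: Dict[str, Variant]) -> List[Variant]:
    """Place each variant at its ordinal, mirroring ``_ordered_entries``."""
    ordered: List[Any] = [None] * len(registry)
    for inst in registry.values():
        ordered[int(inst._value)] = inst
    return ordered


def all_entries(registry: Dict[str, Variant]) -> Iterator[Variant]:
    return iter(ordered_entries(registry))


# Registry (name -> variant) insertion order need not match ordinal order.
registry = {"err": Variant(1, "err"), "ok": Variant(0, "ok"), "retry": Variant(2, "retry")}

by_value = list(all_entries(registry))                 # replaces Cls.by_value
by_name = {e._name: e for e in all_entries(registry)}  # replaces Cls.by_name
print([e._name for e in by_value])  # ['ok', 'err', 'retry']
```

One behavioral difference to keep in mind: the removed `by_name` was documented as a live `Dict`, whereas the comprehension builds a snapshot. Variants registered afterwards, for example from C++ on a shared `type_key`, appear only after the dict is rebuilt.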
---
include/tvm/ffi/enum.h | 8 +-
include/tvm/ffi/reflection/accessor.h | 17 +-
include/tvm/ffi/reflection/enum_def.h | 8 +-
python/tvm_ffi/cython/pyclass_type_converter.pxi | 170 +++++++++++-
python/tvm_ffi/dataclasses/__init__.py | 4 +-
python/tvm_ffi/dataclasses/enum.py | 298 ++++++++++++++++-----
python/tvm_ffi/dataclasses/py_class.py | 2 +
src/ffi/extra/dataclass.cc | 4 +-
src/ffi/object.cc | 5 +-
tests/python/test_dataclass_enum.py | 315 +++++++++++++++++------
tests/python/test_dataclass_py_class.py | 40 ++-
tests/python/test_type_converter.py | 67 ++++-
12 files changed, 755 insertions(+), 183 deletions(-)
diff --git a/include/tvm/ffi/enum.h b/include/tvm/ffi/enum.h
index 5a840cb..0cd97d9 100644
--- a/include/tvm/ffi/enum.h
+++ b/include/tvm/ffi/enum.h
@@ -45,7 +45,7 @@ class Enum;
* \brief Base class for FFI-registered enums.
*
* Each registered variant is a unique, process-wide singleton with a
- * dense ordinal (``value``) and string ``name``. Subclasses may add
+ * dense ordinal (``_value``) and string ``_name``. Subclasses may add
* *declared fields* — part of the variant's schema, set at registration
* time via ``reflection::EnumDef``. Separately, any consumer may
* attach *extensible attributes* (per-variant metadata stored outside
@@ -57,9 +57,9 @@ class Enum;
class EnumObj : public Object {
public:
/*! \brief Declared field: dense ordinal assigned at registration time (0-indexed per class). */
- int64_t value;
+ int64_t _value;
/*! \brief Declared field: instance name (e.g., ``"Add"`` for ``Op.Add``). */
- String name;
+ String _name;
EnumObj() = default;
/*!
@@ -67,7 +67,7 @@ class EnumObj : public Object {
* \param value The dense ordinal (0-indexed per enum class).
* \param name The instance name key.
*/
- EnumObj(int64_t value, String name) : value(value), name(std::move(name)) {}
+ EnumObj(int64_t value, String name) : _value(value), _name(std::move(name)) {}
/*!
* \brief Look up the registered singleton for ``EnumClsObj`` by name.
diff --git a/include/tvm/ffi/reflection/accessor.h b/include/tvm/ffi/reflection/accessor.h
index 9b7950b..ee43379 100644
--- a/include/tvm/ffi/reflection/accessor.h
+++ b/include/tvm/ffi/reflection/accessor.h
@@ -504,13 +504,28 @@ inline constexpr const char* kEnumEntries = "__ffi_enum_entries__";
* \brief Per-class column holding extensible attributes for enum variants.
*
* The outer dict is keyed by extensible-attribute name; each value is a
- * list indexed by the variant's ordinal (``EnumObj::value``). Written
+ * list indexed by the variant's ordinal (``EnumObj::_value``). Written
* by ``refl::EnumDef<T>::set_attr(...)`` on the C++ side and by the
* ``EnumAttrMap`` returned from Python ``Enum.def_attr(...)``.
*
* Value type: ``Dict<String, List<Any>>``.
*/
inline constexpr const char* kEnumAttrs = "__ffi_enum_attrs__";
+/*!
+ * \brief Per-class payload-to-variant index for enums.
+ *
+ * Parallel to ``kEnumEntries`` (name → variant) but keyed by the
+ * user-visible payload carried on each variant — i.e. the ``value``
+ * field on Python ``IntEnum`` / ``StrEnum`` subclasses (``int`` or
+ * ``str``) or the equivalent payload field on a C++ ``EnumObj``
+ * subclass. Populated by the creator of each variant (Python or C++)
+ * when the variant has a payload; absent or partially populated
+ * otherwise. Consumed by FFI converters to resolve a raw payload
+ * (``int``/``str``) to its singleton variant in O(1).
+ *
+ * Value type: ``Dict<Any, Enum>``.
+ */
+inline constexpr const char* kEnumValueEntries = "__ffi_enum_value_entries__";
} // namespace type_attr
/*!
diff --git a/include/tvm/ffi/reflection/enum_def.h b/include/tvm/ffi/reflection/enum_def.h
index f0a20b0..6e1e91c 100644
--- a/include/tvm/ffi/reflection/enum_def.h
+++ b/include/tvm/ffi/reflection/enum_def.h
@@ -49,8 +49,8 @@ namespace reflection {
* \brief Builder that registers a single enum instance on ``EnumClsObj``.
*
* Each ``EnumDef<EnumClsObj>("Name")`` call allocates a fresh dense ordinal
- * (``= len(existing entries)``), constructs a variant with ``value`` and
- * ``name`` populated, and writes it into the per-class registry stored in
+ * (``= len(existing entries)``), constructs a variant with ``_value`` and
+ * ``_name`` populated, and writes it into the per-class registry stored in
* the ``type_attr::kEnumEntries`` TypeAttr column. Subsequent
* ``.set_attr(...)`` calls write *extensible attributes* — per-variant
* metadata attached outside the variant's declared fields — into the
@@ -83,8 +83,8 @@ class EnumDef : public ReflectionDefBase {
}
ordinal_ = static_cast<int64_t>(entries.size());
ObjectPtr<EnumClsObj> obj = make_object<EnumClsObj>();
- obj->value = ordinal_;
- obj->name = name_str;
+ obj->_value = ordinal_;
+ obj->_name = name_str;
instance_ = Enum(ObjectPtr<EnumObj>(std::move(obj)));
entries.Set(name_str, instance_);
// Ensure the attrs dict exists so later ``set_attr`` calls can mutate it.
diff --git a/python/tvm_ffi/cython/pyclass_type_converter.pxi b/python/tvm_ffi/cython/pyclass_type_converter.pxi
index fad1970..af216d1 100644
--- a/python/tvm_ffi/cython/pyclass_type_converter.pxi
+++ b/python/tvm_ffi/cython/pyclass_type_converter.pxi
@@ -36,6 +36,7 @@ cdef object _INT64_MIN = -(1 << 63)
cdef object _INT64_MAX = (1 << 63) - 1
cdef int _VALUE_PROTOCOL_MAX_DEPTH = 64
cdef str _TYPE_ATTR_FFI_CONVERT = "__ffi_convert__"
+cdef str _TYPE_ATTR_ENUM_VALUE_ENTRIES = "__ffi_enum_value_entries__"
cdef class _TypeConverter
ctypedef CAny (*_dispatch_fn_t)(_TypeConverter, object, bint*) except *
@@ -499,24 +500,88 @@ cdef CAny _tc_convert_union(_TypeConverter conv, object value, bint* changed) ex
# Converters (4/N): Object Types
# ---------------------------------------------------------------------------
-cdef CAny _tc_convert_object(_TypeConverter conv, object value, bint* changed) except *:
- """Convert *value* to an object compatible with ``conv.type_index``."""
- # TODO: SmallStr and SmallBytes => ObjectRef conversion is not supported yet
+cdef inline object _tc_get_registered_cls(int32_t type_index):
+ if 0 <= type_index < len(TYPE_INDEX_TO_CLS):
+ return TYPE_INDEX_TO_CLS[type_index]
+ return None
+
+
+cdef inline bint _tc_registered_cls_has_base(
+ int32_t type_index,
+ str module_name,
+ str base_name,
+) except *:
+ cdef object cls = _tc_get_registered_cls(type_index)
+ cdef object base
+ if cls is None:
+ return False
+ for base in cls.__mro__:
+ if (
+ getattr(base, "__module__", None) == module_name
+ and getattr(base, "__name__", None) == base_name
+ ):
+ return True
+ return False
+
+
+cdef object _tc_find_payload_enum_variant(
+ int32_t type_index, object enum_cls, object payload
+) except *:
+ """Resolve *payload* to its singleton variant (``None`` if no match).
+
+ Primary path: O(1) lookup in the cross-language value-entries column
+ (``__ffi_enum_value_entries__``). Falls back to an O(n) linear scan
+ over ``enum_cls.all_entries()`` when the column has no entry for
+ *payload* — needed so correctness is preserved for variants whose
+ creators haven't populated the column.
+ """
+ cdef object value_entries
+ cdef object variant
+ value_entries = _lookup_type_attr(type_index, _TYPE_ATTR_ENUM_VALUE_ENTRIES)
+ if value_entries is not None:
+ variant = value_entries.get(payload)
+ if variant is not None:
+ return variant
+ for variant in enum_cls.all_entries():
+ if variant.value == payload:
+ return variant
+ return None
+
+
+cdef object _tc_normalize_int_enum_payload(object value, bint* matched) except *:
+ cdef object ivalue
+ matched[0] = False
+ if isinstance(value, bool):
+ matched[0] = True
+ return int(value)
+ if isinstance(value, int):
+ if not (_INT64_MIN <= value <= _INT64_MAX):
+ raise _ConvertError(
+ f"integer {value} out of int64 range [{_INT64_MIN}, {_INT64_MAX}]"
+ )
+ matched[0] = True
+ return value
+ if isinstance(value, Integral):
+ try:
+ ivalue = int(value)
+ except Exception as err:
+ raise _ConvertError(f"int() failed for {type(value).__qualname__}: {err}") from None
+ if not (_INT64_MIN <= ivalue <= _INT64_MAX):
+ raise _ConvertError(
+ f"integer {ivalue} out of int64 range [{_INT64_MIN}, {_INT64_MAX}]"
+ )
+ matched[0] = True
+ return ivalue
+ return None
+
+
+cdef CAny _tc_convert_object_marshaled(_TypeConverter conv, object value) except *:
cdef int32_t actual_type_index = kTVMFFINone
cdef CAny packed
cdef CAny converted
cdef Function fn_convert
cdef object err = None
- # Step 1: existing FFI objects that already satisfy the target schema are passthrough.
- assert conv.type_index >= kTVMFFIStaticObjectBegin
- if isinstance(value, CObject):
- actual_type_index = TVMFFIObjectGetTypeIndex((<CObject>value).chandle)
- if _tc_type_index_is_instance(actual_type_index, conv.type_index):
- return CAny(value)
- changed[0] = True
-
- # Step 2: pack, and convert to the target type.
packed = CAnyChecked(value, conv.err_hint, value)
fn_convert = conv.fn_convert
try:
@@ -535,6 +600,73 @@ cdef CAny _tc_convert_object(_TypeConverter conv, object value, bint* changed) e
raise _ConvertError(f"expected {conv.err_hint}, got {_tc_describe_value_type(value)}") from err
+cdef CAny _tc_convert_object(_TypeConverter conv, object value, bint* changed) except *:
+ """Convert *value* to an object compatible with ``conv.type_index``."""
+ # TODO: SmallStr and SmallBytes => ObjectRef conversion is not supported yet
+ cdef int32_t actual_type_index = kTVMFFINone
+
+ # Step 1: existing FFI objects that already satisfy the target schema are passthrough.
+ assert conv.type_index >= kTVMFFIStaticObjectBegin
+ if isinstance(value, CObject):
+ actual_type_index = TVMFFIObjectGetTypeIndex((<CObject>value).chandle)
+ if _tc_type_index_is_instance(actual_type_index, conv.type_index):
+ return CAny(value)
+ changed[0] = True
+
+ # Step 2: pack, and convert to the target type.
+ return _tc_convert_object_marshaled(conv, value)
+
+
+cdef CAny _tc_convert_int_enum(_TypeConverter conv, object value, bint* changed) except *:
+ """Convert *value* to an IntEnum-compatible object."""
+ cdef int32_t actual_type_index = kTVMFFINone
+ cdef object target_cls
+ cdef object ivalue
+ cdef object variant
+ cdef bint is_int_like = False
+
+ assert conv.type_index >= kTVMFFIStaticObjectBegin
+ if isinstance(value, CObject):
+ actual_type_index = TVMFFIObjectGetTypeIndex((<CObject>value).chandle)
+ if _tc_type_index_is_instance(actual_type_index, conv.type_index):
+ return CAny(value)
+
+ target_cls = _tc_get_registered_cls(conv.type_index)
+ if target_cls is not None:
+ ivalue = _tc_normalize_int_enum_payload(value, &is_int_like)
+ if is_int_like:
+ changed[0] = True
+ variant = _tc_find_payload_enum_variant(conv.type_index, target_cls, ivalue)
+ if variant is not None:
+ return CAny(variant)
+
+ changed[0] = True
+ return _tc_convert_object_marshaled(conv, value)
+
+
+cdef CAny _tc_convert_str_enum(_TypeConverter conv, object value, bint* changed) except *:
+ """Convert *value* to a StrEnum-compatible object."""
+ cdef int32_t actual_type_index = kTVMFFINone
+ cdef object target_cls
+ cdef object variant
+
+ assert conv.type_index >= kTVMFFIStaticObjectBegin
+ if isinstance(value, CObject):
+ actual_type_index = TVMFFIObjectGetTypeIndex((<CObject>value).chandle)
+ if _tc_type_index_is_instance(actual_type_index, conv.type_index):
+ return CAny(value)
+
+ target_cls = _tc_get_registered_cls(conv.type_index)
+ if target_cls is not None and isinstance(value, str):
+ changed[0] = True
+ variant = _tc_find_payload_enum_variant(conv.type_index, target_cls, value)
+ if variant is not None:
+ return CAny(variant)
+
+ changed[0] = True
+ return _tc_convert_object_marshaled(conv, value)
+
+
cdef inline bint _tc_type_index_is_instance(int32_t actual_tindex, int32_t target_tindex) noexcept:
"""Check if *actual_tindex* is *target_tindex* or a subclass thereof."""
# TODO: this can be optimized by looking up `TYPE_INDEX_TO_INFO`
@@ -703,14 +835,24 @@ def _build_converter(schema):
conv.err_hint = "Object"
return conv
if origin_tindex >= kTVMFFIStaticObjectBegin:
- conv.dispatch = _tc_convert_object
+ if _tc_registered_cls_has_base(origin_tindex, "tvm_ffi.dataclasses.enum", "IntEnum"):
+ conv.dispatch = _tc_convert_int_enum
+ elif _tc_registered_cls_has_base(origin_tindex, "tvm_ffi.dataclasses.enum", "StrEnum"):
+ conv.dispatch = _tc_convert_str_enum
+ else:
+ conv.dispatch = _tc_convert_object
conv.type_index = origin_tindex
conv.err_hint = origin
return conv
tindex = _object_type_key_to_index(origin)
if tindex is not None:
- conv.dispatch = _tc_convert_object
+ if _tc_registered_cls_has_base(tindex, "tvm_ffi.dataclasses.enum", "IntEnum"):
+ conv.dispatch = _tc_convert_int_enum
+ elif _tc_registered_cls_has_base(tindex, "tvm_ffi.dataclasses.enum", "StrEnum"):
+ conv.dispatch = _tc_convert_str_enum
+ else:
+ conv.dispatch = _tc_convert_object
conv.type_index = tindex
conv.err_hint = origin
return conv
diff --git a/python/tvm_ffi/dataclasses/__init__.py b/python/tvm_ffi/dataclasses/__init__.py
index 43e5cc1..0fb322a 100644
--- a/python/tvm_ffi/dataclasses/__init__.py
+++ b/python/tvm_ffi/dataclasses/__init__.py
@@ -20,7 +20,7 @@ from tvm_ffi.core import MISSING, Object
from .c_class import c_class
from .common import asdict, astuple, fields, is_dataclass, replace
-from .enum import Enum, EnumAttrMap, auto, entry
+from .enum import Enum, EnumAttrMap, IntEnum, StrEnum, auto, entry
from .field import KW_ONLY, Field, field
from .py_class import py_class
@@ -30,7 +30,9 @@ __all__ = [
"Enum",
"EnumAttrMap",
"Field",
+ "IntEnum",
"Object",
+ "StrEnum",
"asdict",
"astuple",
"auto",
diff --git a/python/tvm_ffi/dataclasses/enum.py b/python/tvm_ffi/dataclasses/enum.py
index 610c6c3..3182b7c 100644
--- a/python/tvm_ffi/dataclasses/enum.py
+++ b/python/tvm_ffi/dataclasses/enum.py
@@ -55,8 +55,11 @@ from .py_class import py_class
__all__ = [
"ENUM_ATTRS_ATTR",
"ENUM_ENTRIES_ATTR",
+ "ENUM_VALUE_ENTRIES_ATTR",
"Enum",
"EnumAttrMap",
+ "IntEnum",
+ "StrEnum",
"auto",
"entry",
]
@@ -67,6 +70,13 @@ ENUM_ENTRIES_ATTR = "__ffi_enum_entries__"
#: TypeAttr column storing ``Dict[str, List[Any]]`` of per-variant attrs.
ENUM_ATTRS_ATTR = "__ffi_enum_attrs__"
+#: TypeAttr column storing ``Dict[Any, Enum]`` (payload → singleton) —
+#: parallel to :data:`ENUM_ENTRIES_ATTR` but keyed by the user-visible
+#: payload (``IntEnum.value`` / ``StrEnum.value``). Populated by every
+#: creator of a payload-carrying variant; consumed by FFI converters for
+#: O(1) payload → variant resolution.
+ENUM_VALUE_ENTRIES_ATTR = "__ffi_enum_value_entries__"
+
# ---------------------------------------------------------------------------
# entry() sentinel
@@ -77,7 +87,7 @@ class _EnumEntry:
"""Sentinel produced by :func:`entry`; consumed by ``Enum.__init_subclass__``.
Holds the positional and keyword arguments forwarded to the subclass's
- ``__init__`` when the variant is materialized. ``value`` and ``name``
+ ``__init__`` when the variant is materialized. ``_value`` and ``_name``
are auto-assigned (dense ordinal and class-body name) and must not
appear in the captured arguments.
"""
@@ -102,7 +112,7 @@ def entry(*args: Any, **kwargs: Any) -> Any:
for these sentinels and, for each one, constructs a singleton variant
by forwarding the captured positional and keyword arguments to the
subclass's ``__init__``, together with an auto-assigned
- :attr:`~Enum.value` (dense ordinal) and :attr:`~Enum.name`
+ :attr:`~Enum._value` (dense ordinal) and :attr:`~Enum._name`
(class-body name).
Prefer :func:`auto` when a variant has no declared fields beyond the
@@ -113,7 +123,7 @@ def entry(*args: Any, **kwargs: Any) -> Any:
``refl::ObjectDef``), only keyword arguments are supported — field
values are assigned via reflected setters keyed by name. Passing
positional arguments in that case raises :class:`TypeError`.
- ``entry(value=...)`` and ``entry(name=...)`` always raise
+ ``entry(_value=...)`` and ``entry(_name=...)`` always raise
:class:`TypeError` because those fields are auto-assigned.
Examples
@@ -149,7 +159,7 @@ def auto() -> Any:
Semantically equivalent to :func:`entry` called with no arguments but
reads more clearly for the common case where a variant differs from
its siblings only by name and ordinal. The resulting singleton has
- only the auto-assigned :attr:`~Enum.value` and :attr:`~Enum.name`.
+ only the auto-assigned :attr:`~Enum._value` and :attr:`~Enum._name`.
``auto()`` registers a *new* Python-side variant; it is not the right
tool for binding to a pre-existing C++-registered entry (use a bare
@@ -165,8 +175,8 @@ def auto() -> Any:
retry = auto()
- assert Status.ok.value == 0
- assert Status.err.name == "err"
+ assert Status.ok._value == 0
+ assert Status.err._name == "err"
Returns
-------
@@ -187,9 +197,8 @@ def auto() -> Any:
class _ClassProperty:
"""Read-only descriptor whose getter receives the owning class.
- Used for ``by_name``/``by_value``/``attr_dict`` so they work as class-level
- attribute access (e.g., ``Op.attr_dict["has_side_effects"]``) without
- needing a metaclass.
+ Used for ``attr_dict`` so it works as a class-level attribute access
+ (e.g., ``Op.attr_dict["has_side_effects"]``) without needing a metaclass.
"""
__slots__ = ("_fget",)
@@ -216,10 +225,10 @@ class _EnumMeta(type(Object)):
"""
def __iter__(cls) -> Iterator[Any]:
- return iter(cls.by_value) # ty: ignore[unresolved-attribute]
+ return iter(_ordered_entries(cls)) # ty: ignore[unresolved-attribute]
def __len__(cls) -> int:
- return len(cls.by_value) # ty: ignore[unresolved-attribute]
+ return len(_ordered_entries(cls)) # ty: ignore[unresolved-attribute]
@dataclass_transform(
@@ -246,9 +255,9 @@ class Enum(Object, metaclass=_EnumMeta):
Attributes
----------
- value : int
+ _value : int
Dense ordinal assigned at registration (0-indexed per class).
- name : str
+ _name : str
The variant's string name key (e.g., ``"Add"`` for ``Op.Add``).
Closed Python enum
@@ -302,36 +311,38 @@ class Enum(Object, metaclass=_EnumMeta):
4. ``name: ClassVar[Cls]`` — in cross-language mode, binds to an
existing C++-registered variant (error if unknown); otherwise
registers a new Python variant with only the auto-assigned
- :attr:`value` and :attr:`name` (equivalent to ``name = auto()``).
+ :attr:`_value` and :attr:`_name` (equivalent to ``name = auto()``).
- Integer literals (``ok = 0``) are rejected: :attr:`value` is
+ Integer literals (``ok = 0``) are rejected: :attr:`_value` is
auto-assigned, so a user-supplied ordinal would either silently
- duplicate or conflict. ``entry(value=...)`` and ``entry(name=...)``
+ duplicate or conflict. ``entry(_value=...)`` and ``entry(_name=...)``
raise :class:`TypeError` at class-body time.
Differences from ``enum.Enum``
------------------------------
- * **Same**: :attr:`name`, :attr:`value`, iteration, identity
+ * **Same**: hidden identity fields :attr:`_name`, :attr:`_value`,
+ iteration, identity
comparison; closed-set behavior when ``type_key`` is fresh.
* **Extended**: ``entry(**kwargs)`` replaces the tuple-RHS idiom;
``dataclass_transform`` gives native type-checker support; open
registry when ``type_key`` is shared across languages.
- * **Different**: :attr:`value` is always the ordinal (no
- user-supplied integer values); :meth:`def_attr` adds extensible
- attributes outside the class schema.
+ * **Different**: :attr:`_value` is always the ordinal; :meth:`def_attr`
+ adds extensible attributes outside the class schema. Use
+ :class:`IntEnum` / :class:`StrEnum` when you need a user-visible
+ ``value`` field.
* **Not provided**: ``Flag`` / ``IntFlag``, member aliasing,
``_missing_`` hook.
- Subclasses inherit :meth:`get`, :meth:`entries`, :meth:`def_attr`,
- and the ``by_name`` / ``by_value`` / ``attr_dict`` class-level
- views.
+ Subclasses inherit :meth:`get`, :meth:`all_entries`, :meth:`def_attr`,
+ and the ``attr_dict`` class-level view.
"""
__slots__ = ()
- value: int
- name: str
+ _value: int
+ _name: str
+ __ffi_enum_payload_value_type__: object = None
def __init_subclass__(
cls,
@@ -365,25 +376,9 @@ class Enum(Object, metaclass=_EnumMeta):
raise KeyError(f"{cls.__name__} has no variant named {name!r}")
@classmethod
- def entries(cls) -> Iterator[Enum]:
- """Iterate over all variants, in ordinal (value) order."""
- return iter(cls.by_value)
-
- @_ClassProperty
- def by_name(cls: type) -> Any:
- """Live ``Dict[str, Enum]`` mapping each variant's name to the variant singleton."""
- return _entries_dict(cls) or Dict({})
-
- @_ClassProperty
- def by_value(cls: type) -> list[Any]:
- """Return the variants as a list indexed by ordinal (``variant.value``)."""
- entries = _entries_dict(cls)
- if entries is None:
- return []
- ordered: list[Any] = [None] * len(entries)
- for inst in entries.values():
- ordered[int(inst.value)] = inst
- return ordered
+ def all_entries(cls) -> Iterator[Enum]:
+ """Iterate over all variants, in ordinal (``_value``) order."""
+ return iter(_ordered_entries(cls))
@_ClassProperty
def attr_dict(cls: type) -> Any:
@@ -474,7 +469,7 @@ class EnumAttrMap:
Returned by :meth:`Enum.def_attr`. Writes go to a ``List[Any]``
column keyed by extensible-attribute name inside the per-class
``__ffi_enum_attrs__`` dict. The list is indexed by each variant's
- ordinal (``variant.value``) and padded with ``None`` as new variants
+ ordinal (``variant._value``) and padded with ``None`` as new variants
are registered. The column is shared across every consumer —
including C++ code writing via ``EnumDef::set_attr`` — and the data
is not a field on the variant object. See :meth:`Enum.def_attr` for
@@ -502,7 +497,7 @@ class EnumAttrMap:
f"{self._enum_cls.__name__}.def_attr({self._name!r}) expects a "
f"{self._enum_cls.__name__} variant, got {type(variant).__name__}"
)
- return int(variant.value) # type: ignore[attr-defined]
+ return int(variant._value) # type: ignore[attr-defined]
def _column(self, *, create: bool) -> Any | None:
"""Return the ``List[Any]`` column for this attribute; create if missing.
@@ -543,7 +538,7 @@ class EnumAttrMap:
return v
if self._default is core.MISSING:
raise KeyError(
- f"{self._enum_cls.__name__}.{variant.name} has no " # type: ignore[attr-defined]
+ f"{self._enum_cls.__name__}.{variant._name} has no " # type: ignore[attr-defined]
f"extensible attribute {self._name!r} set"
)
return self._default
@@ -573,6 +568,56 @@ class EnumAttrMap:
return self._name
+class IntEnum(Enum):
+ """Enum variant base with a user-visible ``value: int`` field."""
+
+ __slots__ = ()
+ value: int
+
+ def __init_subclass__(
+ cls,
+ *,
+ type_key: str | None = None,
+ frozen: bool = True,
+ init: bool = True,
+ repr: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ _prepare_payload_enum_subclass(cls, value_type=int, base_name="IntEnum")
+ super().__init_subclass__(
+ type_key=type_key,
+ frozen=frozen,
+ init=init,
+ repr=repr,
+ **kwargs,
+ )
+
+
+class StrEnum(Enum):
+ """Enum variant base with a user-visible ``value: str`` field."""
+
+ __slots__ = ()
+ value: str
+
+ def __init_subclass__(
+ cls,
+ *,
+ type_key: str | None = None,
+ frozen: bool = True,
+ init: bool = True,
+ repr: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ _prepare_payload_enum_subclass(cls, value_type=str, base_name="StrEnum")
+ super().__init_subclass__(
+ type_key=type_key,
+ frozen=frozen,
+ init=init,
+ repr=repr,
+ **kwargs,
+ )
+
+
# ---------------------------------------------------------------------------
# TypeAttr accessors
# ---------------------------------------------------------------------------
@@ -592,6 +637,25 @@ def _attrs_dict(cls: type) -> Any:
return core._lookup_type_attr(type_info.type_index, ENUM_ATTRS_ATTR)
+def _value_entries_dict(cls: type) -> Any:
+ """Return the live ``Dict[Any, Enum]`` payload-to-variant index, or None."""
+ type_info = getattr(cls, "__tvm_ffi_type_info__", None)
+ if type_info is None:
+ return None
+ return core._lookup_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR)
+
+
+def _ordered_entries(cls: type) -> list[Any]:
+ """Return all variants ordered by ordinal."""
+ entries = _entries_dict(cls)
+ if entries is None:
+ return []
+ ordered: list[Any] = [None] * len(entries)
+ for inst in entries.values():
+ ordered[int(inst._value)] = inst
+ return ordered
+
+
def _ensure_entries_dict(cls: type) -> Any:
"""Return the live ``__ffi_enum_entries__`` dict, registering it if absent."""
type_info = cls.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
@@ -615,6 +679,17 @@ def _ensure_attrs_dict(cls: type) -> Any:
return core._lookup_type_attr(type_info.type_index, ENUM_ATTRS_ATTR)
+def _ensure_value_entries_dict(cls: type) -> Any:
+ """Return the live ``__ffi_enum_value_entries__`` dict, registering it if absent."""
+ type_info = cls.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
+ entries = core._lookup_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR)
+ if entries is not None:
+ return entries
+ entries = Dict({})
+ core._register_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR, entries)
+ return core._lookup_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR)
+
+
# ---------------------------------------------------------------------------
# Class-body scanning + entry materialisation
# ---------------------------------------------------------------------------
@@ -642,11 +717,16 @@ def _collect_entry_declarations(
dict_keys = set(cls.__dict__.keys())
binders: list[str] = []
+ ordinary_fields: set[str] = set()
+ payload_value_type = getattr(cls, "__ffi_enum_payload_value_type__", None)
for name, ann in annotations.items():
if name.startswith("_"):
continue
- if _is_class_var(ann) and name not in dict_keys:
- binders.append(name)
+ if _is_class_var(ann):
+ if name not in dict_keys:
+ binders.append(name)
+ else:
+ ordinary_fields.add(name)
python_entries: dict[str, _EnumEntry] = {}
for name, value in list(cls.__dict__.items()):
@@ -658,6 +738,14 @@ def _collect_entry_declarations(
delattr(cls, name)
except AttributeError:
pass
+ elif payload_value_type is not None and name not in ordinary_fields:
+ if isinstance(value, (staticmethod, classmethod, property)) or callable(value):
+ continue
+ python_entries[name] = _EnumEntry(value=value)
+ try:
+ delattr(cls, name)
+ except AttributeError:
+ pass
return binders, python_entries
@@ -685,11 +773,14 @@ def _resolve_entries(
fresh Python-side entries whose ordinals extend past the C++ entries.
"""
entries = _ensure_entries_dict(cls)
+ payload_value_type = getattr(cls, "__ffi_enum_payload_value_type__", None)
for name in binders:
if name in entries:
# Already materialised — either C++-registered or previously bound.
- setattr(cls, name, entries[name])
+ instance = entries[name]
+ setattr(cls, name, instance)
+ _index_payload(cls, instance, payload_value_type)
continue
if cxx_backed:
raise _cxx_backed_unknown_binder_error(cls, name, type_key, entries)
@@ -697,27 +788,43 @@ def _resolve_entries(
instance = _instantiate(cls, args=(), kwargs={}, ordinal=ordinal, name=name)
entries[name] = instance
setattr(cls, name, instance)
+ _index_payload(cls, instance, payload_value_type)
for name, e in python_entries.items():
if name in entries:
raise RuntimeError(
f"Duplicate enum entry {name!r} for {cls.__name__}: already "
- f"registered as ordinal {int(entries[name].value)}."
+ f"registered as ordinal {int(entries[name]._value)}."
)
- if "value" in e.kwargs or "name" in e.kwargs:
+ if "_value" in e.kwargs or "_name" in e.kwargs:
raise TypeError(
- f"{cls.__name__}.{name}: `value` and `name` are auto-assigned "
+ f"{cls.__name__}.{name}: `_value` and `_name` are auto-assigned "
f"and must not appear in entry(...) arguments."
)
ordinal = len(entries)
- if cxx_backed:
- instance = _instantiate_cxx_backed(
- cls, args=e.args, kwargs=e.kwargs, ordinal=ordinal, name=name
- )
- else:
- instance = _instantiate(cls, args=e.args, kwargs=e.kwargs, ordinal=ordinal, name=name)
+ instance = _instantiate_entry(
+ cls, entry=e, ordinal=ordinal, name=name, cxx_backed=cxx_backed
+ )
entries[name] = instance
setattr(cls, name, instance)
+ _index_payload(cls, instance, payload_value_type)
+
+
+def _index_payload(cls: type, instance: Any, payload_value_type: type | None) -> None:
+ """Record ``(instance.value → instance)`` in the value-entries column.
+
+ No-op for non-payload enums. For payload enums, the first writer of a
+ given payload wins — matches the "first-match" semantics of the linear
+ scan that FFI converters fall back to when this column is incomplete.
+ """
+ if payload_value_type is None:
+ return
+ payload = getattr(instance, "value", None)
+ if payload is None:
+ return
+ value_entries = _ensure_value_entries_dict(cls)
+ if payload not in value_entries:
+ value_entries[payload] = instance
def _cxx_backed_unknown_binder_error(
@@ -758,13 +865,32 @@ def _instantiate(
ordinal: int,
name: str,
) -> Any:
- """Construct a subclass instance with auto-assigned ``value``/``name``."""
+ """Construct a subclass instance with auto-assigned ``_value``/``_name``."""
merged = dict(kwargs)
- merged["value"] = ordinal
- merged["name"] = name
+ merged["_value"] = ordinal
+ merged["_name"] = name
return cls(*args, **merged)
+def _instantiate_entry(
+ cls: type,
+ *,
+ entry: _EnumEntry,
+ ordinal: int,
+ name: str,
+ cxx_backed: bool,
+) -> Any:
+ """Instantiate an enum entry and normalize construction errors."""
+ try:
+ if cxx_backed:
+ return _instantiate_cxx_backed(
+ cls, args=entry.args, kwargs=entry.kwargs, ordinal=ordinal, name=name
+ )
+ return _instantiate(cls, args=entry.args, kwargs=entry.kwargs, ordinal=ordinal, name=name)
+ except Exception as err:
+ raise TypeError(f"{cls.__name__}.{name}: invalid enum entry: {err}") from None
+
+
def _instantiate_cxx_backed(
cls: type,
*,
@@ -777,7 +903,7 @@ def _instantiate_cxx_backed(
C++-backed enums whose underlying type is registered with
``refl::init(false)`` (e.g. any subclass of ``EnumObj`` in C++) have no
- ``__ffi_init__``, so the usual ``cls(value=..., name=...)`` path is not
+ ``__ffi_init__``, so the usual ``cls(_value=..., _name=...)`` path is not
available. Mirror ``reflection::EnumDef`` by allocating a blank instance
via ``__ffi_new__`` and populating fields through the frozen-setter
escape hatch exposed on the reflected property descriptors.
@@ -797,15 +923,15 @@ def _instantiate_cxx_backed(
f"blank instances cannot be created from Python."
)
instance = ffi_new()
- for key in ("value", "name", *kwargs.keys()):
+ for key in ("_value", "_name", *kwargs.keys()):
descriptor = getattr(cls, key, None)
if descriptor is None or not hasattr(descriptor, "set"):
raise TypeError(
f"{cls.__name__}.{name}: cannot set field {key!r} on a "
f"C++-backed enum — no reflected setter is available."
)
- getattr(cls, "value").set(instance, ordinal)
- getattr(cls, "name").set(instance, name)
+ getattr(cls, "_value").set(instance, ordinal)
+ getattr(cls, "_name").set(instance, name)
for k, v in kwargs.items():
getattr(cls, k).set(instance, v)
return instance
@@ -833,3 +959,43 @@ def _is_class_var(annotation: Any) -> bool:
stripped = annotation.replace(" ", "")
return stripped.startswith("ClassVar") or stripped.startswith("typing.ClassVar")
return False
+
+
+def _annotation_matches_expected(annotation: Any, expected_type: type) -> bool:
+ """Return True if *annotation* matches the required payload-field type."""
+ if annotation is expected_type:
+ return True
+ if isinstance(annotation, str):
+ stripped = annotation.replace(" ", "")
+ expected = expected_type.__name__
+ return stripped in {expected, f"builtins.{expected}"}
+ return False
+
+
+def _prepare_payload_enum_subclass(
+ cls: type[Enum], *, value_type: type[Any], base_name: str
+) -> None:
+ """Inject and validate the user-visible ``value`` field for payload enums."""
+ if "value" in cls.__dict__:
+ raise TypeError(
+ f"{base_name} reserves `value` as a declared field; use "
+ "`entry(value=...)` or `<variant> = <payload>` to assign each "
+ "variant's payload."
+ )
+
+ annotations = _own_annotations(cls)
+ existing = annotations.get("value")
+ if existing is not None:
+ if _is_class_var(existing):
+ raise TypeError(f"{base_name} reserves `value` as a declared field, not a ClassVar.")
+ if not _annotation_matches_expected(existing, value_type):
+ raise TypeError(
+ f"{base_name} requires `value: {value_type.__name__}` on subclasses, "
+ f"got {existing!r}."
+ )
+ return
+
+ updated = _own_annotations(cls)
+ updated["value"] = value_type
+ cls.__annotations__ = updated
+ cls.__ffi_enum_payload_value_type__ = value_type
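The payload bookkeeping in `enum.py` keeps two views of the same variants: `_ordered_entries` rebuilds a dense, ordinal-ordered list from the name-keyed entries dict, and `_index_payload` records payload-to-variant pairs where the first writer of a given payload wins. The following is a minimal pure-Python sketch of those two rules, not the library code: `Variant`, `index_payload`, and `ordered_entries` are hypothetical stand-ins, and plain `dict`s replace the FFI `Dict` type-attr columns.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass(frozen=True)
class Variant:
    """Hypothetical stand-in for an enum variant: identity plus payload."""

    _value: int  # dense ordinal (identity)
    _name: str  # declaration key (identity)
    value: Any  # user-visible payload


def index_payload(value_entries: Dict[Any, Variant], inst: Variant) -> None:
    """Record payload -> variant; the first writer of a payload wins."""
    if inst.value is None:
        return
    if inst.value not in value_entries:
        value_entries[inst.value] = inst


def ordered_entries(entries: Dict[str, Variant]) -> List[Any]:
    """Return variants placed at the slot given by their dense ordinal."""
    ordered: List[Any] = [None] * len(entries)
    for inst in entries.values():
        ordered[inst._value] = inst
    return ordered


entries: Dict[str, Variant] = {}
value_entries: Dict[Any, Variant] = {}
for name, payload in [("low", 10), ("high", 20), ("also_low", 10)]:
    inst = Variant(_value=len(entries), _name=name, value=payload)
    entries[name] = inst
    index_payload(value_entries, inst)

print([v._name for v in ordered_entries(entries)])
print(value_entries[10]._name)
```

With a duplicated payload (`also_low` reuses `10`), the index keeps `low`, which is also the variant a linear scan in ordinal order would find first, so the indexed lookup and a first-match scan agree.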
diff --git a/python/tvm_ffi/dataclasses/py_class.py b/python/tvm_ffi/dataclasses/py_class.py
index e10e502..a3e249f 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -95,6 +95,8 @@ def _register_type_without_fields(cls: type, type_key: str | None) -> Any:
parent_info: core.TypeInfo | None = None
for base in cls.__bases__:
parent_info = core._type_cls_to_type_info(base)
+ if parent_info is None:
+ parent_info = getattr(base, "__tvm_ffi_type_info__", None)
if parent_info is not None:
break
if parent_info is None:
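The `py_class` change falls back to `base.__tvm_ffi_type_info__` when `core._type_cls_to_type_info(base)` misses, so a Python-only intermediate base such as `IntEnum` can still supply the parent type info. A minimal sketch of that lookup order, assuming a toy registry: `TypeInfo`, `_REGISTRY`, `type_cls_to_type_info`, and the class names below are hypothetical stand-ins, not the actual `tvm_ffi.core` API.

```python
from typing import Dict, Optional


class TypeInfo:
    """Hypothetical stand-in for a reflected type-info record."""

    def __init__(self, type_key: str) -> None:
        self.type_key = type_key


# Hypothetical core registry keyed by Python class. A class whose info lives
# only on its ``__tvm_ffi_type_info__`` attribute is absent from it.
_REGISTRY: Dict[type, TypeInfo] = {}


def type_cls_to_type_info(cls: type) -> Optional[TypeInfo]:
    return _REGISTRY.get(cls)


def find_parent_info(cls: type) -> Optional[TypeInfo]:
    """Registry lookup first, then fall back to the class attribute."""
    parent_info: Optional[TypeInfo] = None
    for base in cls.__bases__:
        parent_info = type_cls_to_type_info(base)
        if parent_info is None:
            parent_info = getattr(base, "__tvm_ffi_type_info__", None)
        if parent_info is not None:
            break
    return parent_info


class EnumBase:
    __tvm_ffi_type_info__ = TypeInfo("ffi.Enum")


_REGISTRY[EnumBase] = EnumBase.__tvm_ffi_type_info__


class IntEnumBase(EnumBase):
    # Carries its own info but was never added to the registry.
    __tvm_ffi_type_info__ = TypeInfo("ffi.IntEnum")


class Priority(IntEnumBase):
    pass


print(type_cls_to_type_info(IntEnumBase))  # registry miss
print(find_parent_info(Priority).type_key)
```

Without the `getattr` fallback, the loop would move past `IntEnumBase` and `Priority` would end up with no parent info at all, since `object` is not registered either.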
diff --git a/src/ffi/extra/dataclass.cc b/src/ffi/extra/dataclass.cc
index c76117b..5032d13 100644
--- a/src/ffi/extra/dataclass.cc
+++ b/src/ffi/extra/dataclass.cc
@@ -881,7 +881,7 @@ class ReprPrinter : public ObjectGraphDFS<ReprPrinter, ReprFrame, std::string> {
*out = result;
return true;
}
- // Default repr for EnumObj subclasses: ``<type_key>.<name>``. Reached only
+ // Default repr for EnumObj subclasses: ``<type_key>.<_name>``. Reached only
// when no user-registered ``__ffi_repr__`` hook has claimed this type, so
// explicit repr overrides on specific enum subclasses still take precedence.
if (obj->IsInstance<EnumObj>()) {
@@ -889,7 +889,7 @@ class ReprPrinter : public ObjectGraphDFS<ReprPrinter, ReprFrame, std::string> {
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(ti);
std::string result(type_info->type_key.data, type_info->type_key.size);
result += '.';
- result.append(enum_obj->name.data(), enum_obj->name.size());
+ result.append(enum_obj->_name.data(), enum_obj->_name.size());
if (show_addr_) result += "@" + AddressStr(obj);
state_[obj] = State::kDone;
repr_cache_[obj] = result;
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index 30d61a4..dcfb189 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -650,11 +650,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::type_attr::kConvert,
&refl::details::FFIConvertFromAnyViewToObjectRef<ffi::Dict<ffi::Any, ffi::Any>>);
refl::ObjectDef<ffi::EnumObj>(refl::init(false))
- .def_ro("value", &ffi::EnumObj::value, "Ordinal assigned at registration.",
+ .def_ro("_value", &ffi::EnumObj::_value, "Ordinal assigned at registration.",
refl::AttachFieldFlag::SEqHashIgnore())
- .def_ro("name", &ffi::EnumObj::name, "Instance name.");
+ .def_ro("_name", &ffi::EnumObj::_name, "Instance name.");
refl::EnsureTypeAttrColumn(refl::type_attr::kEnumEntries);
refl::EnsureTypeAttrColumn(refl::type_attr::kEnumAttrs);
+ refl::EnsureTypeAttrColumn(refl::type_attr::kEnumValueEntries);
refl::GlobalDef()
.def_method("ffi.GetRegisteredTypeKeys",
[]() -> ffi::Array<ffi::String> {
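In `enum.py`, `_prepare_payload_enum_subclass` accepts a user-written `value` annotation only if `_annotation_matches_expected` agrees with the payload type. Because annotations may arrive as strings (for example under `from __future__ import annotations`), the check compares either the type object itself or its unqualified or `builtins.`-qualified name. The sketch below restates that rule as a standalone function; `annotation_matches` is a hypothetical name.

```python
from typing import Any


def annotation_matches(annotation: Any, expected: type) -> bool:
    """Accept the type object itself or its string form (e.g. PEP 563)."""
    if annotation is expected:
        return True
    if isinstance(annotation, str):
        stripped = annotation.replace(" ", "")  # tolerate stray whitespace
        return stripped in {expected.__name__, f"builtins.{expected.__name__}"}
    return False


for ann in (int, "int", "builtins.int", "str", "ClassVar[int]"):
    print(repr(ann), annotation_matches(ann, int))
```

A `ClassVar[int]` string fails this match; in the enum module that case is rejected earlier by `_is_class_var` with a dedicated error message.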
diff --git a/tests/python/test_dataclass_enum.py b/tests/python/test_dataclass_enum.py
index d5a7925..05e95de 100644
--- a/tests/python/test_dataclass_enum.py
+++ b/tests/python/test_dataclass_enum.py
@@ -25,10 +25,11 @@ import pytest
import tvm_ffi
from tvm_ffi import core
from tvm_ffi.core import Object
-from tvm_ffi.dataclasses import Enum, EnumAttrMap, auto, entry
+from tvm_ffi.dataclasses import Enum, EnumAttrMap, IntEnum, StrEnum, auto, entry
from tvm_ffi.dataclasses.enum import (
ENUM_ATTRS_ATTR,
ENUM_ENTRIES_ATTR,
+ ENUM_VALUE_ENTRIES_ATTR,
_EnumEntry,
)
@@ -66,33 +67,33 @@ def test_attribute_carrying_basic() -> None:
assert Activation.silu.is_monotonic is True # ty: ignore[unresolved-attribute]
# Ordinals auto-assigned in declaration order.
- assert Activation.relu.value == 0
- assert Activation.gelu.value == 1
- assert Activation.silu.value == 2
- assert Activation.relu.name == "relu"
- assert Activation.gelu.name == "gelu"
+ assert Activation.relu._value == 0
+ assert Activation.gelu._value == 1
+ assert Activation.silu._value == 2
+ assert Activation.relu._name == "relu"
+ assert Activation.gelu._name == "gelu"
assert Activation.get("relu").same_as(Activation.relu)
assert Activation.get("gelu").same_as(Activation.gelu)
assert Activation.get("silu").same_as(Activation.silu)
-def test_entry_rejects_value_kwarg() -> None:
- """``entry(value=...)`` conflicts with the auto-assigned ordinal."""
+def test_entry_rejects__value_kwarg() -> None:
+ """``entry(_value=...)`` conflicts with the auto-assigned ordinal."""
with pytest.raises(TypeError):
class _Bad(Enum, type_key=_unique_key("BadValue")):
flag: bool
- a: ClassVar[_Bad] = entry(flag=True, value=5)
+ a: ClassVar[_Bad] = entry(flag=True, _value=5)
-def test_entry_rejects_name_kwarg() -> None:
- """``entry(name=...)`` conflicts with the auto-assigned declaration key."""
+def test_entry_rejects__name_kwarg() -> None:
+ """``entry(_name=...)`` conflicts with the auto-assigned declaration key."""
with pytest.raises(TypeError):
class _Bad(Enum, type_key=_unique_key("BadName")):
flag: bool
- a: ClassVar[_Bad] = entry(flag=True, name="other")
+ a: ClassVar[_Bad] = entry(flag=True, _name="other")
def test_get_missing_raises() -> None:
@@ -104,14 +105,167 @@ def test_get_missing_raises() -> None:
Missing.get("no-such-entry")
-def test_entries_iteration_order() -> None:
+def test_public_value_and_name_are_hidden() -> None:
+ class Hidden(Enum, type_key=_unique_key("Hidden")):
+ yes = auto()
+
+ assert Hidden.yes._value == 0
+ assert Hidden.yes._name == "yes"
+ with pytest.raises(AttributeError):
+ _ = Hidden.yes.value
+ with pytest.raises(AttributeError):
+ _ = Hidden.yes.name
+
+
+def test_int_enum_payload_value() -> None:
+ class Colors(IntEnum, type_key=_unique_key("IntEnum")):
+ red = entry(value=10)
+ blue = entry(value=20)
+
+ assert Colors.red.value == 10
+ assert Colors.blue.value == 20
+ assert Colors.red._value == 0
+ assert Colors.blue._value == 1
+ assert Colors.red._name == "red"
+ assert Colors.get("red").same_as(Colors.red)
+ assert list(Colors.all_entries()) == [Colors.red, Colors.blue]
+
+
+def test_int_enum_payload_literal_sugar() -> None:
+ class Priority(IntEnum, type_key=_unique_key("IntEnumLiteral")):
+ low = 10
+ high = 20
+
+ assert isinstance(Priority.low, Priority)
+ assert isinstance(Priority.high, Priority)
+ assert not isinstance(Priority.low, int)
+ assert Priority.low.value == 10
+ assert Priority.high.value == 20
+ assert Priority.low._name == "low"
+ assert Priority.high._name == "high"
+ assert list(Priority.all_entries()) == [Priority.low, Priority.high]
+
+
+def test_str_enum_payload_value() -> None:
+ class Tokens(StrEnum, type_key=_unique_key("StrEnum")):
+ add = entry(value="+")
+ mul = entry(value="*")
+
+ assert Tokens.add.value == "+"
+ assert Tokens.mul.value == "*"
+ assert Tokens.add._value == 0
+ assert Tokens.mul._value == 1
+ assert Tokens.add._name == "add"
+ assert Tokens.get("mul").same_as(Tokens.mul)
+
+
+def test_str_enum_payload_literal_sugar() -> None:
+ class Opcode(StrEnum, type_key=_unique_key("StrEnumLiteral")):
+ add = "+"
+ mul = "*"
+
+ assert isinstance(Opcode.add, Opcode)
+ assert isinstance(Opcode.mul, Opcode)
+ assert not isinstance(Opcode.add, str)
+ assert Opcode.add.value == "+"
+ assert Opcode.mul.value == "*"
+ assert Opcode.add._name == "add"
+ assert Opcode.mul._name == "mul"
+ assert list(Opcode.all_entries()) == [Opcode.add, Opcode.mul]
+
+
+def test_payload_literal_sugar_preserves_annotated_field_defaults() -> None:
+ class Opcode(StrEnum, type_key=_unique_key("StrEnumLiteralDefault")):
+ arity: int = 0
+ add = "+"
+ mul = "*"
+
+ assert isinstance(Opcode.add, Opcode)
+ assert isinstance(Opcode.mul, Opcode)
+ assert Opcode.add.arity == 0 # ty: ignore[unresolved-attribute]
+ assert Opcode.mul.arity == 0 # ty: ignore[unresolved-attribute]
+ assert Opcode.add.value == "+"
+ assert Opcode.mul.value == "*"
+
+
+def test_int_enum_payload_literal_sugar_rejects_invalid_payload() -> None:
+ with pytest.raises(TypeError):
+
+ class _Bad(IntEnum, type_key=_unique_key("IntEnumLiteralBad")):
+ nope = "x"
+
+
+def test_str_enum_payload_literal_sugar_rejects_invalid_payload() -> None:
+ with pytest.raises(TypeError):
+
+ class _Bad(StrEnum, type_key=_unique_key("StrEnumLiteralBad")):
+ nope = 42
+
+
+def test_int_enum_populates_value_entries_typeattr() -> None:
+ class Priority(IntEnum, type_key=_unique_key("IntEnumValueEntries")):
+ low = entry(value=10)
+ high = 20 # literal sugar
+
+ type_info = Priority.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
+ value_entries = core._lookup_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR)
+ assert value_entries is not None
+ assert value_entries[10].same_as(Priority.low)
+ assert value_entries[20].same_as(Priority.high)
+
+
+def test_str_enum_populates_value_entries_typeattr() -> None:
+ class Opcode(StrEnum, type_key=_unique_key("StrEnumValueEntries")):
+ add = entry(value="+")
+ mul = "*" # literal sugar
+
+ type_info = Opcode.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
+ value_entries = core._lookup_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR)
+ assert value_entries is not None
+ assert value_entries["+"].same_as(Opcode.add)
+ assert value_entries["*"].same_as(Opcode.mul)
+
+
+def test_plain_enum_does_not_create_value_entries_typeattr() -> None:
+ class Status(Enum, type_key=_unique_key("PlainEnumNoValue")):
+ ok: ClassVar[Status] = auto()
+
+ type_info = Status.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
+ value_entries = core._lookup_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR)
+ assert value_entries is None
+
+
+def test_payload_literal_sugar_preserves_methods_and_properties() -> None:
+ class Priority(IntEnum, type_key=_unique_key("IntEnumWithMethods")):
+ low = 1
+ high = 10
+
+ def is_high(self) -> bool:
+ return self.value >= 5
+
+ @classmethod
+ def default(cls) -> Priority:
+ return cls.low # ty: ignore[invalid-return-type]
+
+ @property
+ def doubled(self) -> int:
+ return self.value * 2
+
+ assert Priority.low.is_high() is False # ty: ignore[possibly-missing-attribute]
+ assert Priority.high.is_high() is True # ty: ignore[possibly-missing-attribute]
+ assert Priority.default().same_as(Priority.low)
+ assert Priority.high.doubled == 20 # ty: ignore[possibly-missing-attribute]
+ assert list(Priority.all_entries()) == [Priority.low, Priority.high]
+
+
+def test_all_entries_iteration_order() -> None:
class Ordered(Enum, type_key=_unique_key("Ordered")):
tag: str
a: ClassVar[Ordered] = entry(tag="first")
b: ClassVar[Ordered] = entry(tag="second")
c: ClassVar[Ordered] = entry(tag="third")
- values = list(Ordered.entries())
+ values = list(Ordered.all_entries())
assert len(values) == 3
assert values[0].same_as(Ordered.a)
assert values[1].same_as(Ordered.b)
@@ -153,12 +307,12 @@ def test_bare_classvar_without_cxx_entries() -> None:
retry: ClassVar[Status]
assert isinstance(Status.ok, Status)
- assert Status.ok.value == 0
- assert Status.err.value == 1
- assert Status.retry.value == 2
- assert Status.ok.name == "ok"
- assert Status.err.name == "err"
- assert list(Status.entries()) == [Status.ok, Status.err, Status.retry]
+ assert Status.ok._value == 0
+ assert Status.err._value == 1
+ assert Status.retry._value == 2
+ assert Status.ok._name == "ok"
+ assert Status.err._name == "err"
+ assert list(Status.all_entries()) == [Status.ok, Status.err, Status.retry]
assert Status.get("ok").same_as(Status.ok)
@@ -171,10 +325,10 @@ def test_bare_classvar_mixed_with_entry() -> None:
named: ClassVar[Kind] = entry(tag="hi")
# ``ClassVar`` binders are processed before ``entry(...)`` assignments.
- assert Kind.blank.value == 0
- assert Kind.named.value == 1
- assert Kind.blank.name == "blank"
- assert Kind.named.name == "named"
+ assert Kind.blank._value == 0
+ assert Kind.named._value == 1
+ assert Kind.blank._name == "blank"
+ assert Kind.named._name == "named"
assert Kind.named.tag == "hi" # ty: ignore[unresolved-attribute]
@@ -196,7 +350,7 @@ def test_bare_entry_sugar_form() -> None:
assert isinstance(Activation.relu, Activation)
assert Activation.relu.output_zero is True # ty: ignore[unresolved-attribute]
assert Activation.gelu.output_zero is False # ty: ignore[unresolved-attribute]
- assert list(Activation.entries()) == [Activation.relu, Activation.gelu]
+ assert list(Activation.all_entries()) == [Activation.relu, Activation.gelu]
# ---------------------------------------------------------------------------
@@ -213,12 +367,12 @@ def test_auto_basic_no_annotation() -> None:
high = auto()
assert isinstance(Priority.low, Priority)
- assert Priority.low.value == 0
- assert Priority.medium.value == 1
- assert Priority.high.value == 2
- assert Priority.low.name == "low"
- assert Priority.high.name == "high"
- assert list(Priority.entries()) == [Priority.low, Priority.medium, Priority.high]
+ assert Priority.low._value == 0
+ assert Priority.medium._value == 1
+ assert Priority.high._value == 2
+ assert Priority.low._name == "low"
+ assert Priority.high._name == "high"
+ assert list(Priority.all_entries()) == [Priority.low, Priority.medium, Priority.high]
def test_auto_with_classvar_annotation() -> None:
@@ -229,9 +383,9 @@ def test_auto_with_classvar_annotation() -> None:
run: ClassVar[Stage] = auto()
done: ClassVar[Stage] = auto()
- assert Stage.init.value == 0
- assert Stage.run.value == 1
- assert Stage.done.value == 2
+ assert Stage.init._value == 0
+ assert Stage.run._value == 1
+ assert Stage.done._value == 2
def test_auto_mixed_with_bare_classvar() -> None:
@@ -248,10 +402,10 @@ def test_auto_mixed_with_bare_classvar() -> None:
gamma: ClassVar[Mixed]
# Binders (alpha, gamma) come first in annotation order, then sentinels.
- assert Mixed.alpha.value == 0
- assert Mixed.gamma.value == 1
- assert Mixed.beta.value == 2
- assert {v.name for v in Mixed.entries()} == {"alpha", "beta", "gamma"}
+ assert Mixed.alpha._value == 0
+ assert Mixed.gamma._value == 1
+ assert Mixed.beta._value == 2
+ assert {v._name for v in Mixed.all_entries()} == {"alpha", "beta", "gamma"}
def test_auto_mixed_with_entry() -> None:
@@ -263,9 +417,9 @@ def test_auto_mixed_with_entry() -> None:
add = entry(arity=2)
neg = entry(arity=1)
- assert Op.noop.value == 0
- assert Op.add.value == 1
- assert Op.neg.value == 2
+ assert Op.noop._value == 0
+ assert Op.add._value == 1
+ assert Op.neg._value == 2
assert Op.noop.arity == 0 # ty: ignore[unresolved-attribute]
assert Op.add.arity == 2 # ty: ignore[unresolved-attribute]
@@ -295,30 +449,21 @@ def test_auto_returns_fresh_sentinels() -> None:
# ---------------------------------------------------------------------------
-# by_name / by_value / attr_dict
+# all_entries / attr_dict
# ---------------------------------------------------------------------------
-def test_by_name_is_live_dict() -> None:
- class K(Enum, type_key=_unique_key("ByName")):
- a: ClassVar[K]
- b: ClassVar[K]
-
- assert set(K.by_name.keys()) == {"a", "b"}
- assert K.by_name["a"].same_as(K.a)
-
-
-def test_by_value_indexed_by_ordinal() -> None:
- class K(Enum, type_key=_unique_key("ByValue")):
+def test_all_entries_indexed_by_ordinal() -> None:
+ class K(Enum, type_key=_unique_key("AllEntries")):
a: ClassVar[K]
b: ClassVar[K]
c: ClassVar[K]
- by_val = K.by_value
- assert len(by_val) == 3
- assert by_val[0].same_as(K.a)
- assert by_val[1].same_as(K.b)
- assert by_val[2].same_as(K.c)
+ entries = list(K.all_entries())
+ assert len(entries) == 3
+ assert entries[0].same_as(K.a)
+ assert entries[1].same_as(K.b)
+ assert entries[2].same_as(K.c)
def test_attr_dict_direct_access() -> None:
@@ -335,8 +480,8 @@ def test_attr_dict_direct_access() -> None:
# Direct read via class-level property.
column = Op.attr_dict["has_side_effects"]
- assert column[Op.add.value] is False
- assert column[Op.neg.value] is True
+ assert column[Op.add._value] is False
+ assert column[Op.neg._value] is True
# ---------------------------------------------------------------------------
@@ -491,7 +636,7 @@ def test_enum_attrs_typeattr_stored_under_unified_column() -> None:
stored = core._lookup_type_attr(tinfo.type_index, ENUM_ATTRS_ATTR)
assert stored is not None
assert "color" in stored
- assert stored["color"][WithAttr.one.value] == "red"
+ assert stored["color"][WithAttr.one._value] == "red"
# ---------------------------------------------------------------------------
@@ -531,17 +676,17 @@ def test_cxx_backed_classvar_binds_to_existing_entries() -> None:
Beta: ClassVar[Variant]
assert isinstance(Variant.Alpha, Variant)
- assert Variant.Alpha.name == "Alpha"
- assert Variant.Beta.name == "Beta"
+ assert Variant.Alpha._name == "Alpha"
+ assert Variant.Beta._name == "Beta"
# Ordinals come from C++ (registered in Alpha, Beta declaration order).
- assert Variant.Alpha.value == 0
- assert Variant.Beta.value == 1
+ assert Variant.Alpha._value == 0
+ assert Variant.Beta._value == 1
assert Variant.get("Alpha").same_as(Variant.Alpha)
# C++-stored `code` attr is visible through attr_dict.
code_col = Variant.attr_dict["code"]
- assert code_col[Variant.Alpha.value] == 10
- assert code_col[Variant.Beta.value] == 20
+ assert code_col[Variant.Alpha._value] == 10
+ assert code_col[Variant.Beta._value] == 20
def test_cxx_backed_reads_entries_typeattr() -> None:
@@ -594,21 +739,23 @@ def test_cxx_backed_mixed_entries_via_auto() -> None:
MixedTwo = auto()
# Alpha/Beta come from C++ with ordinals 0 and 1.
- assert Mixed.Alpha.value == 0
- assert Mixed.Beta.value == 1
- assert Mixed.Alpha.name == "Alpha"
- assert Mixed.Beta.name == "Beta"
+ assert Mixed.Alpha._value == 0
+ assert Mixed.Beta._value == 1
+ assert Mixed.Alpha._name == "Alpha"
+ assert Mixed.Beta._name == "Beta"
# Python-side entries extend the dense ordinal sequence from the C++ count.
- assert Mixed.MixedOne.name == "MixedOne"
- assert Mixed.MixedTwo.name == "MixedTwo"
- assert Mixed.MixedOne.value == Mixed.Beta.value + 1
- assert Mixed.MixedTwo.value == Mixed.Beta.value + 2
+ assert Mixed.MixedOne._name == "MixedOne"
+ assert Mixed.MixedTwo._name == "MixedTwo"
+ assert Mixed.MixedOne._value == Mixed.Beta._value + 1
+ assert Mixed.MixedTwo._value == Mixed.Beta._value + 2
- # Round-trip through ``get`` / ``by_name`` / ``entries``.
+ # Round-trip through ``get`` / ``all_entries``.
assert Mixed.get("MixedOne").same_as(Mixed.MixedOne)
assert Mixed.get("MixedTwo").same_as(Mixed.MixedTwo)
- assert {"Alpha", "Beta", "MixedOne", "MixedTwo"}.issubset(Mixed.by_name.keys())
+ assert {"Alpha", "Beta", "MixedOne", "MixedTwo"}.issubset(
+ {entry._name for entry in Mixed.all_entries()}
+ )
# Python-side variants are real subclass instances of the cxx-backed type.
assert isinstance(Mixed.MixedOne, Mixed)
@@ -616,8 +763,8 @@ def test_cxx_backed_mixed_entries_via_auto() -> None:
# Existing C++ attrs remain unaffected; new Python variants have no attrs yet.
code = Mixed.attr_dict["code"]
- assert code[Mixed.Alpha.value] == 10
- assert code[Mixed.Beta.value] == 20
+ assert code[Mixed.Alpha._value] == 10
+ assert code[Mixed.Beta._value] == 20
def test_cxx_backed_python_entry_accepts_def_attr() -> None:
@@ -672,12 +819,12 @@ def test_default_repr_in_nested_container() -> None:
red = auto()
green = auto()
- by_name_repr = repr(Color.by_name)
- assert f"{key}.red" in by_name_repr
- assert f"{key}.green" in by_name_repr
+ all_entries_repr = repr(list(Color.all_entries()))
+ assert f"{key}.red" in all_entries_repr
+ assert f"{key}.green" in all_entries_repr
- by_value_entries = [repr(v) for v in Color.by_value]
- assert by_value_entries == [f"{key}.red", f"{key}.green"]
+ all_entries = [repr(v) for v in Color.all_entries()]
+ assert all_entries == [f"{key}.red", f"{key}.green"]
def test_default_repr_with_attribute_carrying_variant() -> None:
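For payload enums, `_collect_entry_declarations` treats a bare class-body assignment such as `add = "+"` as a variant, but leaves alone annotated fields (`arity: int = 0`), methods, and `staticmethod`/`classmethod`/`property` objects; the literal-sugar tests above pin down each of these cases. The sketch below restates that filter over a plain class namespace. `classify_payload_literals` is a hypothetical name, and skipping names that start with `_` is an assumption about the part of the loop not shown in this diff.

```python
from typing import Any, Dict, Set


def classify_payload_literals(
    namespace: Dict[str, Any], ordinary_fields: Set[str]
) -> Dict[str, Any]:
    """Return {name: payload} for class-body literals treated as variants."""
    literals: Dict[str, Any] = {}
    for name, value in namespace.items():
        if name.startswith("_"):  # assumption: private/dunder names skipped
            continue
        if name in ordinary_fields:  # annotated field: keep it as a default
            continue
        # Methods and descriptors are behavior, not payload literals.
        if isinstance(value, (staticmethod, classmethod, property)) or callable(value):
            continue
        literals[name] = value
    return literals


class Opcode:  # a plain class, used only for its namespace
    arity: int = 0
    add = "+"
    mul = "*"

    def describe(self) -> str:
        return "op"

    @property
    def symbol(self) -> str:
        return "?"


print(classify_payload_literals(dict(vars(Opcode)), {"arity"}))
print(classify_payload_literals(dict(vars(Opcode)), set()))
```

Passing an empty `ordinary_fields` set shows why the annotation scan matters: without it, `arity = 0` would be mistaken for a third variant.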
diff --git a/tests/python/test_dataclass_py_class.py b/tests/python/test_dataclass_py_class.py
index acfabe5..1b70489 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -32,7 +32,7 @@ from tvm_ffi import core
from tvm_ffi._dunder import _install_dataclass_dunders
from tvm_ffi._ffi_api import DeepCopy, RecursiveEq, RecursiveHash, ReprPrint
from tvm_ffi.core import MISSING, Object, TypeInfo, TypeSchema, _to_py_class_value
-from tvm_ffi.dataclasses import KW_ONLY, Field, field, py_class
+from tvm_ffi.dataclasses import KW_ONLY, Field, IntEnum, StrEnum, entry, field, py_class
from tvm_ffi.registry import _add_class_attrs
from tvm_ffi.testing import TestObjectBase as _TestObjectBase
from tvm_ffi.testing.testing import requires_py310
@@ -193,6 +193,44 @@ class TestFieldParsing:
obj = BoolFld(x=True)
assert obj.x is True
+ def test_payload_enum_fields_end_to_end(self) -> None:
+ """IntEnum/StrEnum fields on @py_class accept raw payloads and expose enum objects."""
+
+ class Priority(IntEnum, type_key=_unique_key("Priority")):
+ low = entry(value=10)
+ high = entry(value=20)
+
+ class Opcode(StrEnum, type_key=_unique_key("Opcode")):
+ add = entry(value="+")
+ mul = entry(value="*")
+
+ @py_class(_unique_key("EnumFields"))
+ class Instruction(Object):
+ priority: Priority
+ opcode: Opcode
+
+ from_payload = Instruction(priority=20, opcode="*") # ty: ignore[invalid-argument-type]
+ assert isinstance(from_payload.priority, Priority)
+ assert isinstance(from_payload.opcode, Opcode)
+ assert from_payload.priority.same_as(Priority.high)
+ assert from_payload.opcode.same_as(Opcode.mul)
+ assert from_payload.priority.value == 20
+ assert from_payload.opcode.value == "*"
+
+ from_enum = Instruction(priority=Priority.low, opcode=Opcode.add)
+ assert from_enum.priority.same_as(Priority.low)
+ assert from_enum.opcode.same_as(Opcode.add)
+
+ from_enum.priority = 20 # ty: ignore[invalid-assignment]
+ from_enum.opcode = "*" # ty: ignore[invalid-assignment]
+ assert from_enum.priority.same_as(Priority.high)
+ assert from_enum.opcode.same_as(Opcode.mul)
+
+ with pytest.raises((TypeError, RuntimeError), match="expected"):
+ from_enum.priority = 99 # ty: ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError), match="expected"):
+ from_enum.opcode = "/" # ty: ignore[invalid-assignment]
+
@requires_py310
def test_optional_field(self) -> None:
@py_class(_unique_key("OptFld"))
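The `Instruction` test relies on the type converter accepting either an existing variant or a raw payload, and rejecting unknown payloads with an error mentioning "expected". The Cython fast path itself is not part of this excerpt, so the following is only a sketch of the behavior the tests describe, assuming a payload index like `__ffi_enum_value_entries__` plus the first-match linear scan mentioned in the `_index_payload` docstring; `Variant` and `convert_payload_enum` are hypothetical.

```python
from typing import Any, Dict, Iterable


class Variant:
    """Hypothetical payload-enum variant: identity name plus payload."""

    def __init__(self, name: str, value: Any) -> None:
        self._name = name
        self.value = value


def convert_payload_enum(
    arg: Any,
    value_entries: Dict[Any, Variant],
    all_entries: Iterable[Variant],
    payload_type: type,
) -> Variant:
    """Map a variant or a raw payload to a variant, else raise TypeError."""
    if isinstance(arg, Variant):
        return arg  # object passthrough: already a variant
    if isinstance(arg, payload_type):
        hit = value_entries.get(arg)
        if hit is not None:
            return hit  # fast path through the payload index
        for variant in all_entries:  # fallback: first match in ordinal order
            if variant.value == arg:
                return variant
    raise TypeError(
        f"expected a variant or a known {payload_type.__name__} payload, got {arg!r}"
    )


low, high = Variant("low", 10), Variant("high", 20)
index = {v.value: v for v in (low, high)}
print(convert_payload_enum(20, index, [low, high], int)._name)
try:
    convert_payload_enum(99, index, [low, high], int)
except TypeError as err:
    print(err)
```

Checking `payload_type` before the lookup keeps a `StrEnum` from matching an `int` argument (and vice versa) even if the two happened to compare equal.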
diff --git a/tests/python/test_type_converter.py b/tests/python/test_type_converter.py
index 28d885a..6adef75 100644
--- a/tests/python/test_type_converter.py
+++ b/tests/python/test_type_converter.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import collections.abc
import ctypes
+import itertools
import os
import sys
import typing
@@ -36,6 +37,7 @@ from tvm_ffi.core import (
_object_type_key_to_index,
_to_py_class_value,
)
+from tvm_ffi.dataclasses import IntEnum, StrEnum, entry
# Python 3.9+ supports list[int], dict[str, int], tuple[int, ...] at runtime.
# On 3.8, these raise TypeError("'type' object is not subscriptable").
@@ -50,15 +52,37 @@ from tvm_ffi.testing import (
)
from tvm_ffi.testing.testing import requires_py39, requires_py310
-
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
+_TYPE_KEY_COUNTER = itertools.count()
+
+
def S(origin: str, *args: TypeSchema) -> TypeSchema:
"""Shorthand constructor for TypeSchema (string-based)."""
return TypeSchema(origin, tuple(args))
+def _unique_type_key(base: str) -> str:
+ return f"testing.type_converter.{base}_{next(_TYPE_KEY_COUNTER)}"
+
+
+def _make_int_enum_type() -> typing.Any:
+ class Colors(IntEnum, type_key=_unique_type_key("IntEnum")):
+ red = entry(value=10)
+ blue = entry(value=20)
+
+ return Colors
+
+
+def _make_str_enum_type() -> typing.Any:
+ class Tokens(StrEnum, type_key=_unique_type_key("StrEnum")):
+ add = entry(value="+")
+ mul = entry(value="*")
+
+ return Tokens
+
+
# Annotation-based constructor — the main subject under test.
A = TypeSchema.from_annotation
@@ -267,7 +291,42 @@ class TestObjectTypes:
# ---------------------------------------------------------------------------
-# Category 6: Optional
+# Category 6: Payload enums
+# ---------------------------------------------------------------------------
+class TestPayloadEnums:
+ def test_int_enum_convert_from_int(self) -> None:
+ """IntEnum accepts its user-visible integer payload."""
+ Colors = _make_int_enum_type()
+ result = _to_py_class_value(A(Colors).convert(20))
+ assert result.same_as(Colors.blue)
+
+ def test_int_enum_passthrough_existing_object(self) -> None:
+ """IntEnum keeps the object passthrough path for existing enum objects."""
+ Colors = _make_int_enum_type()
+ result = _to_py_class_value(A(Colors).convert(Colors.red))
+ assert result.same_as(Colors.red)
+
+ def test_int_enum_rejects_unknown_payload(self) -> None:
+ """IntEnum still rejects unmatched integer payloads."""
+ Colors = _make_int_enum_type()
+ with pytest.raises(TypeError, match="expected"):
+ A(Colors).check_value(99)
+
+ def test_str_enum_convert_from_str(self) -> None:
+ """StrEnum accepts its user-visible string payload."""
+ Tokens = _make_str_enum_type()
+ result = _to_py_class_value(A(Tokens).convert("*"))
+ assert result.same_as(Tokens.mul)
+
+ def test_str_enum_rejects_unknown_payload(self) -> None:
+ """StrEnum still rejects unmatched string payloads."""
+ Tokens = _make_str_enum_type()
+ with pytest.raises(TypeError, match="expected"):
+ A(Tokens).check_value("/")
+
+
+# ---------------------------------------------------------------------------
+# Category 7: Optional
# ---------------------------------------------------------------------------
class TestOptional:
def test_none_passes(self) -> None:
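The `TestPayloadEnums` cases fix an order for converting a value into a payload-enum field. First, an existing entry of the enum passes through unchanged. Next, a raw value of the declared payload type is matched against the entries' `value` payloads. Anything else is rejected. The pure-Python sketch below models that order. It is not the Cython converter. Only the `__ffi_enum_payload_value_type__` marker name comes from this PR's description. `Variant`, `make_enum`, and `convert_enum` are hypothetical helpers.

```python
from typing import Any, Dict, Optional


class Variant:
    """Hypothetical enum entry: identity (_value, _name) kept apart from the payload (value)."""

    # None means "no payload fast path"; payload enums stamp int or str here.
    __ffi_enum_payload_value_type__: Optional[type] = None
    _entries: Dict[str, "Variant"] = {}

    def __init__(self, ordinal: int, name: str, value: Any = None) -> None:
        self._value = ordinal  # dense ordinal, part of the identity
        self._name = name      # declaration key, part of the identity
        self.value = value     # user-visible payload


def make_enum(cls_name: str, payload_type: type, **payloads: Any) -> type:
    """Build a Variant subclass with one entry per keyword, in declaration order."""
    cls = type(cls_name, (Variant,), {"__ffi_enum_payload_value_type__": payload_type})
    cls._entries = {}
    for ordinal, (name, payload) in enumerate(payloads.items()):
        entry = cls(ordinal, name, payload)
        cls._entries[name] = entry
        setattr(cls, name, entry)
    return cls


def convert_enum(cls: type, value: Any) -> Any:
    # 1. Object passthrough: an existing entry of this enum is returned as-is.
    if isinstance(value, cls):
        return value
    # 2. Payload fast path. The exact type check (a choice made for this sketch)
    #    keeps bool from being treated as an int payload.
    payload_type = cls.__ffi_enum_payload_value_type__
    if payload_type is not None and type(value) is payload_type:
        for entry in cls._entries.values():
            if entry.value == value:
                return entry
    # 3. Unknown payloads and unrelated types are rejected.
    raise TypeError(f"expected {cls.__name__}, got {value!r}")
```

For example, with `Colors = make_enum("Colors", int, red=10, blue=20)`, `convert_enum(Colors, 20)` returns `Colors.blue`, while `convert_enum(Colors, 99)` raises `TypeError`.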
@@ -291,7 +350,7 @@ class TestOptional:
# ---------------------------------------------------------------------------
-# Category 7: Union / Variant
+# Category 8: Union / Variant
# ---------------------------------------------------------------------------
class TestUnion:
def test_first_alt_passes(self) -> None:
@@ -313,7 +372,7 @@ class TestUnion:
# ---------------------------------------------------------------------------
-# Category 8: Containers
+# Category 9: Containers
# ---------------------------------------------------------------------------
class TestContainers:
@requires_py39