This is an automated email from the ASF dual-hosted git repository.
yzh119 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 96edb9a feat(enum): add payload enum compatibility behavior (#568)
96edb9a is described below
commit 96edb9a8fd0c474b311234b174719dda20230419
Author: Junru Shao <[email protected]>
AuthorDate: Fri Apr 24 02:20:59 2026 -0700
feat(enum): add payload enum compatibility behavior (#568)
TL;DR:
- Make `IntEnum` and `StrEnum` variants behave like their raw payloads
for common Python enum operations while preserving FFI singleton
identity through `same_as`.
Examples:
- `Priority.low.name == "low"` and `repr(Priority.low) ==
"Priority.low"`.
- `Priority.low == 10`, `str(Priority.low) == "10"`, and
`hash(Priority.low) == hash(10)`.
- `Priority(10).same_as(Priority.low)` and
`Priority("low").same_as(Priority.low)`.
- `Opcode("+").same_as(Opcode.add)` and
`Opcode("add").same_as(Opcode.add)`.
Details:
- Install payload-only instance dunders after `py_class`/`c_class`
generation so payload semantics override generated object defaults
without clobbering explicit user methods.
- Gate payload behavior on `__ffi_enum_payload_value_type__` so plain
`Enum` does not interpret arbitrary class-body values as public payloads.
- Keep single-argument member lookup in `_EnumMeta.__call__` for
existing variants, raw payloads, and names.
- Document why these methods are dynamically installed and what
identity/equality contract they preserve.
- Add enum tests covering `name`, `repr`, `str`, equality, hashing,
and construction from payload/name.
---
python/tvm_ffi/dataclasses/enum.py | 102 ++++++++++++++++++++++++++++++++++++
tests/python/test_dataclass_enum.py | 28 ++++++++++
2 files changed, 130 insertions(+)
diff --git a/python/tvm_ffi/dataclasses/enum.py b/python/tvm_ffi/dataclasses/enum.py
index 3182b7c..2ec9ce2 100644
--- a/python/tvm_ffi/dataclasses/enum.py
+++ b/python/tvm_ffi/dataclasses/enum.py
@@ -211,6 +211,74 @@ class _ClassProperty:
return self._fget(cls)
+def _install_payload_enum_behaviors(cls: type, *, user_defined_repr: bool) -> None:
+    """Install stdlib-like instance behavior for ``IntEnum`` / ``StrEnum``.
+
+    Payload enums have two identities that plain ``Enum`` does not: the FFI
+    singleton identity (``Priority.low.same_as(...)``) and a user-visible raw
+    payload (``Priority.low.value == 10``). These methods make the Python
+    surface follow the payload where users expect stdlib-like behavior, while
+    keeping the object identity available through ``same_as``.
+
+    This runs after ``py_class`` / ``c_class`` because those decorators may
+    install generated object dunders. Installing here lets payload enums
+    override generated repr/equality/hash defaults, but still preserves methods
+    explicitly written in the enum subclass body.
+    """
+
+    def _payload_enum_eq(self: Enum, other: object) -> Any:
+        """Compare payload enum variants by payload and accept raw payloads.
+
+        Two variants of the same enum class compare like their ``value`` fields,
+        and a variant can compare directly with its raw payload type
+        (``Priority.low == 10``). Unrelated objects return ``NotImplemented``
+        so Python can try reflected comparison or fall back normally.
+        """
+        if isinstance(other, type(self)):
+            return self.value == other.value  # type: ignore[attr-defined]
+        payload_type = getattr(type(self), "__ffi_enum_payload_value_type__", None)
+        if payload_type is not None and isinstance(other, payload_type):
+            return self.value == other  # type: ignore[attr-defined]
+        return NotImplemented
+
+    def _payload_enum_ne(self: Enum, other: object) -> Any:
+        """Negate payload equality while preserving ``NotImplemented`` semantics."""
+        if (eq := _payload_enum_eq(self, other)) is not NotImplemented:
+            return not eq
+        return NotImplemented
+
+    def _payload_enum_str(self: Enum) -> str:
+        """Render as the raw payload, matching ``int`` / ``str`` enum expectations."""
+        return str(self.value)  # type: ignore[attr-defined]
+
+    def _payload_enum_hash(self: Enum) -> int:
+        """Hash like the raw payload so equality and hashing stay consistent."""
+        return hash(self.value)  # type: ignore[attr-defined]
+
+    _PAYLOAD_ENUM_DUNDERS = {
+        "__eq__": _payload_enum_eq,
+        "__ne__": _payload_enum_ne,
+        "__str__": _payload_enum_str,
+        "__hash__": _payload_enum_hash,
+    }
+
+    def _payload_enum_name(self: Enum) -> str:
+        """Expose the declaration name without making plain ``Enum`` grow ``name``."""
+        return str(self._name)
+
+    def _payload_enum_repr(self: Enum) -> str:
+        """Render as ``Class.member``, matching Python payload enum conventions."""
+        return f"{type(self).__name__}.{self.name}"  # ty: ignore[unresolved-attribute]
+
+    if "name" not in cls.__dict__:
+        cls.name = property(_payload_enum_name)  # ty: ignore[unresolved-attribute]
+    if not user_defined_repr:
+        cls.__repr__ = _payload_enum_repr  # type: ignore[attr-defined]
+    for name, method in _PAYLOAD_ENUM_DUNDERS.items():
+        if name not in cls.__dict__:
+            setattr(cls, name, method)
+
+
# ---------------------------------------------------------------------------
# Enum base + EnumAttrMap
# ---------------------------------------------------------------------------
@@ -230,6 +298,36 @@ class _EnumMeta(type(Object)):
def __len__(cls) -> int:
return len(_ordered_entries(cls)) # ty: ignore[unresolved-attribute]
+    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
+        """Construct enum variants from existing variants, payloads, or names.
+
+        Normal object construction is still delegated to ``Object`` / dataclass
+        initialization. The single-argument path is intercepted to match
+        stdlib enum lookup behavior:
+
+        * ``Priority(Priority.low)`` returns the same variant.
+        * ``Priority(10)`` returns the member whose payload value is ``10``.
+        * ``Priority("low")`` returns the member declared with that name.
+
+        Payload lookup is only enabled for classes with
+        ``__ffi_enum_payload_value_type__`` so plain enums do not accidentally
+        interpret arbitrary class-body values as public payloads.
+        """
+        if not kwargs and len(args) == 1:
+            value = args[0]
+            if isinstance(value, cls):
+                return value
+            payload_value_type = getattr(cls, "__ffi_enum_payload_value_type__", None)
+            if payload_value_type is not None and isinstance(value, payload_value_type):
+                value_entries = _value_entries_dict(cls)
+                if value_entries is not None and value in value_entries:
+                    return value_entries[value]
+            if isinstance(value, str):
+                entries = _entries_dict(cls)
+                if entries is not None and value in entries:
+                    return entries[value]
+        return super().__call__(*args, **kwargs)
+
@dataclass_transform(
eq_default=False,
@@ -356,6 +454,7 @@ class Enum(Object, metaclass=_EnumMeta):
super().__init_subclass__(**kwargs)
if type_key is None:
return
+ user_defined_repr = "__repr__" in cls.__dict__
binders, python_entries = _collect_entry_declarations(cls)
@@ -365,6 +464,9 @@ class Enum(Object, metaclass=_EnumMeta):
else:
py_class(type_key, frozen=frozen)(cls)
+ if getattr(cls, "__ffi_enum_payload_value_type__", None) is not None:
+ _install_payload_enum_behaviors(cls,
user_defined_repr=user_defined_repr)
+
_resolve_entries(cls, binders, python_entries, type_key=type_key,
cxx_backed=cxx_backed)
@classmethod
diff --git a/tests/python/test_dataclass_enum.py b/tests/python/test_dataclass_enum.py
index 05e95de..15c0d2b 100644
--- a/tests/python/test_dataclass_enum.py
+++ b/tests/python/test_dataclass_enum.py
@@ -174,6 +174,34 @@ def test_str_enum_payload_literal_sugar() -> None:
assert list(Opcode.all_entries()) == [Opcode.add, Opcode.mul]
+def test_payload_enum_compat_behaviors() -> None:
+    class Priority(IntEnum, type_key=_unique_key("IntEnumCompat")):
+        low = 10
+        high = 20
+
+    class Opcode(StrEnum, type_key=_unique_key("StrEnumCompat")):
+        add = "+"
+        mul = "*"
+
+    assert Priority.low.name == "low"  # ty: ignore[possibly-missing-attribute]
+    assert repr(Priority.low) == "Priority.low"
+    assert Priority.low == 10
+    assert Priority.low != 20
+    assert str(Priority.low) == "10"
+    assert hash(Priority.low) == hash(10)
+    assert Priority(10).same_as(Priority.low)  # ty: ignore[missing-argument]
+    assert Priority("low").same_as(Priority.low)  # ty: ignore[missing-argument, invalid-argument-type]
+
+    assert Opcode.add.name == "add"  # ty: ignore[possibly-missing-attribute]
+    assert repr(Opcode.add) == "Opcode.add"
+    assert Opcode.add == "+"
+    assert Opcode.add != "*"
+    assert str(Opcode.add) == "+"
+    assert hash(Opcode.add) == hash("+")
+    assert Opcode("+").same_as(Opcode.add)  # ty: ignore[missing-argument, invalid-argument-type]
+    assert Opcode("add").same_as(Opcode.add)  # ty: ignore[missing-argument, invalid-argument-type]
+
+
def test_payload_literal_sugar_preserves_annotated_field_defaults() -> None:
class Opcode(StrEnum, type_key=_unique_key("StrEnumLiteralDefault")):
arity: int = 0