This is an automated email from the ASF dual-hosted git repository.

yzh119 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 96edb9a  feat(enum): add payload enum compatibility behavior (#568)
96edb9a is described below

commit 96edb9a8fd0c474b311234b174719dda20230419
Author: Junru Shao <[email protected]>
AuthorDate: Fri Apr 24 02:20:59 2026 -0700

    feat(enum): add payload enum compatibility behavior (#568)
    
    TL;DR:
    - Make `IntEnum` and `StrEnum` variants behave like their raw payloads
    for common Python enum operations while preserving FFI singleton
    identity through `same_as`.
    
    Examples:
    - `Priority.low.name == "low"` and `repr(Priority.low) ==
    "Priority.low"`.
    - `Priority.low == 10`, `str(Priority.low) == "10"`, and
    `hash(Priority.low) == hash(10)`.
    - `Priority(10).same_as(Priority.low)` and
    `Priority("low").same_as(Priority.low)`.
    - `Opcode("+").same_as(Opcode.add)` and
    `Opcode("add").same_as(Opcode.add)`.
    
    Details:
    - Install payload-only instance dunders after `py_class`/`c_class`
    generation so payload semantics override generated object defaults
    without clobbering explicit user methods.
    - Gate payload behavior on `__ffi_enum_payload_value_type__` so plain
    `Enum` does not interpret arbitrary class-body values as public payloads.
    - Keep single-argument member lookup in `_EnumMeta.__call__` for
    existing variants, raw payloads, and names.
    - Document why these methods are dynamically installed and what
    identity/equality contract they preserve.
    - Add enum tests covering `name`, `repr`, `str`, equality, hashing,
    and construction from payload/name.
---
 python/tvm_ffi/dataclasses/enum.py  | 102 ++++++++++++++++++++++++++++++++++++
 tests/python/test_dataclass_enum.py |  28 ++++++++++
 2 files changed, 130 insertions(+)

diff --git a/python/tvm_ffi/dataclasses/enum.py b/python/tvm_ffi/dataclasses/enum.py
index 3182b7c..2ec9ce2 100644
--- a/python/tvm_ffi/dataclasses/enum.py
+++ b/python/tvm_ffi/dataclasses/enum.py
@@ -211,6 +211,74 @@ class _ClassProperty:
         return self._fget(cls)
 
 
+def _install_payload_enum_behaviors(cls: type, *, user_defined_repr: bool) -> None:
+    """Install stdlib-like instance behavior for ``IntEnum`` / ``StrEnum``.
+
+    Payload enums have two identities that plain ``Enum`` does not: the FFI
+    singleton identity (``Priority.low.same_as(...)``) and a user-visible raw
+    payload (``Priority.low.value == 10``).  These methods make the Python
+    surface follow the payload where users expect stdlib-like behavior, while
+    keeping the object identity available through ``same_as``.
+
+    This runs after ``py_class`` / ``c_class`` because those decorators may
+    install generated object dunders.  Installing here lets payload enums
+    override generated repr/equality/hash defaults, but still preserves methods
+    explicitly written in the enum subclass body.
+    """
+
+    def _payload_enum_eq(self: Enum, other: object) -> Any:
+        """Compare payload enum variants by payload and accept raw payloads.
+
+        Two variants of the same enum class compare like their ``value`` fields,
+        and a variant can compare directly with its raw payload type
+        (``Priority.low == 10``).  Unrelated objects return ``NotImplemented``
+        so Python can try reflected comparison or fall back normally.
+        """
+        if isinstance(other, type(self)):
+            return self.value == other.value  # type: ignore[attr-defined]
+        payload_type = getattr(type(self), "__ffi_enum_payload_value_type__", None)
+        if payload_type is not None and isinstance(other, payload_type):
+            return self.value == other  # type: ignore[attr-defined]
+        return NotImplemented
+
+    def _payload_enum_ne(self: Enum, other: object) -> Any:
+        """Negate payload equality while preserving ``NotImplemented`` semantics."""
+        if (eq := _payload_enum_eq(self, other)) is not NotImplemented:
+            return not eq
+        return NotImplemented
+
+    def _payload_enum_str(self: Enum) -> str:
+        """Render as the raw payload, matching ``int`` / ``str`` enum expectations."""
+        return str(self.value)  # type: ignore[attr-defined]
+
+    def _payload_enum_hash(self: Enum) -> int:
+        """Hash like the raw payload so equality and hashing stay consistent."""
+        return hash(self.value)  # type: ignore[attr-defined]
+
+    _PAYLOAD_ENUM_DUNDERS = {
+        "__eq__": _payload_enum_eq,
+        "__ne__": _payload_enum_ne,
+        "__str__": _payload_enum_str,
+        "__hash__": _payload_enum_hash,
+    }
+
+    def _payload_enum_name(self: Enum) -> str:
+        """Expose the declaration name without making plain ``Enum`` grow ``name``."""
+        return str(self._name)
+
+    def _payload_enum_repr(self: Enum) -> str:
+        """Render as ``Class.member``, matching Python payload enum conventions."""
+        return f"{type(self).__name__}.{self.name}"  # ty: ignore[unresolved-attribute]
+
+    if "name" not in cls.__dict__:
+        cls.name = property(_payload_enum_name)  # ty: ignore[unresolved-attribute]
+    if not user_defined_repr:
+        cls.__repr__ = _payload_enum_repr  # type: ignore[attr-defined]
+    for name, method in _PAYLOAD_ENUM_DUNDERS.items():
+        if name not in cls.__dict__:
+            setattr(cls, name, method)
+
+
 # ---------------------------------------------------------------------------
 # Enum base + EnumAttrMap
 # ---------------------------------------------------------------------------
@@ -230,6 +298,36 @@ class _EnumMeta(type(Object)):
     def __len__(cls) -> int:
         return len(_ordered_entries(cls))  # ty: ignore[unresolved-attribute]
 
+    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
+        """Construct enum variants from existing variants, payloads, or names.
+
+        Normal object construction is still delegated to ``Object`` / dataclass
+        initialization.  The single-argument path is intercepted to match
+        stdlib enum lookup behavior:
+
+        * ``Priority(Priority.low)`` returns the same variant.
+        * ``Priority(10)`` returns the member whose payload value is ``10``.
+        * ``Priority("low")`` returns the member declared with that name.
+
+        Payload lookup is only enabled for classes with
+        ``__ffi_enum_payload_value_type__`` so plain enums do not accidentally
+        interpret arbitrary class-body values as public payloads.
+        """
+        if not kwargs and len(args) == 1:
+            value = args[0]
+            if isinstance(value, cls):
+                return value
+            payload_value_type = getattr(cls, "__ffi_enum_payload_value_type__", None)
+            if payload_value_type is not None and isinstance(value, payload_value_type):
+                value_entries = _value_entries_dict(cls)
+                if value_entries is not None and value in value_entries:
+                    return value_entries[value]
+            if isinstance(value, str):
+                entries = _entries_dict(cls)
+                if entries is not None and value in entries:
+                    return entries[value]
+        return super().__call__(*args, **kwargs)
+
 
 @dataclass_transform(
     eq_default=False,
@@ -356,6 +454,7 @@ class Enum(Object, metaclass=_EnumMeta):
         super().__init_subclass__(**kwargs)
         if type_key is None:
             return
+        user_defined_repr = "__repr__" in cls.__dict__
 
         binders, python_entries = _collect_entry_declarations(cls)
 
@@ -365,6 +464,9 @@ class Enum(Object, metaclass=_EnumMeta):
         else:
             py_class(type_key, frozen=frozen)(cls)
 
+        if getattr(cls, "__ffi_enum_payload_value_type__", None) is not None:
+            _install_payload_enum_behaviors(cls, user_defined_repr=user_defined_repr)
+
         _resolve_entries(cls, binders, python_entries, type_key=type_key, cxx_backed=cxx_backed)
 
     @classmethod
diff --git a/tests/python/test_dataclass_enum.py b/tests/python/test_dataclass_enum.py
index 05e95de..15c0d2b 100644
--- a/tests/python/test_dataclass_enum.py
+++ b/tests/python/test_dataclass_enum.py
@@ -174,6 +174,34 @@ def test_str_enum_payload_literal_sugar() -> None:
     assert list(Opcode.all_entries()) == [Opcode.add, Opcode.mul]
 
 
+def test_payload_enum_compat_behaviors() -> None:
+    class Priority(IntEnum, type_key=_unique_key("IntEnumCompat")):
+        low = 10
+        high = 20
+
+    class Opcode(StrEnum, type_key=_unique_key("StrEnumCompat")):
+        add = "+"
+        mul = "*"
+
+    assert Priority.low.name == "low"  # ty: ignore[possibly-missing-attribute]
+    assert repr(Priority.low) == "Priority.low"
+    assert Priority.low == 10
+    assert Priority.low != 20
+    assert str(Priority.low) == "10"
+    assert hash(Priority.low) == hash(10)
+    assert Priority(10).same_as(Priority.low)  # ty: ignore[missing-argument]
+    assert Priority("low").same_as(Priority.low)  # ty: ignore[missing-argument, invalid-argument-type]
+
+    assert Opcode.add.name == "add"  # ty: ignore[possibly-missing-attribute]
+    assert repr(Opcode.add) == "Opcode.add"
+    assert Opcode.add == "+"
+    assert Opcode.add != "*"
+    assert str(Opcode.add) == "+"
+    assert hash(Opcode.add) == hash("+")
+    assert Opcode("+").same_as(Opcode.add)  # ty: ignore[missing-argument, invalid-argument-type]
+    assert Opcode("add").same_as(Opcode.add)  # ty: ignore[missing-argument, invalid-argument-type]
+
+
 def test_payload_literal_sugar_preserves_annotated_field_defaults() -> None:
     class Opcode(StrEnum, type_key=_unique_key("StrEnumLiteralDefault")):
         arity: int = 0

Reply via email to