This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 09e7216  fix(container): align sequence equality and contains (#560)
09e7216 is described below

commit 09e7216697307aaf454c041a9d527bd88f9986ee
Author: Junru Shao <[email protected]>
AuthorDate: Tue Apr 21 09:34:52 2026 -0700

    fix(container): align sequence equality and contains (#560)
---
 python/tvm_ffi/container.py    | 68 ++++++++++++++++++++++++++++++++++--------
 tests/python/test_container.py | 19 ++++++++++--
 2 files changed, 72 insertions(+), 15 deletions(-)

diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 4d25ce9..c7e1df6 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -79,6 +79,40 @@ __all__ = ["Array", "Dict", "List", "Map"]
 T = TypeVar("T")
 K = TypeVar("K")
 V = TypeVar("V")
+
+
+def _sequence_compare_other(this: object, other: object) -> object:
+    """Normalize plain Python sequences for structural container equality."""
+    if isinstance(other, (str, bytes, Mapping)):
+        return NotImplemented
+    if isinstance(other, Sequence):
+        try:
+            return type(this)(other)
+        except (TypeError, ValueError):
+            return NotImplemented
+    return NotImplemented
+
+
+def _sequence_contains(
+    this: Sequence[Any],
+    value: object,
+    ffi_contains: Callable[[Any, object], bool],
+) -> bool:
+    """Containment with a Python-level structural fallback for nested sequences."""
+    if ffi_contains(this, value):
+        return True
+    if not isinstance(value, Sequence) or isinstance(value, (str, bytes)):
+        return False
+    try:
+        search_value = type(this)(value)  # ty: ignore[too-many-positional-arguments]
+    except (TypeError, ValueError):
+        return False
+    for item in this:
+        if item == search_value:
+            return True
+    return False
+
+
 _DefaultT = TypeVar("_DefaultT")
 
 from .core import MISSING
@@ -199,19 +233,23 @@ class Array(core.CContainerBase, core.Object, Sequence[T]):
 
     def __contains__(self, value: object) -> bool:
         """Check if the array contains a value."""
-        return _ffi_api.ArrayContains(self, value)
+        return _sequence_contains(self, value, _ffi_api.ArrayContains)
 
     def __eq__(self, other: object) -> bool:
         """Structural equality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+        if isinstance(other, type(self)) or isinstance(self, type(other)):
+            return _ffi_api.RecursiveEq(self, other)
+        other = _sequence_compare_other(self, other)
+        if other is NotImplemented:
             return NotImplemented
         return _ffi_api.RecursiveEq(self, other)
 
     def __ne__(self, other: object) -> bool:
         """Structural inequality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+        result = self.__eq__(other)
+        if result is NotImplemented:
             return NotImplemented
-        return not _ffi_api.RecursiveEq(self, other)
+        return not result
 
     def __hash__(self) -> int:
         """Structural hash."""
@@ -358,19 +396,23 @@ class List(core.CContainerBase, core.Object, MutableSequence[T]):
 
     def __contains__(self, value: object) -> bool:
         """Check if the list contains a value."""
-        return _ffi_api.ListContains(self, value)
+        return _sequence_contains(self, value, _ffi_api.ListContains)
 
     def __eq__(self, other: object) -> bool:
         """Structural equality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+        if isinstance(other, type(self)) or isinstance(self, type(other)):
+            return _ffi_api.RecursiveEq(self, other)
+        other = _sequence_compare_other(self, other)
+        if other is NotImplemented:
             return NotImplemented
         return _ffi_api.RecursiveEq(self, other)
 
     def __ne__(self, other: object) -> bool:
         """Structural inequality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+        result = self.__eq__(other)
+        if result is NotImplemented:
             return NotImplemented
-        return not _ffi_api.RecursiveEq(self, other)
+        return not result
 
     def __hash__(self) -> int:
         """Structural hash."""
@@ -539,9 +581,10 @@ class Map(core.CContainerBase, core.Object, Mapping[K, V]):
 
     def __ne__(self, other: object) -> bool:
         """Structural inequality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+        result = self.__eq__(other)
+        if result is NotImplemented:
             return NotImplemented
-        return not _ffi_api.RecursiveEq(self, other)
+        return not result
 
     def __hash__(self) -> int:
         """Structural hash."""
@@ -663,9 +706,10 @@ class Dict(core.CContainerBase, core.Object, MutableMapping[K, V]):
 
     def __ne__(self, other: object) -> bool:
         """Structural inequality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+        result = self.__eq__(other)
+        if result is NotImplemented:
             return NotImplemented
-        return not _ffi_api.RecursiveEq(self, other)
+        return not result
 
     def __hash__(self) -> int:
         """Structural hash."""
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index 9b74b0c..82b2cfa 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -235,6 +235,11 @@ def test_array_contains(arr: list[Any], value: Any, expected: bool) -> None:
     assert (value in a) == expected
 
 
+def test_array_contains_plain_tuple() -> None:
+    a = tvm_ffi.Array([("BLOCK_SIZE", 128)])
+    assert ("BLOCK_SIZE", 128) in a
+
+
 @pytest.mark.parametrize(
     "arr, expected",
     [
@@ -764,8 +769,9 @@ def test_array_eq_nested() -> None:
 
 def test_array_eq_not_implemented_for_unrelated() -> None:
     a = tvm_ffi.Array([1, 2, 3])
-    assert a.__eq__([1, 2, 3]) is NotImplemented
-    assert a.__ne__([1, 2, 3]) is NotImplemented
+    assert a == [1, 2, 3]
+    assert a == (1, 2, 3)
+    assert not (a != [1, 2, 3])
     assert a.__eq__("hello") is NotImplemented
 
 
@@ -794,7 +800,14 @@ def test_list_eq_empty() -> None:
 
 def test_list_eq_not_implemented_for_unrelated() -> None:
     a = tvm_ffi.List([1, 2, 3])
-    assert a.__eq__([1, 2, 3]) is NotImplemented
+    assert a == [1, 2, 3]
+    assert a == (1, 2, 3)
+    assert not (a != [1, 2, 3])
+
+
+def test_list_contains_plain_tuple() -> None:
+    a = tvm_ffi.List([("BLOCK_SIZE", 128)])
+    assert ("BLOCK_SIZE", 128) in a
 
 
 def test_list_hash() -> None:

Reply via email to