This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 09e7216 fix(container): align sequence equality and contains (#560)
09e7216 is described below
commit 09e7216697307aaf454c041a9d527bd88f9986ee
Author: Junru Shao <[email protected]>
AuthorDate: Tue Apr 21 09:34:52 2026 -0700
fix(container): align sequence equality and contains (#560)
---
python/tvm_ffi/container.py | 68 ++++++++++++++++++++++++++++++++++--------
tests/python/test_container.py | 19 ++++++++++--
2 files changed, 72 insertions(+), 15 deletions(-)
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 4d25ce9..c7e1df6 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -79,6 +79,54 @@ __all__ = ["Array", "Dict", "List", "Map"]
 T = TypeVar("T")
 K = TypeVar("K")
 V = TypeVar("V")
+
+
+def _sequence_compare_other(this: object, other: object) -> object:
+    """Normalize plain Python sequences for structural container equality.
+
+    Returns ``type(this)(other)`` when ``other`` is a
+    :class:`~collections.abc.Sequence` accepted by the container
+    constructor, and ``NotImplemented`` otherwise. Text, binary and mapping
+    types are never treated as sequences; ``bytearray`` and ``memoryview``
+    are rejected together with ``str`` and ``bytes`` so that, e.g.,
+    ``Array([104, 105])`` equals neither ``b"hi"`` nor ``bytearray(b"hi")``.
+
+    NOTE(review): ``__hash__`` is structural over the FFI object, so a
+    container that compares equal to a tuple presumably does not hash like
+    that tuple -- confirm before mixing the two as dict/set keys.
+    """
+    if isinstance(other, (str, bytes, bytearray, memoryview, Mapping)):
+        return NotImplemented
+    if not isinstance(other, Sequence):
+        return NotImplemented
+    try:
+        return type(this)(other)
+    except (TypeError, ValueError):
+        # The constructor rejected the contents; let Python fall back.
+        return NotImplemented
+
+
+def _sequence_contains(
+    this: Sequence[Any],
+    value: object,
+    ffi_contains: Callable[[Any, object], bool],
+) -> bool:
+    """Containment with a Python-level structural fallback for nested sequences.
+
+    ``ffi_contains`` is tried first. On a miss, ``value`` is normalized by
+    ``_sequence_compare_other`` -- the same rules ``__eq__`` uses, so ``in``
+    and ``==`` accept exactly the same plain sequences -- converted once, and
+    compared against each element with ``==``. This makes
+    ``("BLOCK_SIZE", 128) in Array([("BLOCK_SIZE", 128)])`` hold.
+    """
+    if ffi_contains(this, value):
+        return True
+    search_value = _sequence_compare_other(this, value)
+    if search_value is NotImplemented:
+        return False
+    return any(item == search_value for item in this)
+
+
 _DefaultT = TypeVar("_DefaultT")
 from .core import MISSING
@@ -199,19 +233,23 @@ class Array(core.CContainerBase, core.Object, Sequence[T]):
 
     def __contains__(self, value: object) -> bool:
         """Check if the array contains a value."""
-        return _ffi_api.ArrayContains(self, value)
+        return _sequence_contains(self, value, ffi_contains=_ffi_api.ArrayContains)
 
     def __eq__(self, other: object) -> bool:
         """Structural equality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
-            return NotImplemented
+        related = isinstance(other, type(self)) or isinstance(self, type(other))
+        if not related:
+            # Plain Python sequences (list, tuple, ...) are converted to type(self) first.
+            other = _sequence_compare_other(self, other)
+            if other is NotImplemented:
+                return NotImplemented
         return _ffi_api.RecursiveEq(self, other)
 
     def __ne__(self, other: object) -> bool:
         """Structural inequality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
-            return NotImplemented
-        return not _ffi_api.RecursiveEq(self, other)
+        # Delegate to __eq__ so both operators accept exactly the same operands.
+        equal = self.__eq__(other)
+        return equal if equal is NotImplemented else not equal
 
     def __hash__(self) -> int:
         """Structural hash."""
@@ -358,19 +396,23 @@ class List(core.CContainerBase, core.Object, MutableSequence[T]):
 
     def __contains__(self, value: object) -> bool:
         """Check if the list contains a value."""
-        return _ffi_api.ListContains(self, value)
+        return _sequence_contains(self, value, ffi_contains=_ffi_api.ListContains)
 
     def __eq__(self, other: object) -> bool:
         """Structural equality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
-            return NotImplemented
+        related = isinstance(other, type(self)) or isinstance(self, type(other))
+        if not related:
+            # Plain Python sequences (list, tuple, ...) are converted to type(self) first.
+            other = _sequence_compare_other(self, other)
+            if other is NotImplemented:
+                return NotImplemented
         return _ffi_api.RecursiveEq(self, other)
 
     def __ne__(self, other: object) -> bool:
         """Structural inequality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
-            return NotImplemented
-        return not _ffi_api.RecursiveEq(self, other)
+        # Delegate to __eq__ so both operators accept exactly the same operands.
+        equal = self.__eq__(other)
+        return equal if equal is NotImplemented else not equal
 
     def __hash__(self) -> int:
         """Structural hash."""
@@ -539,9 +581,9 @@ class Map(core.CContainerBase, core.Object, Mapping[K, V]):
 
     def __ne__(self, other: object) -> bool:
         """Structural inequality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
-            return NotImplemented
-        return not _ffi_api.RecursiveEq(self, other)
+        # Delegate to __eq__ so both operators accept exactly the same operands.
+        equal = self.__eq__(other)
+        return equal if equal is NotImplemented else not equal
 
     def __hash__(self) -> int:
         """Structural hash."""
@@ -663,9 +706,9 @@ class Dict(core.CContainerBase, core.Object, MutableMapping[K, V]):
 
     def __ne__(self, other: object) -> bool:
         """Structural inequality."""
-        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
-            return NotImplemented
-        return not _ffi_api.RecursiveEq(self, other)
+        # Delegate to __eq__ so both operators accept exactly the same operands.
+        equal = self.__eq__(other)
+        return equal if equal is NotImplemented else not equal
 
     def __hash__(self) -> int:
         """Structural hash."""
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index 9b74b0c..82b2cfa 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -235,6 +235,15 @@ def test_array_contains(arr: list[Any], value: Any, expected: bool) -> None:
     assert (value in a) == expected
 
 
+def test_array_contains_plain_tuple() -> None:
+    a = tvm_ffi.Array([("BLOCK_SIZE", 128)])
+    assert ("BLOCK_SIZE", 128) in a
+    assert ["BLOCK_SIZE", 128] in a
+    # The structural fallback must not report spurious matches.
+    assert ("BLOCK_SIZE", 256) not in a
+    assert ("BLOCK_SIZE",) not in a
+
+
 @pytest.mark.parametrize(
     "arr, expected",
     [
@@ -764,8 +769,14 @@ def test_array_eq_nested() -> None:
 
 def test_array_eq_not_implemented_for_unrelated() -> None:
     a = tvm_ffi.Array([1, 2, 3])
-    assert a.__eq__([1, 2, 3]) is NotImplemented
-    assert a.__ne__([1, 2, 3]) is NotImplemented
+    assert a == [1, 2, 3]
+    assert a == (1, 2, 3)
+    assert not (a != [1, 2, 3])
+    assert a != [1, 2, 4]
+    assert a != [1, 2]
+    # Text and binary strings are never treated as sequences.
+    assert a.__ne__("hello") is NotImplemented
+    assert a.__eq__(b"\x01\x02\x03") is NotImplemented
     assert a.__eq__("hello") is NotImplemented
 
 
@@ -794,7 +800,20 @@ def test_list_eq_empty() -> None:
 
 def test_list_eq_not_implemented_for_unrelated() -> None:
     a = tvm_ffi.List([1, 2, 3])
-    assert a.__eq__([1, 2, 3]) is NotImplemented
+    assert a == [1, 2, 3]
+    assert a == (1, 2, 3)
+    assert not (a != [1, 2, 3])
+    assert a != [1, 2, 4]
+    # Text strings are still rejected, as the test name promises.
+    assert a.__eq__("hello") is NotImplemented
+    assert a.__ne__("hello") is NotImplemented
+
+
+def test_list_contains_plain_tuple() -> None:
+    a = tvm_ffi.List([("BLOCK_SIZE", 128)])
+    assert ("BLOCK_SIZE", 128) in a
+    assert ["BLOCK_SIZE", 128] in a
+    assert ("BLOCK_SIZE", 256) not in a
 
 
 def test_list_hash() -> None: