This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 86794e7d91 [Relax][Frontend] Add ParameterList and ParameterDict containers (#19495)
86794e7d91 is described below

commit 86794e7d91fa6dce66eb4c3995bc30a50948ae07
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat May 2 23:55:52 2026 +0900

    [Relax][Frontend] Add ParameterList and ParameterDict containers (#19495)
    
    This PR adds first-class `nn.ParameterList` and `nn.ParameterDict`
    containers to the Relax frontend.
    
    These containers provide PyTorch-like list/dict registration for raw
    `nn.Parameter` objects while preserving Relax frontend semantics: values
    must be explicit `nn.Parameter` instances, with no automatic
    tensor-to-parameter conversion.
    
    ### Changes
    
    - Add public `nn.ParameterList` and `nn.ParameterDict` exports.
    - Support stable parameter names in traversal:
      - `params.0`, `params.1`
      - `params.foo`, `params.bar`
    - Integrate the new containers with:
      - `named_parameters()`
      - `parameters()`
      - `state_dict()`
      - `load_state_dict()`
      - `to(dtype=...)`
      - `export_tvm()`
      - `nn.Mutator`
    - Add focused tests for basic container behavior, type validation,
    nested traversal, export parameter names, state loading, dtype
    conversion, and mutator naming.
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 python/tvm/relax/frontend/nn/__init__.py           |  12 +-
 python/tvm/relax/frontend/nn/core.py               | 119 ++++++++++-
 python/tvm/relax/frontend/nn/visitor.py            |  62 +++++-
 tests/python/relax/test_frontend_nn_mutator.py     |  36 ++++
 .../relax/test_frontend_nn_parameter_containers.py | 223 +++++++++++++++++++++
 5 files changed, 446 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index 282944af98..1763ca152f 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -19,7 +19,17 @@
 
 # pylint: disable=redefined-builtin
 from . import op, spec
-from .core import Effect, Module, ModuleDict, ModuleList, Object, Parameter, 
Tensor
+from .core import (
+    Effect,
+    Module,
+    ModuleDict,
+    ModuleList,
+    Object,
+    Parameter,
+    ParameterDict,
+    ParameterList,
+    Tensor,
+)
 from .exporter import add_extern
 from .extern import ExternModule, ObjectModule, SourceModule
 from .modules import (
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index f3886e94cb..3725a84d61 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -625,6 +625,63 @@ class ModuleDict(Module):
             module.to(dtype=dtype)
 
 
+class ParameterDict(Module):
+    """Holds nn.Parameter values under str keys in insertion order; other types raise TypeError."""
+
+    def __init__(
+        self,
+        params: OrderedDict[str, Parameter] | dict[str, Parameter] | None = None,
+    ):
+        self.params: OrderedDict[str, Parameter] = OrderedDict()
+        if params is not None:
+            self.update(params)  # each entry is validated by __setitem__
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self.params)
+
+    def __getitem__(self, key: str) -> Parameter:
+        return self.params[key]
+
+    def __setitem__(self, key: str, param: Parameter) -> None:
+        if not isinstance(key, str):
+            raise TypeError(f"ParameterDict keys must be strings, but got {type(key).__name__}")
+        if not isinstance(param, Parameter):
+            raise TypeError(f"ParameterDict values must be nn.Parameter, but got {type(param).__name__}")
+        self.params[key] = param
+
+    def __len__(self) -> int:
+        return len(self.params)
+
+    def keys(self) -> Iterator[str]:  # NOTE(review): views below are Iterables, not Iterators
+        return self.params.keys()
+
+    def values(self) -> Iterator[Parameter]:
+        return self.params.values()
+
+    def items(self) -> Iterator[tuple[str, Parameter]]:
+        return self.params.items()
+
+    def get(self, key: str, default: Parameter | None = None) -> Parameter | None:
+        return self.params.get(key, default)
+
+    def update(self, params: dict[str, Parameter]) -> None:
+        for key, param in params.items():
+            self[key] = param  # validates; entries before a bad one stay inserted
+
+    def clear(self) -> None:
+        self.params.clear()
+
+    def pop(self, key: str) -> Parameter:
+        return self.params.pop(key)  # KeyError if absent; no default argument
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.params
+
+    def to(self, dtype: str | None = None) -> None:  # pylint: disable=invalid-name
+        for param in self.params.values():
+            param.to(dtype=dtype)
+
+
 class ModuleList(Module):
     """Holds submodules in a list."""
 
@@ -658,6 +715,44 @@ class ModuleList(Module):
         return x
 
 
+class ParameterList(Module):
+    """Holds nn.Parameter objects in a list; non-Parameter elements raise TypeError."""
+
+    def __init__(self, params: list[Parameter] | None = None):
+        self.params: list[Parameter] = []
+        if params is not None:
+            self.extend(params)  # each element is validated by append()
+
+    def __iter__(self) -> Iterator[Parameter]:
+        return iter(self.params)
+
+    def __getitem__(self, idx: int) -> Parameter:
+        return self.params[idx]  # a slice index returns a plain list
+
+    def __setitem__(self, idx: int, param: Parameter) -> None:
+        if not isinstance(param, Parameter):
+            raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}")
+        self.params[idx] = param
+
+    def __len__(self) -> int:
+        return len(self.params)
+
+    def append(self, param: Parameter) -> None:
+        """Append ``param`` to the end; raise TypeError if it is not an nn.Parameter."""
+        if not isinstance(param, Parameter):
+            raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}")
+        self.params.append(param)
+
+    def extend(self, params: list[Parameter]) -> None:
+        """Append each of ``params`` via ``append``; elements before an invalid one stay added."""
+        for param in params:
+            self.append(param)
+
+    def to(self, dtype: str | None = None) -> None:  # pylint: disable=invalid-name
+        for param in self.params:
+            param.to(dtype=dtype)
+
+
 def wrap_nested(expr: rx.Expr, name: str) -> Tensor | Sequence[Tensor]:
     """Wrap the given relax.Expr, emit it using the current BlockBuilder,
     and automatically handle nested cases if the expr represents a Tuple.
@@ -692,7 +787,17 @@ def wrap_nested(expr: rx.Expr, name: str) -> Tensor | Sequence[Tensor]:
 
 def _attribute_finder(root: Module, prefix: str, condition_yield: 
Callable[[Any], bool]):
     """Find attributes that satisfy the condition recursively"""
-    if isinstance(root, ModuleList):
+    if isinstance(root, ParameterList):
+        for i, param in enumerate(root):
+            if condition_yield(param):
+                yield prefix + f"{i}", param
+        return
+    elif isinstance(root, ParameterDict):
+        for name, param in root.items():
+            if condition_yield(param):
+                yield prefix + name, param
+        return
+    elif isinstance(root, ModuleList):
         for i, subitem in enumerate(root):
             yield from _attribute_finder(subitem, prefix + f"{i}.", 
condition_yield)
         return
@@ -703,6 +808,18 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any]
     for name, item in root.__dict__.items():
         if condition_yield(item):
             yield prefix + name, item
+        elif isinstance(item, ParameterList):
+            yield from _attribute_finder(
+                item,
+                prefix + name + ".",
+                condition_yield,
+            )
+        elif isinstance(item, ParameterDict):
+            yield from _attribute_finder(
+                item,
+                prefix + name + ".",
+                condition_yield,
+            )
         elif isinstance(item, ModuleList):
             yield from _attribute_finder(
                 item,
diff --git a/python/tvm/relax/frontend/nn/visitor.py b/python/tvm/relax/frontend/nn/visitor.py
index e3279ceae5..69583eaae8 100644
--- a/python/tvm/relax/frontend/nn/visitor.py
+++ b/python/tvm/relax/frontend/nn/visitor.py
@@ -116,6 +116,42 @@ class Mutator:
         """
         return self.visit(name, node)
 
+    def visit_parameterdict(self, name: str, node: nn.ParameterDict) -> Any:
+        """The base visiting method for mutation of nn.ParameterDict nodes; delegates to visit.
+
+        Parameters
+        ----------
+        name : str
+            The name of the current node in parent's attribute; entry names are built from it.
+
+        node : nn.ParameterDict
+            The current node of nn.ParameterDict to mutate.
+
+        Returns
+        -------
+        ret_node: Any
+            The new node to replace current node.
+        """
+        return self.visit(name, node)
+
+    def visit_parameterlist(self, name: str, node: nn.ParameterList) -> Any:
+        """The base visiting method for mutation of nn.ParameterList nodes; delegates to visit.
+
+        Parameters
+        ----------
+        name : str
+            The name of the current node in parent's attribute; element names append the index.
+
+        node : nn.ParameterList
+            The current node of nn.ParameterList to mutate.
+
+        Returns
+        -------
+        ret_node: Any
+            The new node to replace current node.
+        """
+        return self.visit(name, node)
+
     def visit(self, name: str, node: Any) -> Any:
         """The base dispatching method for visiting of all nodes.
 
@@ -141,9 +177,19 @@ class Mutator:
             else:
                 return f"{parent}.{child}"
 
-        if isinstance(node, nn.ModuleList):
+        if isinstance(node, nn.ParameterList):
+            for i in range(len(node)):
+                node[i] = self.visit_param(_get_child_name(name, str(i)), 
node[i])
+        elif isinstance(node, nn.ParameterDict):
+            for k, v in node.items():
+                node[k] = self.visit_param(_get_child_name(name, k), v)
+        elif isinstance(node, nn.ModuleList):
             for i in range(len(node)):
-                if isinstance(node[i], nn.ModuleDict):
+                if isinstance(node[i], nn.ParameterDict):
+                    node[i] = self.visit_parameterdict(_get_child_name(name, 
str(i)), node[i])
+                elif isinstance(node[i], nn.ParameterList):
+                    node[i] = self.visit_parameterlist(_get_child_name(name, 
str(i)), node[i])
+                elif isinstance(node[i], nn.ModuleDict):
                     node[i] = self.visit_moduledict(f"{name}.{i}", node[i])
                 elif isinstance(node[i], nn.ModuleList):
                     node[i] = self.visit_modulelist(f"{name}.{i}", node[i])
@@ -155,7 +201,11 @@ class Mutator:
                     node[i] = self.visit_param(f"{name}.{i}", node[i])
         elif isinstance(node, nn.ModuleDict):
             for k, v in node.items():
-                if isinstance(v, nn.ModuleDict):
+                if isinstance(v, nn.ParameterDict):
+                    node[k] = self.visit_parameterdict(_get_child_name(name, 
k), v)
+                elif isinstance(v, nn.ParameterList):
+                    node[k] = self.visit_parameterlist(_get_child_name(name, 
k), v)
+                elif isinstance(v, nn.ModuleDict):
                     node[k] = self.visit_moduledict(_get_child_name(name, k), 
v)
                 elif isinstance(v, nn.ModuleList):
                     node[k] = self.visit_modulelist(_get_child_name(name, k), 
v)
@@ -167,7 +217,11 @@ class Mutator:
                     node[k] = self.visit_param(_get_child_name(name, k), v)
         else:
             for key, value in node.__dict__.items():
-                if isinstance(value, nn.ModuleDict):
+                if isinstance(value, nn.ParameterDict):
+                    setattr(node, key, 
self.visit_parameterdict(_get_child_name(name, key), value))
+                elif isinstance(value, nn.ParameterList):
+                    setattr(node, key, 
self.visit_parameterlist(_get_child_name(name, key), value))
+                elif isinstance(value, nn.ModuleDict):
                     setattr(node, key, 
self.visit_moduledict(_get_child_name(name, key), value))
                 elif isinstance(value, nn.ModuleList):
                     setattr(node, key, 
self.visit_modulelist(_get_child_name(name, key), value))
diff --git a/tests/python/relax/test_frontend_nn_mutator.py b/tests/python/relax/test_frontend_nn_mutator.py
index 253e24a4ed..23c8c9cde6 100644
--- a/tests/python/relax/test_frontend_nn_mutator.py
+++ b/tests/python/relax/test_frontend_nn_mutator.py
@@ -127,6 +127,42 @@ def test_mutator_naming_modulelist():
     mutator.visit("mod_list", mod_list)
 
 
+def test_mutator_naming_parameter_containers():  # entries get "<attr>.<index|key>" names
+    class Module(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.param_list = nn.ParameterList(
+                [
+                    nn.Parameter((32, 128), "float64"),
+                    nn.Parameter((32, 128), "float32"),
+                ]
+            )
+            self.param_dict = nn.ParameterDict(
+                {
+                    "k0": nn.Parameter((32, 128), "float16"),
+                    "k1": nn.Parameter((32, 128), "float8"),
+                }
+            )
+
+    seen = []  # (name, dtype) per visit_param call; distinct dtypes identify each param
+
+    class Mutator(nn.Mutator):
+        def visit_param(self, name: str, node: nn.Parameter) -> Any:
+            seen.append((name, node.dtype))
+            return node  # record only; the parameter is left unchanged
+
+    module = Module()
+    mutator = Mutator()
+    mutator.visit("", module)  # empty root name, so recorded names start at the attribute
+
+    assert seen == [
+        ("param_list.0", "float64"),
+        ("param_list.1", "float32"),
+        ("param_dict.k0", "float16"),
+        ("param_dict.k1", "float8"),
+    ]
+
+
 def test_mutator_module():
     class SubModule1(nn.Module):
         def __init__(self) -> None:
diff --git a/tests/python/relax/test_frontend_nn_parameter_containers.py b/tests/python/relax/test_frontend_nn_parameter_containers.py
new file mode 100644
index 0000000000..d07a21405a
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_parameter_containers.py
@@ -0,0 +1,223 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm.relax.frontend import nn
+
+
+class ParamContainerModule(nn.Module):  # fixture: two list params and two dict params, (4,) float32
+    def __init__(self):  # NOTE(review): skips super().__init__(); confirm nn.Module needs none
+        self.list_params = nn.ParameterList(
+            [
+                nn.Parameter((4,), "float32"),
+                nn.Parameter((4,), "float32"),
+            ]
+        )
+        self.dict_params = nn.ParameterDict(
+            {
+                "foo": nn.Parameter((4,), "float32"),
+                "bar": nn.Parameter((4,), "float32"),
+            }
+        )
+
+
+def test_parameter_list_basic_behavior():  # append, indexing, iteration, item assignment, extend
+    p0 = nn.Parameter((4,), "float32")
+    p1 = nn.Parameter((4,), "float32")
+    params = nn.ParameterList([p0])
+    params.append(p1)
+
+    assert len(params) == 2
+    assert params[0] is p0  # the container stores the given objects, not copies
+    assert list(params) == [p0, p1]
+
+    p2 = nn.Parameter((4,), "float32")
+    params[1] = p2
+    assert params[1] is p2
+
+    p3 = nn.Parameter((4,), "float32")
+    params.extend([p3])
+    assert list(params) == [p0, p2, p3]
+
+
+def test_parameter_dict_basic_behavior():  # dict-style API; keys keep insertion order
+    p0 = nn.Parameter((4,), "float32")
+    p1 = nn.Parameter((4,), "float32")
+    params = nn.ParameterDict({"foo": p0})
+    params["bar"] = p1
+
+    assert len(params) == 2
+    assert params["foo"] is p0
+    assert "bar" in params
+    assert list(params) == ["foo", "bar"]  # iterating the container yields keys
+    assert list(params.keys()) == ["foo", "bar"]
+    assert list(params.values()) == [p0, p1]
+    assert list(params.items()) == [("foo", p0), ("bar", p1)]
+    assert params.get("foo") is p0
+
+    p2 = nn.Parameter((4,), "float32")
+    params.update({"baz": p2})
+    assert list(params.keys()) == ["foo", "bar", "baz"]
+    assert params.pop("baz") is p2
+    params.clear()
+    assert len(params) == 0
+
+
+def test_type_validation():  # non-Parameter values and non-str dict keys raise TypeError
+    with pytest.raises(TypeError):
+        nn.ParameterList([object()])
+
+    with pytest.raises(TypeError):
+        nn.ParameterDict({"bad": object()})
+
+    with pytest.raises(TypeError):
+        nn.ParameterDict({1: nn.Parameter((4,), "float32")})
+
+    with pytest.raises(TypeError):
+        nn.ParameterList()[0] = object()  # type check precedes indexing, so no IndexError
+
+
+def test_named_parameters_parameters_and_state_dict():  # names are "<attr>.<index|key>"
+    m = ParamContainerModule()
+
+    expected = [  # attribute definition order, then element order within each container
+        "list_params.0",
+        "list_params.1",
+        "dict_params.foo",
+        "dict_params.bar",
+    ]
+
+    assert list(m.state_dict().keys()) == expected
+    assert [name for name, _ in m.named_parameters()] == expected
+    assert len(list(m.parameters())) == 4
+
+
+def test_nested_traversal_through_module_dict():  # ParameterList inside a ModuleDict entry
+    class Inner(nn.Module):
+        def __init__(self):
+            self.params = nn.ParameterList([nn.Parameter((4,), "float32")])
+
+    class Outer(nn.Module):
+        def __init__(self):
+            self.blocks = nn.ModuleDict({"inner": Inner()})
+
+    m = Outer()
+    assert list(m.state_dict().keys()) == ["blocks.inner.params.0"]
+
+
+def test_nested_traversal_through_module_list():  # ParameterList inside a ModuleList element
+    class Inner(nn.Module):
+        def __init__(self):
+            self.params = nn.ParameterList([nn.Parameter((4,), "float32")])
+
+    class Outer(nn.Module):
+        def __init__(self):
+            self.blocks = nn.ModuleList([Inner()])
+
+    m = Outer()
+    assert list(m.state_dict().keys()) == ["blocks.0.params.0"]
+
+
+def test_to_dtype():  # Module.to(dtype=...) reaches the params held in both containers
+    m = ParamContainerModule()
+    m.to(dtype="float16")
+
+    assert m.list_params[0].dtype == "float16"
+    assert m.list_params[1].dtype == "float16"
+    assert m.dict_params["foo"].dtype == "float16"
+    assert m.dict_params["bar"].dtype == "float16"
+
+
+def test_load_state_dict():  # values are loaded into container params by "<attr>.<index|key>"
+    m = ParamContainerModule()
+    p0 = nn.Parameter((4,), "float32")
+    p0.data = np.full((4,), 1.0, dtype="float32")
+    p1 = nn.Parameter((4,), "float32")
+    p1.data = np.full((4,), 2.0, dtype="float32")
+    p2 = nn.Parameter((4,), "float32")
+    p2.data = np.full((4,), 3.0, dtype="float32")
+    p3 = nn.Parameter((4,), "float32")
+    p3.data = np.full((4,), 4.0, dtype="float32")
+    state_dict = {
+        "list_params.0": p0,
+        "list_params.1": p1,
+        "dict_params.foo": p2,
+        "dict_params.bar": p3,
+    }
+
+    missing_keys, unexpected_keys = m.load_state_dict(state_dict)
+
+    assert missing_keys == []
+    assert unexpected_keys == []
+    tvm.testing.assert_allclose(m.list_params[0].data.numpy(), np.full((4,), 1.0, "float32"))
+    tvm.testing.assert_allclose(m.list_params[1].data.numpy(), np.full((4,), 2.0, "float32"))
+    tvm.testing.assert_allclose(
+        m.dict_params["foo"].data.numpy(), np.full((4,), 3.0, "float32")
+    )
+    tvm.testing.assert_allclose(
+        m.dict_params["bar"].data.numpy(), np.full((4,), 4.0, "float32")
+    )
+
+
+def test_export_tvm_parameter_names():  # exported param names follow "<attr>.<index|key>"
+    class M(nn.Module):
+        def __init__(self):
+            self.biases = nn.ParameterList(
+                [
+                    nn.Parameter((4,), "float32"),
+                    nn.Parameter((4,), "float32"),
+                ]
+            )
+            self.scales = nn.ParameterDict({"main": nn.Parameter((4,), "float32")})
+
+        def forward(self, x):
+            return x + self.biases[0] + self.biases[1] + self.scales["main"]
+
+    _, params = M().export_tvm(
+        spec={"forward": {"x": nn.spec.Tensor((4,), "float32")}},
+        debug=False,
+    )
+    assert [name for name, _ in params] == ["biases.0", "biases.1", "scales.main"]
+
+
+def test_mutator_parameter_container_names():  # Mutator.visit_module sees the same dotted names
+    seen = []
+
+    class Recorder(nn.Mutator):
+        def visit_param(self, name: str, node: nn.Parameter) -> Any:
+            seen.append(name)
+            return node  # record only; the parameter is left unchanged
+
+    m = ParamContainerModule()
+    Recorder().visit_module("", m)
+
+    assert seen == [
+        "list_params.0",
+        "list_params.1",
+        "dict_params.foo",
+        "dict_params.bar",
+    ]
+
+
+if __name__ == "__main__":  # allow running this test file directly as a script
+    tvm.testing.main()

Reply via email to