This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 86794e7d91 [Relax][Frontend] Add ParameterList and ParameterDict
containers (#19495)
86794e7d91 is described below
commit 86794e7d91fa6dce66eb4c3995bc30a50948ae07
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat May 2 23:55:52 2026 +0900
[Relax][Frontend] Add ParameterList and ParameterDict containers (#19495)
This PR adds first-class `nn.ParameterList` and `nn.ParameterDict`
containers to the Relax frontend.
These containers provide PyTorch-like list/dict registration for raw
`nn.Parameter` objects while preserving Relax frontend semantics: values
must be explicit `nn.Parameter` instances, with no automatic
tensor-to-parameter conversion.
### Changes
- Add public `nn.ParameterList` and `nn.ParameterDict` exports.
- Support stable parameter names in traversal:
  - `params.0`, `params.1`
  - `params.foo`, `params.bar`
- Integrate the new containers with:
  - `named_parameters()`
  - `parameters()`
  - `state_dict()`
  - `load_state_dict()`
  - `to(dtype=...)`
  - `export_tvm()`
  - `nn.Mutator`
- Add focused tests for basic container behavior, type validation,
nested traversal, export parameter names, state loading, dtype
conversion, and mutator naming.
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
python/tvm/relax/frontend/nn/__init__.py | 12 +-
python/tvm/relax/frontend/nn/core.py | 119 ++++++++++-
python/tvm/relax/frontend/nn/visitor.py | 62 +++++-
tests/python/relax/test_frontend_nn_mutator.py | 36 ++++
.../relax/test_frontend_nn_parameter_containers.py | 223 +++++++++++++++++++++
5 files changed, 446 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index 282944af98..1763ca152f 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -19,7 +19,17 @@
# pylint: disable=redefined-builtin
from . import op, spec
-from .core import Effect, Module, ModuleDict, ModuleList, Object, Parameter,
Tensor
+from .core import (
+ Effect,
+ Module,
+ ModuleDict,
+ ModuleList,
+ Object,
+ Parameter,
+ ParameterDict,
+ ParameterList,
+ Tensor,
+)
from .exporter import add_extern
from .extern import ExternModule, ObjectModule, SourceModule
from .modules import (
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index f3886e94cb..3725a84d61 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -625,6 +625,63 @@ class ModuleDict(Module):
module.to(dtype=dtype)
+class ParameterDict(Module):
+ """Holds parameters in a dict."""
+
+ def __init__(
+ self,
+ params: OrderedDict[str, Parameter] | dict[str, Parameter] | None =
None,
+ ):
+ self.params: OrderedDict[str, Parameter] = OrderedDict()
+ if params is not None:
+ self.update(params)
+
+ def __iter__(self) -> Iterator[str]:
+ return iter(self.params)
+
+ def __getitem__(self, key: str) -> Parameter:
+ return self.params[key]
+
+ def __setitem__(self, key: str, param: Parameter) -> None:
+ if not isinstance(key, str):
+ raise TypeError(f"ParameterDict keys must be strings, but got
{type(key).__name__}")
+ if not isinstance(param, Parameter):
+ raise TypeError(f"ParameterDict values must be nn.Parameter, but
got {type(param).__name__}")
+ self.params[key] = param
+
+ def __len__(self) -> int:
+ return len(self.params)
+
+ def keys(self) -> Iterator[str]:
+ return self.params.keys()
+
+ def values(self) -> Iterator[Parameter]:
+ return self.params.values()
+
+ def items(self) -> Iterator[tuple[str, Parameter]]:
+ return self.params.items()
+
+ def get(self, key: str, default: Parameter | None = None) -> Parameter |
None:
+ return self.params.get(key, default)
+
+ def update(self, params: dict[str, Parameter]) -> None:
+ for key, param in params.items():
+ self[key] = param
+
+ def clear(self) -> None:
+ self.params.clear()
+
+ def pop(self, key: str) -> Parameter:
+ return self.params.pop(key)
+
+ def __contains__(self, key: str) -> bool:
+ return key in self.params
+
+ def to(self, dtype: str | None = None) -> None: # pylint:
disable=invalid-name
+ for param in self.params.values():
+ param.to(dtype=dtype)
+
+
class ModuleList(Module):
"""Holds submodules in a list."""
@@ -658,6 +715,44 @@ class ModuleList(Module):
return x
+class ParameterList(Module):
+ """Holds parameters in a list."""
+
+ def __init__(self, params: list[Parameter] | None = None):
+ self.params: list[Parameter] = []
+ if params is not None:
+ self.extend(params)
+
+ def __iter__(self) -> Iterator[Parameter]:
+ return iter(self.params)
+
+ def __getitem__(self, idx: int) -> Parameter:
+ return self.params[idx]
+
+ def __setitem__(self, idx: int, param: Parameter) -> None:
+ if not isinstance(param, Parameter):
+ raise TypeError(f"ParameterList elements must be nn.Parameter, but
got {type(param).__name__}")
+ self.params[idx] = param
+
+ def __len__(self) -> int:
+ return len(self.params)
+
+ def append(self, param: Parameter) -> None:
+ """Add a parameter to the end of the ParameterList"""
+ if not isinstance(param, Parameter):
+ raise TypeError(f"ParameterList elements must be nn.Parameter, but
got {type(param).__name__}")
+ self.params.append(param)
+
+ def extend(self, params: list[Parameter]) -> None:
+ """Add parameters to the end of the ParameterList"""
+ for param in params:
+ self.append(param)
+
+ def to(self, dtype: str | None = None) -> None: # pylint:
disable=invalid-name
+ for param in self.params:
+ param.to(dtype=dtype)
+
+
def wrap_nested(expr: rx.Expr, name: str) -> Tensor | Sequence[Tensor]:
"""Wrap the given relax.Expr, emit it using the current BlockBuilder,
and automatically handle nested cases if the expr represents a Tuple.
@@ -692,7 +787,17 @@ def wrap_nested(expr: rx.Expr, name: str) -> Tensor |
Sequence[Tensor]:
def _attribute_finder(root: Module, prefix: str, condition_yield:
Callable[[Any], bool]):
"""Find attributes that satisfy the condition recursively"""
- if isinstance(root, ModuleList):
+ if isinstance(root, ParameterList):
+ for i, param in enumerate(root):
+ if condition_yield(param):
+ yield prefix + f"{i}", param
+ return
+ elif isinstance(root, ParameterDict):
+ for name, param in root.items():
+ if condition_yield(param):
+ yield prefix + name, param
+ return
+ elif isinstance(root, ModuleList):
for i, subitem in enumerate(root):
yield from _attribute_finder(subitem, prefix + f"{i}.",
condition_yield)
return
@@ -703,6 +808,18 @@ def _attribute_finder(root: Module, prefix: str,
condition_yield: Callable[[Any]
for name, item in root.__dict__.items():
if condition_yield(item):
yield prefix + name, item
+ elif isinstance(item, ParameterList):
+ yield from _attribute_finder(
+ item,
+ prefix + name + ".",
+ condition_yield,
+ )
+ elif isinstance(item, ParameterDict):
+ yield from _attribute_finder(
+ item,
+ prefix + name + ".",
+ condition_yield,
+ )
elif isinstance(item, ModuleList):
yield from _attribute_finder(
item,
diff --git a/python/tvm/relax/frontend/nn/visitor.py b/python/tvm/relax/frontend/nn/visitor.py
index e3279ceae5..69583eaae8 100644
--- a/python/tvm/relax/frontend/nn/visitor.py
+++ b/python/tvm/relax/frontend/nn/visitor.py
@@ -116,6 +116,42 @@ class Mutator:
"""
return self.visit(name, node)
+ def visit_parameterdict(self, name: str, node: nn.ParameterDict) -> Any:
+ """The base visiting method for mutation of nn.ParameterDict nodes.
+
+ Parameters
+ ----------
+ name : str
+ The name of the current node in parent's attribute.
+
+ node : nn.ParameterDict
+ The current node of nn.ParameterDict to mutate.
+
+ Returns
+ ------
+ ret_node: Any
+ The new node to replace current node.
+ """
+ return self.visit(name, node)
+
+ def visit_parameterlist(self, name: str, node: nn.ParameterList) -> Any:
+ """The base visiting method for mutation of nn.ParameterList nodes.
+
+ Parameters
+ ----------
+ name : str
+ The name of the current node in parent's attribute.
+
+ node : nn.ParameterList
+ The current node of nn.ParameterList to mutate.
+
+ Returns
+ ------
+ ret_node: Any
+ The new node to replace current node.
+ """
+ return self.visit(name, node)
+
def visit(self, name: str, node: Any) -> Any:
"""The base dispatching method for visiting of all nodes.
@@ -141,9 +177,19 @@ class Mutator:
else:
return f"{parent}.{child}"
- if isinstance(node, nn.ModuleList):
+ if isinstance(node, nn.ParameterList):
+ for i in range(len(node)):
+ node[i] = self.visit_param(_get_child_name(name, str(i)),
node[i])
+ elif isinstance(node, nn.ParameterDict):
+ for k, v in node.items():
+ node[k] = self.visit_param(_get_child_name(name, k), v)
+ elif isinstance(node, nn.ModuleList):
for i in range(len(node)):
- if isinstance(node[i], nn.ModuleDict):
+ if isinstance(node[i], nn.ParameterDict):
+ node[i] = self.visit_parameterdict(_get_child_name(name,
str(i)), node[i])
+ elif isinstance(node[i], nn.ParameterList):
+ node[i] = self.visit_parameterlist(_get_child_name(name,
str(i)), node[i])
+ elif isinstance(node[i], nn.ModuleDict):
node[i] = self.visit_moduledict(f"{name}.{i}", node[i])
elif isinstance(node[i], nn.ModuleList):
node[i] = self.visit_modulelist(f"{name}.{i}", node[i])
@@ -155,7 +201,11 @@ class Mutator:
node[i] = self.visit_param(f"{name}.{i}", node[i])
elif isinstance(node, nn.ModuleDict):
for k, v in node.items():
- if isinstance(v, nn.ModuleDict):
+ if isinstance(v, nn.ParameterDict):
+ node[k] = self.visit_parameterdict(_get_child_name(name,
k), v)
+ elif isinstance(v, nn.ParameterList):
+ node[k] = self.visit_parameterlist(_get_child_name(name,
k), v)
+ elif isinstance(v, nn.ModuleDict):
node[k] = self.visit_moduledict(_get_child_name(name, k),
v)
elif isinstance(v, nn.ModuleList):
node[k] = self.visit_modulelist(_get_child_name(name, k),
v)
@@ -167,7 +217,11 @@ class Mutator:
node[k] = self.visit_param(_get_child_name(name, k), v)
else:
for key, value in node.__dict__.items():
- if isinstance(value, nn.ModuleDict):
+ if isinstance(value, nn.ParameterDict):
+ setattr(node, key,
self.visit_parameterdict(_get_child_name(name, key), value))
+ elif isinstance(value, nn.ParameterList):
+ setattr(node, key,
self.visit_parameterlist(_get_child_name(name, key), value))
+ elif isinstance(value, nn.ModuleDict):
setattr(node, key,
self.visit_moduledict(_get_child_name(name, key), value))
elif isinstance(value, nn.ModuleList):
setattr(node, key,
self.visit_modulelist(_get_child_name(name, key), value))
diff --git a/tests/python/relax/test_frontend_nn_mutator.py b/tests/python/relax/test_frontend_nn_mutator.py
index 253e24a4ed..23c8c9cde6 100644
--- a/tests/python/relax/test_frontend_nn_mutator.py
+++ b/tests/python/relax/test_frontend_nn_mutator.py
@@ -127,6 +127,42 @@ def test_mutator_naming_modulelist():
mutator.visit("mod_list", mod_list)
+def test_mutator_naming_parameter_containers():
+ class Module(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.param_list = nn.ParameterList(
+ [
+ nn.Parameter((32, 128), "float64"),
+ nn.Parameter((32, 128), "float32"),
+ ]
+ )
+ self.param_dict = nn.ParameterDict(
+ {
+ "k0": nn.Parameter((32, 128), "float16"),
+ "k1": nn.Parameter((32, 128), "float8"),
+ }
+ )
+
+ seen = []
+
+ class Mutator(nn.Mutator):
+ def visit_param(self, name: str, node: nn.Parameter) -> Any:
+ seen.append((name, node.dtype))
+ return node
+
+ module = Module()
+ mutator = Mutator()
+ mutator.visit("", module)
+
+ assert seen == [
+ ("param_list.0", "float64"),
+ ("param_list.1", "float32"),
+ ("param_dict.k0", "float16"),
+ ("param_dict.k1", "float8"),
+ ]
+
+
def test_mutator_module():
class SubModule1(nn.Module):
def __init__(self) -> None:
diff --git a/tests/python/relax/test_frontend_nn_parameter_containers.py b/tests/python/relax/test_frontend_nn_parameter_containers.py
new file mode 100644
index 0000000000..d07a21405a
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_parameter_containers.py
@@ -0,0 +1,223 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm.relax.frontend import nn
+
+
+class ParamContainerModule(nn.Module):
+ def __init__(self):
+ self.list_params = nn.ParameterList(
+ [
+ nn.Parameter((4,), "float32"),
+ nn.Parameter((4,), "float32"),
+ ]
+ )
+ self.dict_params = nn.ParameterDict(
+ {
+ "foo": nn.Parameter((4,), "float32"),
+ "bar": nn.Parameter((4,), "float32"),
+ }
+ )
+
+
+def test_parameter_list_basic_behavior():
+ p0 = nn.Parameter((4,), "float32")
+ p1 = nn.Parameter((4,), "float32")
+ params = nn.ParameterList([p0])
+ params.append(p1)
+
+ assert len(params) == 2
+ assert params[0] is p0
+ assert list(params) == [p0, p1]
+
+ p2 = nn.Parameter((4,), "float32")
+ params[1] = p2
+ assert params[1] is p2
+
+ p3 = nn.Parameter((4,), "float32")
+ params.extend([p3])
+ assert list(params) == [p0, p2, p3]
+
+
+def test_parameter_dict_basic_behavior():
+ p0 = nn.Parameter((4,), "float32")
+ p1 = nn.Parameter((4,), "float32")
+ params = nn.ParameterDict({"foo": p0})
+ params["bar"] = p1
+
+ assert len(params) == 2
+ assert params["foo"] is p0
+ assert "bar" in params
+ assert list(params) == ["foo", "bar"]
+ assert list(params.keys()) == ["foo", "bar"]
+ assert list(params.values()) == [p0, p1]
+ assert list(params.items()) == [("foo", p0), ("bar", p1)]
+ assert params.get("foo") is p0
+
+ p2 = nn.Parameter((4,), "float32")
+ params.update({"baz": p2})
+ assert list(params.keys()) == ["foo", "bar", "baz"]
+ assert params.pop("baz") is p2
+ params.clear()
+ assert len(params) == 0
+
+
+def test_type_validation():
+ with pytest.raises(TypeError):
+ nn.ParameterList([object()])
+
+ with pytest.raises(TypeError):
+ nn.ParameterDict({"bad": object()})
+
+ with pytest.raises(TypeError):
+ nn.ParameterDict({1: nn.Parameter((4,), "float32")})
+
+ with pytest.raises(TypeError):
+ nn.ParameterList()[0] = object()
+
+
+def test_named_parameters_parameters_and_state_dict():
+ m = ParamContainerModule()
+
+ expected = [
+ "list_params.0",
+ "list_params.1",
+ "dict_params.foo",
+ "dict_params.bar",
+ ]
+
+ assert list(m.state_dict().keys()) == expected
+ assert [name for name, _ in m.named_parameters()] == expected
+ assert len(list(m.parameters())) == 4
+
+
+def test_nested_traversal_through_module_dict():
+ class Inner(nn.Module):
+ def __init__(self):
+ self.params = nn.ParameterList([nn.Parameter((4,), "float32")])
+
+ class Outer(nn.Module):
+ def __init__(self):
+ self.blocks = nn.ModuleDict({"inner": Inner()})
+
+ m = Outer()
+ assert list(m.state_dict().keys()) == ["blocks.inner.params.0"]
+
+
+def test_nested_traversal_through_module_list():
+ class Inner(nn.Module):
+ def __init__(self):
+ self.params = nn.ParameterList([nn.Parameter((4,), "float32")])
+
+ class Outer(nn.Module):
+ def __init__(self):
+ self.blocks = nn.ModuleList([Inner()])
+
+ m = Outer()
+ assert list(m.state_dict().keys()) == ["blocks.0.params.0"]
+
+
+def test_to_dtype():
+ m = ParamContainerModule()
+ m.to(dtype="float16")
+
+ assert m.list_params[0].dtype == "float16"
+ assert m.list_params[1].dtype == "float16"
+ assert m.dict_params["foo"].dtype == "float16"
+ assert m.dict_params["bar"].dtype == "float16"
+
+
+def test_load_state_dict():
+ m = ParamContainerModule()
+ p0 = nn.Parameter((4,), "float32")
+ p0.data = np.full((4,), 1.0, dtype="float32")
+ p1 = nn.Parameter((4,), "float32")
+ p1.data = np.full((4,), 2.0, dtype="float32")
+ p2 = nn.Parameter((4,), "float32")
+ p2.data = np.full((4,), 3.0, dtype="float32")
+ p3 = nn.Parameter((4,), "float32")
+ p3.data = np.full((4,), 4.0, dtype="float32")
+ state_dict = {
+ "list_params.0": p0,
+ "list_params.1": p1,
+ "dict_params.foo": p2,
+ "dict_params.bar": p3,
+ }
+
+ missing_keys, unexpected_keys = m.load_state_dict(state_dict)
+
+ assert missing_keys == []
+ assert unexpected_keys == []
+ tvm.testing.assert_allclose(m.list_params[0].data.numpy(), np.full((4,),
1.0, "float32"))
+ tvm.testing.assert_allclose(m.list_params[1].data.numpy(), np.full((4,),
2.0, "float32"))
+ tvm.testing.assert_allclose(
+ m.dict_params["foo"].data.numpy(), np.full((4,), 3.0, "float32")
+ )
+ tvm.testing.assert_allclose(
+ m.dict_params["bar"].data.numpy(), np.full((4,), 4.0, "float32")
+ )
+
+
+def test_export_tvm_parameter_names():
+ class M(nn.Module):
+ def __init__(self):
+ self.biases = nn.ParameterList(
+ [
+ nn.Parameter((4,), "float32"),
+ nn.Parameter((4,), "float32"),
+ ]
+ )
+ self.scales = nn.ParameterDict({"main": nn.Parameter((4,),
"float32")})
+
+ def forward(self, x):
+ return x + self.biases[0] + self.biases[1] + self.scales["main"]
+
+ _, params = M().export_tvm(
+ spec={"forward": {"x": nn.spec.Tensor((4,), "float32")}},
+ debug=False,
+ )
+ assert [name for name, _ in params] == ["biases.0", "biases.1",
"scales.main"]
+
+
+def test_mutator_parameter_container_names():
+ seen = []
+
+ class Recorder(nn.Mutator):
+ def visit_param(self, name: str, node: nn.Parameter) -> Any:
+ seen.append(name)
+ return node
+
+ m = ParamContainerModule()
+ Recorder().visit_module("", m)
+
+ assert seen == [
+ "list_params.0",
+ "list_params.1",
+ "dict_params.foo",
+ "dict_params.bar",
+ ]
+
+
+if __name__ == "__main__":
+ tvm.testing.main()