This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 43adad701f [Relax][PyTorch] Add support for argsort, sort, topk ops
(#17810)
43adad701f is described below
commit 43adad701fd135de7cb3c082e4f87f2e63b16bf3
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Apr 5 05:13:19 2025 +0800
[Relax][PyTorch] Add support for argsort, sort, topk ops (#17810)
* Update fx_translator.py
* Update base_fx_graph_translator.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* fix lint
---
.../frontend/torch/base_fx_graph_translator.py | 28 ++++++++++
python/tvm/relax/frontend/torch/fx_translator.py | 3 ++
tests/python/relax/test_frontend_from_fx.py | 62 ++++++++++++++++++++++
3 files changed, 93 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d99411bd56..affbd81e1c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -960,6 +960,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
########## Manipulation ##########
+    def _argsort(self, node: fx.Node) -> relax.Var:  # torch.argsort -> relax.op.argsort
+        pos = node.args  # a positional operand takes precedence over its keyword form
+        axis = pos[1] if len(pos) > 1 else node.kwargs.get("dim", -1)
+        reverse = pos[2] if len(pos) > 2 else node.kwargs.get("descending", False)
+        return self.block_builder.emit(relax.op.argsort(self.env[pos[0]], axis, reverse))
+
def _cat(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
@@ -1090,6 +1096,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
raise Exception("Unexpected args " + str(node.args))
return self.block_builder.emit(relax.op.scatter_elements(x, index,
src, axis=dim))
+    def _sort(self, node: fx.Node) -> relax.Var:  # torch.sort -> relax.op.sort
+        operands, options = node.args, node.kwargs
+        axis = operands[1] if len(operands) > 1 else options.get("dim", -1)
+        reverse = operands[2] if len(operands) > 2 else options.get("descending", False)
+        return self.block_builder.emit(relax.op.sort(self.env[operands[0]], axis, reverse))
+
def _split(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
split_size = node.args[1]
@@ -1140,6 +1152,22 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else
args[1:]
return self.block_builder.emit(relax.op.tile(x, dims))
+    def _topk(self, node: fx.Node) -> relax.Var:  # torch.topk -> relax.op.topk
+        args = self.retrieve_args(node)
+        x = args[0]
+        k = args[1] if len(args) > 1 else node.kwargs.get("k", 1)
+        dim = args[2] if len(args) > 2 else node.kwargs.get("dim", -1)
+        largest = args[3] if len(args) > 3 else node.kwargs.get("largest", True)
+        _sorted = args[4] if len(args) > 4 else node.kwargs.get("sorted", True)
+
+        if not _sorted:  # torch's keyword is "sorted"; only sorted output is handled here
+            msg = "Currently supports only sorted output for topk operator."
+            raise AssertionError(msg)
+
+        return self.block_builder.emit(
+            relax.op.topk(x, k=k, axis=dim, largest=largest, ret_type="both", dtype="int64")
+        )
+
def _transpose(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
full_idx = list(range(len(self.shape_of(args[0]))))
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 0b649eb755..a151a57ae6 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -754,6 +754,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"argmin": self._argmax_argmin(relax.op.argmin),
"where": self._where,
# tensor manipulation
+ "argsort": self._argsort,
"cat": self._cat,
"chunk": self._chunk,
"concat": self._cat,
@@ -772,11 +773,13 @@ class TorchFXImporter(BaseFXGraphImporter):
"scatter": self._scatter,
"select": self._select,
"size": self._size,
+ "sort": self._sort,
"split": self._split,
"squeeze": self._squeeze,
"stack": self._stack,
"take": self._take,
"tile": self._tile,
+ "topk": self._topk,
"transpose": self._transpose,
"unsqueeze": lambda node: self.block_builder.emit(
relax.op.expand_dims(self.env[node.args[0]], node.args[1])
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index d913baf13a..2c5560b577 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4368,5 +4368,67 @@ def test_where():
)
+def test_argsort():  # torch.argsort along dim 1, descending, against the expected Relax IR
+    class ArgsortModule(Module):
+        def forward(self, x):
+            return torch.argsort(x, dim=1, descending=True)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="int32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="int32") = R.argsort(inp_0, axis=1, descending=True)
+                gv: R.Tensor((5, 3), dtype="int32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(ArgsortModule(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_sort():  # torch.sort along dim 1, descending, against the expected Relax IR
+    class SortModule(Module):
+        def forward(self, x):
+            return torch.sort(x, dim=1, descending=True)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.sort(inp_0, axis=1, descending=True)
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(SortModule(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_topk():  # torch.topk with k=2 on dim 1 against the expected (values, indices) IR
+    class TopkModule(Module):
+        def forward(self, x):
+            return torch.topk(x, k=2, dim=1, largest=True, sorted=True)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")
+                ) = R.topk(inp_0, k=2, axis=1, ret_type="both", largest=True, dtype="int64")
+                gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = lv
+                R.output(gv)
+            return gv
+
+    verify_model(TopkModule(), [([5, 3], "float32")], {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()