This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 43adad701f [Relax][PyTorch] Add support for argsort, sort, topk ops (#17810)
43adad701f is described below

commit 43adad701fd135de7cb3c082e4f87f2e63b16bf3
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Apr 5 05:13:19 2025 +0800

    [Relax][PyTorch] Add support for argsort, sort, topk ops (#17810)
    
    * Update fx_translator.py
    
    * Update base_fx_graph_translator.py
    
    * Update test_frontend_from_fx.py
    
    * Update test_frontend_from_fx.py
    
    * fix lint
---
 .../frontend/torch/base_fx_graph_translator.py     | 28 ++++++++++
 python/tvm/relax/frontend/torch/fx_translator.py   |  3 ++
 tests/python/relax/test_frontend_from_fx.py        | 62 ++++++++++++++++++++++
 3 files changed, 93 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d99411bd56..affbd81e1c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -960,6 +960,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     ########## Manipulation ##########
 
+    def _argsort(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
+        descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)
+        return self.block_builder.emit(relax.op.argsort(x, dim, descending))
+
     def _cat(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
@@ -1090,6 +1096,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             raise Exception("Unexpected args " + str(node.args))
         return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim))
 
+    def _sort(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
+        descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)
+        return self.block_builder.emit(relax.op.sort(x, dim, descending))
+
     def _split(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         split_size = node.args[1]
@@ -1140,6 +1152,22 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.tile(x, dims))
 
+    def _topk(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        k = args[1] if len(args) > 1 else node.kwargs.get("k", 1)
+        dim = args[2] if len(args) > 2 else node.kwargs.get("dim", -1)
+        largest = args[3] if len(args) > 3 else node.kwargs.get("largest", True)
+        _sorted = args[4] if len(args) > 4 else node.kwargs.get("sorted", True)
+
+        if not _sorted:
+            msg = "Currently supports only sorted output for topk operator."
+            raise AssertionError(msg)
+
+        return self.block_builder.emit(
+            relax.op.topk(x, k=k, axis=dim, largest=largest, ret_type="both", dtype="int64")
+        )
+
     def _transpose(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         full_idx = list(range(len(self.shape_of(args[0]))))
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 0b649eb755..a151a57ae6 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -754,6 +754,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "argmin": self._argmax_argmin(relax.op.argmin),
             "where": self._where,
             # tensor manipulation
+            "argsort": self._argsort,
             "cat": self._cat,
             "chunk": self._chunk,
             "concat": self._cat,
@@ -772,11 +773,13 @@ class TorchFXImporter(BaseFXGraphImporter):
             "scatter": self._scatter,
             "select": self._select,
             "size": self._size,
+            "sort": self._sort,
             "split": self._split,
             "squeeze": self._squeeze,
             "stack": self._stack,
             "take": self._take,
             "tile": self._tile,
+            "topk": self._topk,
             "transpose": self._transpose,
             "unsqueeze": lambda node: self.block_builder.emit(
                 relax.op.expand_dims(self.env[node.args[0]], node.args[1])
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index d913baf13a..2c5560b577 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4368,5 +4368,67 @@ def test_where():
     )
 
 
+def test_argsort():
+    class Argsort(Module):
+        def forward(self, x):
+            return torch.argsort(x, dim=1, descending=True)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="int32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="int32") = R.argsort(inp_0, axis=1, descending=True)
+                gv: R.Tensor((5, 3), dtype="int32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Argsort(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_sort():
+    class Sort(Module):
+        def forward(self, x):
+            return torch.sort(x, dim=1, descending=True)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.sort(inp_0, axis=1, descending=True)
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Sort(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_topk():
+    class Topk(Module):
+        def forward(self, x):
+            return torch.topk(x, k=2, dim=1, largest=True, sorted=True)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")
+                ) = R.topk(inp_0, k=2, axis=1, ret_type="both", largest=True, dtype="int64")
+                gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Topk(), [([5, 3], "float32")], {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to