This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ca68bef6b5 [Relax][TVMScript] Print ExternFunc struct_info when non-default (#19416)
ca68bef6b5 is described below

commit ca68bef6b5ecec9872a7de4fb1c63139dd150048
Author: Neo Chien <[email protected]>
AuthorDate: Sun Apr 19 03:28:22 2026 +0800

    [Relax][TVMScript] Print ExternFunc struct_info when non-default (#19416)
    
    ### Summary
    
    1. Add HasDefaultExternFuncStructInfo helper to detect default
    FuncStructInfo for extern functions.
    
    2. Update the relax::ExternFunc printer to:
       - emit global_symbol using the correct AccessPath attribute key,
       - include struct_info only when it differs from the default
         inferred-by-sinfo-args derive function,
       - build the ExternFunc call from an argument array instead of a
         single positional literal.
    
    3. This reduces noisy/redundant output when printing ExternFunc nodes
    while preserving explicit struct_info when it conveys meaningful
    information.
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 src/script/printer/relax/function.cc               | 17 +++++++++--
 tests/python/relax/test_tvmscript_printer_relax.py | 35 ++++++++++++++++++++++
 2 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc
index 24ae192c73..c759fa80ae 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -22,6 +22,15 @@ namespace tvm {
 namespace script {
 namespace printer {
 
+static bool HasDefaultExternFuncStructInfo(const relax::ExternFunc& n) {
+  const auto* sinfo = n->struct_info_.as<relax::FuncStructInfoNode>();
+  if (sinfo == nullptr || sinfo->params.defined() || sinfo->purity ||
+      !sinfo->ret->IsInstance<relax::ObjectStructInfoNode>()) {
+    return false;
+  }
+  return true;
+}
+
 bool AtTopLevelFunction(const IRDocsifier& d) {
   // fewer than 2 frames: not in a function at all
   if (d->frames.size() < 2) {
@@ -128,8 +137,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<relax::ExternFunc>(  //
         "", [](relax::ExternFunc n, AccessPath n_p, IRDocsifier d) -> Doc {
-          // TODO(@junrushao): print more information out of extern function.
-          return Relax(d, "ExternFunc")->Call({LiteralDoc::Str(n->global_symbol, n_p)});
+          ffi::Array<ExprDoc> args;
+          args.push_back(LiteralDoc::Str(n->global_symbol, n_p->Attr("global_symbol")));
+          if (!HasDefaultExternFuncStructInfo(n)) {
+            args.push_back(d->AsDoc<ExprDoc>(n->struct_info_, n_p->Attr("struct_info_")));
+          }
+          return Relax(d, "ExternFunc")->Call(args);
         });
 
 TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax);
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index c50c1fcb25..cf3e28388e 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -98,6 +98,41 @@ class Module:
     )
 
 
+def test_extern_func_with_struct_info():
+    obj = IRModule(
+        {
+            "my_ext": relax.ExternFunc(
+                "my_ext",
+                relax.FuncStructInfo([], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True),
+            ),
+        }
+    )
+    _assert_print(
+        obj,
+        """
+# from tvm.script import ir as I
+# from tvm.script import relax as R
+
+@I.ir_module
+class Module:
+    my_ext = R.ExternFunc("my_ext", R.Callable((), R.Tensor(dtype="float32", ndim=2), True))
+""",
+    )
+
+
+def test_extern_func_with_struct_info_roundtrip():
+    mod = IRModule(
+        {
+            "my_ext": relax.ExternFunc(
+                "my_ext",
+                relax.FuncStructInfo([], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True),
+            ),
+        }
+    )
+    roundtrip = tvm.script.from_source(mod.script(verbose_expr=True))
+    tvm.ir.assert_structural_equal(mod, roundtrip)
+
+
 def test_nested_function():
     @I.ir_module
     class NestedFunction:

Reply via email to