This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7504e3ed1a [REFACTOR][SCRIPT] TVMScript dialect-friendly refactor:
per-dialect restructure + dialect registry (#19479)
7504e3ed1a is described below
commit 7504e3ed1ab6e211a4102210fe97d63decfab1b6
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Apr 30 07:22:56 2026 -0400
[REFACTOR][SCRIPT] TVMScript dialect-friendly refactor: per-dialect
restructure + dialect registry (#19479)
## Summary
Restructure TVMScript to be dialect-agnostic at the script-core layer
while letting each extension dialect (TIRX, Relax) own its own
per-dialect script subtree. IR is below script in the dependency
stack and is NOT a peer dialect — its script handlers stay in the
shared core.
This PR folds together two coupled refactors that were initially
opened as separate PRs (#19478 and the original #19479); they
share rename / relocation surface so they ship as one cohesive
change.
## What this PR does
### Per-dialect script subtree (originally #19479)
- Moves per-dialect printer + builder from
`src/script/{printer,ir_builder}/{tirx,relax}/` to
`src/{tirx,relax}/script/{printer,builder}/`.
- Tightens `src/script/*.cc` CMake glob to the dialect-free core.
- Refactors `IRBuilder::DeclFunction` to dispatch via FFI registry
(`script.ir_builder.decl_function.<type-key>`); removes
cross-dialect includes from the shared core.
- Adds `tvm.script.register_dialect` API + `__getattr__` + a
`sys.meta_path` finder for Python-side dialect discovery.
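In-tree dialects (tirx, relax) each register themselves from their own
package `__init__.py`; `python/tvm/__init__.py` imports `tvm.script`
first so that `register_dialect` is reachable when they load.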
- Drops the obsolete static re-export shims at
`python/tvm/script/{parser,ir_builder}/{tirx,relax}/`.
### Dialect-agnostic printer config (originally #19478)
- Relocates `include/tvm/ir/script_printer.h` →
`include/tvm/script/printer/config.h` next to the rest of the
printer's public surface. The header is not IR-specific.
- Renames `TVM_SCRIPT_REPR` → `TVM_REGISTER_SCRIPT_AS_REPR` for
clarity (the macro registers Script as the kRepr callback +
per-type vtable dispatch). Aligns with the `TVM_REGISTER_*`
family.
- Drops dialect-hardcoded `PrinterConfig` fields (`tir_prefix`,
`relax_prefix`, `show_all_struct_info`, `buffer_dtype`) in favor
of a generic `ffi::Map<String, Any> extra_config` keyed by
`"<dialect>.<knob>"`. Each call site reads via the templated
accessor `config->GetExtraConfig<T>("...", default)`.
- Promotes `std::string` config fields to `ffi::String`.
After this lands, the script-printer core knows nothing specific
about any dialect — new dialects plug in via the registry pattern
with zero core edits. Public Python API surface unchanged.
---
CMakeLists.txt | 15 +-
include/tvm/ir/expr.h | 2 +-
include/tvm/ir/module.h | 2 +-
.../relax => relax/script/builder}/frame.h | 6 +-
.../ir_builder/relax => relax/script/builder}/ir.h | 8 +-
.../script_printer.h => script/printer/config.h} | 114 ++++++------
include/tvm/tirx/buffer.h | 2 +-
include/tvm/tirx/function.h | 2 +-
.../tirx => tirx/script/builder}/frame.h | 6 +-
.../ir_builder/tirx => tirx/script/builder}/ir.h | 8 +-
include/tvm/tirx/stmt.h | 2 +-
python/tvm/__init__.py | 7 +-
python/tvm/relax/__init__.py | 4 +
.../_ffi_api.py => relax/script/__init__.py} | 12 +-
.../script/builder/__init__.py} | 10 +-
.../relax => relax/script/builder}/_ffi_api.py | 0
.../script/builder}/distributed/__init__.py | 0
.../script/builder}/distributed/_ffi_api.py | 0
.../script/builder}/distributed/ir.py | 6 +-
.../relax => relax/script/builder}/frame.py | 2 +-
.../relax => relax/script/builder}/ir.py | 2 +-
.../relax => relax/script/parser}/__init__.py | 5 +-
.../parser/relax => relax/script/parser}/dist.py | 6 +-
.../parser/relax => relax/script/parser}/entry.py | 11 +-
.../parser/relax => relax/script/parser}/parser.py | 10 +-
python/tvm/relax/transform/legalize_ops/grad.py | 2 +-
python/tvm/runtime/script_printer.py | 14 +-
python/tvm/script/__init__.py | 206 ++++++++++++++++++++-
python/tvm/script/ir_builder/__init__.py | 32 +++-
python/tvm/script/ir_builder/relax/__init__.py | 21 ---
python/tvm/script/parser/__init__.py | 33 +++-
python/tvm/script/relax.py | 20 --
python/tvm/script/tirx.py | 20 --
python/tvm/tirx/__init__.py | 4 +
.../_ffi_api.py => tirx/script/__init__.py} | 12 +-
.../tirx => tirx/script/builder}/__init__.py | 0
.../tirx => tirx/script/builder}/_ffi_api.py | 0
.../script/builder}/external_kernel.py | 3 +-
.../tirx => tirx/script/builder}/frame.py | 3 +-
.../ir_builder/tirx => tirx/script/builder}/ir.py | 0
.../tirx => tirx/script/builder}/triton.py | 0
.../tirx => tirx/script/builder}/utils.py | 0
.../parser/tirx => tirx/script/parser}/__init__.py | 5 +-
.../parser/tirx => tirx/script/parser}/entry.py | 7 +-
.../tirx => tirx/script/parser}/operation.py | 3 +-
.../parser/tirx => tirx/script/parser}/parser.py | 11 +-
src/ir/expr.cc | 2 +-
src/ir/script_printer.cc | 50 +++--
src/ir/structural_equal.cc | 2 +-
.../relax => relax/script/builder}/distributed.cc | 2 +-
.../relax => relax/script/builder}/frame.cc | 4 +-
.../relax => relax/script/builder}/ir.cc | 2 +-
.../relax => relax/script/builder}/utils.h | 10 +-
.../relax => relax/script/printer}/binding.cc | 8 +-
.../printer/relax => relax/script/printer}/call.cc | 2 +-
.../relax => relax/script/printer}/distributed.cc | 8 +-
.../printer/relax => relax/script/printer}/expr.cc | 18 +-
.../relax => relax/script/printer}/function.cc | 4 +-
.../relax => relax/script/printer}/region.cc | 6 +-
.../relax => relax/script/printer}/struct_info.cc | 12 +-
.../printer/relax => relax/script/printer}/tir.cc | 2 +-
.../printer/relax => relax/script/printer}/type.cc | 8 +-
.../printer/relax => relax/script/printer}/utils.h | 11 +-
src/script/ir_builder/ir/ir.cc | 48 ++---
.../printer/{ir/distributed.cc => config.cc} | 33 ++--
src/script/printer/ir/distributed.cc | 2 -
src/script/printer/ir/ir.cc | 10 +-
src/script/printer/utils.h | 26 ++-
src/tirx/ir/expr.cc | 2 +-
.../tirx => tirx/script/builder}/frame.cc | 4 +-
.../ir_builder/tirx => tirx/script/builder}/ir.cc | 16 +-
.../tirx => tirx/script/builder}/utils.h | 10 +-
.../printer/tirx => tirx/script/printer}/block.cc | 4 +-
.../printer/tirx => tirx/script/printer}/buffer.cc | 14 +-
.../printer/tirx => tirx/script/printer}/expr.cc | 64 +++----
.../tirx => tirx/script/printer}/for_loop.cc | 2 +-
.../tirx => tirx/script/printer}/function.cc | 2 +-
.../printer/tirx => tirx/script/printer}/ir.cc | 10 +-
.../printer/tirx => tirx/script/printer}/stmt.cc | 23 +--
.../printer/tirx => tirx/script/printer}/utils.h | 8 +-
80 files changed, 668 insertions(+), 399 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ee6726c1b9..af86e3d846 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -275,7 +275,19 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/topi/*.cc
src/driver/*.cc
src/support/*.cc
- src/script/*.cc
+ # TVMScript shared core (Doc IR + dispatch infrastructure + IR-layer
+ # printer/builder). Per-dialect (relax, tirx, ...) printer/builder pieces
+ # live under src/<dialect>/script/. The list below is intentionally explicit
+ # (not src/script/*.cc) so that any new file accidentally added under
+ # src/script/{printer,ir_builder}/<dialect>/ for a non-IR dialect is
+ # rejected at link time rather than silently included.
+ src/script/ir_builder/base.cc
+ src/script/ir_builder/ir/*.cc
+ src/script/printer/config.cc
+ src/script/printer/doc.cc
+ src/script/printer/doc_printer/*.cc
+ src/script/printer/ir_docsifier.cc
+ src/script/printer/ir/*.cc
src/relax/ir/*.cc
src/relax/op/*.cc
src/relax/analysis/*.cc
@@ -288,6 +300,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/relax/distributed/*.cc
src/relax/distributed/transform/*.cc
src/relax/op/distributed/*.cc
+ src/relax/script/*.cc
src/relax/testing/*.cc
)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index cb73ec7a3a..1ce7a112a3 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -29,9 +29,9 @@
#include <tvm/ir/cast.h>
#include <tvm/ir/cow.h>
#include <tvm/ir/repr.h>
-#include <tvm/ir/script_printer.h>
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
+#include <tvm/script/printer/config.h>
#include <algorithm>
#include <functional>
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 137909bd47..5f9994c3df 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -32,9 +32,9 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/global_info.h>
-#include <tvm/ir/script_printer.h>
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
+#include <tvm/script/printer/config.h>
#include <string>
#include <unordered_map>
diff --git a/include/tvm/script/ir_builder/relax/frame.h
b/include/tvm/relax/script/builder/frame.h
similarity index 98%
rename from include/tvm/script/ir_builder/relax/frame.h
rename to include/tvm/relax/script/builder/frame.h
index 6dd8b6e2a1..799e602cd0 100644
--- a/include/tvm/script/ir_builder/relax/frame.h
+++ b/include/tvm/relax/script/builder/frame.h
@@ -16,8 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_
-#define TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_
+#ifndef TVM_RELAX_SCRIPT_BUILDER_FRAME_H_
+#define TVM_RELAX_SCRIPT_BUILDER_FRAME_H_
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/block_builder.h>
@@ -329,4 +329,4 @@ class ElseFrame : public SeqExprFrame {
} // namespace script
} // namespace tvm
-#endif // TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_
+#endif // TVM_RELAX_SCRIPT_BUILDER_FRAME_H_
diff --git a/include/tvm/script/ir_builder/relax/ir.h
b/include/tvm/relax/script/builder/ir.h
similarity index 96%
rename from include/tvm/script/ir_builder/relax/ir.h
rename to include/tvm/relax/script/builder/ir.h
index ac26ddc036..d0047b2ab1 100644
--- a/include/tvm/script/ir_builder/relax/ir.h
+++ b/include/tvm/relax/script/builder/ir.h
@@ -16,13 +16,13 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_
-#define TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_
+#ifndef TVM_RELAX_SCRIPT_BUILDER_IR_H_
+#define TVM_RELAX_SCRIPT_BUILDER_IR_H_
#include <tvm/relax/expr.h>
+#include <tvm/relax/script/builder/frame.h>
#include <tvm/relax/struct_info.h>
#include <tvm/script/ir_builder/base.h>
-#include <tvm/script/ir_builder/relax/frame.h>
namespace tvm {
namespace script {
@@ -143,4 +143,4 @@ ElseFrame Else();
} // namespace script
} // namespace tvm
-#endif // TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_
+#endif // TVM_RELAX_SCRIPT_BUILDER_IR_H_
diff --git a/include/tvm/ir/script_printer.h
b/include/tvm/script/printer/config.h
similarity index 66%
rename from include/tvm/ir/script_printer.h
rename to include/tvm/script/printer/config.h
index dff4a17a60..5f5486ac57 100644
--- a/include/tvm/ir/script_printer.h
+++ b/include/tvm/script/printer/config.h
@@ -17,11 +17,11 @@
* under the License.
*/
/*!
- * \file tvm/ir/script_printer.h
+ * \file tvm/script/printer/config.h
* \brief Printer class to print repr string of each AST/IR nodes.
*/
-#ifndef TVM_IR_SCRIPT_PRINTER_H_
-#define TVM_IR_SCRIPT_PRINTER_H_
+#ifndef TVM_SCRIPT_PRINTER_CONFIG_H_
+#define TVM_SCRIPT_PRINTER_CONFIG_H_
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
@@ -29,10 +29,10 @@
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
+#include <tvm/ir/cast.h>
#include <tvm/ir/node_functor.h>
#include <tvm/runtime/data_type.h>
-#include <iostream>
#include <string>
namespace tvm {
@@ -44,18 +44,12 @@ class PrinterConfigNode : public ffi::Object {
/*! \brief Whether or not to show metadata. */
bool show_meta = false;
/*! \brief The prefix of IR nodes */
- std::string ir_prefix = "I";
- /*! \brief The prefix of TIR nodes */
- std::string tir_prefix = "T";
- /*! \brief The prefix of Relax nodes */
- std::string relax_prefix = "R";
+ ffi::String ir_prefix = "I";
/*!
* \brief The alias of the current module at cross-function call
* \note Directly use module name if it's empty.
*/
- std::string module_alias = "cls";
- /*! \brief Default data type of TIR buffer */
- DataType buffer_dtype = DataType::Float(32);
+ ffi::String module_alias = "cls";
/*! \brief Default data type of integer literals */
DataType int_dtype = DataType::Int(32);
/*!
@@ -77,41 +71,6 @@ class PrinterConfigNode : public ffi::Object {
/*! \brief Whether variable names should include the object's address */
bool show_object_address = false;
- /*! \brief In Relax, whether to show all StructInfo annotations
- *
- * If true (default), all variable bindings will be annotated with
- * the struct info of the variable being bound.
- *
- * If false, the annotations will only be shown when they are
- * required for correct parsing of the Relax function. For example,
- * function parameters must always have struct info annotations, but
- * the struct info for expressions within a function body may be inferred from their
- * arguments, and are therefore
- *
- * Example:
- *
- * \code{.py}
- * # func.show(show_all_struct_info=True)
- * @R.function
- * def func(
- * A: R.Tensor((10, 20), dtype="float32"),
- * B: R.Tensor((10,20), dtype="float32"),
- * ) -> R.Tensor((10, 20), dtype="float32"):
- * C: R.Tensor((10,20), dtype="float32") = R.add(A, B2)
- * return C
- *
- * # func.show(show_all_struct_info=False)
- * @R.function
- * def func(
- * A: R.Tensor((10, 20), dtype="float32"),
- * B: R.Tensor((10,20), dtype="float32"),
- * ) -> R.Tensor((10, 20), dtype="float32"):
- * C = R.add(A, B2)
- * return C
- * \endcode
- */
- bool show_all_struct_info = true;
-
/* \brief ffi::Object path to be underlined */
ffi::Array<ffi::reflection::AccessPath> path_to_underline;
/*! \brief ffi::Object path to be annotated. */
@@ -121,16 +80,39 @@ class PrinterConfigNode : public ffi::Object {
/*! \brief ffi::Object to be annotated. */
ffi::Map<ffi::ObjectRef, ffi::String> obj_to_annotate =
ffi::Map<ffi::ObjectRef, ffi::String>();
+ /*!
+ * \brief Generic extension map for dialect-specific config knobs.
+ *
+ * Keys are conventionally namespaced as "<dialect>.<knob>", e.g.:
+ * "tirx.prefix" — the TIR prefix (default "T")
+ * "tirx.buffer_dtype" — default buffer dtype (default float32)
+ * "relax.prefix" — the Relax prefix (default "R")
+ * "relax.show_all_struct_info" — whether to show all struct info (default true)
+ *
+ * Use GetExtraConfig<T>(key, fallback) to read values with a typed fallback.
+ */
+ ffi::Map<ffi::String, ffi::Any> extra_config;
+
+ /*!
+ * \brief Look up a value in extra_config with type cast and fallback.
+ *
+ * Keys are conventionally namespaced as "<dialect>.<knob>"
+ * (e.g. "tirx.prefix", "relax.show_all_struct_info").
+ */
+ template <typename T>
+ T GetExtraConfig(const ffi::String& key, T fallback) const {
+ auto it = extra_config.find(key);
+ if (it == extra_config.end()) return fallback;
+ return Downcast<T>((*it).second);
+ }
+
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PrinterConfigNode>()
.def_ro("binding_names", &PrinterConfigNode::binding_names)
.def_ro("show_meta", &PrinterConfigNode::show_meta)
.def_ro("ir_prefix", &PrinterConfigNode::ir_prefix)
- .def_ro("tir_prefix", &PrinterConfigNode::tir_prefix)
- .def_ro("relax_prefix", &PrinterConfigNode::relax_prefix)
.def_ro("module_alias", &PrinterConfigNode::module_alias)
- .def_ro("buffer_dtype", &PrinterConfigNode::buffer_dtype)
.def_ro("int_dtype", &PrinterConfigNode::int_dtype)
.def_ro("float_dtype", &PrinterConfigNode::float_dtype)
.def_ro("verbose_expr", &PrinterConfigNode::verbose_expr)
@@ -139,11 +121,11 @@ class PrinterConfigNode : public ffi::Object {
.def_ro("num_context_lines", &PrinterConfigNode::num_context_lines)
.def_ro("syntax_sugar", &PrinterConfigNode::syntax_sugar)
.def_ro("show_object_address", &PrinterConfigNode::show_object_address)
- .def_ro("show_all_struct_info", &PrinterConfigNode::show_all_struct_info)
.def_ro("path_to_underline", &PrinterConfigNode::path_to_underline)
.def_ro("path_to_annotate", &PrinterConfigNode::path_to_annotate)
.def_ro("obj_to_underline", &PrinterConfigNode::obj_to_underline)
- .def_ro("obj_to_annotate", &PrinterConfigNode::obj_to_annotate);
+ .def_ro("obj_to_annotate", &PrinterConfigNode::obj_to_annotate)
+ .def_ro("extra_config", &PrinterConfigNode::extra_config);
}
ffi::Array<ffi::String> GetBuiltinKeywords();
@@ -177,5 +159,31 @@ class TVMScriptPrinter {
config.value_or(PrinterConfig()));
\
}
+/*!
+ * \brief The fallback body used by TVM_REGISTER_SCRIPT_AS_REPR.
+ *
+ * Tries to format \p obj via TVMScriptPrinter::Script; on error falls back to
+ * a plain address string. Defined in src/script/printer/config.cc so that
+ * <tvm/runtime/logging.h> is not pulled into this public header.
+ */
+TVM_DLL std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj);
+
+/*!
+ * \brief Register Script as the kRepr callback for ObjectType and install
+ * the per-type dispatch entry in TVMScriptPrinter::vtable().
+ *
+ * \param ObjectType The concrete object node type (e.g. tirx::VarNode).
+ * \param Method The TVMScriptPrinter vtable dispatch function.
+ */
+#define TVM_REGISTER_SCRIPT_AS_REPR(ObjectType, Method) \
+ TVM_FFI_STATIC_INIT_BLOCK() { \
+ namespace refl = tvm::ffi::reflection; \
+ refl::TypeAttrDef<ObjectType>().def(refl::type_attr::kRepr, \
+ [](ffi::ObjectRef obj, ffi::Function) -> ffi::String { \
+ return RedirectedReprPrinterMethod(obj); \
+ }); \
+ } \
+ TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter, vtable).set_dispatch<ObjectType>(Method)
+
} // namespace tvm
-#endif // TVM_IR_SCRIPT_PRINTER_H_
+#endif // TVM_SCRIPT_PRINTER_CONFIG_H_
diff --git a/include/tvm/tirx/buffer.h b/include/tvm/tirx/buffer.h
index 9adbf15f48..72640a80df 100644
--- a/include/tvm/tirx/buffer.h
+++ b/include/tvm/tirx/buffer.h
@@ -29,7 +29,7 @@
#include <tvm/ffi/string.h>
#include <tvm/ir/cow.h>
#include <tvm/ir/expr.h>
-#include <tvm/ir/script_printer.h>
+#include <tvm/script/printer/config.h>
#include <tvm/tirx/var.h>
#include <string>
diff --git a/include/tvm/tirx/function.h b/include/tvm/tirx/function.h
index e413a97652..aec5f30454 100644
--- a/include/tvm/tirx/function.h
+++ b/include/tvm/tirx/function.h
@@ -28,8 +28,8 @@
#include <tvm/ffi/container/variant.h>
#include <tvm/ir/cow.h>
#include <tvm/ir/function.h>
-#include <tvm/ir/script_printer.h>
#include <tvm/runtime/tensor.h>
+#include <tvm/script/printer/config.h>
#include <tvm/tirx/buffer.h>
#include <tvm/tirx/expr.h>
#include <tvm/tirx/stmt.h>
diff --git a/include/tvm/script/ir_builder/tirx/frame.h
b/include/tvm/tirx/script/builder/frame.h
similarity index 99%
rename from include/tvm/script/ir_builder/tirx/frame.h
rename to include/tvm/tirx/script/builder/frame.h
index 0979e9bf12..e90e7e7e74 100644
--- a/include/tvm/script/ir_builder/tirx/frame.h
+++ b/include/tvm/tirx/script/builder/frame.h
@@ -16,8 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
-#define TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
+#ifndef TVM_TIRX_SCRIPT_BUILDER_FRAME_H_
+#define TVM_TIRX_SCRIPT_BUILDER_FRAME_H_
#include <tvm/script/ir_builder/base.h>
#include <tvm/script/ir_builder/ir/frame.h>
@@ -602,4 +602,4 @@ class ElseFrame : public TIRFrame {
} // namespace script
} // namespace tvm
-#endif // TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
+#endif // TVM_TIRX_SCRIPT_BUILDER_FRAME_H_
diff --git a/include/tvm/script/ir_builder/tirx/ir.h
b/include/tvm/tirx/script/builder/ir.h
similarity index 99%
rename from include/tvm/script/ir_builder/tirx/ir.h
rename to include/tvm/tirx/script/builder/ir.h
index bfe6261233..31cca16709 100644
--- a/include/tvm/script/ir_builder/tirx/ir.h
+++ b/include/tvm/tirx/script/builder/ir.h
@@ -16,12 +16,12 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
-#define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
+#ifndef TVM_TIRX_SCRIPT_BUILDER_IR_H_
+#define TVM_TIRX_SCRIPT_BUILDER_IR_H_
#include <tvm/script/ir_builder/base.h>
-#include <tvm/script/ir_builder/tirx/frame.h>
#include <tvm/tirx/op.h>
+#include <tvm/tirx/script/builder/frame.h>
namespace tvm {
namespace script {
@@ -519,4 +519,4 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
} // namespace script
} // namespace tvm
-#endif // TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
+#endif // TVM_TIRX_SCRIPT_BUILDER_IR_H_
diff --git a/include/tvm/tirx/stmt.h b/include/tvm/tirx/stmt.h
index 13be56eab1..7be0153e76 100644
--- a/include/tvm/tirx/stmt.h
+++ b/include/tvm/tirx/stmt.h
@@ -26,7 +26,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/cow.h>
-#include <tvm/ir/script_printer.h>
+#include <tvm/script/printer/config.h>
#include <tvm/tirx/expr.h>
#include <optional>
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 5798f684dd..72f212a9a1 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -44,7 +44,11 @@ from .ir import transform
from .ir import instrument
from . import ir
-# tvm.tirx
+# tvm.script — must be imported before any dialect package so that
+# tvm.script.register_dialect is reachable when dialect __init__.py files run.
+from . import script
+
+# tvm.tirx — registers itself via tvm.script.register_dialect in its __init__
from . import tirx
# tvm.s_tir
@@ -71,6 +75,7 @@ from .contrib import rocm as _rocm, nvcc as _nvcc
# Relax contain modules that are only available in compiler package
# Do not import them if TVM is built with runtime only
if not _RUNTIME_ONLY:
+ # tvm.relax — registers itself via tvm.script.register_dialect in its __init__
from . import relax
# NOTE: This file should be python2 compatible so we can
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index 65c44db3ac..3784c6a63d 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -120,3 +120,7 @@ from . import utils
from .vm_build import build, VMExecutable
from .binding_rewrite import DataflowBlockRewrite
+
+import tvm.script
+
+tvm.script.register_dialect("relax", "tvm.relax.script")
diff --git a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
b/python/tvm/relax/script/__init__.py
similarity index 63%
copy from python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
copy to python/tvm/relax/script/__init__.py
index 2ebc4935b8..acbd9d2435 100644
--- a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
+++ b/python/tvm/relax/script/__init__.py
@@ -14,8 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.script.ir_builder.relax.distributed"""
+"""Relax-layer TVMScript pieces (parser, builder).
-import tvm_ffi
+After the per-dialect TVMScript restructure, the Relax layer owns its own
+``script/{parser,builder}`` subpackages. ``tvm.script.relax`` resolves to
+this module via the dialect registry, so the public parser surface
+(``function``, ``Tensor``, ``match_cast``, etc.) is re-exported here.
+"""
-tvm_ffi.init_ffi_api("script.ir_builder.relax.distributed", __name__) # pylint: disable=protected-access
+# pylint: disable=redefined-builtin,wildcard-import,unused-wildcard-import
+from .parser import *
+from .parser import dist
diff --git a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
b/python/tvm/relax/script/builder/__init__.py
similarity index 73%
copy from python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
copy to python/tvm/relax/script/builder/__init__.py
index 2ebc4935b8..021bcf9b1c 100644
--- a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
+++ b/python/tvm/relax/script/builder/__init__.py
@@ -14,8 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.script.ir_builder.relax.distributed"""
+"""Package tvm.relax.script.builder.
-import tvm_ffi
+Holds the per-dialect ir_builder API for Relax. The legacy path
+``tvm.script.ir_builder.relax`` resolves here via the dialect registry.
+"""
-tvm_ffi.init_ffi_api("script.ir_builder.relax.distributed", __name__) # pylint: disable=protected-access
+# pylint: disable=wildcard-import,redefined-builtin
+from . import distributed, frame, ir
+from .ir import *
diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py
b/python/tvm/relax/script/builder/_ffi_api.py
similarity index 100%
rename from python/tvm/script/ir_builder/relax/_ffi_api.py
rename to python/tvm/relax/script/builder/_ffi_api.py
diff --git a/python/tvm/script/ir_builder/relax/distributed/__init__.py
b/python/tvm/relax/script/builder/distributed/__init__.py
similarity index 100%
rename from python/tvm/script/ir_builder/relax/distributed/__init__.py
rename to python/tvm/relax/script/builder/distributed/__init__.py
diff --git a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
b/python/tvm/relax/script/builder/distributed/_ffi_api.py
similarity index 100%
copy from python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
copy to python/tvm/relax/script/builder/distributed/_ffi_api.py
diff --git a/python/tvm/script/ir_builder/relax/distributed/ir.py
b/python/tvm/relax/script/builder/distributed/ir.py
similarity index 97%
rename from python/tvm/script/ir_builder/relax/distributed/ir.py
rename to python/tvm/relax/script/builder/distributed/ir.py
index 485e91e3e5..a74727c995 100644
--- a/python/tvm/script/ir_builder/relax/distributed/ir.py
+++ b/python/tvm/relax/script/builder/distributed/ir.py
@@ -40,12 +40,12 @@ from tvm.relax.op.distributed import (
from tvm.relax.op.distributed import (
redistribute as _redistribute,
)
+from tvm.relax.script.builder.ir import py_str
from tvm.relax.utils import convert_to_expr
from tvm.runtime import _tensor
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder.ir import IRModuleFrame
-from ... import IRBuilder
-from ...ir import IRModuleFrame
-from ..ir import py_str
from . import _ffi_api
diff --git a/python/tvm/script/ir_builder/relax/frame.py
b/python/tvm/relax/script/builder/frame.py
similarity index 96%
rename from python/tvm/script/ir_builder/relax/frame.py
rename to python/tvm/relax/script/builder/frame.py
index 028fa56b64..fec8a88af9 100644
--- a/python/tvm/script/ir_builder/relax/frame.py
+++ b/python/tvm/relax/script/builder/frame.py
@@ -18,7 +18,7 @@
from tvm_ffi import register_object as _register_object
-from ..base import IRBuilderFrame
+from tvm.script.ir_builder.base import IRBuilderFrame
@_register_object("script.ir_builder.relax.RelaxFrame")
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/relax/script/builder/ir.py
similarity index 99%
rename from python/tvm/script/ir_builder/relax/ir.py
rename to python/tvm/relax/script/builder/ir.py
index 51aae5350e..f62164dbd7 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/relax/script/builder/ir.py
@@ -214,8 +214,8 @@ from tvm.runtime._tensor import (
vulkan,
webgpu,
)
+from tvm.script.ir_builder.ir import decl_function, lookup_vdevice
-from ..ir import decl_function, lookup_vdevice
from . import _ffi_api, frame
##################### Python Native Function Alias ######################
diff --git a/python/tvm/script/parser/relax/__init__.py
b/python/tvm/relax/script/parser/__init__.py
similarity index 92%
rename from python/tvm/script/parser/relax/__init__.py
rename to python/tvm/relax/script/parser/__init__.py
index ddac5342ee..70703eac11 100644
--- a/python/tvm/script/parser/relax/__init__.py
+++ b/python/tvm/relax/script/parser/__init__.py
@@ -20,8 +20,9 @@
from typing import TYPE_CHECKING
-from ...ir_builder.relax import * # pylint: disable=redefined-builtin
-from ...ir_builder.relax import ir as _relax
+from tvm.relax.script.builder import * # pylint: disable=redefined-builtin
+from tvm.relax.script.builder import ir as _relax
+
from . import parser as _parser
from .entry import Callable, Object, Prim, Shape, Tensor, Tuple, match_cast
diff --git a/python/tvm/script/parser/relax/dist.py
b/python/tvm/relax/script/parser/dist.py
similarity index 98%
rename from python/tvm/script/parser/relax/dist.py
rename to python/tvm/relax/script/parser/dist.py
index cc3da49215..f4e59f95fe 100644
--- a/python/tvm/script/parser/relax/dist.py
+++ b/python/tvm/relax/script/parser/dist.py
@@ -22,9 +22,7 @@ from typing import Any, Optional, Union
from tvm.ir import Range
from tvm.relax import TensorStructInfo
from tvm.relax.distributed import DeviceMesh, DTensorStructInfo, Placement, device_mesh
-from tvm.script.ir_builder import IRBuilder
-from tvm.script.ir_builder.ir import IRModuleFrame
-from tvm.script.ir_builder.relax.distributed import (
+from tvm.relax.script.builder.distributed import (
annotate_sharding,
call_tir,
call_tir_local_view,
@@ -32,6 +30,8 @@ from tvm.script.ir_builder.relax.distributed import (
redistribute,
redistribute_replica_to_shard,
)
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder.ir import IRModuleFrame
from tvm.tirx import PrimExpr
from .entry import StructInfoProxy, TensorProxy
diff --git a/python/tvm/script/parser/relax/entry.py
b/python/tvm/relax/script/parser/entry.py
similarity index 98%
rename from python/tvm/script/parser/relax/entry.py
rename to python/tvm/relax/script/parser/entry.py
index a7e6181412..5a14fe4ecd 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/relax/script/parser/entry.py
@@ -34,15 +34,14 @@ from tvm.relax import (
TupleStructInfo,
)
from tvm.relax.expr import Var
+from tvm.relax.script import builder as R
from tvm.runtime import ObjectConvertible
+from tvm.script.ir_builder.ir import lookup_vdevice
+from tvm.script.parser._core import doc, parse, utils
+from tvm.script.parser.core.entry import scan_macro
+from tvm.script.parser.core.parser import Parser, ScriptMacro
from tvm.tirx import PrimExpr
-from ...ir_builder import relax as R
-from .._core import doc, parse, utils
-from ..core.entry import scan_macro
-from ..core.parser import Parser, ScriptMacro
-from ..ir import lookup_vdevice
-
FType = TypeVar("FType", bound=_Callable)
############################## R.function ##############################
diff --git a/python/tvm/script/parser/relax/parser.py
b/python/tvm/relax/script/parser/parser.py
similarity index 98%
rename from python/tvm/script/parser/relax/parser.py
rename to python/tvm/relax/script/parser/parser.py
index 8800ea156a..47daf17d35 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/relax/script/parser/parser.py
@@ -23,13 +23,13 @@ from typing import Any
from tvm import relax, tirx
from tvm.ir import GlobalVar, structural_equal
from tvm.relax import Expr, StructInfo
+from tvm.relax.script import builder as R
+from tvm.relax.script.builder.frame import BindingBlockFrame
from tvm.relax.utils import convert_to_expr
-from tvm.script.ir_builder.relax.frame import BindingBlockFrame
+from tvm.script.ir_builder import ir as I
+from tvm.script.ir_builder.base import IRBuilder
+from tvm.script.parser._core import Parser, dispatch, doc
-from ...ir_builder import ir as I
-from ...ir_builder import relax as R
-from ...ir_builder.base import IRBuilder
-from .._core import Parser, dispatch, doc
from .entry import (
MatchCastPair,
StructInfoProxy,
diff --git a/python/tvm/relax/transform/legalize_ops/grad.py
b/python/tvm/relax/transform/legalize_ops/grad.py
index 53222b5d50..cf8e7764d5 100644
--- a/python/tvm/relax/transform/legalize_ops/grad.py
+++ b/python/tvm/relax/transform/legalize_ops/grad.py
@@ -22,7 +22,7 @@ import logging
from tvm import te, tirx, topi
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tirx as T
-from tvm.script.ir_builder.tirx.utils import buffer_proxy
+from tvm.tirx.script.builder.utils import buffer_proxy
from ...block_builder import BlockBuilder
from ...expr import Call, Expr
diff --git a/python/tvm/runtime/script_printer.py
b/python/tvm/runtime/script_printer.py
index 6ba9abb032..31f39acac9 100644
--- a/python/tvm/runtime/script_printer.py
+++ b/python/tvm/runtime/script_printer.py
@@ -34,10 +34,7 @@ class PrinterConfig(Object):
binding_names: Sequence[str]
show_meta: bool
ir_prefix: str
- tir_prefix: str
- relax_prefix: str
module_alias: str
- buffer_dtype: str
int_dtype: str
float_dtype: str
verbose_expr: bool
@@ -46,7 +43,7 @@ class PrinterConfig(Object):
num_context_lines: int
syntax_sugar: bool
show_object_address: bool
- show_all_struct_info: bool
+ extra_config: dict
path_to_underline: list[AccessPath] | None
path_to_annotate: dict[AccessPath, str] | None
obj_to_underline: list[AccessPath] | None
@@ -81,10 +78,7 @@ class PrinterConfig(Object):
cfg = {
"show_meta": show_meta,
"ir_prefix": ir_prefix,
- "tir_prefix": tir_prefix,
- "relax_prefix": relax_prefix,
"module_alias": module_alias,
- "buffer_dtype": buffer_dtype,
"int_dtype": int_dtype,
"float_dtype": float_dtype,
"verbose_expr": verbose_expr,
@@ -93,11 +87,15 @@ class PrinterConfig(Object):
"num_context_lines": num_context_lines,
"syntax_sugar": syntax_sugar,
"show_object_address": show_object_address,
- "show_all_struct_info": show_all_struct_info,
"path_to_underline": path_to_underline,
"path_to_annotate": path_to_annotate,
"obj_to_underline": obj_to_underline,
"obj_to_annotate": obj_to_annotate,
+ # Dialect-specific options, passed as dotted top-level keys; the C++
+ # PrinterConfig constructor reads these keys and stores them in extra_config.
+ "tirx.prefix": tir_prefix,
+ "tirx.buffer_dtype": buffer_dtype,
+ "relax.prefix": relax_prefix,
+ "relax.show_all_struct_info": show_all_struct_info,
}
if name is not None:
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index d44f2f7e87..99c7aec2bf 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -1,4 +1,3 @@
-# isort: skip_file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,7 +14,206 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script APIs of TVM Python Package"""
+"""
+TVMScript public namespace.
-from .parser import ir, ir_module
-from .parser import parse as from_source
+Dialect resolution mechanism
+----------------------------
+
+``tvm.script`` is a virtual namespace: dialect names like ``tirx`` and
+``relax`` are not bound as static attributes here. Instead:
+
+- ``register_dialect(name, module_path)`` writes an entry to
+ ``_DIALECT_REGISTRY: dict[str, str]``. Each in-tree dialect's
+ ``__init__.py`` calls this on import (e.g., ``tvm.tirx.__init__.py``
+ calls ``tvm.script.register_dialect("tirx", "tvm.tirx.script")``).
+ Out-of-tree dialects can register themselves the same way.
+
+- ``__getattr__(name)`` (PEP 562) fires on missing attribute access.
+ If ``name`` is in ``_DIALECT_REGISTRY``, the listed module is imported
+ and cached as a normal module attribute. Subsequent accesses
+ skip ``__getattr__`` (cached in ``globals()``).
+
+- Subpackages ``tvm.script.parser``, ``tvm.script.ir_builder``, etc.
+ each define their own ``__getattr__`` that consults the SAME
+ ``_DIALECT_REGISTRY`` and appends their suffix. So
+ ``tvm.script.parser.tirx`` resolves to ``tvm.tirx.script.parser`` via
+ the dialect registry + ``.parser`` suffix.
+
+- For deep statement-form imports like
+ ``from tvm.script.parser.tirx.entry import ObjectProxy``, PEP 562's
+ ``__getattr__`` is not enough — it only handles one-level
+ ``from X import Y``. A ``sys.meta_path`` finder (see
+ ``_DialectRedirectFinder``) intercepts the import machinery to
+ register the real module under the legacy name in ``sys.modules``,
+ so subsequent attribute walks resolve correctly.
+
+Each dialect's ``tvm.<dialect>.script`` package MUST expose ``parser``,
+``builder``, and (where applicable) ``printer`` as submodules. This
+convention is what makes the suffix-append redirect work uniformly.
+IR is foundational (script depends on ir) and is NOT a dialect; its
+script handlers live in the shared core, not via this registry.
+
+Bootstrap order
+---------------
+
+``python/tvm/__init__.py`` imports ``tvm.script`` BEFORE importing any
+dialect package (``tvm.tirx``, ``tvm.relax``, …). This guarantees that
+``tvm.script.register_dialect`` is reachable the moment a dialect's own
+``__init__.py`` runs and calls it. The ``tvm.script`` module itself
+stays dialect-agnostic at load time (no dialect submodules are eagerly
+imported here), so there is no circular dependency.
+"""
+
+import importlib
+import importlib.util
+import sys
+from typing import Any
+
+_DIALECT_REGISTRY: dict[str, str] = {}
+
+# Subpackages of `tvm.script` whose per-dialect children are redirected to a
+# matching subpackage under `tvm.<dialect>.script`. The values are the
+# subpackage name on the dialect side (e.g. `tvm.<dialect>.script.parser`).
+_REDIRECTED_SUBPACKAGES = {
+ "tvm.script.parser": "parser",
+ "tvm.script.ir_builder": "builder",
+}
+
+
+def register_dialect(name: str, module_path: str) -> None:
+ """Register a dialect's script package path.
+
+ Writes ``name -> module_path`` into ``_DIALECT_REGISTRY``. After
+ registration, ``tvm.script.<name>`` resolves to ``module_path`` via
+ ``__getattr__``, and ``tvm.script.parser.<name>`` / ``tvm.script.ir_builder.<name>``
+ resolve to ``module_path + ".parser"`` / ``module_path + ".builder"`` etc.
+ via each subpackage's own ``__getattr__``. Deep statement-form imports
+ (e.g., ``from tvm.script.parser.<name>.entry import X``) are handled
+ by ``_DialectRedirectFinder`` on ``sys.meta_path``.
+
+ This function is idempotent — re-registering the same name with the same
+ path is harmless.
+
+ Each in-tree dialect calls this from its own ``__init__.py``::
+
+ import tvm.script
+ tvm.script.register_dialect("tirx", "tvm.tirx.script")
+
+ Out-of-tree dialects do the same in their own package init without
+ editing any in-tree file.
+
+ Parameters
+ ----------
+ name : str
+ The short name exposed under ``tvm.script.<name>`` (e.g. ``"tirx"``).
+ module_path : str
+ The full dotted module path of the dialect's script package, e.g.
+ ``"tvm.tirx.script"``. That package must expose ``parser`` and
+ ``builder`` as submodules (and ``printer`` where applicable) so
+ that the suffix-append redirect works uniformly.
+ """
+ _DIALECT_REGISTRY[name] = module_path
+
+
+def _redirect_target(fullname: str) -> str | None:
+ """Return the target module path for a redirected ``tvm.script[...]`` name.
+
+ Returns ``None`` if ``fullname`` is not a redirected name.
+ """
+ if fullname.startswith("tvm.script."):
+ # tvm.script.<dialect>[.subpath]
+ rest = fullname[len("tvm.script.") :]
+ head, _, tail = rest.partition(".")
+ if head in _DIALECT_REGISTRY and "." not in head:
+ target = _DIALECT_REGISTRY[head]
+ return f"{target}.{tail}" if tail else target
+ # tvm.script.parser.<dialect>[.subpath] / tvm.script.ir_builder.<dialect>[.subpath]
+ for prefix, sub in _REDIRECTED_SUBPACKAGES.items():
+ if fullname == prefix or not fullname.startswith(prefix + "."):
+ continue
+ rest = fullname[len(prefix) + 1 :]
+ head, _, tail = rest.partition(".")
+ if head in _DIALECT_REGISTRY:
+ target = f"{_DIALECT_REGISTRY[head]}.{sub}"
+ return f"{target}.{tail}" if tail else target
+ return None
+
+
+class _DialectRedirectFinder:
+ """``sys.meta_path`` finder that redirects ``tvm.script.<dialect>`` import
paths.
+
+ PEP 562 ``__getattr__`` only handles one-level attribute lookups
+ (``from tvm.script import tirx``). It cannot intercept deep
+ statement-form imports such as::
+
+ from tvm.script.parser.tirx.entry import ObjectProxy
+ import tvm.script.ir_builder.relax.ir
+
+ This finder is installed on ``sys.meta_path`` to cover those cases.
+ When the import machinery asks for a module whose full name starts with
+ ``tvm.script.<dialect>`` (or ``tvm.script.parser.<dialect>``, etc.) and
+ that dialect is in ``_DIALECT_REGISTRY``, :meth:`find_spec` imports the
+ real target module (e.g. ``tvm.tirx.script.parser.entry``) and registers
+ it in ``sys.modules`` under the legacy name, so all subsequent imports and
+ attribute walks resolve correctly without going through the redirect again.
+ """
+
+ @classmethod
+ def find_spec(cls, fullname, path, target=None):
+ redirected = _redirect_target(fullname)
+ if redirected is None:
+ return None
+ # Resolve the target module and alias it under the legacy name.
+ module = importlib.import_module(redirected)
+ sys.modules[fullname] = module
+ return importlib.util.spec_from_loader(fullname, _AliasLoader(module))
+
+
+class _AliasLoader:
+ """Loader that returns an already-resolved module for an alias spec."""
+
+ def __init__(self, module):
+ self._module = module
+
+ def create_module(self, spec):
+ return self._module
+
+ def exec_module(self, module):
+ # Module is already populated by the redirect target.
+ return None
+
+
+# Install the redirect finder once. Re-importing tvm.script (e.g. during a
+# pytest reload) must not stack duplicates.
+if not any(isinstance(f, _DialectRedirectFinder) for f in sys.meta_path):
+ sys.meta_path.append(_DialectRedirectFinder())
+
+
+def __getattr__(name: str) -> Any:
+ if name in _DIALECT_REGISTRY:
+ module = importlib.import_module(_DIALECT_REGISTRY[name])
+ globals()[name] = module
+ return module
+ if name == "ir":
+ # IR is foundational — its parser is a real submodule under
+ # tvm.script.parser.ir, exposed here as `tvm.script.ir` for the
+ # legacy `from tvm.script import ir as I` pattern.
+ ir_parser = importlib.import_module("tvm.script.parser.ir")
+ globals()["ir"] = ir_parser
+ return ir_parser
+ if name in ("from_source", "parse"):
+ from .parser._core import parse # pylint:
disable=import-outside-toplevel
+
+ globals()["from_source"] = parse
+ globals()["parse"] = parse
+ return parse
+ if name == "ir_module":
+ # ir_module lives in the IR parser at tvm.script.parser.ir; the IR
+ # layer is foundational, so we resolve it directly rather than via
+ # the dialect registry.
+ ir_parser = importlib.import_module("tvm.script.parser.ir")
+ ir_module_value = ir_parser.ir_module
+ globals()["ir_module"] = ir_module_value
+ return ir_module_value
+ raise AttributeError(f"module 'tvm.script' has no attribute {name!r}")
diff --git a/python/tvm/script/ir_builder/__init__.py
b/python/tvm/script/ir_builder/__init__.py
index 7925a8830d..ee114fd1d8 100644
--- a/python/tvm/script/ir_builder/__init__.py
+++ b/python/tvm/script/ir_builder/__init__.py
@@ -1,4 +1,3 @@
-# isort: skip_file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,6 +14,35 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.script.ir_builder is a generic IR builder for TVM."""
+"""The ir_builder subpackage of TVMScript.
+
+Per-dialect builder submodules (``tvm.script.ir_builder.tirx``, etc.) are
+resolved lazily via :data:`tvm.script._DIALECT_REGISTRY`. When a dialect
+is accessed (e.g. ``tvm.script.ir_builder.tirx``), this subpackage's
+``__getattr__`` looks up the dialect in ``_DIALECT_REGISTRY`` and imports
+``<dialect_module_path>.builder`` (e.g. ``tvm.tirx.script.builder``),
+caching the result so subsequent accesses skip ``__getattr__``.
+
+The IR layer is foundational and is NOT registered as a dialect — its
+builder lives as a real submodule ``tvm.script.ir_builder.ir``.
+
+See :mod:`tvm.script` for a full description of the dialect resolution
+mechanism, including the ``_DialectRedirectFinder`` that handles
+deep statement-form imports.
+"""
+
+import importlib
+from typing import Any
from .base import IRBuilder
+
+
+def __getattr__(name: str) -> Any:
+ # Lazy import to avoid loading tvm.script during dialect bootstrap.
+ from tvm.script import _DIALECT_REGISTRY # pylint:
disable=import-outside-toplevel
+
+ if name in _DIALECT_REGISTRY:
+ module = importlib.import_module(f"{_DIALECT_REGISTRY[name]}.builder")
+ globals()[name] = module
+ return module
+ raise AttributeError(f"module 'tvm.script.ir_builder' has no attribute
{name!r}")
diff --git a/python/tvm/script/ir_builder/relax/__init__.py
b/python/tvm/script/ir_builder/relax/__init__.py
deleted file mode 100644
index 35291ce7ea..0000000000
--- a/python/tvm/script/ir_builder/relax/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# isort: skip_file
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Package tvm.script.ir_builder.relax"""
-
-from . import distributed, frame
-from .ir import * # pylint: disable=wildcard-import,redefined-builtin
diff --git a/python/tvm/script/parser/__init__.py
b/python/tvm/script/parser/__init__.py
index 16a6a10133..279b0ec00a 100644
--- a/python/tvm/script/parser/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -1,4 +1,3 @@
-# isort: skip_file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,8 +14,38 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""The parser"""
+"""The parser subpackage of TVMScript.
+
+Per-dialect parser submodules (``tvm.script.parser.tirx``, etc.) are
+resolved lazily via :data:`tvm.script._DIALECT_REGISTRY`. When a dialect
+is accessed (e.g. ``tvm.script.parser.tirx``), this subpackage's
+``__getattr__`` looks up the dialect in ``_DIALECT_REGISTRY`` and imports
+``<dialect_module_path>.parser`` (e.g. ``tvm.tirx.script.parser``),
+caching the result so subsequent accesses skip ``__getattr__``.
+
+The IR layer is foundational and is NOT registered as a dialect — its
+parser lives as a real submodule ``tvm.script.parser.ir``, with
+``ir_module`` re-exported at this level for convenience.
+
+See :mod:`tvm.script` for a full description of the dialect resolution
+mechanism, including the ``_DialectRedirectFinder`` that handles
+deep statement-form imports.
+"""
+
+import importlib
+from typing import Any
from . import _core, ir
from ._core import parse
from .ir import ir_module
+
+
+def __getattr__(name: str) -> Any:
+ # Lazy import to avoid loading tvm.script during dialect bootstrap.
+ from tvm.script import _DIALECT_REGISTRY # pylint:
disable=import-outside-toplevel
+
+ if name in _DIALECT_REGISTRY:
+ module = importlib.import_module(f"{_DIALECT_REGISTRY[name]}.parser")
+ globals()[name] = module
+ return module
+ raise AttributeError(f"module 'tvm.script.parser' has no attribute
{name!r}")
diff --git a/python/tvm/script/relax.py b/python/tvm/script/relax.py
deleted file mode 100644
index 5afeb12449..0000000000
--- a/python/tvm/script/relax.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# ruff: noqa: F403
-"""TVM Script APIs of TVM Python Package for Relax"""
-
-from .parser.relax import * # pylint:
disable=redefined-builtin,unused-wildcard-import,wildcard-import
diff --git a/python/tvm/script/tirx.py b/python/tvm/script/tirx.py
deleted file mode 100644
index 21a2c6f427..0000000000
--- a/python/tvm/script/tirx.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# ruff: noqa: F403
-"""TVM Script APIs of TVM Python Package for TIR"""
-
-from .parser.tirx import * # pylint:
disable=redefined-builtin,unused-wildcard-import,wildcard-import
diff --git a/python/tvm/tirx/__init__.py b/python/tvm/tirx/__init__.py
index caa9494980..4d727a812a 100644
--- a/python/tvm/tirx/__init__.py
+++ b/python/tvm/tirx/__init__.py
@@ -121,3 +121,7 @@ from . import stmt_functor
from .build import build
from .pipeline import get_tir_pipeline, get_default_tir_pipeline
from .functor import PyStmtExprVisitor, PyStmtExprMutator
+
+import tvm.script
+
+tvm.script.register_dialect("tirx", "tvm.tirx.script")
diff --git a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
b/python/tvm/tirx/script/__init__.py
similarity index 62%
rename from python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
rename to python/tvm/tirx/script/__init__.py
index 2ebc4935b8..25bfe3148d 100644
--- a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py
+++ b/python/tvm/tirx/script/__init__.py
@@ -14,8 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.script.ir_builder.relax.distributed"""
+"""TIRX-layer TVMScript pieces (parser, builder).
-import tvm_ffi
+After the per-dialect TVMScript restructure, the TIRX layer owns its own
+``script/{parser,builder}`` subpackages. ``tvm.script.tirx`` resolves to
+this module via the dialect registry, so the public parser surface
+(``prim_func``, ``Buffer``, ``Ptr``, etc.) is re-exported here.
+"""
-tvm_ffi.init_ffi_api("script.ir_builder.relax.distributed", __name__) #
pylint: disable=protected-access
+# pylint: disable=redefined-builtin,wildcard-import,unused-wildcard-import
+from .parser import *
+from .parser import Buffer, Ptr, macro, prim_func
diff --git a/python/tvm/script/ir_builder/tirx/__init__.py
b/python/tvm/tirx/script/builder/__init__.py
similarity index 100%
rename from python/tvm/script/ir_builder/tirx/__init__.py
rename to python/tvm/tirx/script/builder/__init__.py
diff --git a/python/tvm/script/ir_builder/tirx/_ffi_api.py
b/python/tvm/tirx/script/builder/_ffi_api.py
similarity index 100%
rename from python/tvm/script/ir_builder/tirx/_ffi_api.py
rename to python/tvm/tirx/script/builder/_ffi_api.py
diff --git a/python/tvm/script/ir_builder/tirx/external_kernel.py
b/python/tvm/tirx/script/builder/external_kernel.py
similarity index 98%
rename from python/tvm/script/ir_builder/tirx/external_kernel.py
rename to python/tvm/tirx/script/builder/external_kernel.py
index 1785454a0f..e76854b936 100644
--- a/python/tvm/script/ir_builder/tirx/external_kernel.py
+++ b/python/tvm/tirx/script/builder/external_kernel.py
@@ -206,10 +206,11 @@ def call_kernel(
kwargs : Dict[str, Any]
Additional keyword arguments to pass to the kernel or compilation.
"""
- from ..ir import ( # pylint: disable=import-outside-toplevel
+ from tvm.script.ir_builder.ir import ( # pylint:
disable=import-outside-toplevel
module_get_attr,
module_set_attr,
)
+
from .ir import call_packed # pylint: disable=import-outside-toplevel
kernel_type = f"{type(kernel).__module__}.{type(kernel).__qualname__}"
diff --git a/python/tvm/script/ir_builder/tirx/frame.py
b/python/tvm/tirx/script/builder/frame.py
similarity index 97%
rename from python/tvm/script/ir_builder/tirx/frame.py
rename to python/tvm/tirx/script/builder/frame.py
index aeced570ba..8d0feeb4c5 100644
--- a/python/tvm/script/ir_builder/tirx/frame.py
+++ b/python/tvm/tirx/script/builder/frame.py
@@ -18,10 +18,9 @@
from tvm_ffi import register_object as _register_object
+from tvm.script.ir_builder.base import IRBuilderFrame
from tvm.tirx import Var
-from ..base import IRBuilderFrame
-
@_register_object("script.ir_builder.tirx.TIRFrame")
class TIRFrame(IRBuilderFrame): ...
diff --git a/python/tvm/script/ir_builder/tirx/ir.py
b/python/tvm/tirx/script/builder/ir.py
similarity index 100%
rename from python/tvm/script/ir_builder/tirx/ir.py
rename to python/tvm/tirx/script/builder/ir.py
diff --git a/python/tvm/script/ir_builder/tirx/triton.py
b/python/tvm/tirx/script/builder/triton.py
similarity index 100%
rename from python/tvm/script/ir_builder/tirx/triton.py
rename to python/tvm/tirx/script/builder/triton.py
diff --git a/python/tvm/script/ir_builder/tirx/utils.py
b/python/tvm/tirx/script/builder/utils.py
similarity index 100%
rename from python/tvm/script/ir_builder/tirx/utils.py
rename to python/tvm/tirx/script/builder/utils.py
diff --git a/python/tvm/script/parser/tirx/__init__.py
b/python/tvm/tirx/script/parser/__init__.py
similarity index 91%
rename from python/tvm/script/parser/tirx/__init__.py
rename to python/tvm/tirx/script/parser/__init__.py
index 929cf96635..bfae9d06eb 100644
--- a/python/tvm/script/parser/tirx/__init__.py
+++ b/python/tvm/tirx/script/parser/__init__.py
@@ -19,8 +19,9 @@
from typing import TYPE_CHECKING
-from ...ir_builder.tirx import * # pylint: disable=redefined-builtin
-from ...ir_builder.tirx import ir as _tir
+from tvm.tirx.script.builder import * # pylint: disable=redefined-builtin
+from tvm.tirx.script.builder import ir as _tir
+
from . import operation as _operation
from . import parser as _parser
from .entry import Buffer, Ptr
diff --git a/python/tvm/script/parser/tirx/entry.py
b/python/tvm/tirx/script/parser/entry.py
similarity index 97%
rename from python/tvm/script/parser/tirx/entry.py
rename to python/tvm/tirx/script/parser/entry.py
index fdac3b0db4..4764a10243 100644
--- a/python/tvm/script/parser/tirx/entry.py
+++ b/python/tvm/tirx/script/parser/entry.py
@@ -20,11 +20,10 @@ import inspect
from collections.abc import Callable
from tvm.ir.base import deprecated
+from tvm.script.parser._core import parse, scan_macro, utils
+from tvm.script.parser.core.parser import Parser, ScriptMacro
from tvm.tirx import Buffer, PrimFunc
-
-from ...ir_builder.tirx import block_name_suffix_context, buffer, ptr
-from .._core import parse, scan_macro, utils
-from ..core.parser import Parser, ScriptMacro
+from tvm.tirx.script.builder import block_name_suffix_context, buffer, ptr
def prim_func(
diff --git a/python/tvm/script/parser/tirx/operation.py
b/python/tvm/tirx/script/parser/operation.py
similarity index 98%
rename from python/tvm/script/parser/tirx/operation.py
rename to python/tvm/tirx/script/parser/operation.py
index fd528ba353..dac8f06ebf 100644
--- a/python/tvm/script/parser/tirx/operation.py
+++ b/python/tvm/tirx/script/parser/operation.py
@@ -18,11 +18,10 @@
from tvm import tirx
from tvm.runtime import DataType, DataTypeCode
+from tvm.script.parser._core import OpMethod, doc, register_op
from tvm.tirx import IntImm
from tvm.tirx.expr import FloatImm
-from .._core import OpMethod, doc, register_op
-
def _register_expr_op(ty: type): # pylint: disable=invalid-name
ty._dispatch_type = ty # pylint: disable=protected-access
diff --git a/python/tvm/script/parser/tirx/parser.py
b/python/tvm/tirx/script/parser/parser.py
similarity index 98%
rename from python/tvm/script/parser/tirx/parser.py
rename to python/tvm/tirx/script/parser/parser.py
index 825eda154d..3fc06e8e1a 100644
--- a/python/tvm/script/parser/tirx/parser.py
+++ b/python/tvm/tirx/script/parser/parser.py
@@ -24,13 +24,12 @@ import tvm_ffi
import tvm
from tvm.ir import GlobalVar, PrimType
+from tvm.script.ir_builder import ir as I
+from tvm.script.ir_builder.base import IRBuilder
+from tvm.script.ir_builder.base import IRBuilderFrame as Frame
+from tvm.script.parser._core import Parser, dispatch, doc
from tvm.tirx import Buffer, IterVar, PrimExpr, Var
-
-from ...ir_builder import ir as I
-from ...ir_builder import tirx as T
-from ...ir_builder.base import IRBuilder
-from ...ir_builder.base import IRBuilderFrame as Frame
-from .._core import Parser, dispatch, doc
+from tvm.tirx.script import builder as T
def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any)
-> Any:
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index a45c927bcc..f3f55878f8 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -228,7 +228,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
return ss.str();
});
// Note: kRepr for GlobalVarNode is registered in script/printer/ir/ir.cc
- // via TVM_SCRIPT_REPR(GlobalVarNode, ReprPrintIR).
+ // via TVM_REGISTER_SCRIPT_AS_REPR(GlobalVarNode, ReprPrintIR).
}
} // namespace tvm
diff --git a/src/ir/script_printer.cc b/src/ir/script_printer.cc
index 8297a4fe77..ea0c3d031e 100644
--- a/src/ir/script_printer.cc
+++ b/src/ir/script_printer.cc
@@ -22,7 +22,7 @@
#include <tvm/ir/cast.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/repr.h>
-#include <tvm/ir/script_printer.h>
+#include <tvm/script/printer/config.h>
#include <algorithm>
@@ -71,18 +71,9 @@ PrinterConfig::PrinterConfig(ffi::Map<ffi::String, Any>
config_dict) {
if (auto v = config_dict.Get("ir_prefix")) {
n->ir_prefix = Downcast<ffi::String>(v.value());
}
- if (auto v = config_dict.Get("tir_prefix")) {
- n->tir_prefix = Downcast<ffi::String>(v.value());
- }
- if (auto v = config_dict.Get("relax_prefix")) {
- n->relax_prefix = Downcast<ffi::String>(v.value());
- }
if (auto v = config_dict.Get("module_alias")) {
n->module_alias = Downcast<ffi::String>(v.value());
}
- if (auto v = config_dict.Get("buffer_dtype")) {
- n->buffer_dtype =
DataType(ffi::StringToDLDataType(Downcast<ffi::String>(v.value())));
- }
if (auto v = config_dict.Get("int_dtype")) {
n->int_dtype =
DataType(ffi::StringToDLDataType(Downcast<ffi::String>(v.value())));
}
@@ -123,22 +114,47 @@ PrinterConfig::PrinterConfig(ffi::Map<ffi::String, Any>
config_dict) {
if (auto v = config_dict.Get("show_object_address")) {
n->show_object_address = v.value().cast<bool>();
}
- if (auto v = config_dict.Get("show_all_struct_info")) {
- n->show_all_struct_info = v.value().cast<bool>();
+ // Dialect-specific keys are stored in extra_config with dotted-name keys.
+ // String-typed dialect keys are passed through directly.
+ for (const char* key : {"tirx.prefix", "relax.prefix"}) {
+ if (auto v = config_dict.Get(key)) {
+ n->extra_config.Set(ffi::String(key), v.value());
+ }
+ }
+ // "tirx.buffer_dtype" is passed as a DLDataType string from Python; convert to DataType.
+ if (auto v = config_dict.Get("tirx.buffer_dtype")) {
+ DataType dt(ffi::StringToDLDataType(Downcast<ffi::String>(v.value())));
+ n->extra_config.Set(ffi::String("tirx.buffer_dtype"), ffi::Any(dt));
+ }
+ // Boolean dialect keys.
+ if (auto v = config_dict.Get("relax.show_all_struct_info")) {
+ n->extra_config.Set(ffi::String("relax.show_all_struct_info"), v.value());
+ }
+ if (auto v = config_dict.Get("extra_config")) {
+ auto extra = Downcast<ffi::Map<ffi::String, ffi::Any>>(v.value());
+ for (auto kv : extra) {
+ n->extra_config.Set(kv.first, kv.second);
+ }
}
// Checking prefixes if they are valid Python identifiers.
- TVM_FFI_ICHECK(IsIdentifier(n->ir_prefix)) << "Invalid `ir_prefix`: " <<
n->ir_prefix;
- TVM_FFI_ICHECK(IsIdentifier(n->tir_prefix)) << "Invalid `tir_prefix`: " <<
n->tir_prefix;
- TVM_FFI_ICHECK(IsIdentifier(n->relax_prefix)) << "Invalid `relax_prefix`: "
<< n->relax_prefix;
- TVM_FFI_ICHECK(n->module_alias.empty() || IsIdentifier(n->module_alias))
+ TVM_FFI_ICHECK(IsIdentifier(std::string(n->ir_prefix)))
+ << "Invalid `ir_prefix`: " << n->ir_prefix;
+ ffi::String tir_prefix = n->GetExtraConfig<ffi::String>("tirx.prefix", "T");
+ ffi::String relax_prefix = n->GetExtraConfig<ffi::String>("relax.prefix",
"R");
+ TVM_FFI_ICHECK(IsIdentifier(std::string(tir_prefix))) << "Invalid
`tirx.prefix`: " << tir_prefix;
+ TVM_FFI_ICHECK(IsIdentifier(std::string(relax_prefix)))
+ << "Invalid `relax.prefix`: " << relax_prefix;
+ TVM_FFI_ICHECK(n->module_alias.empty() ||
IsIdentifier(std::string(n->module_alias)))
<< "Invalid `module_alias`: " << n->module_alias;
this->data_ = std::move(n);
}
ffi::Array<ffi::String> PrinterConfigNode::GetBuiltinKeywords() {
- ffi::Array<ffi::String> result{this->ir_prefix, this->tir_prefix,
this->relax_prefix};
+ ffi::String tir_prefix = GetExtraConfig<ffi::String>("tirx.prefix", "T");
+ ffi::String relax_prefix = GetExtraConfig<ffi::String>("relax.prefix", "R");
+ ffi::Array<ffi::String> result{this->ir_prefix, tir_prefix, relax_prefix};
if (!this->module_alias.empty()) {
result.push_back(this->module_alias);
}
diff --git a/src/ir/structural_equal.cc b/src/ir/structural_equal.cc
index 1ce037f504..4dcf2a32a6 100644
--- a/src/ir/structural_equal.cc
+++ b/src/ir/structural_equal.cc
@@ -26,7 +26,7 @@
#include <tvm/ir/module.h>
#include <tvm/ir/node_functor.h>
#include <tvm/ir/repr.h>
-#include <tvm/ir/script_printer.h>
+#include <tvm/script/printer/config.h>
#include <optional>
#include <unordered_map>
diff --git a/src/script/ir_builder/relax/distributed.cc
b/src/relax/script/builder/distributed.cc
similarity index 98%
rename from src/script/ir_builder/relax/distributed.cc
rename to src/relax/script/builder/distributed.cc
index be2bf1a22a..0561e582dd 100644
--- a/src/script/ir_builder/relax/distributed.cc
+++ b/src/relax/script/builder/distributed.cc
@@ -20,8 +20,8 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/distributed/struct_info.h>
+#include <tvm/relax/script/builder/ir.h>
#include <tvm/relax/struct_info.h>
-#include <tvm/script/ir_builder/relax/ir.h>
#include <tvm/tirx/op.h>
#include "./utils.h"
diff --git a/src/script/ir_builder/relax/frame.cc
b/src/relax/script/builder/frame.cc
similarity index 99%
rename from src/script/ir_builder/relax/frame.cc
rename to src/relax/script/builder/frame.cc
index 1550730e1f..14b658085a 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/relax/script/builder/frame.cc
@@ -21,8 +21,8 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
-#include <tvm/script/ir_builder/relax/frame.h>
-#include <tvm/script/ir_builder/relax/ir.h>
+#include <tvm/relax/script/builder/frame.h>
+#include <tvm/relax/script/builder/ir.h>
#include "./utils.h"
diff --git a/src/script/ir_builder/relax/ir.cc b/src/relax/script/builder/ir.cc
similarity index 99%
rename from src/script/ir_builder/relax/ir.cc
rename to src/relax/script/builder/ir.cc
index 7c517717ea..46bc4ecfeb 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/relax/script/builder/ir.cc
@@ -18,8 +18,8 @@
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
+#include <tvm/relax/script/builder/ir.h>
#include <tvm/relax/struct_info.h>
-#include <tvm/script/ir_builder/relax/ir.h>
#include <tvm/tirx/op.h>
#include "./utils.h"
diff --git a/src/script/ir_builder/relax/utils.h
b/src/relax/script/builder/utils.h
similarity index 96%
rename from src/script/ir_builder/relax/utils.h
rename to src/relax/script/builder/utils.h
index a2204ef54e..14e762064f 100644
--- a/src/script/ir_builder/relax/utils.h
+++ b/src/relax/script/builder/utils.h
@@ -16,13 +16,13 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_
-#define TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_
+#ifndef TVM_RELAX_SCRIPT_BUILDER_UTILS_H_
+#define TVM_RELAX_SCRIPT_BUILDER_UTILS_H_
+#include <tvm/relax/script/builder/frame.h>
+#include <tvm/relax/script/builder/ir.h>
#include <tvm/relax/struct_info_functor.h>
#include <tvm/relax/utils.h>
-#include <tvm/script/ir_builder/relax/frame.h>
-#include <tvm/script/ir_builder/relax/ir.h>
#include <string>
@@ -138,4 +138,4 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const
SeqExprFrame& frame, ffi::S
} // namespace script
} // namespace tvm
-#endif // TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_
+#endif // TVM_RELAX_SCRIPT_BUILDER_UTILS_H_
diff --git a/src/script/printer/relax/binding.cc
b/src/relax/script/printer/binding.cc
similarity index 93%
rename from src/script/printer/relax/binding.cc
rename to src/relax/script/printer/binding.cc
index bde2038d91..ec158a0b67 100644
--- a/src/script/printer/relax/binding.cc
+++ b/src/relax/script/printer/binding.cc
@@ -47,7 +47,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
using relax::StructInfo;
using relax::MatchStructInfo;
ffi::Optional<ExprDoc> ann = std::nullopt;
- if (d->cfg->show_all_struct_info) {
+ if (d->cfg->GetExtraConfig<bool>("relax.show_all_struct_info",
true)) {
ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
}
ExprDoc rhs = Relax(d, "match_cast")
@@ -88,9 +88,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return PrintIfExpr(n, n_p, d, std::nullopt, std::nullopt);
});
-TVM_SCRIPT_REPR(relax::MatchCastNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::VarBindingNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::IfNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::MatchCastNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::VarBindingNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::IfNode, ReprPrintRelax);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/relax/call.cc b/src/relax/script/printer/call.cc
similarity index 99%
rename from src/script/printer/relax/call.cc
rename to src/relax/script/printer/call.cc
index 9abe464d24..262be66e92 100644
--- a/src/script/printer/relax/call.cc
+++ b/src/relax/script/printer/call.cc
@@ -333,7 +333,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return prefix->Call(args, kwargs_keys, kwargs_values);
});
-TVM_SCRIPT_REPR(relax::CallNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::CallNode, ReprPrintRelax);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/relax/distributed.cc
b/src/relax/script/printer/distributed.cc
similarity index 95%
rename from src/script/printer/relax/distributed.cc
rename to src/relax/script/printer/distributed.cc
index bdc15b6ee6..0a67b55af8 100644
--- a/src/script/printer/relax/distributed.cc
+++ b/src/relax/script/printer/distributed.cc
@@ -20,7 +20,7 @@
#include <tvm/ir/expr.h>
#include <tvm/relax/distributed/struct_info.h>
-#include "../ir/utils.h"
+#include "../../../script/printer/ir/utils.h"
#include "./utils.h"
namespace tvm {
@@ -126,9 +126,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
});
-TVM_SCRIPT_REPR(relax::distributed::DeviceMeshNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::distributed::PlacementNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::distributed::DTensorStructInfoNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::distributed::DeviceMeshNode,
ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::distributed::PlacementNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::distributed::DTensorStructInfoNode,
ReprPrintRelax);
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/script/printer/relax/expr.cc b/src/relax/script/printer/expr.cc
similarity index 91%
rename from src/script/printer/relax/expr.cc
rename to src/relax/script/printer/expr.cc
index 99d9618639..c8a813b8d5 100644
--- a/src/script/printer/relax/expr.cc
+++ b/src/relax/script/printer/expr.cc
@@ -163,15 +163,15 @@ Doc PrintRelaxVar(relax::Var n, AccessPath p, IRDocsifier
d) {
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<relax::Var>("",
PrintRelaxVar);
TVM_STATIC_IR_FUNCTOR(IRDocsifier,
vtable).set_dispatch<relax::DataflowVar>("", PrintRelaxVar);
-TVM_SCRIPT_REPR(relax::PrimValueNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::StringImmNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::DataTypeImmNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::TupleNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::TupleGetItemNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::ShapeExprNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::VarNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::DataflowVarNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::ConstantNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::PrimValueNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::StringImmNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::DataTypeImmNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::TupleNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::TupleGetItemNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::ShapeExprNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::VarNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::DataflowVarNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::ConstantNode, ReprPrintRelax);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/relax/function.cc
b/src/relax/script/printer/function.cc
similarity index 97%
rename from src/script/printer/relax/function.cc
rename to src/relax/script/printer/function.cc
index c759fa80ae..e30a2b0bf4 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/relax/script/printer/function.cc
@@ -145,8 +145,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return Relax(d, "ExternFunc")->Call(args);
});
-TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::ExternFuncNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::FunctionNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::ExternFuncNode, ReprPrintRelax);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/relax/region.cc
b/src/relax/script/printer/region.cc
similarity index 95%
rename from src/script/printer/relax/region.cc
rename to src/relax/script/printer/region.cc
index 83e527ccc1..f5b50e66c9 100644
--- a/src/script/printer/relax/region.cc
+++ b/src/relax/script/printer/region.cc
@@ -94,9 +94,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return ScopeDoc(std::nullopt, Relax(d, "dataflow")->Call({}), stmts);
});
-TVM_SCRIPT_REPR(relax::SeqExprNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::BindingBlockNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::DataflowBlockNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::SeqExprNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::BindingBlockNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::DataflowBlockNode, ReprPrintRelax);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/relax/struct_info.cc
b/src/relax/script/printer/struct_info.cc
similarity index 94%
rename from src/script/printer/relax/struct_info.cc
rename to src/relax/script/printer/struct_info.cc
index a480c5e839..1019cfa7e9 100644
--- a/src/script/printer/relax/struct_info.cc
+++ b/src/relax/script/printer/struct_info.cc
@@ -186,12 +186,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return Relax(d, "Callable")->Call({TupleDoc(params_doc), ret_doc,
purity_doc});
});
-TVM_SCRIPT_REPR(relax::ObjectStructInfoNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::PrimStructInfoNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::ShapeStructInfoNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::TensorStructInfoNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::TupleStructInfoNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::FuncStructInfoNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::ObjectStructInfoNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::PrimStructInfoNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::ShapeStructInfoNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::TensorStructInfoNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::TupleStructInfoNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::FuncStructInfoNode, ReprPrintRelax);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/relax/tir.cc b/src/relax/script/printer/tir.cc
similarity index 99%
rename from src/script/printer/relax/tir.cc
rename to src/relax/script/printer/tir.cc
index 2c1bf77156..e0742f8edd 100644
--- a/src/script/printer/relax/tir.cc
+++ b/src/relax/script/printer/tir.cc
@@ -19,7 +19,7 @@
#include <tvm/ffi/cast.h>
#include <tvm/ir/expr.h>
-#include "../tirx/utils.h"
+#include "../../../tirx/script/printer/utils.h"
#include "./utils.h"
namespace tvm {
diff --git a/src/script/printer/relax/type.cc b/src/relax/script/printer/type.cc
similarity index 92%
rename from src/script/printer/relax/type.cc
rename to src/relax/script/printer/type.cc
index 0322052443..f5cbfcb166 100644
--- a/src/script/printer/relax/type.cc
+++ b/src/relax/script/printer/type.cc
@@ -80,10 +80,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
d->AsDoc<ExprDoc>(n->ret_type, n_p->Attr("ret_type"))});
});
-TVM_SCRIPT_REPR(relax::ShapeTypeNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::ObjectTypeNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::TensorTypeNode, ReprPrintRelax);
-TVM_SCRIPT_REPR(relax::PackedFuncTypeNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::ShapeTypeNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::ObjectTypeNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::TensorTypeNode, ReprPrintRelax);
+TVM_REGISTER_SCRIPT_AS_REPR(relax::PackedFuncTypeNode, ReprPrintRelax);
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("script.printer.ReprPrintRelax", ReprPrintRelax);
diff --git a/src/script/printer/relax/utils.h b/src/relax/script/printer/utils.h
similarity index 95%
rename from src/script/printer/relax/utils.h
rename to src/relax/script/printer/utils.h
index 7901690963..607728cb5b 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/relax/script/printer/utils.h
@@ -16,8 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
-#define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
+#ifndef TVM_RELAX_SCRIPT_PRINTER_UTILS_H_
+#define TVM_RELAX_SCRIPT_PRINTER_UTILS_H_
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
@@ -32,7 +32,7 @@
#include <utility>
#include <vector>
-#include "../utils.h"
+#include "../../../script/printer/utils.h"
namespace tvm {
namespace script {
@@ -85,7 +85,8 @@ inline ffi::Optional<ExprDoc> StructInfoAsAnn(const
relax::Var& v, const AccessP
if (!v->struct_info_.defined()) {
return std::nullopt;
}
- bool attempt_to_hide_struct_info = !d->cfg->show_all_struct_info;
+ bool attempt_to_hide_struct_info =
+ !d->cfg->GetExtraConfig<bool>("relax.show_all_struct_info", true);
if (const auto* call = rhs.as<relax::CallNode>()) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
@@ -156,4 +157,4 @@ inline int FindVDeviceIndexByTargetKind(const VDevice&
vdevice, const IRDocsifie
} // namespace script
} // namespace tvm
-#endif // TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
+#endif // TVM_RELAX_SCRIPT_PRINTER_UTILS_H_
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index e357b9783f..153c42e9d1 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -19,10 +19,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
-#include <tvm/relax/analysis.h>
#include <tvm/script/ir_builder/ir/ir.h>
-#include <tvm/tirx/function.h>
-#include <tvm/tirx/op.h>
#include "./utils.h"
@@ -38,15 +35,24 @@ IRModuleFrame IRModule() {
return IRModuleFrame(n);
}
-inline relax::StructInfo GetGlobalVarStructInfo(const BaseFunc& func) {
+// DeclFunction lives at the IR layer because an IRModule may host
+// heterogeneous function kinds (e.g. relax::Function, tirx::PrimFunc).
+// To derive the GlobalVar's struct_info_ without coupling the IR layer to
+// any specific dialect, dispatch is keyed by the function's type-key:
+// each dialect registers its own handler that maps a function of that
+// type to the appropriate struct_info.
+inline ffi::Optional<ffi::ObjectRef> GetGlobalVarStructInfo(const BaseFunc&
func) {
if (func->struct_info_.defined()) {
- return tvm::relax::GetStructInfo(func);
- } else if (const auto* prim_func = func.as<tvm::tirx::PrimFuncNode>()) {
- return tvm::relax::FuncStructInfo::OpaqueFunc(
- tvm::relax::StructInfoFromType(prim_func->ret_type));
- } else {
- TVM_FFI_THROW(InternalError) << "Unsupported function type: " <<
func->GetTypeKey();
+ return func->struct_info_;
+ }
+ // Registry: "script.ir_builder.decl_function.<type-key>" — per-function-kind
+ // handler that derives the GlobalVar struct_info from the function
signature.
+ // Grep hint: grep -rn 'script.ir_builder.decl_function.' src/
+ const std::string key = "script.ir_builder.decl_function." +
func->GetTypeKey();
+ if (auto fn = tvm::ffi::Function::GetGlobal(key)) {
+ return (*fn)(func).cast<ffi::Optional<ffi::ObjectRef>>();
}
+ return std::nullopt;
}
GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc&
func_signature) {
@@ -54,18 +60,12 @@ GlobalVar DeclFunction(const ffi::String& func_name, const
BaseFunc& func_signat
TVM_FFI_CHECK(!frame->global_var_map.count(func_name), ValueError)
<< "function " << func_name << " already exists";
- auto gvar_type = [&]() -> Type {
- if (auto prim_func = func_signature.as<tirx::PrimFuncNode>()) {
- ffi::Array<Type> arg_types =
- prim_func->params.Map([](const auto& var) { return GetType(var); });
- return FuncType(arg_types, prim_func->ret_type);
- }
-
- return {};
- }();
-
GlobalVar gv = GlobalVar(func_name);
- gv->struct_info_ = GetGlobalVarStructInfo(func_signature);
+ if (auto sinfo = GetGlobalVarStructInfo(func_signature)) {
+ gv->struct_info_ = sinfo.value();
+ } else {
+ TVM_FFI_THROW(InternalError) << "Unsupported function type: " <<
func_signature->GetTypeKey();
+ }
TVM_FFI_CHECK(frame->functions.find(gv) == frame->functions.end(),
ValueError)
<< "function " << func_name << " has already been defined.";
frame->global_var_map.Set(func_name, gv);
@@ -80,7 +80,11 @@ void DefFunction(const ffi::String& func_name, const
BaseFunc& func) {
<< "function " << func_name << " does not exist, please declare it
first.";
const GlobalVar& gv = (*it).second;
frame->functions.Set(gv, func);
- gv->struct_info_ = GetGlobalVarStructInfo(func);
+ if (auto sinfo = GetGlobalVarStructInfo(func)) {
+ gv->struct_info_ = sinfo.value();
+ } else {
+ TVM_FFI_THROW(InternalError) << "Unsupported function type: " <<
func->GetTypeKey();
+ }
}
void ModuleAttrs(ffi::Map<ffi::String, Any> attrs, bool allow_overwrite) {
diff --git a/src/script/printer/ir/distributed.cc b/src/script/printer/config.cc
similarity index 58%
copy from src/script/printer/ir/distributed.cc
copy to src/script/printer/config.cc
index 62d4c3ad61..d68aaff2ce 100644
--- a/src/script/printer/ir/distributed.cc
+++ b/src/script/printer/config.cc
@@ -16,26 +16,23 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include <tvm/ir/expr.h>
-#include <tvm/relax/distributed/global_info.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/script/printer/config.h>
+
+#include <sstream>
-#include "../relax/utils.h"
-#include "./utils.h"
namespace tvm {
-namespace script {
-namespace printer {
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<ffi::Shape>("", [](ffi::Shape n, AccessPath n_p, IRDocsifier
d) -> Doc {
- int s = n.size();
- ffi::Array<ExprDoc> results;
- results.reserve(s);
- for (int i = 0; i < s; ++i) {
- results.push_back(d->AsDoc<ExprDoc>(Integer(n[i]), n_p->ArrayItem(i)));
- }
- return TupleDoc(results);
- });
+std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj) {
+ try {
+ return TVMScriptPrinter::Script(obj, std::nullopt);
+ } catch (const tvm::ffi::Error& e) {
+ LOG(WARNING) << "TVMScript printer falls back to the basic address printer
with the error:\n"
+ << e.what();
+ std::ostringstream os;
+ os << obj->GetTypeKey() << '(' << obj.get() << ')';
+ return os.str();
+ }
+}
-} // namespace printer
-} // namespace script
} // namespace tvm
diff --git a/src/script/printer/ir/distributed.cc
b/src/script/printer/ir/distributed.cc
index 62d4c3ad61..5abc316154 100644
--- a/src/script/printer/ir/distributed.cc
+++ b/src/script/printer/ir/distributed.cc
@@ -17,9 +17,7 @@
* under the License.
*/
#include <tvm/ir/expr.h>
-#include <tvm/relax/distributed/global_info.h>
-#include "../relax/utils.h"
#include "./utils.h"
namespace tvm {
namespace script {
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index f48fca168f..a9b998d03e 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -164,11 +164,11 @@ std::string ReprPrintIRModule(const ffi::ObjectRef& mod,
const PrinterConfig& cf
return ReprPrintIR(mod, cfg);
}
-TVM_SCRIPT_REPR(GlobalVarNode, ReprPrintIR);
-TVM_SCRIPT_REPR(DictAttrsNode, ReprPrintIR);
-TVM_SCRIPT_REPR(FuncTypeNode, ReprPrintIR);
-TVM_SCRIPT_REPR(RangeNode, ReprPrintIR);
-TVM_SCRIPT_REPR(IRModuleNode, ReprPrintIRModule);
+TVM_REGISTER_SCRIPT_AS_REPR(GlobalVarNode, ReprPrintIR);
+TVM_REGISTER_SCRIPT_AS_REPR(DictAttrsNode, ReprPrintIR);
+TVM_REGISTER_SCRIPT_AS_REPR(FuncTypeNode, ReprPrintIR);
+TVM_REGISTER_SCRIPT_AS_REPR(RangeNode, ReprPrintIR);
+TVM_REGISTER_SCRIPT_AS_REPR(IRModuleNode, ReprPrintIRModule);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
index 5840c49e8e..84d1854b75 100644
--- a/src/script/printer/utils.h
+++ b/src/script/printer/utils.h
@@ -25,6 +25,7 @@
#include <tvm/ffi/extra/serialization.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/base.h>
+#include <tvm/script/printer/config.h>
#include <tvm/script/printer/ir_docsifier.h>
#include <sstream>
@@ -39,16 +40,11 @@ namespace tvm {
namespace script {
namespace printer {
-#define TVM_SCRIPT_REPR(ObjectType, Method)
\
- TVM_FFI_STATIC_INIT_BLOCK() {
\
- namespace refl = tvm::ffi::reflection;
\
- refl::TypeAttrDef<ObjectType>().def(refl::type_attr::kRepr,
\
- [](ffi::ObjectRef obj, ffi::Function)
-> ffi::String { \
- return
RedirectedReprPrinterMethod(obj); \
- });
\
- }
\
- TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter,
vtable).set_dispatch<ObjectType>(Method)
-
+// Note: the `TVM_SCRIPT_REPR` macro formerly defined here has been removed;
+// its call sites now use `TVM_REGISTER_SCRIPT_AS_REPR` instead.
+// NOTE(review): the replacement macro is presumably provided by
+// <tvm/script/printer/config.h> (included above), not duplicated in each
+// `src/<dialect>/script/printer/utils.h` -- confirm against that header.
inline std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj) {
try {
return TVMScriptPrinter::Script(obj, std::nullopt);
@@ -116,13 +112,13 @@ inline ExprDoc IR(const IRDocsifier& d, const
ffi::String& attr) {
/*! \brief Creates the TIR common prefix, which is by default `T` */
inline ExprDoc TIR(const IRDocsifier& d, const ffi::String& attr) {
d->ir_usage.insert("tirx");
- return IdDoc(d->cfg->tir_prefix)->Attr(attr);
+ return IdDoc(d->cfg->GetExtraConfig<ffi::String>("tirx.prefix",
"T"))->Attr(attr);
}
/*! \brief Creates the Relax common prefix, which is by default `R` */
inline ExprDoc Relax(const IRDocsifier& d, const ffi::String& attr) {
d->ir_usage.insert("relax");
- return IdDoc(d->cfg->relax_prefix)->Attr(attr);
+ return IdDoc(d->cfg->GetExtraConfig<ffi::String>("relax.prefix",
"R"))->Attr(attr);
}
inline std::string DType2Str(const runtime::DataType& dtype) {
@@ -137,10 +133,12 @@ inline Doc HeaderWrapper(const IRDocsifier& d, const Doc&
doc) {
stmts.push_back(CommentDoc("from tvm.script import ir as " +
d->cfg->ir_prefix));
}
if (d->ir_usage.count("tirx")) {
- stmts.push_back(CommentDoc("from tvm.script import tirx as " +
d->cfg->tir_prefix));
+ stmts.push_back(CommentDoc("from tvm.script import tirx as " +
+
d->cfg->GetExtraConfig<ffi::String>("tirx.prefix", "T")));
}
if (d->ir_usage.count("relax")) {
- stmts.push_back(CommentDoc("from tvm.script import relax as " +
d->cfg->relax_prefix));
+ stmts.push_back(CommentDoc("from tvm.script import relax as " +
+
d->cfg->GetExtraConfig<ffi::String>("relax.prefix", "R")));
}
stmts.push_back(CommentDoc(""));
stmts.push_back(Downcast<StmtDoc>(doc));
diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc
index 77c6b12aff..1aa2407d23 100644
--- a/src/tirx/ir/expr.cc
+++ b/src/tirx/ir/expr.cc
@@ -85,7 +85,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tirx.convert",
[](ffi::Variant<PrimExpr, ffi::Array<PrimExpr>> expr)
{ return expr; });
- // Note: kRepr for VarNode/SizeVarNode is registered via TVM_SCRIPT_REPR in
+ // Note: kRepr for VarNode/SizeVarNode is registered via
TVM_REGISTER_SCRIPT_AS_REPR in
- // src/script/printer/tirx/expr.cc (-> ReprPrintTIR which delegates to
TVMScriptPrinter).
+ // src/tirx/script/printer/expr.cc (-> ReprPrintTIR which delegates to
TVMScriptPrinter).
}
diff --git a/src/script/ir_builder/tirx/frame.cc
b/src/tirx/script/builder/frame.cc
similarity index 98%
rename from src/script/ir_builder/tirx/frame.cc
rename to src/tirx/script/builder/frame.cc
index 659c23bf3b..5defb1b821 100644
--- a/src/script/ir_builder/tirx/frame.cc
+++ b/src/tirx/script/builder/frame.cc
@@ -17,10 +17,10 @@
* under the License.
*/
#include <tvm/script/ir_builder/ir/ir.h>
-#include <tvm/script/ir_builder/tirx/frame.h>
#include <tvm/tirx/function.h>
+#include <tvm/tirx/script/builder/frame.h>
-#include "../../../tirx/ir/script/script_complete.h"
+#include "../../ir/script/script_complete.h"
#include "./utils.h"
namespace tvm {
diff --git a/src/script/ir_builder/tirx/ir.cc b/src/tirx/script/builder/ir.cc
similarity index 97%
rename from src/script/ir_builder/tirx/ir.cc
rename to src/tirx/script/builder/ir.cc
index 9e90a418b4..7044cfe7e3 100644
--- a/src/script/ir_builder/tirx/ir.cc
+++ b/src/tirx/script/builder/ir.cc
@@ -20,7 +20,9 @@
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/script/ir_builder/tirx/ir.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/tirx/script/builder/ir.h>
#include "./utils.h"
@@ -884,6 +886,18 @@ TVM_FFI_STATIC_INIT_BLOCK() {
[](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); })
.def("script.ir_builder.tirx.max",
[](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); });
+ // Registry: "script.ir_builder.decl_function.tirx.PrimFunc" — derives the
+ // GlobalVar struct_info for a tirx PrimFunc declared via I.DeclFunction.
+ // The IR layer builds this key from the function's type-key and calls the
+ // handler when the function has no pre-existing struct_info_ set.
+ refl::GlobalDef().def("script.ir_builder.decl_function.tirx.PrimFunc",
+ [](const BaseFunc& func) -> ffi::ObjectRef {
+ const auto* prim_func =
func.as<tvm::tirx::PrimFuncNode>();
+ TVM_FFI_ICHECK(prim_func != nullptr)
+ << "Expected tirx::PrimFunc, got " <<
func->GetTypeKey();
+ return tvm::relax::FuncStructInfo::OpaqueFunc(
+
tvm::relax::StructInfoFromType(prim_func->ret_type));
+ });
}
} // namespace tirx
} // namespace ir_builder
diff --git a/src/script/ir_builder/tirx/utils.h
b/src/tirx/script/builder/utils.h
similarity index 96%
rename from src/script/ir_builder/tirx/utils.h
rename to src/tirx/script/builder/utils.h
index 542a577c77..9504519126 100644
--- a/src/script/ir_builder/tirx/utils.h
+++ b/src/tirx/script/builder/utils.h
@@ -16,13 +16,13 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_
-#define TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_
+#ifndef TVM_TIRX_SCRIPT_BUILDER_UTILS_H_
+#define TVM_TIRX_SCRIPT_BUILDER_UTILS_H_
#include <tvm/ffi/cast.h>
-#include <tvm/script/ir_builder/tirx/frame.h>
-#include <tvm/script/ir_builder/tirx/ir.h>
#include <tvm/tirx/op.h>
+#include <tvm/tirx/script/builder/frame.h>
+#include <tvm/tirx/script/builder/ir.h>
#include <tvm/tirx/stmt.h>
namespace tvm {
@@ -138,4 +138,4 @@ inline tvm::tirx::BufferRegion
BufferRegionFromLoad(tvm::tirx::BufferLoad buffer
} // namespace script
} // namespace tvm
-#endif // TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_
+#endif // TVM_TIRX_SCRIPT_BUILDER_UTILS_H_
diff --git a/src/script/printer/tirx/block.cc b/src/tirx/script/printer/block.cc
similarity index 98%
rename from src/script/printer/tirx/block.cc
rename to src/tirx/script/printer/block.cc
index 00f8fc0df0..6c86d68ff5 100644
--- a/src/script/printer/tirx/block.cc
+++ b/src/tirx/script/printer/block.cc
@@ -231,8 +231,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return PrintBlock(d, block, p, std::nullopt, std::nullopt);
});
-TVM_SCRIPT_REPR(tirx::SBlockNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::SBlockRealizeNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::SBlockNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::SBlockRealizeNode, ReprPrintTIR);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/tirx/buffer.cc
b/src/tirx/script/printer/buffer.cc
similarity index 96%
rename from src/script/printer/tirx/buffer.cc
rename to src/tirx/script/printer/buffer.cc
index 0e8e1dbb93..eb34153557 100644
--- a/src/script/printer/tirx/buffer.cc
+++ b/src/tirx/script/printer/buffer.cc
@@ -90,7 +90,7 @@ ffi::Map<ffi::String, ExprDoc> BufferAttrs(tirx::Buffer
buffer, const AccessPath
kwargs.Set("shape", TupleDoc(results));
}
// Step 2. Handle `buffer.dtype`
- if (buffer->dtype != d->cfg->buffer_dtype) {
+ if (buffer->dtype != d->cfg->GetExtraConfig<DataType>("tirx.buffer_dtype",
DataType::Float(32))) {
kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype,
buffer_p->Attr("dtype")));
}
// Step 3. Handle `buffer.data`
@@ -342,12 +342,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return prefix[BufferIndices(load->indices, p->Attr("indices"), d)];
});
-TVM_SCRIPT_REPR(tirx::BufferRegionNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::BufferLoadNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::BufferStoreNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::BufferNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::MatchBufferRegionNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::ProducerLoadNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::BufferRegionNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::BufferLoadNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::BufferStoreNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::BufferNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::MatchBufferRegionNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::ProducerLoadNode, ReprPrintTIR);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/tirx/expr.cc b/src/tirx/script/printer/expr.cc
similarity index 90%
rename from src/script/printer/tirx/expr.cc
rename to src/tirx/script/printer/expr.cc
index fa2fe05b43..6d2e13cbd4 100644
--- a/src/script/printer/tirx/expr.cc
+++ b/src/tirx/script/printer/expr.cc
@@ -386,38 +386,38 @@ TVM_SCRIPT_PRINTER_DEF_BINARY(Max, "max");
#undef TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR
#undef TVM_SCRIPT_PRINTER_DEF_BINARY
-TVM_SCRIPT_REPR(tirx::VarNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::SizeVarNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::IterVarNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::StringImmNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::CastNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::AddNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::SubNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::MulNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::DivNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::ModNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::FloorDivNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::FloorModNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::MinNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::MaxNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::LTNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::LENode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::EQNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::NENode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::GTNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::GENode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::AndNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::OrNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::NotNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::SelectNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::RampNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::BroadcastNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::LetNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::CallNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::ShuffleNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::CommReducerNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::IndexMapNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::ReduceNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::VarNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::SizeVarNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::IterVarNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::StringImmNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::CastNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::AddNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::SubNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::MulNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::DivNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::ModNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::FloorDivNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::FloorModNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::MinNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::MaxNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::LTNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::LENode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::EQNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::NENode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::GTNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::GENode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::AndNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::OrNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::NotNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::SelectNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::RampNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::BroadcastNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::LetNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::CallNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::ShuffleNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::CommReducerNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::IndexMapNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::ReduceNode, ReprPrintTIR);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/tirx/for_loop.cc
b/src/tirx/script/printer/for_loop.cc
similarity index 98%
rename from src/script/printer/tirx/for_loop.cc
rename to src/tirx/script/printer/for_loop.cc
index a4d9d2c4b0..9897dd2189 100644
--- a/src/script/printer/tirx/for_loop.cc
+++ b/src/tirx/script/printer/for_loop.cc
@@ -131,7 +131,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return ForDoc(lhs, rhs, (*f)->stmts);
});
-TVM_SCRIPT_REPR(tirx::ForNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::ForNode, ReprPrintTIR);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/tirx/function.cc
b/src/tirx/script/printer/function.cc
similarity index 99%
rename from src/script/printer/tirx/function.cc
rename to src/tirx/script/printer/function.cc
index 27212685ab..a743539c53 100644
--- a/src/script/printer/tirx/function.cc
+++ b/src/tirx/script/printer/function.cc
@@ -204,7 +204,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
/*body=*/(*f)->stmts));
});
-TVM_SCRIPT_REPR(tirx::PrimFuncNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::PrimFuncNode, ReprPrintTIR);
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tvm::GlobalVar>(
//
diff --git a/src/script/printer/tirx/ir.cc b/src/tirx/script/printer/ir.cc
similarity index 93%
rename from src/script/printer/tirx/ir.cc
rename to src/tirx/script/printer/ir.cc
index 4a7517599b..57bec5a561 100644
--- a/src/script/printer/tirx/ir.cc
+++ b/src/tirx/script/printer/ir.cc
@@ -95,11 +95,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return TIR(d, "target")->Call({d->AsDoc<ExprDoc>(config, p)});
});
-TVM_SCRIPT_REPR(IntImmNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(FloatImmNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(PrimTypeNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(PointerTypeNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(TupleTypeNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(IntImmNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(FloatImmNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(PrimTypeNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(PointerTypeNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(TupleTypeNode, ReprPrintTIR);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/tirx/stmt.cc b/src/tirx/script/printer/stmt.cc
similarity index 94%
rename from src/script/printer/tirx/stmt.cc
rename to src/tirx/script/printer/stmt.cc
index 7794eadd62..3c3ab21f93 100644
--- a/src/script/printer/tirx/stmt.cc
+++ b/src/tirx/script/printer/stmt.cc
@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include "../../../tirx/transform/ir_utils.h" // For `GetPtrStorageScope`
+#include "../../transform/ir_utils.h" // For `GetPtrStorageScope`
#include "./utils.h"
namespace tvm {
@@ -246,10 +246,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise);
});
-TVM_SCRIPT_REPR(tirx::BindNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::AttrStmtNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::AssertStmtNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::WhileNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::BindNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::AttrStmtNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::AssertStmtNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::WhileNode, ReprPrintTIR);
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::AllocBuffer>( //
"", [](tirx::AllocBuffer stmt, AccessPath p, IRDocsifier d) -> Doc {
@@ -287,7 +287,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
args.push_back(TupleDoc(shape_docs));
}
// dtype (positional, skip if default float32)
- if (buffer->dtype != d->cfg->buffer_dtype) {
+ if (buffer->dtype !=
+ d->cfg->GetExtraConfig<DataType>("tirx.buffer_dtype", DataType::Float(32))) {
args.push_back(LiteralDoc::DataType(buffer->dtype, buffer_p->Attr("dtype")));
}
// scope (keyword, skip if "global")
@@ -309,11 +310,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return AssignDoc(lhs, rhs, std::nullopt);
});
-TVM_SCRIPT_REPR(tirx::AllocBufferNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::DeclBufferNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::SeqStmtNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::IfThenElseNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tirx::EvaluateNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::AllocBufferNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::DeclBufferNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::SeqStmtNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::IfThenElseNode, ReprPrintTIR);
+TVM_REGISTER_SCRIPT_AS_REPR(tirx::EvaluateNode, ReprPrintTIR);
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/script/printer/tirx/utils.h b/src/tirx/script/printer/utils.h
similarity index 98%
rename from src/script/printer/tirx/utils.h
rename to src/tirx/script/printer/utils.h
index f2c4272125..8dc6e703bc 100644
--- a/src/script/printer/tirx/utils.h
+++ b/src/tirx/script/printer/utils.h
@@ -16,8 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_SCRIPT_PRINTER_TIR_UTILS_H_
-#define TVM_SCRIPT_PRINTER_TIR_UTILS_H_
+#ifndef TVM_TIRX_SCRIPT_PRINTER_UTILS_H_
+#define TVM_TIRX_SCRIPT_PRINTER_UTILS_H_
#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
@@ -36,7 +36,7 @@
#include <utility>
#include <vector>
-#include "../utils.h"
+#include "../../../script/printer/utils.h"
namespace tvm {
namespace script {
@@ -290,4 +290,4 @@ class OccurrenceCounter : public tirx::StmtExprVisitor {
} // namespace script
} // namespace tvm
-#endif // TVM_SCRIPT_PRINTER_TIR_UTILS_H_
+#endif // TVM_TIRX_SCRIPT_PRINTER_UTILS_H_