This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d26ea6ff51 [REFACTOR][SCRIPT] tvmscript streamline: lift printer.h,
restore one-way dep, migrate dialect config to extra_config (#19631)
d26ea6ff51 is described below
commit d26ea6ff5113e78148324447681f284bc96e2dc2
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu May 28 14:17:32 2026 -0400
[REFACTOR][SCRIPT] tvmscript streamline: lift printer.h, restore one-way
dep, migrate dialect config to extra_config (#19631)
## Background
The `tvm::ir` layer previously had a reverse dependency on
`tvm::script`, injected via the `TVM_OBJECT_ENABLE_SCRIPT_PRINTER()`
macro that added a `Script()` member method to IR node types (IRModule,
PrimExpr, Buffer, PrimFunc, Stmt). This violated the intended one-way
dependency: `script` should depend on `ir`, never the other way around.
Additionally, `PrinterConfigNode` accumulated dialect-specific fields
(`tir_prefix`, `tir_import_module`, `tirx_prefix`, `relax_prefix`) that
leaked dialect internals into the generic printer infrastructure.
## Changes
This PR restores the clean dependency direction and encapsulates dialect
config properly, in 5 commits:
1. **Lift TVMScript entry point into `script/printer/printer.h`**: New
header `include/tvm/script/printer/printer.h` introduces:
- `tvm::Script()` free function, replacing the former
`TVMScriptPrinter::Script()` static method
- `TVMScriptPrinter` class, which now holds only the dispatch vtable
(`NodeFunctor<std::string(...)>`)
- `TVM_REGISTER_SCRIPT_AS_REPR` macro for registering per-type repr
callbacks
2. **Drop `TVM_OBJECT_ENABLE_SCRIPT_PRINTER` macro**: Remove the macro
from all IR headers (`ir/expr.h`, `ir/module.h`, `tirx/buffer.h`,
`tirx/function.h`, `tirx/stmt.h`), eliminating the reverse `ir` →
`script` dependency. All call sites of `.Script()` member methods
updated to use `tvm::Script()`.
3. **Move dialect-specific `PrinterConfig` fields to `extra_config`**:
Remove `tir_prefix`, `tir_import_module`, `tirx_prefix`, `relax_prefix`
from `PrinterConfigNode`. Dialect internals now read their config via
`GetExtraConfig<T>(key, fallback)` with dotted keys (e.g.,
`"tirx.prefix"`). `buffer_dtype` is kept as a top-level field alongside
`int_dtype`/`float_dtype`. Like them, it is a shared default dtype rather
than a dialect-specific knob, so the former `"tirx.buffer_dtype"` extra
key is removed.
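4. **Python: drop dialect kwargs, expose `extra_config`**: Update
`PrinterConfig`, `Scriptable.script()`, `Scriptable.show()`,
`Scriptable._relax_script()`, and `BasePyModule.script()` to use
`extra_config: dict | None = None` instead of individual dialect kwargs.
`buffer_dtype` remains a `PrinterConfig` keyword argument. However, the
`Scriptable` and `BasePyModule` helpers no longer accept or forward it.
The tirx auto-switch logic is preserved: it sets `"tirx.prefix"` to
`"Tx"` only when the caller has not already supplied a `"tirx.prefix"`
entry in `extra_config`.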
5. **Fix transitive include breakage**: Explicitly add direct includes
for `config.h` and `node_functor.h` where headers previously relied on
transitive paths through `expr.h`/`module.h`.
## Testing
- C++ unit tests: 118/118 pass
- TVMScript printer tests: 771 passed, 1 skipped, 1 xfailed
- TIR namespace tests
(`tests/python/tirx/test_printer_tir_namespaces.py`): 13/13 pass
- Relax AST printer tests: 24/24 pass
- Minimal platform tests: 37/37 pass
- Pre-commit (ASF headers, ruff, clang-format): all clean
---
include/tvm/ir/expr.h | 3 -
include/tvm/ir/module.h | 3 -
include/tvm/script/ir_builder/base.h | 1 +
include/tvm/script/printer/config.h | 65 +++--------------
include/tvm/script/printer/doc.h | 1 +
include/tvm/script/printer/ir_docsifier.h | 1 +
include/tvm/script/printer/printer.h | 71 ++++++++++++++++++
include/tvm/tirx/buffer.h | 2 -
include/tvm/tirx/function.h | 2 -
include/tvm/tirx/stmt.h | 3 -
python/tvm/relax/base_py_module.py | 8 +--
python/tvm/runtime/script_printer.py | 83 +++++++---------------
src/s_tir/meta_schedule/database/json_database.cc | 14 ++--
src/s_tir/schedule/error.cc | 4 +-
src/script/printer/config.cc | 4 +-
src/script/printer/script_printer.cc | 24 ++-----
src/script/printer/utils.h | 13 +---
src/tirx/script/printer/buffer.cc | 7 +-
tests/cpp/tir_scalable_datatype.cc | 3 +-
tests/python/tirx/test_printer_tir_namespaces.py | 2 +-
.../tirx/transform/test_transform_lower_tirx.py | 10 +--
21 files changed, 146 insertions(+), 178 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index fcd267163c..c351dd83d8 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -31,7 +31,6 @@
#include <tvm/ir/cow.h>
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
-#include <tvm/script/printer/config.h>
#include <algorithm>
#include <functional>
@@ -113,8 +112,6 @@ class PrimExprNode : public BaseExprNode {
refl::ObjectDef<PrimExprNode>().def_ro("dtype", &PrimExprNode::dtype);
}
- TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
-
static constexpr const uint32_t _type_child_slots = 40;
TVM_FFI_DECLARE_OBJECT_INFO("ir.PrimExpr", PrimExprNode, BaseExprNode);
};
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 6a5f41ca8d..34a451be08 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -34,7 +34,6 @@
#include <tvm/ir/global_info.h>
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
-#include <tvm/script/printer/config.h>
#include <string>
#include <unordered_map>
@@ -241,8 +240,6 @@ class IRModuleNode : public ffi::Object {
*/
TVM_DLL std::unordered_set<ffi::String> Imports() const;
- TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
-
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.IRModule", IRModuleNode, ffi::Object);
diff --git a/include/tvm/script/ir_builder/base.h
b/include/tvm/script/ir_builder/base.h
index a459df6ee6..0d9c8ccc4f 100644
--- a/include/tvm/script/ir_builder/base.h
+++ b/include/tvm/script/ir_builder/base.h
@@ -23,6 +23,7 @@
#include <tvm/ir/cast.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
+#include <tvm/ir/node_functor.h>
#include <vector>
diff --git a/include/tvm/script/printer/config.h
b/include/tvm/script/printer/config.h
index 19510e76a8..541d66f635 100644
--- a/include/tvm/script/printer/config.h
+++ b/include/tvm/script/printer/config.h
@@ -18,7 +18,11 @@
*/
/*!
* \file tvm/script/printer/config.h
- * \brief Printer class to print repr string of each AST/IR nodes.
+ * \brief Configuration object for the TVMScript printer.
+ *
+ * Contains PrinterConfig / PrinterConfigNode, GetBuiltinKeywords,
GetExtraConfig,
+ * and RedirectedReprPrinterMethod. The entry-point free function
tvm::Script()
+ * and the dispatch vtable TVMScriptPrinter live in printer.h.
*/
#ifndef TVM_SCRIPT_PRINTER_CONFIG_H_
#define TVM_SCRIPT_PRINTER_CONFIG_H_
@@ -30,7 +34,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/cast.h>
-#include <tvm/ir/node_functor.h>
#include <tvm/runtime/data_type.h>
#include <string>
@@ -45,25 +48,13 @@ class PrinterConfigNode : public ffi::Object {
bool show_meta = false;
/*! \brief The prefix of IR nodes */
ffi::String ir_prefix = "I";
- /*! \brief The prefix of TIR nodes */
- ffi::String tir_prefix = "T";
- /*!
- * \brief The TIR module name used in the printed import (e.g. "tir" or
"tirx").
- * Used in the header comment: "from tvm.script import <tir_import_module>
as <tir_prefix>".
- * When tir_prefix is "Tx", set to "tirx" so the printed script uses "import
tirx as Tx".
- */
- ffi::String tir_import_module = "tir";
- /*! \brief The prefix of TIRX nodes */
- ffi::String tirx_prefix = "Tx";
- /*! \brief Default buffer dtype */
- DataType buffer_dtype = DataType::Float(32);
- /*! \brief The prefix of Relax nodes */
- ffi::String relax_prefix = "R";
/*!
* \brief The alias of the current module at cross-function call
* \note Directly use module name if it's empty.
*/
ffi::String module_alias = "cls";
+ /*! \brief Default buffer dtype */
+ DataType buffer_dtype = DataType::Float(32);
/*! \brief Default data type of integer literals */
DataType int_dtype = DataType::Int(32);
/*!
@@ -99,7 +90,6 @@ class PrinterConfigNode : public ffi::Object {
*
* Keys are conventionally namespaced as "<dialect>.<knob>", e.g.:
* "tirx.prefix" — the TIR prefix (default "T")
- * "tirx.buffer_dtype" — default buffer dtype (default float32)
* "relax.prefix" — the Relax prefix (default "R")
* "relax.show_all_struct_info" — whether to show all struct info (default
true)
*
@@ -127,6 +117,7 @@ class PrinterConfigNode : public ffi::Object {
.def_ro("show_meta", &PrinterConfigNode::show_meta)
.def_ro("ir_prefix", &PrinterConfigNode::ir_prefix)
.def_ro("module_alias", &PrinterConfigNode::module_alias)
+ .def_ro("buffer_dtype", &PrinterConfigNode::buffer_dtype)
.def_ro("int_dtype", &PrinterConfigNode::int_dtype)
.def_ro("float_dtype", &PrinterConfigNode::float_dtype)
.def_ro("verbose_expr", &PrinterConfigNode::verbose_expr)
@@ -156,48 +147,14 @@ class TVM_DLL PrinterConfig : public ffi::ObjectRef {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrinterConfig, ffi::ObjectRef,
PrinterConfigNode);
};
-/*! \brief TVMScript-based printer for IR nodes. */
-class TVMScriptPrinter {
- public:
- /* Convert the object to TVMScript format */
- TVM_DLL static std::string Script(const ffi::ObjectRef& node,
- const ffi::Optional<PrinterConfig>& cfg);
- // Allow registration to be printer.
- using FType = NodeFunctor<std::string(const ffi::ObjectRef&, const
PrinterConfig&)>;
- TVM_DLL static FType& vtable();
-};
-
-#define TVM_OBJECT_ENABLE_SCRIPT_PRINTER()
\
- std::string Script(const ffi::Optional<PrinterConfig>& config =
std::nullopt) const { \
- return TVMScriptPrinter::Script(ffi::GetRef<ffi::ObjectRef>(this),
\
- config.value_or(PrinterConfig()));
\
- }
-
/*!
- * \brief The fallback body used by TVM_REGISTER_SCRIPT_AS_REPR.
+ * \brief The fallback body used by TVM_REGISTER_SCRIPT_AS_REPR (defined in
printer.h).
*
- * Tries to format \p obj via TVMScriptPrinter::Script; on error falls back to
- * a plain address string. Defined in src/script/printer/config.cc so that
+ * Tries to format \p obj via tvm::Script; on error falls back to a plain
+ * address string. Defined in src/script/printer/config.cc so that
* <tvm/runtime/logging.h> is not pulled into this public header.
*/
TVM_DLL std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj);
-/*!
- * \brief Register Script as the kRepr callback for ObjectType and install
- * the per-type dispatch entry in TVMScriptPrinter::vtable().
- *
- * \param ObjectType The concrete object node type (e.g. tirx::VarNode).
- * \param Method The TVMScriptPrinter vtable dispatch function.
- */
-#define TVM_REGISTER_SCRIPT_AS_REPR(ObjectType, Method)
\
- TVM_FFI_STATIC_INIT_BLOCK() {
\
- namespace refl = tvm::ffi::reflection;
\
- refl::TypeAttrDef<ObjectType>().def(refl::type_attr::kRepr,
\
- [](ffi::ObjectRef obj, ffi::Function)
-> ffi::String { \
- return
RedirectedReprPrinterMethod(obj); \
- });
\
- }
\
- TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter,
vtable).set_dispatch<ObjectType>(Method)
-
} // namespace tvm
#endif // TVM_SCRIPT_PRINTER_CONFIG_H_
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index c602fc80a4..d63942ac71 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -24,6 +24,7 @@
#include <tvm/ir/expr.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/script/printer/config.h>
#include <string>
diff --git a/include/tvm/script/printer/ir_docsifier.h
b/include/tvm/script/printer/ir_docsifier.h
index e49d4f8a1c..32f2281828 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -23,6 +23,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/cast.h>
#include <tvm/ir/module.h>
+#include <tvm/script/printer/config.h>
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/ir_docsifier_functor.h>
diff --git a/include/tvm/script/printer/printer.h
b/include/tvm/script/printer/printer.h
new file mode 100644
index 0000000000..6ace9b8420
--- /dev/null
+++ b/include/tvm/script/printer/printer.h
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/script/printer/printer.h
+ * \brief Entry-point header for TVMScript printing.
+ *
+ * Declares the free function `tvm::Script(node, optional_config)` and the
+ * dispatch vtable `TVMScriptPrinter::vtable()` used by per-dialect printers.
+ * `PrinterConfig` and its dataclass helpers live in config.h; this header is
+ * what callers include to invoke printing.
+ */
+#ifndef TVM_SCRIPT_PRINTER_PRINTER_H_
+#define TVM_SCRIPT_PRINTER_PRINTER_H_
+
+#include <tvm/ir/node_functor.h>
+#include <tvm/script/printer/config.h>
+
+namespace tvm {
+
+/*! \brief Print \p node as TVMScript with the given \p config.
+ *
+ * Falls back to ffi::ReprPrint for types not registered with
TVMScriptPrinter.
+ */
+TVM_DLL std::string Script(const ffi::ObjectRef& node,
+ const ffi::Optional<PrinterConfig>& config =
std::nullopt);
+
+/*! \brief Dispatch vtable used by per-dialect printers to register their
+ * object-type printing functions. Internal, but exposed here because
+ * TVM_REGISTER_SCRIPT_AS_REPR refers to it.
+ */
+class TVMScriptPrinter {
+ public:
+ using FType = NodeFunctor<std::string(const ffi::ObjectRef&, const
PrinterConfig&)>;
+ TVM_DLL static FType& vtable();
+};
+
+/*!
+ * \brief Register Script as the kRepr callback for ObjectType and install
+ * the per-type dispatch entry in TVMScriptPrinter::vtable().
+ *
+ * \param ObjectType The concrete object node type (e.g. tirx::VarNode).
+ * \param Method The TVMScriptPrinter vtable dispatch function.
+ */
+#define TVM_REGISTER_SCRIPT_AS_REPR(ObjectType, Method)
\
+ TVM_FFI_STATIC_INIT_BLOCK() {
\
+ namespace refl = tvm::ffi::reflection;
\
+ refl::TypeAttrDef<ObjectType>().def(refl::type_attr::kRepr,
\
+ [](ffi::ObjectRef obj, ffi::Function)
-> ffi::String { \
+ return
RedirectedReprPrinterMethod(obj); \
+ });
\
+ }
\
+ TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter,
vtable).set_dispatch<ObjectType>(Method)
+
+} // namespace tvm
+#endif // TVM_SCRIPT_PRINTER_PRINTER_H_
diff --git a/include/tvm/tirx/buffer.h b/include/tvm/tirx/buffer.h
index b32b06b755..a5146600f4 100644
--- a/include/tvm/tirx/buffer.h
+++ b/include/tvm/tirx/buffer.h
@@ -28,7 +28,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/expr.h>
-#include <tvm/script/printer/config.h>
#include <tvm/tirx/layout.h>
#include <tvm/tirx/var.h>
@@ -166,7 +165,6 @@ class BufferNode : public ffi::Object {
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Buffer", BufferNode, ffi::Object);
- TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
};
/*!
diff --git a/include/tvm/tirx/function.h b/include/tvm/tirx/function.h
index 45a8600a6e..0fae5bb961 100644
--- a/include/tvm/tirx/function.h
+++ b/include/tvm/tirx/function.h
@@ -29,7 +29,6 @@
#include <tvm/ir/cow.h>
#include <tvm/ir/function.h>
#include <tvm/runtime/tensor.h>
-#include <tvm/script/printer/config.h>
#include <tvm/tirx/buffer.h>
#include <tvm/tirx/expr.h>
#include <tvm/tirx/stmt.h>
@@ -120,7 +119,6 @@ class PrimFuncNode : public BaseFuncNode {
*/
TVM_DLL FuncType func_type_annotation() const;
- TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.PrimFunc", PrimFuncNode,
BaseFuncNode);
};
diff --git a/include/tvm/tirx/stmt.h b/include/tvm/tirx/stmt.h
index 39cfbac0cd..2e336d1292 100644
--- a/include/tvm/tirx/stmt.h
+++ b/include/tvm/tirx/stmt.h
@@ -25,7 +25,6 @@
#define TVM_TIRX_STMT_H_
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/script/printer/config.h>
#include <tvm/tirx/exec_scope.h>
#include <tvm/tirx/expr.h>
#include <tvm/tirx/layout.h>
@@ -55,8 +54,6 @@ class StmtNode : public ffi::Object {
refl::ObjectDef<StmtNode>().def_ro("span", &StmtNode::span);
}
- TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
-
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const uint32_t _type_child_slots = 15;
diff --git a/python/tvm/relax/base_py_module.py
b/python/tvm/relax/base_py_module.py
index 1834c25c31..5dd8a107ae 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -501,10 +501,7 @@ class BasePyModule:
name: str | None = None,
show_meta: bool = False,
ir_prefix: str = "I",
- tir_prefix: str = "T",
- relax_prefix: str = "R",
module_alias: str = "cls",
- buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
verbose_expr: bool = False,
@@ -514,6 +511,7 @@ class BasePyModule:
syntax_sugar: bool = True,
show_object_address: bool = False,
show_all_struct_info: bool = True,
+ extra_config: dict | None = None,
) -> str:
"""Print TVM IR into TVMScript text format with Python function
support.
@@ -525,10 +523,7 @@ class BasePyModule:
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
- tir_prefix=tir_prefix,
- relax_prefix=relax_prefix,
module_alias=module_alias,
- buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
verbose_expr=verbose_expr,
@@ -538,6 +533,7 @@ class BasePyModule:
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
show_all_struct_info=show_all_struct_info,
+ extra_config=extra_config,
)
# If there are no Python functions, return the base script
diff --git a/python/tvm/runtime/script_printer.py
b/python/tvm/runtime/script_printer.py
index e67d950a4c..209efe77a0 100644
--- a/python/tvm/runtime/script_printer.py
+++ b/python/tvm/runtime/script_printer.py
@@ -34,10 +34,8 @@ class PrinterConfig(Object):
binding_names: Sequence[str]
show_meta: bool
ir_prefix: str
- tir_prefix: str
- tir_import_module: str
- relax_prefix: str
module_alias: str
+ buffer_dtype: str
int_dtype: str
float_dtype: str
verbose_expr: bool
@@ -58,9 +56,6 @@ class PrinterConfig(Object):
name: str | None = None,
show_meta: bool = False,
ir_prefix: str = "I",
- tir_prefix: str = "T",
- tir_import_module: str = "tir",
- relax_prefix: str = "R",
module_alias: str = "cls",
buffer_dtype: str = "float32",
int_dtype: str = "int32",
@@ -72,6 +67,7 @@ class PrinterConfig(Object):
syntax_sugar: bool = True,
show_object_address: bool = False,
show_all_struct_info: bool = True,
+ extra_config: dict | None = None,
path_to_underline: list[AccessPath] | None = None,
path_to_annotate: dict[AccessPath, str] | None = None,
obj_to_underline: list[Object] | None = None,
@@ -79,13 +75,11 @@ class PrinterConfig(Object):
) -> None:
if num_context_lines is None:
num_context_lines = -1
- cfg = {
+ cfg: dict = {
"show_meta": show_meta,
"ir_prefix": ir_prefix,
- "tir_prefix": tir_prefix,
- "tir_import_module": tir_import_module,
- "relax_prefix": relax_prefix,
"module_alias": module_alias,
+ "buffer_dtype": buffer_dtype,
"int_dtype": int_dtype,
"float_dtype": float_dtype,
"verbose_expr": verbose_expr,
@@ -99,14 +93,13 @@ class PrinterConfig(Object):
"obj_to_underline": obj_to_underline,
"obj_to_annotate": obj_to_annotate,
# Dialect-specific config via dotted keys in extra_config
- "tirx.prefix": tir_prefix,
- "tirx.buffer_dtype": buffer_dtype,
- "relax.prefix": relax_prefix,
"relax.show_all_struct_info": show_all_struct_info,
}
if name is not None:
cfg["name"] = name
+ if extra_config is not None:
+ cfg["extra_config"] = extra_config
self.__init_handle_by_constructor__(
_ffi_node_api.PrinterConfig,
cfg, # type: ignore # pylint: disable=no-member
@@ -131,11 +124,7 @@ class Scriptable:
name: str | None = None,
show_meta: bool = False,
ir_prefix: str = "I",
- tir_prefix: str = "T",
- tir_import_module: str = "tir",
- relax_prefix: str = "R",
module_alias: str = "cls",
- buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
verbose_expr: bool = False,
@@ -145,6 +134,7 @@ class Scriptable:
syntax_sugar: bool = True,
show_object_address: bool = False,
show_all_struct_info: bool = True,
+ extra_config: dict | None = None,
path_to_underline: list[AccessPath] | None = None,
path_to_annotate: dict[AccessPath, str] | None = None,
obj_to_underline: list[Object] | None = None,
@@ -160,18 +150,9 @@ class Scriptable:
Whether to print the meta data of the object
ir_prefix : str = "I"
The prefix of AST nodes from tvm.ir
- tir_prefix : str = "T"
- The prefix of AST nodes from tvm.tir
- tir_import_module : str = "tir"
- The module name in the printed import (e.g. \"tir\" or \"tirx\").
- Use tir_import_module=\"tirx\" with tir_prefix=\"Tx\" for all-Tx
output.
- relax_prefix : str = "R"
- The prefix of AST nodes from tvm.relax
module_alias : str = "cls"
The alias of the current module at cross-function call,
Directly use module name if it's empty.
- buffer_dtype : str = "float32"
- The default data type of buffer
int_dtype : str = "int32"
The default data type of integer
float_dtype : str = "void"
@@ -192,6 +173,10 @@ class Scriptable:
If True (default), annotate all variable bindings with the struct
info of that variable. If False, only add annotations where
required for unambiguous round-trip of Relax -> TVMScript -> Relax.
+ extra_config : Optional[dict] = None
+ Dialect-specific configuration passed through to
PrinterConfig.extra_config.
+ Keys are conventionally namespaced as "<dialect>.<knob>", e.g.
+ ``{"tirx.prefix": "Tx"}``.
path_to_underline : Optional[List[AccessPath]] = None
Object path to be underlined
path_to_annotate : Optional[Dict[AccessPath, str]] = None
@@ -211,9 +196,12 @@ class Scriptable:
# printing a PrimFunc / IRModule that has no s_tir-tagged content.
# Free objects (Buffer, BufferRegion, ...) keep the default `T`/`tir`
# flavor — they have no enclosing function to indicate tirx vs s_tir.
- tir_prefix_val = tir_prefix
- tir_import_module_val = tir_import_module
- if tir_prefix == "T" and tir_import_module == "tir":
+ merged_extra: dict = {}
+ if extra_config is not None:
+ merged_extra.update(extra_config)
+
+ # Only auto-switch if the caller has not already set a tirx.prefix
override.
+ if "tirx.prefix" not in merged_extra:
from tvm.ir import IRModule # pylint:
disable=import-outside-toplevel
from tvm.tirx import PrimFunc # pylint:
disable=import-outside-toplevel
@@ -236,19 +224,15 @@ class Scriptable:
if any_prim and not any_s_tir:
switch_to_tirx = True
if switch_to_tirx:
- tir_prefix_val = "Tx"
- tir_import_module_val = "tirx"
+ merged_extra["tirx.prefix"] = "Tx"
+
return _script(
self,
PrinterConfig(
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
- tir_prefix=tir_prefix_val,
- tir_import_module=tir_import_module_val,
- relax_prefix=relax_prefix,
module_alias=module_alias,
- buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
verbose_expr=verbose_expr,
@@ -258,6 +242,7 @@ class Scriptable:
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
show_all_struct_info=show_all_struct_info,
+ extra_config=merged_extra if merged_extra else None,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
@@ -271,11 +256,7 @@ class Scriptable:
name: str | None = None,
show_meta: bool = False,
ir_prefix: str = "I",
- tir_prefix: str = "T",
- tir_import_module: str = "tir",
- relax_prefix: str = "R",
module_alias: str = "cls",
- buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
verbose_expr: bool = False,
@@ -284,6 +265,7 @@ class Scriptable:
num_context_lines: int = -1,
syntax_sugar: bool = True,
show_object_address: bool = False,
+ extra_config: dict | None = None,
path_to_underline: list[AccessPath] | None = None,
path_to_annotate: dict[AccessPath, str] | None = None,
obj_to_underline: list[Object] | None = None,
@@ -295,11 +277,7 @@ class Scriptable:
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
- tir_prefix=tir_prefix,
- tir_import_module=tir_import_module,
- relax_prefix=relax_prefix,
module_alias=module_alias,
- buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
verbose_expr=verbose_expr,
@@ -308,6 +286,7 @@ class Scriptable:
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
+ extra_config=extra_config,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
@@ -323,11 +302,7 @@ class Scriptable:
name: str | None = None,
show_meta: bool = False,
ir_prefix: str = "I",
- tir_prefix: str = "T",
- tir_import_module: str = "tir",
- relax_prefix: str = "R",
module_alias: str = "cls",
- buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
verbose_expr: bool = False,
@@ -337,6 +312,7 @@ class Scriptable:
syntax_sugar: bool = True,
show_object_address: bool = False,
show_all_struct_info: bool = True,
+ extra_config: dict | None = None,
path_to_underline: list[AccessPath] | None = None,
path_to_annotate: dict[AccessPath, str] | None = None,
obj_to_underline: list[Object] | None = None,
@@ -375,15 +351,9 @@ class Scriptable:
Whether to print the meta data of the object
ir_prefix : str = "I"
The prefix of AST nodes from tvm.ir
- tir_prefix : str = "T"
- The prefix of AST nodes from tvm.tirx
- relax_prefix : str = "R"
- The prefix of AST nodes from tvm.relax
module_alias : str = "cls"
The alias of the current module at cross-function call,
Directly use module name if it's empty.
- buffer_dtype : str = "float32"
- The default data type of buffer
int_dtype : str = "int32"
The default data type of integer
float_dtype : str = "void"
@@ -404,6 +374,8 @@ class Scriptable:
If True (default), annotate all variable bindings with the struct
info of that variable. If False, only add annotations where
required for unambiguous round-trip of Relax -> TVMScript -> Relax.
+ extra_config : Optional[dict] = None
+ Dialect-specific configuration passed through to
PrinterConfig.extra_config.
path_to_underline : Optional[List[AccessPath]] = None
Object path to be underlined
path_to_annotate : Optional[Dict[AccessPath, str]] = None
@@ -425,11 +397,7 @@ class Scriptable:
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
- tir_prefix=tir_prefix,
- tir_import_module=tir_import_module,
- relax_prefix=relax_prefix,
module_alias=module_alias,
- buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
verbose_expr=verbose_expr,
@@ -439,6 +407,7 @@ class Scriptable:
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
show_all_struct_info=show_all_struct_info,
+ extra_config=extra_config,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
diff --git a/src/s_tir/meta_schedule/database/json_database.cc
b/src/s_tir/meta_schedule/database/json_database.cc
index 8705412fa2..9722dc39b4 100644
--- a/src/s_tir/meta_schedule/database/json_database.cc
+++ b/src/s_tir/meta_schedule/database/json_database.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/script/printer/printer.h>
#include <set>
#include <thread>
@@ -199,12 +200,13 @@ Database Database::JSONDatabase(ffi::String
path_workload, ffi::String path_tuni
workload = workloads[workload_index];
records[task_id] =
TuningRecord::FromJSON(arr->at(1).cast<ffi::ObjectRef>(), workload);
} catch (std::runtime_error& e) {
- TVM_FFI_THROW(ValueError) << "Unable to parse TuningRecord, on
line " << (task_id + 1)
- << " of file " << path_tuning_record <<
". The workload is:\n"
- << (workload.defined() ?
workload->mod->Script() : "(null)")
- << "\nThe JSONObject of TuningRecord
is:\n"
- << json_obj << "\nThe error message
is:\n"
- << e.what();
+ TVM_FFI_THROW(ValueError)
+ << "Unable to parse TuningRecord, on line " << (task_id + 1)
<< " of file "
+ << path_tuning_record << ". The workload is:\n"
+ << (workload.defined() ? tvm::Script(workload->mod) : "(null)")
+ << "\nThe JSONObject of TuningRecord is:\n"
+ << json_obj << "\nThe error message is:\n"
+ << e.what();
}
});
for (const TuningRecord& record : records) {
diff --git a/src/s_tir/schedule/error.cc b/src/s_tir/schedule/error.cc
index 422352ad88..73a29a59d5 100644
--- a/src/s_tir/schedule/error.cc
+++ b/src/s_tir/schedule/error.cc
@@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <tvm/script/printer/printer.h>
+
#include "./utils.h"
namespace tvm {
@@ -47,7 +49,7 @@ ffi::String ScheduleError::RenderReport(const ffi::String&
primitive) const {
}
os << "ScheduleError: An error occurred in the schedule primitive '" <<
primitive
<< "'.\n\nThe IR with diagnostic is:\n"
- << TVMScriptPrinter::Script(mod, cfg) << std::endl;
+ << tvm::Script(mod, cfg) << std::endl;
// print error message
os << "Error message: " << msg;
diff --git a/src/script/printer/config.cc b/src/script/printer/config.cc
index d68aaff2ce..87ca87979f 100644
--- a/src/script/printer/config.cc
+++ b/src/script/printer/config.cc
@@ -17,7 +17,7 @@
* under the License.
*/
#include <tvm/runtime/logging.h>
-#include <tvm/script/printer/config.h>
+#include <tvm/script/printer/printer.h>
#include <sstream>
@@ -25,7 +25,7 @@ namespace tvm {
std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj) {
try {
- return TVMScriptPrinter::Script(obj, std::nullopt);
+ return tvm::Script(obj, std::nullopt);
} catch (const tvm::ffi::Error& e) {
LOG(WARNING) << "TVMScript printer falls back to the basic address printer
with the error:\n"
<< e.what();
diff --git a/src/script/printer/script_printer.cc
b/src/script/printer/script_printer.cc
index f3fc27cf42..d595898c91 100644
--- a/src/script/printer/script_printer.cc
+++ b/src/script/printer/script_printer.cc
@@ -21,7 +21,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/cast.h>
#include <tvm/ir/expr.h>
-#include <tvm/script/printer/config.h>
+#include <tvm/script/printer/printer.h>
#include <algorithm>
@@ -34,8 +34,7 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() {
return inst;
}
-std::string TVMScriptPrinter::Script(const ffi::ObjectRef& node,
- const ffi::Optional<PrinterConfig>& cfg) {
+std::string Script(const ffi::ObjectRef& node, const
ffi::Optional<PrinterConfig>& cfg) {
if (!TVMScriptPrinter::vtable().can_dispatch(node)) {
// Fall back to ffi::ReprPrint for types not registered with
TVMScriptPrinter.
return std::string(ffi::ReprPrint(ffi::Any(node)));
@@ -68,18 +67,12 @@ PrinterConfig::PrinterConfig(ffi::Map<ffi::String, Any>
config_dict) {
if (auto v = config_dict.Get("ir_prefix")) {
n->ir_prefix = Downcast<ffi::String>(v.value());
}
- if (auto v = config_dict.Get("tir_prefix")) {
- n->tir_prefix = Downcast<ffi::String>(v.value());
- }
- if (auto v = config_dict.Get("tir_import_module")) {
- n->tir_import_module = Downcast<ffi::String>(v.value());
- }
- if (auto v = config_dict.Get("relax_prefix")) {
- n->relax_prefix = Downcast<ffi::String>(v.value());
- }
if (auto v = config_dict.Get("module_alias")) {
n->module_alias = Downcast<ffi::String>(v.value());
}
+ if (auto v = config_dict.Get("buffer_dtype")) {
+ n->buffer_dtype =
DataType(ffi::StringToDLDataType(Downcast<ffi::String>(v.value())));
+ }
if (auto v = config_dict.Get("int_dtype")) {
n->int_dtype =
DataType(ffi::StringToDLDataType(Downcast<ffi::String>(v.value())));
}
@@ -129,11 +122,6 @@ PrinterConfig::PrinterConfig(ffi::Map<ffi::String, Any>
config_dict) {
n->extra_config.Set(ffi::String(key), v.value());
}
}
- // "tirx.buffer_dtype" is passed as a DLDataType string from Python; convert
to DataType.
- if (auto v = config_dict.Get("tirx.buffer_dtype")) {
- DataType dt(ffi::StringToDLDataType(Downcast<ffi::String>(v.value())));
- n->extra_config.Set(ffi::String("tirx.buffer_dtype"), ffi::Any(dt));
- }
// Boolean dialect keys.
if (auto v = config_dict.Get("relax.show_all_struct_info")) {
n->extra_config.Set(ffi::String("relax.show_all_struct_info"), v.value());
@@ -174,7 +162,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef()
.def("node.PrinterConfig",
[](ffi::Map<ffi::String, Any> config_dict) { return
PrinterConfig(config_dict); })
- .def("node.TVMScriptPrinterScript", TVMScriptPrinter::Script);
+ .def("node.TVMScriptPrinterScript", tvm::Script);
}
} // namespace tvm
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
index 67fbf8e155..e1b59aa0c7 100644
--- a/src/script/printer/utils.h
+++ b/src/script/printer/utils.h
@@ -26,8 +26,8 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/logging.h>
-#include <tvm/script/printer/config.h>
#include <tvm/script/printer/ir_docsifier.h>
+#include <tvm/script/printer/printer.h>
#include <sstream>
#include <string>
@@ -46,17 +46,6 @@ namespace printer {
// definition here would force the dialect headers to depend on this shared
// header, which the per-dialect restructure aims to avoid for cross-directory
// references. See each `<dialect>/script/printer/utils.h` for the macro.
-inline std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj) {
- try {
- return TVMScriptPrinter::Script(obj, std::nullopt);
- } catch (const tvm::ffi::Error& e) {
- LOG(WARNING) << "TVMScript printer falls back to the basic address printer
with the error:\n"
- << e.what();
- std::ostringstream os;
- os << obj->GetTypeKey() << '(' << obj.get() << ')';
- return os.str();
- }
-}
inline std::string Docsify(const ffi::ObjectRef& obj, const IRDocsifier& d,
const Frame& f,
const PrinterConfig& cfg) {
diff --git a/src/tirx/script/printer/buffer.cc
b/src/tirx/script/printer/buffer.cc
index 32d50a8f8d..2333eb8900 100644
--- a/src/tirx/script/printer/buffer.cc
+++ b/src/tirx/script/printer/buffer.cc
@@ -92,8 +92,11 @@ ffi::Map<ffi::String, ExprDoc> BufferAttrs(tirx::Buffer
buffer, const AccessPath
kwargs.Set("shape", TupleDoc(results));
}
// Step 2. Handle `buffer.dtype`
- if (buffer->dtype != d->cfg->buffer_dtype) {
- kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype,
buffer_p->Attr("dtype")));
+ {
+ DataType default_buf_dtype = d->cfg->buffer_dtype;
+ if (buffer->dtype != default_buf_dtype) {
+ kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype,
buffer_p->Attr("dtype")));
+ }
}
// Step 3. Handle `buffer.data`
// For tmem scope, DeclBuffer does not accept `data` (it auto-creates the
data var).
diff --git a/tests/cpp/tir_scalable_datatype.cc
b/tests/cpp/tir_scalable_datatype.cc
index fd9f76eee3..5ead9c7d40 100644
--- a/tests/cpp/tir_scalable_datatype.cc
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -20,6 +20,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <tvm/runtime/data_type.h>
+#include <tvm/script/printer/printer.h>
#include <tvm/tirx/builtin.h>
#include <tvm/tirx/expr.h>
@@ -195,7 +196,7 @@ TEST(ScalableDataType, TestScalableIntrinCall) {
::llvm::Intrinsic::experimental_stepvector)});
#endif
ASSERT_EQ(call->dtype, scalable_type);
- ASSERT_EQ(call->Script(),
+ ASSERT_EQ(tvm::Script(call),
#if TVM_LLVM_VERSION >= 200
"T.call_llvm_intrin(\"int32xvscalex4\", \"llvm.stepvector\")");
#else
diff --git a/tests/python/tirx/test_printer_tir_namespaces.py
b/tests/python/tirx/test_printer_tir_namespaces.py
index 79d37ea571..50fdd4eea9 100644
--- a/tests/python/tirx/test_printer_tir_namespaces.py
+++ b/tests/python/tirx/test_printer_tir_namespaces.py
@@ -21,7 +21,7 @@ from tvm import tirx as tir
def _assert_print(obj, expected):
# Use Tx prefix so standalone TIR nodes (non-PrimFunc) print as Tx to
match tirx namespace
- out = obj.script(verbose_expr=True, tir_prefix="Tx",
tir_import_module="tirx").strip()
+ out = obj.script(verbose_expr=True, extra_config={"tirx.prefix":
"Tx"}).strip()
assert out == expected.strip()
diff --git a/tests/python/tirx/transform/test_transform_lower_tirx.py
b/tests/python/tirx/transform/test_transform_lower_tirx.py
index c8434f5055..3e20d61f80 100644
--- a/tests/python/tirx/transform/test_transform_lower_tirx.py
+++ b/tests/python/tirx/transform/test_transform_lower_tirx.py
@@ -953,7 +953,7 @@ def
test_lower_exec_context_keeps_plain_predicate_condition():
with tvm.target.Target("cuda"):
lowered = LowerTIRx()(tvm.IRModule({"main": before}))
- script = lowered.script(tir_prefix="Tx", tir_import_module="tirx")
+ script = lowered.script(extra_config={"tirx.prefix": "Tx"})
assert "if wg_id == 0:" in script
assert "0 <= wg_id" not in script
assert "wg_id < 1" not in script
@@ -977,7 +977,7 @@ def
test_lower_exec_context_keeps_plain_scope_predicate_condition():
with tvm.target.Target("cuda"):
lowered = LowerTIRx()(tvm.IRModule({"main": before}))
- script = lowered.script(tir_prefix="Tx", tir_import_module="tirx")
+ script = lowered.script(extra_config={"tirx.prefix": "Tx"})
assert "if wg_id == 0:" in script
assert "0 <= wg_id" not in script
assert "wg_id < 1" not in script
@@ -1002,7 +1002,7 @@ def
test_simplify_uses_floor_div_scope_predicate_as_context_fact():
lowered = LowerTIRx()(tvm.IRModule({"main": before}))
simplified = Simplify()(lowered)
- script = simplified.script(tir_prefix="Tx", tir_import_module="tirx")
+ script = simplified.script(extra_config={"tirx.prefix": "Tx"})
assert "if warp_id_in_cta // 4 == 0:" in script
assert "if 0 <= warp_id_in_cta" not in script
assert "A_1[warp_id_in_cta] = Tx.Cast" in script
@@ -1018,7 +1018,7 @@ def
test_lower_exec_context_selector_filter_for_elect_sync():
@register_dispatch("copy", "cuda", variant=variant, priority=10_000)
def _probe(op_call, sctx):
- seen.append(sctx.inter["laneid"][1].script(tir_prefix="Tx",
tir_import_module="tirx"))
+
seen.append(sctx.inter["laneid"][1].script(extra_config={"tirx.prefix": "Tx"}))
@Tx.prim_func(private=True)
def impl():
@@ -1088,7 +1088,7 @@ def
test_lower_exec_context_scope_guard_mixes_structural_and_selector():
assert _int_pair(seen[0]["inter"], "warpid") == (1, 0)
assert int(seen[0]["inter"]["laneid"][0]) == 1
assert (
- seen[0]["inter"]["laneid"][1].script(tir_prefix="Tx",
tir_import_module="tirx")
+ seen[0]["inter"]["laneid"][1].script(extra_config={"tirx.prefix":
"Tx"})
== "Tx.selector(lane_id, Tx.ptx.elect_sync())"
)
assert len(seen[0]["intra"]) == 0