This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0082836d2d [REFACTOR][RELAX] Phase out Relax PrimType (#19858)
0082836d2d is described below
commit 0082836d2d2bf55892d10644dfd2913274e95b88
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Jun 22 11:27:20 2026 -0400
[REFACTOR][RELAX] Phase out Relax PrimType (#19858)
Summary:
- Remove the Relax-specific PrimType node/API and use canonical
ir.PrimType for dtype-only scalar types.
- Update parser, printer, analysis, op inference/legalization, and tests
to avoid value-bearing PrimType semantics.
- Preserve scalar values where needed by reading PrimValue expressions
directly instead of storing values in the type.
---
include/tvm/relax/type.h | 35 -----------
python/tvm/relax/__init__.py | 1 -
python/tvm/relax/backend/metal/coreml.py | 3 +-
python/tvm/relax/expr.py | 4 +-
python/tvm/relax/script/parser/entry.py | 35 +++++------
python/tvm/relax/testing/ast_printer.py | 2 +-
.../tvm/relax/transform/lazy_transform_params.py | 2 +-
python/tvm/relax/transform/legalize_ops/index.py | 15 +++--
python/tvm/relax/type.py | 73 +---------------------
python/tvm/relax/utils.py | 31 ++++-----
src/relax/analysis/type_analysis.cc | 56 ++---------------
src/relax/backend/vm/vm_shape_lower.cc | 7 ---
src/relax/ir/block_builder.cc | 9 ---
src/relax/ir/dataflow_expr_rewriter.cc | 2 +-
src/relax/ir/dependent_type.cc | 27 --------
src/relax/ir/expr.cc | 2 +-
src/relax/ir/type_functor.cc | 19 +-----
src/relax/op/memory/view.cc | 4 +-
src/relax/op/op.cc | 2 +-
src/relax/op/tensor/index.cc | 53 +++++++++-------
src/relax/op/tensor/inspect.cc | 50 +++++++--------
src/relax/script/printer/dependent_type.cc | 17 -----
src/relax/transform/fuse_tir.cc | 18 +++---
src/relax/transform/remove_unused_parameters.cc | 2 +-
src/relax/utils.cc | 13 ----
src/tirx/ir/function.cc | 4 +-
tests/cpp/nested_msg_test.cc | 10 +--
tests/python/relax/test_analysis_type_analysis.py | 23 +++----
tests/python/relax/test_ast_printer.py | 2 +-
.../relax/test_backend_transform_shape_lower.py | 3 +
tests/python/relax/test_bind_symbolic_vars.py | 1 +
tests/python/relax/test_blockbuilder_core.py | 4 +-
tests/python/relax/test_blockbuilder_emit_te.py | 14 ++---
tests/python/relax/test_dataflow_rewriter.py | 1 +
tests/python/relax/test_expr.py | 8 +--
tests/python/relax/test_op_binary.py | 14 ++---
tests/python/relax/test_op_manipulate.py | 4 +-
.../relax/test_transform_compute_prim_value.py | 3 +
.../relax/test_transform_lazy_transform_params.py | 3 +
.../test_transform_remove_unused_parameters.py | 3 +
.../test_transform_rewrite_dataflow_reshape.py | 4 +-
tests/python/relax/test_tvmscript_parser.py | 64 +------------------
tests/python/relax/test_tvmscript_printer_relax.py | 15 +++--
tests/python/relax/test_type.py | 26 ++------
tests/python/relax/test_utils.py | 6 +-
tests/python/relax/test_vm_build.py | 2 +
tests/python/tirx-base/test_tir_specialize.py | 4 +-
.../python/tvmscript/test_tvmscript_parser_tir.py | 10 +--
48 files changed, 206 insertions(+), 504 deletions(-)
diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h
index fcaddc3ab5..9174e66bdd 100644
--- a/include/tvm/relax/type.h
+++ b/include/tvm/relax/type.h
@@ -114,41 +114,6 @@ class ObjectType : public Type {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ObjectType, Type,
ObjectTypeNode);
};
-/*!
- * \brief Primitive value.
- */
-class PrimTypeNode : public TypeNode {
- public:
- /*! \brief Underlying primitive value, if known */
- ffi::Optional<PrimExpr> value;
-
- /*! \brief Underlying data type of the primitive value */
- DataType dtype;
-
- static void RegisterReflection() {
- namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<PrimTypeNode>()
- .def_ro("value", &PrimTypeNode::value)
- .def_ro("dtype", &PrimTypeNode::dtype);
- }
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.PrimType", PrimTypeNode, TypeNode);
-};
-
-/*!
- * \brief Managed reference to PrimTypeNode.
- * \sa PrimTypeNode
- */
-class PrimType : public Type {
- public:
- /* Construct a PrimType with a known dtype, but unknown value */
- TVM_DLL PrimType(DataType dtype, Span span = Span());
-
- /* Construct a PrimType with a known value */
- TVM_DLL PrimType(PrimExpr value, Span span = Span());
-
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimType, Type, PrimTypeNode);
-};
-
/*!
* \brief Type of shape value.
*/
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index 3eea8b0b02..b0dca99248 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -54,7 +54,6 @@ from .expr import const, extern, get_shape_of
from .type import (
Type,
ObjectType,
- PrimType,
ShapeType,
TensorType,
TupleType,
diff --git a/python/tvm/relax/backend/metal/coreml.py b/python/tvm/relax/backend/metal/coreml.py
index d0b7ea3fc8..0152e965b9 100644
--- a/python/tvm/relax/backend/metal/coreml.py
+++ b/python/tvm/relax/backend/metal/coreml.py
@@ -24,6 +24,7 @@ import tvm_ffi
import tvm
from tvm.contrib import coreml_runtime
+from tvm.ir import PrimType
from tvm.relax import transform
from tvm.relax.dpl.pattern import is_op, wildcard
from tvm.relax.expr import (
@@ -37,7 +38,7 @@ from tvm.relax.expr import (
VarBinding,
)
from tvm.relax.transform import PatternCheckContext
-from tvm.relax.type import PrimType, TensorType
+from tvm.relax.type import TensorType
from tvm.support.xcode import compile_coreml
from ...expr_functor import PyExprVisitor, visitor
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index e02bbb51ca..ac10fb45ae 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -419,7 +419,7 @@ class _DLTensorShapeProxy(tvm.runtime.ObjectConvertible):
if not isinstance(axis, tvm.relax.Expr):
axis = tvm.relax.PrimValue(axis)
- if axis.ty is not None and not isinstance(axis.ty, tvm.relax.PrimType):
+ if axis.ty is not None and not isinstance(axis.ty, tvm.ir.PrimType):
raise TypeError(
f"The index used to access {self.tensor}.shape "
f'must have type R.Prim("int64"), '
@@ -487,7 +487,7 @@ class _DLTensorStrideProxy(tvm.runtime.ObjectConvertible):
if not isinstance(axis, tvm.relax.Expr):
axis = tvm.relax.PrimValue(axis)
- if axis.ty is not None and not isinstance(axis.ty, tvm.relax.PrimType):
+ if axis.ty is not None and not isinstance(axis.ty, tvm.ir.PrimType):
raise TypeError(
f"The index used to access {self.tensor}.strides "
f'must have type R.Prim("int64"), '
diff --git a/python/tvm/relax/script/parser/entry.py b/python/tvm/relax/script/parser/entry.py
index cbcc07ce1b..7301a73a1a 100644
--- a/python/tvm/relax/script/parser/entry.py
+++ b/python/tvm/relax/script/parser/entry.py
@@ -20,12 +20,12 @@ from collections.abc import Callable as _Callable
from typing import Any, TypeVar
import tvm
+from tvm.ir import PrimType
from tvm.relax import (
Expr,
Function,
FuncType,
ObjectType,
- PrimType,
SeqExpr,
ShapeExpr,
ShapeType,
@@ -444,7 +444,6 @@ def Shape(values: list[PrimExpr] | None = None, ndim: int = -1) -> ShapeProxy:
class PrimProxy(TypeProxy):
dtype: str | None
- value: int | float | str | PrimExpr | None
"""The type of TIR-representable values.
@@ -453,8 +452,6 @@ class PrimProxy(TypeProxy):
dtype : Optional[str]
The data type.
- value: Optional[Union[int, float, str, PrimExpr]]
- The known value
"""
def __init__(
@@ -462,26 +459,23 @@ class PrimProxy(TypeProxy):
dtype: str | None = None,
value: int | float | str | PrimExpr | None = None,
) -> None:
- if dtype is None and value is None:
- raise TypeError(
- "R.Prim missing required argument. Must provide either
'dtype' or 'value'"
- )
+ if dtype is None:
+ if isinstance(value, PrimExpr):
+ dtype = value.dtype
+ elif isinstance(value, float):
+ dtype = "float32"
+ elif value is not None:
+ dtype = "int64"
+ else:
+ raise TypeError("R.Prim missing required argument 'dtype'")
self.dtype = dtype
- self.value = value
def get_symbolic_vars(self) -> set[str]:
- if isinstance(self.value, str) and self.value.isidentifier():
- return {self.value}
- else:
- return set()
+ return set()
def as_ty(self, dict_globals: dict[str, Any] | None = None) -> PrimType:
- if self.value is None:
- return PrimType(dtype=self.dtype)
- else:
- value = _eval_shape(self.value, dict_globals)
- return PrimType(dtype=self.dtype, value=value)
+ return PrimType(self.dtype)
def Prim(
@@ -515,7 +509,10 @@ def _normalize_ty_proxy(annotation) -> TypeProxy:
if annotation is None:
return TupleProxy([])
elif callable(annotation):
- return annotation()
+ annotation = annotation()
+ if isinstance(annotation, PrimExpr):
+ return PrimProxy(annotation.dtype)
+ return annotation
elif isinstance(annotation, TypeProxy):
return annotation
else:
diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py
index eb20d8ab5d..220d113a23 100644
--- a/python/tvm/relax/testing/ast_printer.py
+++ b/python/tvm/relax/testing/ast_printer.py
@@ -281,7 +281,7 @@ class ASTPrinter(ExprFunctor):
return self.build_ast_node("ShapeType", **fields)
elif isinstance(ty_node, relax.ObjectType):
return self.build_ast_node("ObjectType")
- elif isinstance(ty_node, relax.PrimType):
+ elif isinstance(ty_node, tvm.ir.PrimType):
return self.build_ast_node("PrimType", dtype=ty_node.dtype)
elif isinstance(ty_node, relax.TensorType):
fields = {}
diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py
index 432426bf74..e49ad1c948 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -216,7 +216,7 @@ class LazyTransformParamsFuncCreator:
# direct iterate over the type annotation
for param in func.params[num_input:]:
for ty in unpack_ty(param.ty):
- if isinstance(ty, relax.PrimType | relax.ShapeType):
+ if isinstance(ty, tvm.ir.PrimType | relax.ShapeType):
params.append(relax.Var("symbolic_var_holder", ty))
return relax.Function(
diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py
index b71d8958e0..5c7fdca1f4 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -18,11 +18,12 @@
"""Default legalization function for index operators."""
from tvm import te, tirx, topi
+from tvm.ir import PrimType
from ...block_builder import BlockBuilder
-from ...expr import Call, Expr
+from ...expr import Call, Expr, PrimValue, Tuple
from ...op import tensor_to_shape
-from ...type import PrimType, ShapeType
+from ...type import ShapeType
from .common import register_legalize
@@ -36,11 +37,17 @@ def _take(bb: BlockBuilder, call: Call) -> Expr:
@register_legalize("relax.strided_slice")
def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
def _relax_tuple_to_tir(relax_tuple):
+ if isinstance(relax_tuple, Tuple):
+ output = []
+ for field in relax_tuple.fields:
+ assert isinstance(field, PrimValue)
+ output.append(field.value)
+ return output
+
output = []
for field in relax_tuple.ty.fields:
assert isinstance(field, PrimType)
- assert field.value is not None
- output.append(field.value)
+ return None
return output
if len(call.args) == 4:
diff --git a/python/tvm/relax/type.py b/python/tvm/relax/type.py
index 3ba5a86b90..ad8f469826 100644
--- a/python/tvm/relax/type.py
+++ b/python/tvm/relax/type.py
@@ -21,10 +21,7 @@
import tvm_ffi
from tvm_ffi import Array
-import tvm
-from tvm.ir import EnvFunc, Span, TupleType, VDevice
-from tvm.runtime import DataType
-from tvm.tirx import PrimExpr
+from tvm.ir import EnvFunc, PrimExpr, Span, TupleType, VDevice
from . import _ffi_api
from .expr import Expr, ShapeExpr, Type
@@ -39,72 +36,6 @@ class ObjectType(Type):
self.__init_handle_by_constructor__(_ffi_api.ObjectType, span) #
type: ignore
-@tvm_ffi.register_object("relax.PrimType")
-class PrimType(Type):
- """Type of a primitive POD value.
-
- Parameters
- ----------
- dtype_or_expr : Union[str, DataType, PrimExpr]
-
- The data type of the prim value, or a known expression for the prim
- value.
- """
-
- value: PrimExpr | None
- dtype: str
-
- def __init__(
- self,
- dtype: str | DataType | None = None,
- value: int | float | PrimExpr | None = None,
- span: Span = None,
- ) -> None:
- # Guard against incorrect usage. For backwards compatibility,
- # the dtype and value are in the opposite order from most
- # usages. While PrimType could take a single positional
- # argument and check the type, this would require an API
- # difference from TVMScript's PrimProxy, which cannot.
- # (PrimProxy uses string arguments for datatype, and also for
- # inline variable definitions when used in a function
- # signature, and requires separate arguments to distinguish
- # the two cases.)
- if isinstance(dtype, PrimExpr | int | float):
- raise TypeError(
- f"The first positional argument of PrimType must be the
datatype, "
- f", but received {type(dtype)}. "
- f"The value can be specified as a keyword argument "
- f"without needing specifying the dtype: "
- f"PrimType(value=arg)."
- )
-
- if dtype is None and value is None:
- raise TypeError(
- "PrimType.__init__ missing required argument. "
- "Must provide either 'dtype' or 'value'"
- )
-
- if dtype is not None:
- if isinstance(value, PrimExpr):
- assert value.dtype == dtype, (
- "When providing both 'value' and 'dtype' to
PrimType.__init__, "
- "they must be consistent with each other. "
- "However, the value {value} has dtype {value.dtype}, "
- "but the specified dtype was {dtype}."
- )
- elif isinstance(value, int | float):
- value = tvm.tirx.const(value, dtype)
-
- # Use relax's default integer type if not otherwise specified.
- if isinstance(value, int):
- value = tvm.tirx.IntImm("int64", value)
-
- if value is None:
- self.__init_handle_by_constructor__(_ffi_api.PrimTypeFromDtype,
dtype, span) # type: ignore
- else:
- self.__init_handle_by_constructor__(_ffi_api.PrimTypeFromValue,
value, span) # type: ignore
-
-
@tvm_ffi.register_object("relax.ShapeType")
class ShapeType(Type):
"""Type of a shape value.
@@ -261,5 +192,5 @@ class FuncType(Type):
"""
if isinstance(derive_func, str):
- derive_func = tvm.ir.EnvFunc.get("tvm.relax.type.infer_view_ty")
+ derive_func = EnvFunc.get("tvm.relax.type.infer_view_ty")
return _ffi_api.FuncTypeOpaqueFunc(ret, derive_func, purity, span) #
type: ignore
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 89c9ac82c1..d143b43eaf 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -29,6 +29,7 @@ import tvm_ffi
from tvm_ffi import Array, Map
import tvm
+from tvm.ir import PrimType
from .. import tirx
from ..ir import Attrs, Type, VDevice
@@ -38,7 +39,7 @@ from ..tirx import PrimExpr
from . import _ffi_api
from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm, te_tensor
from .expr import Tuple as rx_Tuple
-from .type import PrimType, ShapeType, TensorType
+from .type import ShapeType, TensorType
def metadata_partitioner(rx_txt: str) -> list[str]:
@@ -250,23 +251,23 @@ def gen_call_tir_inputs(
return [_convert_te_arg_helper(val) for val in arg.values]
if isinstance(arg.ty, PrimType):
- if arg.ty.value is None:
- n_args = len(create_primfunc_args)
- if isinstance(arg, tvm.relax.Var):
- name = arg.name_hint
- elif n_args < len(string.ascii_lowercase):
- name = string.ascii_lowercase[n_args]
- else:
- name = f"scalar_input_{n_args}"
+ if isinstance(arg, PrimValue):
+ return _convert_te_arg_helper(arg.value)
- tir_param = tirx.Var(name, arg.ty.dtype)
+ n_args = len(create_primfunc_args)
+ if isinstance(arg, tvm.relax.Var):
+ name = arg.name_hint
+ elif n_args < len(string.ascii_lowercase):
+ name = string.ascii_lowercase[n_args]
+ else:
+ name = f"scalar_input_{n_args}"
- call_tir_args.append(arg)
- create_primfunc_args.append(tir_param)
+ tir_param = tirx.Var(name, arg.ty.dtype)
- return tir_param
- else:
- return _convert_te_arg_helper(arg.ty.value)
+ call_tir_args.append(arg)
+ create_primfunc_args.append(tir_param)
+
+ return tir_param
elif isinstance(arg, list | Array):
return [_convert_te_arg_helper(x) for x in arg]
diff --git a/src/relax/analysis/type_analysis.cc b/src/relax/analysis/type_analysis.cc
index b6c272a827..33070051ae 100644
--- a/src/relax/analysis/type_analysis.cc
+++ b/src/relax/analysis/type_analysis.cc
@@ -87,8 +87,6 @@ Type TypeFromStaticType(const Type& type) {
return ObjectType(type->span);
} else if (const PrimTypeNode* prim_type = type.as<PrimTypeNode>()) {
return PrimType(prim_type->dtype, prim_type->span);
- } else if (const tvm::PrimTypeNode* prim_type =
type.as<tvm::PrimTypeNode>()) {
- return PrimType(prim_type->dtype, prim_type->span);
} else if (const ShapeTypeNode* shape_type = type.as<ShapeTypeNode>()) {
return ShapeType(shape_type->ndim, type->span);
} else if (const TensorTypeNode* tensor_type = type.as<TensorTypeNode>()) {
@@ -127,27 +125,7 @@ class WellDefinedEraser : public TypeMutator, public ExprMutatorBase, public tir
arith::AnalyzerObj* ana)
: f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {}
- Type VisitType_(const PrimTypeNode* op) final {
- bool has_undefined = false;
- ffi::Optional<PrimExpr> value;
-
- if (op->value.defined()) {
- std::swap(has_undefined_, has_undefined);
- value = VisitPrimExpr(op->value.value());
- std::swap(has_undefined_, has_undefined);
- }
-
- // erase symbolic shape if we have undefined.
- if (!has_undefined) {
- if (value.same_as(op->value)) {
- return ffi::GetRef<Type>(op);
- } else {
- return PrimType(value.value(), op->span);
- }
- } else {
- return PrimType(op->dtype, op->span);
- }
- }
+ Type VisitType_(const PrimTypeNode* op) final { return
ffi::GetRef<Type>(op); }
Type VisitType_(const ShapeTypeNode* op) final {
bool has_undefined = false;
@@ -341,10 +319,7 @@ class TypeBaseChecker : public TypeFunctor<BaseCheckResult(const Type&, const Ty
return BaseCheckResult::kFailL0;
}
- if (!lhs->value.defined()) return BaseCheckResult::kPass;
- if (!rhs->value.defined()) return BaseCheckResult::kFailL2;
-
- return PrimValueMatchCheck(lhs->value.value(), rhs->value.value());
+ return BaseCheckResult::kPass;
}
BaseCheckResult VisitType_(const ShapeTypeNode* lhs, const Type& other)
final {
@@ -662,13 +637,7 @@ class TypeBasePreconditionCollector : public TypeFunctor<PrimExpr(const Type&, c
return IntImm::Bool(false);
}
- if (lhs->value.defined() && rhs->value.defined()) {
- return lhs->value.value() == rhs->value.value();
- } else if (lhs->value.defined() && !rhs->value.defined()) {
- return IntImm::Bool(false);
- } else {
- return IntImm::Bool(true);
- }
+ return IntImm::Bool(true);
}
PrimExpr VisitType_(const ShapeTypeNode* lhs, const Type& other) final {
@@ -1019,19 +988,6 @@ class TypeLCAFinder : public TypeFunctor<Type(const Type&, const Type&)> {
// as a result we can unify to object.
return ObjectType(lhs->span);
}
- if (!lhs->value.defined() || !rhs->value.defined() ||
- !analyzer_->CanProveEqual(lhs->value.value(), rhs->value.value())) {
- // The two values are known to contain the same dtype, but may
- // contain different values.
- if (!lhs->value.defined()) {
- // If the mismatch was due to extra information in the RHS,
- // prefer to avoid constructing a new object.
- return ffi::GetRef<Type>(lhs);
- } else {
- return PrimType(lhs->dtype, lhs->span);
- }
- }
-
return ffi::GetRef<Type>(lhs);
}
@@ -1234,11 +1190,7 @@ class TIRVarsDetector : public TypeVisitor {
}
}
- void VisitType_(const PrimTypeNode* prim_ty) final {
- if (prim_ty->value.defined()) {
- VisitPrimExpr(prim_ty->value.value());
- }
- }
+ void VisitType_(const PrimTypeNode* prim_ty) final {}
void VisitType_(const ShapeTypeNode* shape_ty) final {
if (shape_ty->values.defined()) {
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 9ba34d6945..8cac4a12f7 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -648,13 +648,6 @@ class VMShapeLowerMutator
{value, DataTypeImm(op->dtype), GetErrContext(err_ctx)},
Attrs(), {void_ty_});
builder_->Emit(call, "_");
}
- if (op->value.defined()) {
- MatchShapeTodoItem item;
- item.input = value;
- item.pattern = {op->value.value()};
- item.err_ctx = err_ctx;
- match_todos->push_back(item);
- }
}
void VisitType_(const ShapeTypeNode* op, Expr value, bool always_check, bool
dynamic_only,
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index aab0fdf8b4..7a45902068 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -498,15 +498,6 @@ class BlockBuilderImpl : public BlockBuilderNode {
}
}
- void VisitType_(const PrimTypeNode* op) final {
- // Only collect single var defined shape. Ignore something like
`R.Prim(value=m + 1)`
- if (op->value.defined()) {
- if (auto var = op->value.as<tirx::Var>()) {
- shape_var_map_.Set(var.value(), op->value.value());
- }
- }
- }
-
private:
ffi::Map<tirx::Var, PrimExpr> shape_var_map_;
};
diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc
index cb59f566e6..d8100c2563 100644
--- a/src/relax/ir/dataflow_expr_rewriter.cc
+++ b/src/relax/ir/dataflow_expr_rewriter.cc
@@ -736,7 +736,7 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) {
return ExternFuncPattern(func->global_symbol);
} else if (auto prim = expr.as<PrimValueNode>()) {
- return TypePattern(WildcardPattern(), PrimType(prim->value));
+ return TypePattern(WildcardPattern(), PrimType(prim->value.dtype()));
} else {
TVM_FFI_THROW(TypeError) << "Cannot convert Relax expression of type "
<< expr->GetTypeKey()
diff --git a/src/relax/ir/dependent_type.cc b/src/relax/ir/dependent_type.cc
index c0ee21646d..6a2034ccc2 100644
--- a/src/relax/ir/dependent_type.cc
+++ b/src/relax/ir/dependent_type.cc
@@ -32,7 +32,6 @@ namespace relax {
TVM_FFI_STATIC_INIT_BLOCK() {
ObjectTypeNode::RegisterReflection();
- PrimTypeNode::RegisterReflection();
ShapeTypeNode::RegisterReflection();
TensorTypeNode::RegisterReflection();
FuncTypeNode::RegisterReflection();
@@ -49,32 +48,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef().def("relax.ObjectType", [](Span span) { return
ObjectType(span); });
}
-// Prim
-PrimType::PrimType(PrimExpr value, Span span) {
- ffi::ObjectPtr<PrimTypeNode> n = ffi::make_object<PrimTypeNode>();
- n->dtype = value->dtype;
- n->value = std::move(value);
- n->span = span;
- data_ = std::move(n);
-}
-
-PrimType::PrimType(DataType dtype, Span span) {
- ffi::ObjectPtr<PrimTypeNode> n = ffi::make_object<PrimTypeNode>();
- n->dtype = dtype;
- n->value = std::nullopt;
- n->span = span;
- data_ = std::move(n);
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef()
- .def("relax.PrimTypeFromDtype",
- [](DataType dtype, Span span) { return PrimType(dtype, span); })
- .def("relax.PrimTypeFromValue",
- [](PrimExpr value, Span span) { return PrimType(value, span); });
-}
-
// Shape
ShapeType::ShapeType(ffi::Array<PrimExpr> values, Span span) {
ffi::ObjectPtr<ShapeTypeNode> n = ffi::make_object<ShapeTypeNode>();
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index d84ddf2bae..ab9b0e92f0 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -363,7 +363,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
PrimValue::PrimValue(PrimExpr value, Span span) {
ffi::ObjectPtr<PrimValueNode> n = ffi::make_object<PrimValueNode>();
- n->ty = PrimType(value);
+ n->ty = PrimType(value.dtype());
n->value = std::move(value);
n->span = std::move(span);
data_ = std::move(n);
diff --git a/src/relax/ir/type_functor.cc b/src/relax/ir/type_functor.cc
index d578db704a..e173e10c03 100644
--- a/src/relax/ir/type_functor.cc
+++ b/src/relax/ir/type_functor.cc
@@ -29,11 +29,7 @@ namespace relax {
void TypeVisitor::VisitType_(const ObjectTypeNode* op) {}
-void TypeVisitor::VisitType_(const PrimTypeNode* op) {
- if (op->value.defined()) {
- this->VisitTypeExprField(op->value.value());
- }
-}
+void TypeVisitor::VisitType_(const PrimTypeNode* op) {}
void TypeVisitor::VisitType_(const ShapeTypeNode* op) {
if (op->values.defined()) {
@@ -70,18 +66,7 @@ void TypeVisitor::VisitType_(const FuncTypeNode* op) {
Type TypeMutator::VisitType_(const ObjectTypeNode* op) { return
ffi::GetRef<Type>(op); }
-Type TypeMutator::VisitType_(const PrimTypeNode* op) {
- if (!op->value.defined()) {
- return ffi::GetRef<Type>(op);
- }
-
- auto new_expr = VisitTypeExprField(op->value.value());
- if (new_expr.same_as(op->value)) {
- return ffi::GetRef<Type>(op);
- } else {
- return PrimType(new_expr);
- }
-}
+Type TypeMutator::VisitType_(const PrimTypeNode* op) { return
ffi::GetRef<Type>(op); }
Type TypeMutator::VisitType_(const ShapeTypeNode* op) {
if (!op->values.defined()) {
diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc
index 1b21432b8d..25ad9aa66d 100644
--- a/src/relax/op/memory/view.cc
+++ b/src/relax/op/memory/view.cc
@@ -135,10 +135,10 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) {
<< "Operator " << call->op
<< " expects the relative_byte_offset to be a 64-bit integer, but
received "
<< arg_relative_byte_offset << ", which has type " << ty;
- if (prim_ty->value.defined()) {
+ if (const auto* prim_value =
arg_relative_byte_offset.as<PrimValueNode>()) {
// An offset of known value is applied. The known value may
// be dynamic.
- return prim_ty->value.value();
+ return prim_value->value;
} else {
// An offset of unknown value is applied.
return std::nullopt;
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 2e1fa02591..739517de43 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -409,7 +409,7 @@ static ffi::Optional<Type> InferCallTIROutputTypeFromArguments(
TVM_FFI_ICHECK(packed_tuple_ty);
PrimType dummy_arg_ty = [&]() {
if (packed_tuple_ty->values) {
- return PrimType(packed_tuple_ty->values.value()[i]);
+ return PrimType(packed_tuple_ty->values.value()[i].dtype());
} else {
return PrimType(DataType::Int(64));
}
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 0f82327994..665ea24e73 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -184,10 +184,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
*
* A `relax::Tuple` may be provided to an operator as an in-line
* expression, as a variable bound to known tuple within the current
- * function, as a function argument, etc. The Type of the tuple
- * tracks the known values of any `PrimValue` elements, but it can be
- * tedious to extract. This utility extracts the `PrimExpr` contents
- * of a `relax::Tuple`.
+ * function, as a function argument, etc. This overload validates that
+ * the Type could contain a tuple of `PrimValue` elements. Without a
+ * concrete tuple expression, the values are not statically known.
*
* If the Type cannot contain a tuple of the type specified,
* this function will throw an exception. (e.g. Attempting to extract
@@ -198,7 +197,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
*
* \param ty The Type to inspect
*
- * \returns An array of the `PrimType`, if it can be extracted.
+ * \returns An empty array for an empty tuple, if it can be extracted.
* Otherwise, `std::nullopt`.
*/
template <typename PrimType = PrimExpr,
@@ -227,12 +226,7 @@ ffi::Optional<ffi::Array<PrimType>> UnpackTupleOfPrimValue(ffi::Optional<Type> t
<< "The type " << ty << " cannot contain a tuple whose elements are "
<< PrimType::ContainerType::_type_key << ", because element " << i <<
" has type " << field;
- if (!prim_ty->value.defined()) return std::nullopt;
-
- ffi::Optional<PrimType> element = prim_ty->value.as<PrimType>();
- if (!element) return std::nullopt;
-
- output.push_back(element.value());
+ return std::nullopt;
}
return output;
}
@@ -241,10 +235,9 @@ ffi::Optional<ffi::Array<PrimType>> UnpackTupleOfPrimValue(ffi::Optional<Type> t
*
* A `relax::Tuple` may be provided to an operator as an in-line
* expression, as a variable bound to known tuple within the current
- * function, as a function argument, etc. The Type of the tuple
- * tracks the known values of any `PrimValue` elements, but it can be
- * tedious to extract. This utility extracts the `PrimExpr` contents
- * of a `relax::Tuple`.
+ * function, as a function argument, etc. This utility extracts
+ * `PrimValue` contents only when the concrete tuple expression is
+ * available.
*
* If the Type cannot contain a tuple of the type specified,
* this function will throw an exception. (e.g. Attempting to extract
@@ -261,11 +254,29 @@ ffi::Optional<ffi::Array<PrimType>> UnpackTupleOfPrimValue(ffi::Optional<Type> t
template <typename PrimType = PrimExpr,
typename = std::enable_if_t<std::is_base_of_v<PrimExpr, PrimType>>>
ffi::Optional<ffi::Array<PrimType>> UnpackTupleOfPrimValue(ffi::Optional<Expr>
expr) {
- if (expr) {
- return UnpackTupleOfPrimValue<PrimType>(GetType(expr.value()));
- } else {
- return std::nullopt;
+ if (!expr) return std::nullopt;
+
+ const Expr& value = expr.value();
+ if (const auto* tuple = value.as<TupleNode>()) {
+ ffi::Array<PrimType> output;
+ for (size_t i = 0; i < tuple->fields.size(); i++) {
+ const Expr& field = tuple->fields[i];
+ auto prim_value = field.as<PrimValueNode>();
+ TVM_FFI_CHECK(prim_value, TypeError)
+ << "The expression " << value << " cannot contain a tuple whose
elements are "
+ << PrimType::ContainerType::_type_key << ", because element " << i
<< " is " << field;
+
+ TVM_FFI_CHECK(prim_value->value.template as<typename
PrimType::ContainerType>(), TypeError)
+ << "The expression " << value << " cannot contain a tuple whose
elements are "
+ << PrimType::ContainerType::_type_key << ", because element " << i
<< " has value "
+ << prim_value->value;
+
+ output.push_back(Downcast<PrimType>(prim_value->value));
+ }
+ return output;
}
+
+ return UnpackTupleOfPrimValue<PrimType>(GetType(value));
}
Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) {
@@ -315,7 +326,7 @@ Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) {
if (!tuple) return false;
return std::all_of(tuple->fields.begin(), tuple->fields.end(), [](const
Type& field) {
- return IsBaseOf(relax::PrimType(DataType::Int(64)), field);
+ return IsBaseOf(tvm::PrimType(DataType::Int(64)), field);
});
};
auto check_tuple = [&](const char* name, Expr expr) {
@@ -454,7 +465,7 @@ InferLayoutOutput InferLayoutStridedSlice(
existing_layout = LayoutDecision(InitialLayout(tensor_ty->ndim));
}
- auto opt_axes_tuple = UnpackTupleOfPrimValue<IntImm>(GetType(call->args[1]));
+ auto opt_axes_tuple = UnpackTupleOfPrimValue<IntImm>(call->args[1]);
TVM_FFI_ICHECK(opt_axes_tuple) << "Layout inference of " << call->op
<< " requires slices to be along static axes.
"
<< "However, expression " << call
diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc
index ebfbccf11e..1494d407b9 100644
--- a/src/relax/op/tensor/inspect.cc
+++ b/src/relax/op/tensor/inspect.cc
@@ -51,7 +51,7 @@ TensorType GetTensorArgInfo(const Call& call) {
return tensor_ty.value();
}
-std::tuple<TensorType, PrimType> GetTensorArgInfoWithIndex(const Call& call) {
+std::tuple<TensorType, ffi::Optional<int64_t>> GetTensorArgInfoWithIndex(const
Call& call) {
TVM_FFI_CHECK_EQ(call->args.size(), 2, TypeError)
<< "Operator " << call->op << " expects two arguments, "
<< "but received " << call->args.size() << " arguments: " << call->args;
@@ -68,19 +68,24 @@ std::tuple<TensorType, PrimType> GetTensorArgInfoWithIndex(const Call& call) {
<< "Operator " << call->op << " expects arguments (tensor, axis), "
<< "but the second argument " << arg << " in expression " << call << "
has type " << axis->ty;
- auto int_imm_axis = axis_ty->value.as<IntImmNode>();
+ ffi::Optional<int64_t> int_imm_axis = std::nullopt;
+ if (const auto* prim_value = axis.as<PrimValueNode>()) {
+ if (const auto* int_imm = prim_value->value.as<IntImmNode>()) {
+ int_imm_axis = int_imm->value;
+ }
+ }
if (int_imm_axis) {
- TVM_FFI_ICHECK_GE(int_imm_axis->value, 0);
+ TVM_FFI_ICHECK_GE(int_imm_axis.value(), 0);
}
if (int_imm_axis && !tensor_ty->IsUnknownNdim()) {
- TVM_FFI_CHECK_LT(int_imm_axis->value, tensor_ty->ndim, ValueError)
+ TVM_FFI_CHECK_LT(int_imm_axis.value(), tensor_ty->ndim, ValueError)
<< "Expression " << call << " attempts to access " << arg << ".shape["
- << int_imm_axis->value << "]"
+ << int_imm_axis.value() << "]"
<< ", but " << arg << ".shape only has " << tensor_ty->ndim << "
elements";
}
- return {ffi::GetRef<TensorType>(tensor_ty), ffi::GetRef<PrimType>(axis_ty)};
+ return {ffi::GetRef<TensorType>(tensor_ty), int_imm_axis};
}
DataType GetTensorDataType(const Call& call) { return
GetTensorArgInfo(call)->dtype; }
@@ -106,14 +111,7 @@ tirx::PrimFunc
GetDLTensorField(tirx::builtin::TVMStructFieldKind field, DataTyp
return func;
}
-Expr NormalizeToKnownPrimValue(const BlockBuilder&, Call call) {
- if (auto prim_ty = call->ty.as<PrimTypeNode>()) {
- if (prim_ty->value.defined()) {
- return PrimValue(prim_ty->value.value());
- }
- }
- return call;
-}
+Expr NormalizeToKnownPrimValue(const BlockBuilder&, Call call) { return call; }
//// relax.tensor_dtype_code
@@ -129,7 +127,7 @@ Type InferTypeTensorDtypeCode(const Call& call, const
BlockBuilder&) {
if (dtype.is_void()) {
return PrimType(dlpack_type);
} else {
- return PrimType(IntImm(dlpack_type, dtype.code()));
+ return PrimType(dlpack_type);
}
}
@@ -167,7 +165,7 @@ Type InferTypeTensorDtypeBits(const Call& call, const
BlockBuilder&) {
if (dtype.is_void()) {
return PrimType(dlpack_type);
} else {
- return PrimType(IntImm(dlpack_type, dtype.bits()));
+ return PrimType(dlpack_type);
}
}
@@ -205,7 +203,7 @@ Type InferTypeTensorDtypeLanes(const Call& call, const
BlockBuilder&) {
if (dtype.is_void()) {
return PrimType(dlpack_type);
} else {
- return PrimType(IntImm(dlpack_type, dtype.lanes()));
+ return PrimType(dlpack_type);
}
}
@@ -243,7 +241,7 @@ Type InferTypeTensorNDim(const Call& call, const
BlockBuilder&) {
if (ty->IsUnknownNdim()) {
return PrimType(dlpack_type);
} else {
- return PrimType(IntImm(dlpack_type, ty->ndim));
+ return PrimType(dlpack_type);
}
}
@@ -277,13 +275,12 @@ Expr tensor_shape_i(Expr expr) {
Type InferTypeTensorShape(const Call& call, const BlockBuilder&) {
auto dlpack_type = DataType::Int(64);
- auto [tensor_ty, axis_ty] = GetTensorArgInfoWithIndex(call);
+ auto [tensor_ty, int_imm_axis] = GetTensorArgInfoWithIndex(call);
auto tensor_shape = tensor_ty->GetShape();
- auto int_imm_axis = axis_ty->value.as<IntImmNode>();
if (int_imm_axis && tensor_shape.defined()) {
- return PrimType(tensor_shape.value()[int_imm_axis->value]);
+ return PrimType(tensor_shape.value()[int_imm_axis.value()].dtype());
} else {
return PrimType(dlpack_type);
}
@@ -354,10 +351,9 @@ Expr tensor_stride_i(Expr expr) {
Type InferTypeTensorStride(const Call& call, const BlockBuilder&) {
auto dlpack_type = DataType::Int(64);
- auto [tensor_ty, axis_ty] = GetTensorArgInfoWithIndex(call);
+ auto [tensor_ty, int_imm_axis] = GetTensorArgInfoWithIndex(call);
auto opt_tensor_shape = tensor_ty->GetShape();
- auto int_imm_axis = axis_ty->value.as<IntImmNode>();
if (int_imm_axis && opt_tensor_shape.defined()) {
// As of 2024-03-14, Relax does not have an explicit
@@ -374,10 +370,10 @@ Type InferTypeTensorStride(const Call& call, const
BlockBuilder&) {
// for any legalizable Tensor.
auto tensor_shape = opt_tensor_shape.value();
PrimExpr stride = IntImm::Int64(1);
- for (size_t axis = int_imm_axis->value + 1; axis < tensor_shape.size();
axis++) {
+ for (size_t axis = int_imm_axis.value() + 1; axis < tensor_shape.size();
axis++) {
stride = stride * tensor_shape[axis];
}
- return PrimType(stride);
+ return PrimType(stride.dtype());
} else {
return PrimType(dlpack_type);
}
@@ -409,7 +405,7 @@ Type InferTypeTensorByteOffset(const Call& call, const
BlockBuilder&) {
// Relax implicitly requires that the byte offset is zero for any
// legalizable tensor. See InferTypeTensorStride for full
// explanation.
- return PrimType(IntImm(dlpack_type, 0));
+ return PrimType(dlpack_type);
} else {
return PrimType(dlpack_type);
}
@@ -440,7 +436,7 @@ Type InferTypeTensorElemOffset(const Call& call, const
BlockBuilder&) {
// Relax implicitly requires that the element offset is zero for
// any legalizable tensor. See InferTypeTensorStride for
// full explanation.
- return PrimType(IntImm(dlpack_type, 0));
+ return PrimType(dlpack_type);
} else {
return PrimType(dlpack_type);
}
diff --git a/src/relax/script/printer/dependent_type.cc
b/src/relax/script/printer/dependent_type.cc
index ee3aa6663c..a37c21406f 100644
--- a/src/relax/script/printer/dependent_type.cc
+++ b/src/relax/script/printer/dependent_type.cc
@@ -61,22 +61,6 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath&
e_p, const IRDocsifie
return expr_doc;
}
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<relax::PrimType>("", [](relax::PrimType n, AccessPath n_p,
IRDocsifier d) -> Doc {
- ffi::Array<ExprDoc, void> args;
- ffi::Array<ffi::String> kwargs_keys;
- ffi::Array<ExprDoc, void> kwargs_values;
-
- if (n->value.defined()) {
- kwargs_keys.push_back("value");
- kwargs_values.push_back(PrintShapeVar(n->value.value(),
n_p->Attr("value"), d));
- } else {
- args.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")));
- }
-
- return Relax(d, "Prim")->Call(args, kwargs_keys, kwargs_values);
- });
-
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::ShapeType>(
"", [](relax::ShapeType n, AccessPath n_p, IRDocsifier d) -> Doc {
@@ -172,7 +156,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
});
TVM_REGISTER_SCRIPT_AS_REPR(relax::ObjectTypeNode, ReprPrintRelax);
-TVM_REGISTER_SCRIPT_AS_REPR(relax::PrimTypeNode, ReprPrintRelax);
TVM_REGISTER_SCRIPT_AS_REPR(relax::ShapeTypeNode, ReprPrintRelax);
TVM_REGISTER_SCRIPT_AS_REPR(relax::TensorTypeNode, ReprPrintRelax);
TVM_REGISTER_SCRIPT_AS_REPR(relax::FuncTypeNode, ReprPrintRelax);
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 308c9da4c3..d54e8faf6e 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -980,8 +980,7 @@ class FusedTIRConstructor : public ExprVisitor {
} else if (const auto* prim_value = ty.as<PrimTypeNode>()) {
// Case 2. The relax param is a scalar, we directly create a tirx var
- TVM_FFI_ICHECK(prim_value->value->IsInstance<tirx::VarNode>());
- out->push_back(Downcast<tirx::Var>(prim_value->value));
+ out->push_back(tirx::Var(name_hint, prim_value->dtype));
} else if (const auto* shape_expr = ty.as<ShapeTypeNode>()) {
// Case 3. The relax param is a tuple of scalars, each represented as a
tirx var
@@ -1255,13 +1254,14 @@ class TIRFuseMutator : public ExprMutator {
tir_vars.push_back(prim_value);
}
} else if (const auto* prim_value = ty.as<PrimTypeNode>()) {
- TVM_FFI_ICHECK(prim_value->value.defined())
- << "FuseTIR requires all R.Prim arguments to have a known value.";
- PrimExpr expr = prim_value->value.value();
- TVM_FFI_ICHECK(expr->IsInstance<tirx::VarNode>())
- << "FuseTIR currently requires all R.Prim "
- "arguments to provide a single tirx::Var.";
- tir_vars.push_back(expr);
+ if (const auto* literal = arg.as<PrimValueNode>()) {
+ tir_vars.push_back(literal->value);
+ } else if (const auto* var = arg.as<VarNode>()) {
+ tir_vars.push_back(tirx::Var(var->name_hint(), prim_value->dtype));
+ } else {
+ TVM_FFI_THROW(TypeError) << "FuseTIR expects scalar arguments to be
PrimValue or Var, "
+ << "but received " << arg;
+ }
} else {
arg_list.push_back(arg);
diff --git a/src/relax/transform/remove_unused_parameters.cc
b/src/relax/transform/remove_unused_parameters.cc
index 598478c9c2..218a808565 100644
--- a/src/relax/transform/remove_unused_parameters.cc
+++ b/src/relax/transform/remove_unused_parameters.cc
@@ -100,7 +100,7 @@ std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {
}
for (const auto& tir_var : free_tir_vars) {
- Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var));
+ Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var.dtype()));
params.push_back(relax_var);
}
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index d35b32ac58..370947e4b0 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -119,18 +119,6 @@ tvm::ffi::Map<tirx::Var, PrimExpr> InferSymbolicVarMap(
}
};
- auto bind_from_prim_value = [&bind_from_prim_expr](const Type& var, const
Type& expr) {
- auto var_ty = var.as<PrimTypeNode>();
- if (!var_ty) return;
-
- auto expr_ty = expr.as<PrimTypeNode>();
- if (!expr_ty) return;
-
- if (!var_ty->value.defined() || !expr_ty->value.defined()) return;
-
- bind_from_prim_expr(var_ty->value.value(), expr_ty->value.value());
- };
-
auto bind_from_shape = [&bind_from_prim_expr](const Type& var, const Type&
expr) {
auto var_shape = var.as<ShapeTypeNode>();
if (!var_shape) return;
@@ -178,7 +166,6 @@ tvm::ffi::Map<tirx::Var, PrimExpr> InferSymbolicVarMap(
bind_from_ty = [&](const Type& var, const Type& expr) {
bind_from_tensor(var, expr);
bind_from_shape(var, expr);
- bind_from_prim_value(var, expr);
bind_from_tuple(var, expr);
};
diff --git a/src/tirx/ir/function.cc b/src/tirx/ir/function.cc
index c44c279980..d6b171481e 100644
--- a/src/tirx/ir/function.cc
+++ b/src/tirx/ir/function.cc
@@ -54,14 +54,14 @@ tvm::Type InferType(const PrimFunc& prim_func) {
return relax::ObjectType();
}
- return relax::PrimType(param->dtype);
+ return PrimType(param->dtype);
}();
params.push_back(param_ty);
}
tvm::Type ret = [&]() -> tvm::Type {
if (const auto* prim = prim_func->ret_type.as<PrimTypeNode>()) {
- return relax::PrimType(prim->dtype);
+ return PrimType(prim->dtype);
} else if (IsVoidType(prim_func->ret_type)) {
return relax::TupleType(ffi::Array<tvm::Type>{});
} else {
diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc
index 7b624304e5..07d9995bbd 100644
--- a/tests/cpp/nested_msg_test.cc
+++ b/tests/cpp/nested_msg_test.cc
@@ -145,9 +145,9 @@ TEST(NestedMsg, Equal) {
}
TEST(NestedMsg, MapAndDecompose) {
- relax::Var x("x", relax::PrimType(runtime::DataType::Int(16)));
- relax::Var y("y", relax::PrimType(runtime::DataType::Int(32)));
- relax::Var z("z", relax::PrimType(runtime::DataType::Int(64)));
+ relax::Var x("x", PrimType(runtime::DataType::Int(16)));
+ relax::Var y("y", PrimType(runtime::DataType::Int(32)));
+ relax::Var z("z", PrimType(runtime::DataType::Int(64)));
BlockBuilder bb = BlockBuilder::Create(std::nullopt);
relax::Expr t0 = bb->Normalize(Tuple({x, y}));
@@ -169,7 +169,7 @@ TEST(NestedMsg, MapAndDecompose) {
[](IntImm lhs, IntImm rhs) -> bool { return lhs->value ==
rhs->value; }));
auto output2 = MapToNestedMsg<IntImm>(GetType(t1), [&](Type ty) ->
NestedMsg<IntImm> {
- const auto* prim_ty = ty.as<relax::PrimTypeNode>();
+ const auto* prim_ty = ty.as<PrimTypeNode>();
if (prim_ty == nullptr) return std::nullopt;
int bits = prim_ty->dtype.bits();
if (bits == 16) return c0;
@@ -306,7 +306,7 @@ TEST(NestedMsg, TransformTupleLeaf) {
NInt msg1 = {c0, {c0, c1}, c2, {c0, {c1, c2}}};
NInt msg2 = {c1, {c2, c0}, c2, {c1, {c2, c0}}};
- relax::PrimType s = relax::PrimType(runtime::DataType::Int(32));
+ PrimType s = PrimType(runtime::DataType::Int(32));
relax::Var x("x", s), y("y", s), z("z", s);
BlockBuilder bb = BlockBuilder::Create(std::nullopt);
Expr expr = bb->Normalize(Tuple({x, Tuple({x, x}), x, Tuple({x, Tuple({x,
x})})}));
diff --git a/tests/python/relax/test_analysis_type_analysis.py
b/tests/python/relax/test_analysis_type_analysis.py
index 20ccb4a0e7..c91c504958 100644
--- a/tests/python/relax/test_analysis_type_analysis.py
+++ b/tests/python/relax/test_analysis_type_analysis.py
@@ -35,8 +35,8 @@ def test_get_static_type_basic():
tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s0),
rx.ObjectType())
# prim
- s1 = rx.PrimType("float32")
- tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1),
rx.PrimType("float32"))
+ s1 = tvm.ir.PrimType("float32")
+ tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1),
tvm.ir.PrimType("float32"))
def test_get_static_type_shape():
@@ -105,7 +105,7 @@ def test_erase_to_well_defined_basic():
tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s0), s0)
# prim
- s1 = rx.PrimType("float32")
+ s1 = tvm.ir.PrimType("float32")
tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1), s1)
@@ -208,8 +208,8 @@ def test_base_check():
n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64")
obj0 = rx.ObjectType()
- prim0 = rx.PrimType("int32")
- prim1 = rx.PrimType("float32")
+ prim0 = tvm.ir.PrimType("int32")
+ prim1 = tvm.ir.PrimType("float32")
shape0 = rx.ShapeType(ndim=-1)
shape1 = rx.ShapeType(ndim=2)
@@ -362,7 +362,7 @@ def _check_derive(ctx, finfo, args_ty, ret):
def test_derive_call_ret_type():
obj0 = rx.ObjectType()
- prim0 = rx.PrimType("float32")
+ prim0 = tvm.ir.PrimType("float32")
n, m = tirx.Var("n0", "int64"), tirx.Var("m0", "int64")
bb = rx.BlockBuilder()
@@ -517,8 +517,8 @@ def _check_lca(lhs, rhs, target):
def test_type_lca():
n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64")
obj0 = rx.ObjectType()
- prim0 = rx.PrimType("int32")
- prim1 = rx.PrimType("float32")
+ prim0 = tvm.ir.PrimType("int32")
+ prim1 = tvm.ir.PrimType("float32")
vdevice0 = ir.VDevice("llvm")
vdevice1 = ir.VDevice("cuda", 0)
@@ -764,7 +764,7 @@ def test_collect_symbolic_var_from_tensor_shape():
assert free_vars == {n, p, q}
-param_type = tvm.testing.parameter("shape_expr", "prim_value")
+param_type = tvm.testing.parameter("shape_expr")
param_order = tvm.testing.parameter("definition_first", "usage_first")
@@ -779,11 +779,6 @@ def
test_collect_symbolic_var_from_non_tensor_params(param_type, param_order):
extra_params = [
rx.Var("shape_expr", rx.ShapeType([tir_n, tir_m])),
]
- elif param_type == "prim_value":
- extra_params = [
- rx.Var("n", rx.PrimType(value=tir_n)),
- rx.Var("m", rx.PrimType(value=tir_m)),
- ]
else:
raise ValueError(f"Unknown param_type: {param_type}")
diff --git a/tests/python/relax/test_ast_printer.py
b/tests/python/relax/test_ast_printer.py
index 49c52ef888..710daf55dc 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -289,7 +289,7 @@ def test_ty():
assert printer.visit_ty_(rx.ObjectType()) == "ObjectType()"
- assert printer.visit_ty_(rx.PrimType("int32")) == "PrimType(dtype=int32)"
+ assert printer.visit_ty_(tvm.ir.PrimType("int32")) ==
"PrimType(dtype=int32)"
# empty shape
empty_ssi = rx.ShapeType()
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py
b/tests/python/relax/test_backend_transform_shape_lower.py
index 045468e57a..f2bc877694 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -16,6 +16,8 @@
# under the License.
# ruff: noqa: F841
+import pytest
+
import tvm.script
import tvm.testing
from tvm import relax
@@ -816,6 +818,7 @@ def test_check_weights_with_dynamic_shape():
assert_structural_equal(after, expected)
[email protected](reason="value-bearing R.Prim annotations were removed")
def test_update_symbolic_vars_in_match_cast_rhs():
"""Symbolic variables may be used on the RHS of match_cast"""
diff --git a/tests/python/relax/test_bind_symbolic_vars.py
b/tests/python/relax/test_bind_symbolic_vars.py
index 90fe4864c7..fdc514696f 100644
--- a/tests/python/relax/test_bind_symbolic_vars.py
+++ b/tests/python/relax/test_bind_symbolic_vars.py
@@ -204,6 +204,7 @@ def test_bind_symbolic_vars_in_shape_expr():
tvm.ir.assert_structural_equal(expected, after)
[email protected](reason="value-bearing R.Prim annotations were removed")
def test_bind_defining_of_symbolic_vars_in_prim_value():
"""R.Prim may define symbolic variables
diff --git a/tests/python/relax/test_blockbuilder_core.py
b/tests/python/relax/test_blockbuilder_core.py
index 2d2eb95ec8..34be64df5d 100644
--- a/tests/python/relax/test_blockbuilder_core.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -643,8 +643,8 @@ def test_emit_nested_tuple(emit_nested_tuple):
n_sym = tirx.Var("n", "int64")
m_sym = tirx.Var("m", "int64")
- n = rx.Var("n", rx.PrimType(value=n_sym))
- m = rx.Var("m", rx.PrimType(value=m_sym))
+ n = rx.Var("n", tvm.ir.PrimType("int64"))
+ m = rx.Var("m", tvm.ir.PrimType("int64"))
x = rx.Var("x", rx.TensorType([n_sym, m_sym], "float32"))
y = rx.Var("y", rx.TensorType([m_sym, n_sym], "float32"))
diff --git a/tests/python/relax/test_blockbuilder_emit_te.py
b/tests/python/relax/test_blockbuilder_emit_te.py
index 0ca90e5a8b..7643d96980 100644
--- a/tests/python/relax/test_blockbuilder_emit_te.py
+++ b/tests/python/relax/test_blockbuilder_emit_te.py
@@ -75,7 +75,7 @@ def test_emit_te_with_symbolic_arg():
def test_symbolic_shape_in_prim_value():
- """Symbolic vars may be provided to TE in R.Prim"""
+ """Scalar Relax vars may be provided to TE as PrimFunc parameters."""
def te_slice(tensor, i):
return tvm.te.compute([tensor.shape[1]], lambda j: tensor[i, j],
name="slice")
@@ -83,8 +83,7 @@ def test_symbolic_shape_in_prim_value():
def from_builder():
bb = rx.BlockBuilder()
A = rx.Var("A", R.Tensor([16, 16], "float32"))
- tir_i = tvm.tirx.Var("tir_i", "int64")
- relax_i = rx.Var("relax_i", R.Prim(value=tir_i))
+ relax_i = rx.Var("relax_i", tvm.ir.PrimType("int64"))
with bb.function("main", params=[A, relax_i]):
A_sliced = bb.emit_te(te_slice, A, relax_i)
@@ -97,8 +96,8 @@ def test_symbolic_shape_in_prim_value():
@T.prim_func(private=True, s_tir=True)
def te_slice(
A: T.Buffer([T.int64(16), T.int64(16)], "float32"),
- Output: T.Buffer(T.int64(16), "float32"),
row_index: T.int64,
+ Output: T.Buffer(T.int64(16), "float32"),
):
T.func_attr({"tirx.noalias": True})
@@ -110,16 +109,13 @@ def test_symbolic_shape_in_prim_value():
@R.function
def main(
A: R.Tensor([16, 16], "float32"),
- arg_row_index: R.Prim(value="row_index"),
+ arg_row_index: R.Prim("int64"),
):
cls = Expected
- row_index = T.int64()
-
gv = R.call_tir(
cls.te_slice,
- A,
- tir_vars=[row_index],
+ (A, arg_row_index),
out_ty=R.Tensor([16], "float32"),
)
return gv
diff --git a/tests/python/relax/test_dataflow_rewriter.py
b/tests/python/relax/test_dataflow_rewriter.py
index 5264913909..d9a46ba3fd 100644
--- a/tests/python/relax/test_dataflow_rewriter.py
+++ b/tests/python/relax/test_dataflow_rewriter.py
@@ -366,6 +366,7 @@ def test_recursive_rewrite_rules():
tvm.ir.assert_structural_equal(expected, after)
[email protected](reason="value-bearing R.Prim match-cast semantics were
removed")
def test_rewrite_of_arbitrary_dtype():
"""A pattern-match may apply to a tensor with unknown dtype
diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py
index 855361712e..b64804b4ac 100644
--- a/tests/python/relax/test_expr.py
+++ b/tests/python/relax/test_expr.py
@@ -271,7 +271,7 @@ def test_prim_value_with_var():
n = tirx.Var("n", "int64")
pv = rx.PrimValue(n)
assert pv.value.same_as(n)
- tvm.ir.assert_structural_equal(pv.ty, rx.PrimType(value=n))
+ tvm.ir.assert_structural_equal(pv.ty, tvm.ir.PrimType("int64"))
_check_equal(pv, rx.PrimValue(n))
_check_json_roundtrip(pv)
@@ -279,7 +279,7 @@ def test_prim_value_with_var():
def test_prim_value_with_expr():
n = tirx.Var("n", "int64")
pv = rx.PrimValue(n + 1)
- tvm.ir.assert_structural_equal(pv.ty, rx.PrimType(value=n + 1))
+ tvm.ir.assert_structural_equal(pv.ty, tvm.ir.PrimType("int64"))
_check_equal(pv, rx.PrimValue(n + 1))
_check_json_roundtrip(pv)
@@ -301,7 +301,7 @@ def test_datatype_imm():
def test_call():
- dtype = rx.PrimType("int32")
+ dtype = tvm.ir.PrimType("int32")
func = rx.Var("func", rx.FuncType([dtype], dtype))
arg = rx.Var("arg", dtype)
call = rx.Call(func, [arg])
@@ -312,7 +312,7 @@ def test_call():
def test_call_raises_error_for_invalid_function():
"""relax::Call requires the function to have FuncType"""
- dtype = rx.PrimType("int32")
+ dtype = tvm.ir.PrimType("int32")
func = rx.Var("func", dtype)
arg = rx.Var("arg", dtype)
diff --git a/tests/python/relax/test_op_binary.py
b/tests/python/relax/test_op_binary.py
index 953e744fb7..f5d12bbe67 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -141,7 +141,7 @@ def
test_infer_ty_binary_arith_prim_value_with_prim_value(binary_arith_op: Calla
x = relax.Var("x", R.Prim("float32"))
y = relax.Var("y", R.Prim("float32"))
- _check_inference(bb, binary_arith_op(x, y), relax.PrimType("float32"))
+ _check_inference(bb, binary_arith_op(x, y), tvm.ir.PrimType("float32"))
@pytest.mark.parametrize("binary_arith_op,tir_arith_op", binary_arith_ops)
@@ -157,8 +157,8 @@ def
test_infer_ty_binary_arith_known_prim_value_with_prim_value(
x = relax.Var("x", R.Prim(value=tir_x))
y = relax.Var("y", R.Prim(value=tir_y))
- _check_inference(bb, binary_arith_op(x, y), relax.PrimType(value=tir_x +
tir_y))
- _check_inference(bb, binary_arith_op(y, x), relax.PrimType(value=tir_y +
tir_x))
+ _check_inference(bb, binary_arith_op(x, y), tvm.ir.PrimType("float32"))
+ _check_inference(bb, binary_arith_op(y, x), tvm.ir.PrimType("float32"))
binary_cmp_ops = [
@@ -202,8 +202,8 @@ def
test_infer_ty_binary_cmp_prim_value_to_prim_value(binary_cmp_op: Callable):
bb = relax.BlockBuilder()
x = relax.Var("x", R.Prim("float32"))
y = relax.Var("y", R.Prim("float32"))
- _check_inference(bb, binary_cmp_op(x, y), relax.PrimType("bool"))
- _check_inference(bb, binary_cmp_op(y, x), relax.PrimType("bool"))
+ _check_inference(bb, binary_cmp_op(x, y), tvm.ir.PrimType("bool"))
+ _check_inference(bb, binary_cmp_op(y, x), tvm.ir.PrimType("bool"))
@pytest.mark.parametrize("binary_cmp_op,tir_cmp_op", binary_cmp_ops)
@@ -217,8 +217,8 @@ def
test_infer_ty_binary_cmp_known_prim_value_to_prim_value(binary_cmp_op: Calla
x = relax.Var("x", R.Prim(value=tir_x))
y = relax.Var("y", R.Prim(value=tir_y))
- _check_inference(bb, binary_cmp_op(x, y),
relax.PrimType(value=tir_cmp_op(tir_x, tir_y)))
- _check_inference(bb, binary_cmp_op(y, x),
relax.PrimType(value=tir_cmp_op(tir_y, tir_x)))
+ _check_inference(bb, binary_cmp_op(x, y), tvm.ir.PrimType("bool"))
+ _check_inference(bb, binary_cmp_op(y, x), tvm.ir.PrimType("bool"))
@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
diff --git a/tests/python/relax/test_op_manipulate.py
b/tests/python/relax/test_op_manipulate.py
index c09a04893f..9a938b647d 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -289,7 +289,7 @@ def test_reshape_infer_ty_wrong_input_type():
x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4, 5), "float32")))
x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
ns = relax.Var("ns", relax.TensorType((120,), "float32"))
- pv = relax.Var("pv", relax.PrimType("int64"))
+ pv = relax.Var("pv", tvm.ir.PrimType("int64"))
with pytest.raises(TypeError):
bb.normalize(relax.op.reshape(x0, (2, 3, 4, 5)))
@@ -2222,7 +2222,7 @@ def test_split_infer_ty_axis_out_of_range():
def test_split_infer_invalid_ty_indices():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
- v = relax.Var("v", relax.PrimType("int64"))
+ v = relax.Var("v", tvm.ir.PrimType("int64"))
with pytest.raises(TypeError):
bb.normalize(relax.op.split(x0, [v], axis=1))
diff --git a/tests/python/relax/test_transform_compute_prim_value.py
b/tests/python/relax/test_transform_compute_prim_value.py
index 733dbc295a..a7b89f0654 100644
--- a/tests/python/relax/test_transform_compute_prim_value.py
+++ b/tests/python/relax/test_transform_compute_prim_value.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+import pytest
+
import tvm
import tvm.testing
from tvm.script import ir as I
@@ -82,6 +84,7 @@ def test_prim_value_in_branch_condition():
tvm.ir.assert_structural_equal(After, Expected)
[email protected](reason="value-bearing R.Prim annotations were removed")
def test_prim_value_in_pure_function():
@I.ir_module
class Before:
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py
b/tests/python/relax/test_transform_lazy_transform_params.py
index 5642f72a09..35482e3bc0 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -16,6 +16,7 @@
# under the License.
# ruff: noqa: F841
import numpy as np
+import pytest
import tvm
import tvm.testing
@@ -751,6 +752,7 @@ def test_params_without_tuple():
tvm.ir.assert_structural_equal(After, Expected)
[email protected](reason="value-bearing R.Prim annotations were removed")
def test_retain_before_num_input():
"""Only lazily load parameters after num_input"""
@@ -844,6 +846,7 @@ def test_get_item_callback():
tvm.ir.assert_structural_equal(After, Expected)
[email protected](reason="value-bearing R.Prim annotations were removed")
def test_get_item_callback_num_attrs():
@I.ir_module(s_tir=True)
class Before:
diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py
b/tests/python/relax/test_transform_remove_unused_parameters.py
index 1be2a5ac2f..4c05cbdb29 100644
--- a/tests/python/relax/test_transform_remove_unused_parameters.py
+++ b/tests/python/relax/test_transform_remove_unused_parameters.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+import pytest
+
import tvm
import tvm.testing
from tvm.script import ir as I
@@ -54,6 +56,7 @@ def test_remove_unused_relax_parameter():
tvm.ir.assert_structural_equal(After, Expected)
[email protected](reason="value-bearing R.Prim annotations were removed")
def test_replace_symbolic_variables():
"""If a parameter is only required for its symbolic variables, provide
them directly
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index df774656e1..ef9ef115b1 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -705,7 +705,7 @@ def test_rewrite_dynamic_reshape():
@I.ir_module(s_tir=True)
class Before:
@R.function
- def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
+ def main(x: R.Tensor(["N", 16], dtype="float32")):
N = T.int64()
with R.dataflow():
y = R.reshape(x, [N * 4, T.int64(4)])
@@ -716,7 +716,7 @@ def test_rewrite_dynamic_reshape():
@I.ir_module(s_tir=True)
class Expected:
@R.function
- def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
+ def main(x: R.Tensor(["N", 16], dtype="float32")):
N = T.int64()
cls = Expected
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index 0f251940dd..902c142b8e 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1337,7 +1337,7 @@ def test_computed_prim_value_as_branch_condition():
if_else = func.body.blocks[0].bindings[0].value
assert isinstance(if_else.cond, relax.PrimValue)
tvm.ir.assert_structural_equal(N % 16 == 0, if_else.cond.value)
- tvm.ir.assert_structural_equal(if_else.cond.ty, R.Prim(value=N % 16 == 0))
+ tvm.ir.assert_structural_equal(if_else.cond.ty, R.Prim("bool"))
def test_tir_expr_as_branch_condition():
@@ -1409,7 +1409,7 @@ def test_computed_prim_value_as_assert_condition():
condition = assert_op.args[0]
assert isinstance(condition, relax.PrimValue)
tvm.ir.assert_structural_equal(N % 16 == 0, condition.value)
- tvm.ir.assert_structural_equal(condition.ty, R.Prim(value=N % 16 == 0))
+ tvm.ir.assert_structural_equal(condition.ty, R.Prim("bool"))
def test_tir_expr_as_assert_condition():
@@ -1472,19 +1472,6 @@ def
test_erase_to_well_defined_keeps_variants_exposed_by_shape_expr():
_check(foo)
-def test_erase_to_well_defined_keeps_variants_exposed_by_prim_value():
- @R.function
- def foo(x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")):
- q = x
- m, n = T.int64(), T.int64()
- z = R.match_cast(q, R.Tensor((m, n)))
- w = z
- return w
-
- assert foo.ret_ty.shape is not None
- _check(foo)
-
-
def test_erase_to_well_defined_infers_from_shape_expr():
@I.ir_module(s_tir=True)
class Module:
@@ -1510,33 +1497,6 @@ def test_erase_to_well_defined_infers_from_shape_expr():
_check(Module)
-def test_erase_to_well_defined_infers_from_prim_value():
- @I.ir_module(s_tir=True)
- class Module:
- # The subroutine's symbolic variables are only in-scope for the
subroutine.
- @R.function
- def subroutine(x: R.Tensor, _m: R.Prim(value="m"), _n:
R.Prim(value="n")) -> R.Tensor(
- ["m", "n"]
- ):
- q = x
- m, n = T.int64(), T.int64()
- z = R.match_cast(q, R.Tensor((m, n)))
- w = z
- return w
-
- # However, struct inference can make the symbolic variables in
- # the main function to the symbolic variables in the
- # subroutine. Therefore, the shape of the tensor returned
- # from main can have a well-defined shape.
- @R.function
- def main(x: R.Tensor, relax_m: R.Prim(value="m"), relax_n:
R.Prim(value="n")):
- output = Module.subroutine(x, relax_m, relax_n)
- return output
-
- assert Module["main"].ret_ty.shape is not None
- _check(Module)
-
-
def test_empty_tuple():
@R.function
def foo(x: R.Tuple()):
@@ -1617,26 +1577,6 @@ def test_symbolic_vars_in_shape():
_check(baz, bb.get()["baz"])
-def test_symbolic_vars_in_prim_value():
- """Symbolic variable may be defined in R.Prim"""
-
- @R.function
- def baz(x: R.Prim(value="m"), y: R.Tensor(("m * 2",), "float32")):
- m = T.int64()
- z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,),
dtype="float32"))
- return z
-
- m = tirx.Var("m", "int64")
- x = relax.Var("x", relax.PrimType(value=m))
- y = relax.Var("y", relax.TensorType([m * 2], "float32"))
- bb = relax.BlockBuilder()
- with bb.function("baz", (x, y)):
- z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m *
2,), dtype="float32")))
- bb.emit_func_output(z)
-
- _check(baz, bb.get()["baz"])
-
-
def test_undefined_symbolic_var_raises_error():
"""An undefined symbolic variable in an error
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py
b/tests/python/relax/test_tvmscript_printer_relax.py
index 012aac8c55..9b4fa20678 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -177,8 +177,8 @@ def test_object_ty():
def test_prim_ty():
- obj = relax.PrimType("float32")
- _assert_print(obj, 'R.Prim("float32")')
+ obj = tvm.ir.PrimType("float32")
+ _assert_print(obj, "T.float32")
def test_shape_ty_0():
@@ -223,7 +223,7 @@ def test_tuple_ty_empty():
def test_tuple_ty():
obj = relax.TupleType(
[
- relax.PrimType("float32"),
+ tvm.ir.PrimType("float32"),
relax.ObjectType(),
relax.ShapeType([1, tirx.Var("a", "int64"), 3]),
]
@@ -231,7 +231,7 @@ def test_tuple_ty():
_assert_print(
obj._relax_script(), # pylint: disable=protected-access
"""
-R.Tuple(R.Prim("float32"), R.Object, R.Shape([1, a, 3]))
+R.Tuple(T.float32, R.Object, R.Shape([1, a, 3]))
""",
)
@@ -239,10 +239,10 @@ R.Tuple(R.Prim("float32"), R.Object, R.Shape([1, a, 3]))
def test_func_ty():
obj = relax.FuncType(
params=[
- relax.PrimType("float32"),
+ tvm.ir.PrimType("float32"),
relax.ObjectType(),
relax.ShapeType([1, tirx.Var("a", "int64"), 3]),
- relax.PrimType(value=tirx.Var("b", "int64")),
+ tvm.ir.PrimType("int64"),
],
ret=relax.TensorType(
shape=relax.ShapeExpr([1, 2, 3]),
@@ -252,8 +252,7 @@ def test_func_ty():
_assert_print(
obj,
"a = T.int64()\n"
- "b = T.int64()\n"
- 'R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3]),
R.Prim(value=b)), '
+ "R.Callable((T.float32, R.Object, R.Shape([1, a, 3]), T.int64), "
'R.Tensor((1, 2, 3), dtype="float32"), True)',
)
diff --git a/tests/python/relax/test_type.py b/tests/python/relax/test_type.py
index 1679048df4..8490b1fbbd 100644
--- a/tests/python/relax/test_type.py
+++ b/tests/python/relax/test_type.py
@@ -67,9 +67,9 @@ def test_dyn_tensor_type():
def test_prim_ty():
- s0 = rx.PrimType("float32")
- s1 = rx.PrimType("float32")
- s2 = rx.PrimType("int32")
+ s0 = tvm.ir.PrimType("float32")
+ s1 = tvm.ir.PrimType("float32")
+ s2 = tvm.ir.PrimType("int32")
_check_equal(s0, s1)
@@ -79,7 +79,7 @@ def test_prim_ty():
assert s0 == s1
assert s0 != s2
- assert isinstance(s0, rx.PrimType)
+ assert isinstance(s0, tvm.ir.PrimType)
_check_json_roundtrip(s0)
_check_json_roundtrip(s1)
@@ -88,23 +88,7 @@ def test_prim_ty():
# wrong API constructors
with pytest.raises((RuntimeError, TypeError)):
- rx.PrimType([1])
-
-
-def test_prim_ty_with_expr():
- n = tirx.Var("n", "int64")
- ty = rx.PrimType(value=n + 1)
-
- _check_equal(ty, rx.PrimType(value=n + 1))
- assert not tvm_ffi.structural_equal(ty, rx.PrimType(dtype=n.dtype))
-
- # can turn into str
- str(ty)
-
- assert isinstance(ty, rx.PrimType)
- _check_json_roundtrip(ty)
-
- assert ty.dtype == "int64"
+ tvm.ir.PrimType([1])
def test_shape_ty():
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index c2afcaf21b..1a15484ded 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -171,6 +171,7 @@ def test_structural_equal_of_call_nodes():
tvm.ir.assert_structural_equal(uses_same_object_twice,
uses_two_different_objects)
+@pytest.mark.skip(reason="value-bearing R.Prim annotations were removed")
def test_structural_equal_with_recursive_lambda_function():
"""A recursive lambda function may be checked for structural equality
@@ -263,10 +264,9 @@ def test_structural_equal_with_distinct_recursive_lambda_function():
"blocks[0]",
"bindings[0]",
"value",
- "true_branch",
- "body",
- "value",
+ "cond",
"value",
+ "a",
]
with pytest.raises(ValueError, match=re.escape(".".join(mismatch_path))):
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index d555f0d5a9..d04f59379f 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -556,6 +556,7 @@ def test_vm_relax_symbolic_shape_tuple(exec_mode):
func(R.prim_value(2))
+@pytest.mark.skip(reason="value-bearing R.Prim annotations are erased to dtype-only PrimType")
def test_vm_relax_symbolic_prim_value(exec_mode):
@I.ir_module(s_tir=True)
class mod:
@@ -576,6 +577,7 @@ def test_vm_relax_symbolic_prim_value(exec_mode):
func(Shape([2]))
+@pytest.mark.skip(reason="value-bearing R.Prim annotations are erased to dtype-only PrimType")
def test_vm_relax_multiple_symbolic_prim_value(exec_mode):
"""Like test_vm_relax_symbolic_prim_value, but with multiple variables"""
diff --git a/tests/python/tirx-base/test_tir_specialize.py b/tests/python/tirx-base/test_tir_specialize.py
index cecaf07ab8..0529bd90a4 100644
--- a/tests/python/tirx-base/test_tir_specialize.py
+++ b/tests/python/tirx-base/test_tir_specialize.py
@@ -347,10 +347,10 @@ def test_specialization_updates_ty():
def expected() -> T.int32:
T.ret(50)
- ty_before = tvm.relax.FuncType([tvm.relax.PrimType("int32")], tvm.relax.PrimType("int32"))
+ ty_before = tvm.relax.FuncType([tvm.ir.PrimType("int32")], tvm.ir.PrimType("int32"))
tvm.ir.assert_structural_equal(before.ty, ty_before)
- ty_expected = tvm.relax.FuncType([], tvm.relax.PrimType("int32"))
+ ty_expected = tvm.relax.FuncType([], tvm.ir.PrimType("int32"))
tvm.ir.assert_structural_equal(expected.ty, ty_expected)
n = before.params[0]
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 5972decaa5..9c1e26459d 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -413,10 +413,10 @@ def test_inferred_ty_with_prim_args():
expected = tvm.relax.FuncType(
[
- tvm.relax.PrimType("int32"),
- tvm.relax.PrimType("int32"),
+ tvm.ir.PrimType("int32"),
+ tvm.ir.PrimType("int32"),
],
- tvm.relax.PrimType("int32"),
+ tvm.ir.PrimType("int32"),
purity=True,
)
tvm.ir.assert_structural_equal(func.ty, expected)
@@ -434,7 +434,7 @@ def test_inferred_ty_with_buffer_args():
tvm.relax.TensorType([16, 16], "float32"),
tvm.relax.TensorType([256], "int32"),
],
- tvm.relax.PrimType("float32"),
+ tvm.ir.PrimType("float32"),
purity=True,
)
tvm.ir.assert_structural_equal(func.ty, expected)
@@ -460,7 +460,7 @@ def test_inferred_ty_with_internal_allocation():
[
tvm.relax.TensorType([16, 16], "float32"),
],
- tvm.relax.PrimType("float32"),
+ tvm.ir.PrimType("float32"),
purity=True,
)
tvm.ir.assert_structural_equal(func.ty, expected)