This is an automated email from the ASF dual-hosted git repository.
spectrometerHBH pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9edd5bd958 [REFACTOR] Remove tvm.runtime.packed_func and container
shims; route via tvm_ffi (#19442)
9edd5bd958 is described below
commit 9edd5bd958ae81d3f77ddb2f53099ca0b31f5e72
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Apr 25 11:02:08 2026 -0400
[REFACTOR] Remove tvm.runtime.packed_func and container shims; route via
tvm_ffi (#19442)
## Summary
- Delete the three Python shim modules that re-exported tvm-ffi types
under `tvm.runtime` / `tvm.ir`:
`python/tvm/runtime/packed_func.py`, `python/tvm/runtime/container.py`,
`python/tvm/ir/container.py`.
- Drop the matching re-exports from `tvm.runtime`, `tvm.ir`, and `tvm`
package init files, so
`tvm.runtime.PackedFunc`, `tvm.runtime.ShapeTuple`,
`tvm.runtime.String`, `tvm.ir.Array`,
`tvm.ir.Map`, and `tvm.container.Array` no longer exist.
- Migrate every production caller, test, and tutorial to the canonical
names: `tvm_ffi.Function`,
`tvm_ffi.Shape`, `tvm_ffi.core.String`, `tvm_ffi.Array`, and
`tvm_ffi.Map`.
## Test plan
- [x] `pytest tests/python/all-platform-minimal-test` (75 passed, 77
skipped)
- [x] `pytest tests/python/runtime/test_runtime_container.py
tests/python/all-platform-minimal-test/test_runtime_packed_func.py` (20
passed)
- [x] `pytest tests/python/ir/test_node_reflection.py
tests/python/ir/test_container_structural_equal.py` (32 passed)
- [x] `pytest tests/python/relax/test_vm_build.py
tests/python/relax/test_vm_execbuilder.py
tests/python/relax/test_vm_codegen_only.py` (125 passed, 2 xfailed)
- [x] `pytest tests/python/relax/test_runtime_builtin.py
tests/python/relax/test_op_misc.py` (19 passed)
- [x] `pytest tests/python/target/test_target_target.py` (37 passed, 3
skipped)
- [x] `pre-commit run` clean on touched files
---
docs/arch/index.rst | 13 ++++++-----
docs/how_to/tutorials/cross_compilation_and_rpc.py | 3 ++-
.../how_to/tutorials/export_and_load_executable.py | 6 ++++--
docs/how_to/tutorials/optimize_llm.py | 17 ++++++++-------
python/tvm/__init__.py | 1 -
python/tvm/contrib/cutlass/gen_tensor_op.py | 2 +-
python/tvm/exec/disco_worker.py | 11 +++++-----
python/tvm/ir/__init__.py | 1 -
python/tvm/ir/container.py | 21 ------------------
python/tvm/ir/supply.py | 3 +--
python/tvm/relax/base_py_module.py | 9 ++++----
python/tvm/relax/block_builder.py | 2 +-
python/tvm/relax/distributed/global_info.py | 14 ++++++------
python/tvm/relax/dpl/pattern.py | 6 +++---
python/tvm/relax/exec_builder.py | 6 +++---
python/tvm/relax/expr.py | 7 +++---
python/tvm/relax/frontend/nn/subroutine.py | 4 +++-
python/tvm/relax/frontend/nn/torch.py | 8 +++----
python/tvm/relax/frontend/torch/dynamo.py | 4 +++-
python/tvm/relax/op/base.py | 10 +++++----
python/tvm/relax/struct_info.py | 3 ++-
python/tvm/relax/training/optimizer.py | 5 +++--
python/tvm/relax/transform/transform.py | 2 +-
python/tvm/relax/utils.py | 5 +++--
python/tvm/runtime/__init__.py | 2 --
python/tvm/runtime/_tensor.py | 12 +++++------
python/tvm/runtime/container.py | 22 -------------------
python/tvm/runtime/disco/process_pool.py | 6 ++----
python/tvm/runtime/disco/session.py | 9 ++++----
python/tvm/runtime/executable.py | 8 ++++---
python/tvm/runtime/packed_func.py | 23 --------------------
python/tvm/runtime/vm.py | 22 +++++++++----------
python/tvm/s_tir/meta_schedule/arg_info.py | 17 ++++++---------
python/tvm/s_tir/meta_schedule/utils.py | 13 ++++++-----
python/tvm/s_tir/schedule/trace.py | 3 ++-
python/tvm/script/ir_builder/tirx/ir.py | 7 ++++--
python/tvm/script/parser/tirx/parser.py | 4 +++-
python/tvm/target/codegen.py | 3 ++-
python/tvm/target/target.py | 4 ++--
python/tvm/te/operation.py | 3 ++-
python/tvm/tirx/op.py | 5 +++--
python/tvm/topi/image/resize.py | 6 +++---
python/tvm/topi/nn/upsampling.py | 4 ++--
.../test_runtime_packed_func.py | 7 +++---
tests/python/arith/test_arith_domain_touched.py | 4 +++-
tests/python/contrib/test_popen_pool.py | 4 +++-
tests/python/disco/test_custom_allreduce.py | 7 +++---
tests/python/disco/test_loader.py | 13 ++++++-----
tests/python/disco/test_nvshmem.py | 6 +++---
tests/python/disco/test_session.py | 7 +++---
tests/python/ir/test_container_structural_equal.py | 5 +++--
tests/python/ir/test_node_reflection.py | 7 +++---
.../test_runtime_builtin_kv_cache_transfer.py | 19 ++++++++--------
...est_runtime_builtin_kv_cache_transfer_kernel.py | 11 +++++-----
tests/python/relax/test_contrib_vllm.py | 3 ++-
tests/python/relax/test_frontend_onnx.py | 12 ++++++++---
tests/python/relax/test_op_gradient_numeric.py | 3 ++-
tests/python/relax/test_op_misc.py | 4 +++-
tests/python/relax/test_pipeline.py | 5 +++--
tests/python/relax/test_relax_operators.py | 25 +++++++++-------------
tests/python/relax/test_runtime_builtin.py | 19 ++++++++--------
...runtime_builtin_paged_attention_kv_cache_cpu.py | 19 ++++++++--------
..._builtin_paged_attention_kv_cache_flashinfer.py | 11 +++++-----
...ltin_paged_attention_kv_cache_mla_flashinfer.py | 11 +++++-----
...ime_builtin_paged_attention_kv_cache_mla_tir.py | 11 +++++-----
...runtime_builtin_paged_attention_kv_cache_tir.py | 19 ++++++++--------
.../python/relax/test_runtime_builtin_rnn_state.py | 10 ++++-----
.../relax/test_training_optimizer_numeric.py | 3 ++-
tests/python/relax/test_vm_build.py | 19 ++++++++--------
tests/python/relax/test_vm_codegen_only.py | 3 ++-
tests/python/relax/test_vm_execbuilder.py | 7 +++---
tests/python/runtime/test_runtime_container.py | 10 ++++-----
tests/python/runtime/test_runtime_rpc.py | 7 +++---
tests/python/target/test_target_target.py | 3 ++-
.../python/tvmscript/test_tvmscript_printer_doc.py | 2 +-
75 files changed, 302 insertions(+), 330 deletions(-)
diff --git a/docs/arch/index.rst b/docs/arch/index.rst
index ba4db56422..381abd1f75 100644
--- a/docs/arch/index.rst
+++ b/docs/arch/index.rst
@@ -151,19 +151,19 @@ The main goal of TVM's runtime is to provide a minimal
API for loading and execu
# Example runtime execution program in python, with type annotated
mod: tvm.runtime.Module = tvm.runtime.load_module("compiled_artifact.so")
arr: tvm.runtime.Tensor = tvm.runtime.tensor([1, 2, 3], device=tvm.cuda(0))
- fun: tvm.runtime.PackedFunc = mod["addone"]
+ fun: tvm_ffi.Function = mod["addone"]
fun(arr)
print(arr.numpy())
-:py:class:`tvm.runtime.Module` encapsulates the result of compilation. A
runtime.Module contains a GetFunction method to obtain PackedFuncs by name.
+:py:class:`tvm.runtime.Module` encapsulates the result of compilation. A
runtime.Module contains a GetFunction method to obtain
:py:class:`tvm_ffi.Function` instances by name.
-:py:class:`tvm.runtime.PackedFunc` is a type-erased function interface for
both the generated functions. A runtime.PackedFunc can take arguments and
return values with the
-following types: POD types(int, float), string, runtime.PackedFunc,
runtime.Module, runtime.Tensor, and other sub-classes of runtime.Object.
+:py:class:`tvm_ffi.Function` is a type-erased function interface for both the
generated functions. A tvm_ffi.Function can take arguments and return values
with the
+following types: POD types(int, float), string, tvm_ffi.Function,
runtime.Module, runtime.Tensor, and other sub-classes of runtime.Object.
-:py:class:`tvm.runtime.Module` and :py:class:`tvm.runtime.PackedFunc` are
powerful mechanisms to modularize the runtime. For example, to get the above
`addone` function on CUDA, we can use LLVM to generate the host-side code to
compute the launching parameters(e.g. size of the thread groups) and then call
into another PackedFunc from a CUDAModule that is backed by the CUDA driver
API. The same mechanism can be used for OpenCL kernels.
+:py:class:`tvm.runtime.Module` and :py:class:`tvm_ffi.Function` are powerful
mechanisms to modularize the runtime. For example, to get the above `addone`
function on CUDA, we can use LLVM to generate the host-side code to compute the
launching parameters(e.g. size of the thread groups) and then call into another
tvm_ffi.Function from a CUDAModule that is backed by the CUDA driver API. The
same mechanism can be used for OpenCL kernels.
-The above example only deals with a simple `addone` function. The code snippet
below gives an example of an end-to-end model execution using the Relax Virtual
Machine, which is built on the same runtime.Module and runtime.PackedFunc
interface:
+The above example only deals with a simple `addone` function. The code snippet
below gives an example of an end-to-end model execution using the Relax Virtual
Machine, which is built on the same runtime.Module and tvm_ffi.Function
interface:
.. code-block:: python
@@ -434,4 +434,3 @@ and then integrate it into the IRModule.
While possible to construct operators directly via TensorIR or tensor
expressions (TE) for each use case, it is tedious to do so.
`topi` (Tensor operator inventory) provides a set of pre-defined operators
defined by numpy and found in common deep learning workloads.
-
diff --git a/docs/how_to/tutorials/cross_compilation_and_rpc.py
b/docs/how_to/tutorials/cross_compilation_and_rpc.py
index f573dfc7ce..7ef45c38b0 100644
--- a/docs/how_to/tutorials/cross_compilation_and_rpc.py
+++ b/docs/how_to/tutorials/cross_compilation_and_rpc.py
@@ -97,6 +97,7 @@ and the Firefly-RK3399 for an OpenCL example.
# Here we will declare a simple kernel on the local machine:
import numpy as np
+import tvm_ffi
import tvm
from tvm import rpc, te
@@ -481,7 +482,7 @@ def run_pytorch_model_via_rpc():
output = vm.get_outputs("main")
# Extract result (handle both tuple and single tensor outputs)
- if isinstance(output, tvm.ir.Array) and len(output) > 0:
+ if isinstance(output, tvm_ffi.Array) and len(output) > 0:
result = output[0]
else:
result = output
diff --git a/docs/how_to/tutorials/export_and_load_executable.py
b/docs/how_to/tutorials/export_and_load_executable.py
index 7378b3c71c..0b206267bb 100644
--- a/docs/how_to/tutorials/export_and_load_executable.py
+++ b/docs/how_to/tutorials/export_and_load_executable.py
@@ -62,6 +62,8 @@ except ImportError: # pragma: no cover
# model is exported to a :py:class:`torch.export.ExportedProgram` and then
# translated into a Relax ``IRModule``.
+import tvm_ffi
+
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
@@ -174,7 +176,7 @@ if RUN_EXAMPLE:
# TVM returns Array objects for tuple outputs, access via indexing.
# For models imported from PyTorch, outputs are typically tuples (even for
single outputs).
# For ONNX models, outputs may be a single Tensor directly.
- if isinstance(tvm_output, tvm.ir.Array) and len(tvm_output) > 0:
+ if isinstance(tvm_output, tvm_ffi.Array) and len(tvm_output) > 0:
result_tensor = tvm_output[0]
else:
result_tensor = tvm_output
@@ -263,7 +265,7 @@ if RUN_EXAMPLE:
#
# # Step 6: Extract result (output may be tuple or single Tensor)
# # PyTorch models typically return tuples, ONNX models may return a single
Tensor
-# if isinstance(output, tvm.ir.Array) and len(output) > 0:
+# if isinstance(output, tvm_ffi.Array) and len(output) > 0:
# result_tensor = output[0]
# else:
# result_tensor = output
diff --git a/docs/how_to/tutorials/optimize_llm.py
b/docs/how_to/tutorials/optimize_llm.py
index 58727923a5..0c20f30e40 100644
--- a/docs/how_to/tutorials/optimize_llm.py
+++ b/docs/how_to/tutorials/optimize_llm.py
@@ -61,13 +61,14 @@ import os
from pathlib import Path
from pprint import pprint
+from tvm_ffi import Shape
+
import tvm
from tvm import relax, te, tirx
from tvm.relax import register_pipeline
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op
from tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache, TIRPagedKVCache
-from tvm.runtime import ShapeTuple
from tvm.s_tir import dlight
######################################################################
@@ -534,10 +535,10 @@ if not IS_IN_CI:
if not IS_IN_CI:
kv_cache = vm["create_tir_paged_kv_cache"](
- ShapeTuple([1]), # max_batch_size=1
- ShapeTuple([2048]), # max_total_seq_len=2048
- ShapeTuple([2048]), # prefill_chunk_size=2048
- ShapeTuple([16]), # page_size=16
+ Shape([1]), # max_batch_size=1
+ Shape([2048]), # max_total_seq_len=2048
+ Shape([2048]), # prefill_chunk_size=2048
+ Shape([16]), # page_size=16
)
@@ -553,7 +554,7 @@ nd_view_func = tvm.get_global_func("vm.builtin.reshape")
def embed(tokens, params):
_embed = vm["embed"](tokens, params)
# Reshape hidden from [seq_len, hidden_size] to [1, seq_len, hidden_size]
- _embed = nd_view_func(_embed, ShapeTuple([1, _embed.shape[0],
_embed.shape[1]]))
+ _embed = nd_view_func(_embed, Shape([1, _embed.shape[0], _embed.shape[1]]))
return _embed
@@ -575,7 +576,7 @@ if not IS_IN_CI:
seq_id = 0
add_sequence_func(kv_cache, seq_id)
hidden_states = embed(tokens, params)
- begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([input_len]))
+ begin_forward_func(kv_cache, Shape([seq_id]), Shape([input_len]))
logits, kv_cache = vm["prefill"](hidden_states, kv_cache, params)
end_forward_func(kv_cache)
@@ -611,7 +612,7 @@ if not IS_IN_CI:
while last_token != tokenizer.eos_token_id:
tokens = tvm.runtime.tensor(np.array([last_token]).astype("int32"),
device=dev)
hidden_states = embed(tokens, params)
- begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([1]))
+ begin_forward_func(kv_cache, Shape([seq_id]), Shape([1]))
logits, kv_cache = vm["decode"](hidden_states, kv_cache, params)
end_forward_func(kv_cache)
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index dfcc8e20ab..7dca7b36fb 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -42,7 +42,6 @@ from . import error
from .ir import IRModule
from .ir import transform
from .ir import instrument
-from .ir import container
from . import ir
# tvm.tirx
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index eab4c46c5b..477c1ee449 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -473,7 +473,7 @@ def instantiate_template(func_name, annotations, func_args):
func_name: str
A string to identify the type of the kernel (dense/matmul,
batched_matmul, or conv2d).
- annotations: container.Map
+ annotations: tvm_ffi.Map
Key and value pairs annotated during kernel selection.
func_args: list
diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py
index d247a3c2e9..2052ae7b04 100644
--- a/python/tvm/exec/disco_worker.py
+++ b/python/tvm/exec/disco_worker.py
@@ -22,10 +22,11 @@ import os
import sys
from collections.abc import Callable
-from tvm_ffi import get_global_func, register_global_func
+from tvm_ffi import Shape, get_global_func, register_global_func
+from tvm_ffi.core import String
import tvm
-from tvm.runtime import ShapeTuple, String, Tensor, tensor
+from tvm.runtime import Tensor, tensor
@register_global_func("tests.disco.add_one", override=True)
@@ -55,9 +56,9 @@ def _str_obj_func(x: str):
@register_global_func("tests.disco.shape_tuple", override=True)
-def _shape_tuple_func(x: ShapeTuple):
- assert isinstance(x, ShapeTuple)
- return ShapeTuple(list(x) + [4, 5])
+def _shape_tuple_func(x: Shape):
+ assert isinstance(x, Shape)
+ return Shape(list(x) + [4, 5])
@register_global_func("tests.disco.test_callback", override=True)
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 0b4398944e..a63829ef40 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -32,7 +32,6 @@ from .base import (
structural_equal,
structural_hash,
)
-from .container import Array, Map
from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelaxExpr
from .function import BaseFunc, CallingConv
from .global_info import GlobalInfo, DummyGlobalInfo, VDevice
diff --git a/python/tvm/ir/container.py b/python/tvm/ir/container.py
deleted file mode 100644
index eae012b879..0000000000
--- a/python/tvm/ir/container.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Additional container data structures used across IR variants."""
-
-from tvm_ffi import Array, Map
-
-__all__ = ["Array", "Map"]
diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py
index 8e73ec4d12..183e20f257 100644
--- a/python/tvm/ir/supply.py
+++ b/python/tvm/ir/supply.py
@@ -18,7 +18,6 @@
import tvm_ffi
-import tvm
from tvm import IRModule, Object
from . import _ffi_api
@@ -100,7 +99,7 @@ class GlobalVarSupply(Object):
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply,
name_supply)
elif isinstance(value, NameSupply):
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, value)
- elif isinstance(value, list | tvm.container.Array):
+ elif isinstance(value, list | tvm_ffi.Array):
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModules, value)
elif isinstance(value, IRModule):
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value)
diff --git a/python/tvm/relax/base_py_module.py
b/python/tvm/relax/base_py_module.py
index 67b6761633..1834c25c31 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -22,11 +22,12 @@ import os
from typing import Any, Optional, Union
import numpy as np
+from tvm_ffi import Function
import tvm
from tvm import relax, tirx
from tvm.ir import IRModule
-from tvm.runtime import Device, PackedFunc, Tensor
+from tvm.runtime import Device, Tensor
from tvm.target import Target
try:
@@ -100,8 +101,8 @@ class BasePyModule:
self.__getattr__ = _getattr_python_function
- self.compiled_tir_funcs: dict[str, PackedFunc] = {}
- self.extern_funcs: dict[str, PackedFunc] = {}
+ self.compiled_tir_funcs: dict[str, Function] = {}
+ self.extern_funcs: dict[str, Function] = {}
self.tir_func_names: list[str] = []
self.relax_func_names: list[str] = []
self.relax_vm: relax.VirtualMachine | None = None
@@ -450,7 +451,7 @@ class BasePyModule:
numpy_array = tvm_tensor.numpy()
return torch.from_numpy(numpy_array)
- def get_function(self, name: str) -> PackedFunc | None:
+ def get_function(self, name: str) -> Function | None:
"""Get a compiled function by name."""
if name in self.compiled_tir_funcs:
return self.compiled_tir_funcs[name]
diff --git a/python/tvm/relax/block_builder.py
b/python/tvm/relax/block_builder.py
index 13bea3180c..7c1fed673e 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -297,7 +297,7 @@ class BlockBuilder(Object):
called with python `list` or `tuple` objects. These objects
should be converted to `relax.Tuple` prior to calling an FFI
function, as they would otherwise be converted to
- `tvm.runtime.Array`. In addition, any nested tuple objects
+ `tvm_ffi.Array`. In addition, any nested tuple objects
should be converted.
"""
if isinstance(expr, list | tuple):
diff --git a/python/tvm/relax/distributed/global_info.py
b/python/tvm/relax/distributed/global_info.py
index 576ab0c70e..125e6652f5 100644
--- a/python/tvm/relax/distributed/global_info.py
+++ b/python/tvm/relax/distributed/global_info.py
@@ -18,10 +18,10 @@
"""Global Info Data structures for distributed tensor."""
import tvm_ffi
+from tvm_ffi import Shape
from tvm.ir import Range
from tvm.ir.global_info import GlobalInfo
-from tvm.runtime import ShapeTuple
from . import _ffi_api as ffi
@@ -33,15 +33,15 @@ class DeviceMesh(GlobalInfo):
Parameters
----------
- shape: Union[ShapeTuple, List[int], Tuple[int]]
+ shape: Union[Shape, List[int], Tuple[int]]
Logical shape of device mesh
device_ids: Union[List[int], Range]
Represents the device id in the mesh
"""
- def __init__(self, shape: ShapeTuple | list[int] | tuple[int], device_ids:
list[int] | Range):
- if isinstance(shape, list | tuple):
- shape = ShapeTuple(shape)
+ def __init__(self, shape: Shape | list[int] | tuple[int], device_ids:
list[int] | Range):
+ if not isinstance(shape, Shape):
+ shape = Shape(shape)
device_range = None
if isinstance(device_ids, Range):
device_range = device_ids
@@ -49,11 +49,11 @@ class DeviceMesh(GlobalInfo):
self.__init_handle_by_constructor__(ffi.DeviceMesh, shape, device_ids,
device_range) # type: ignore
-def device_mesh(shape: ShapeTuple, device_ids: list[int] | Range) ->
DeviceMesh:
+def device_mesh(shape: Shape, device_ids: list[int] | Range) -> DeviceMesh:
"""Create a device mesh expression.
Parameters
----------
- shape : ShapeTuple
+ shape : Shape
The shape of the device mesh.
device_ids: Union[List[int], Range]
Represents the device id in the mesh
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index fefff21432..89feac1bce 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -23,9 +23,9 @@
from typing import Union
import tvm_ffi
+from tvm_ffi import Array
import tvm
-from tvm.ir.container import Array
from tvm.ir.expr import PrimExpr
from tvm.ir.op import Op
@@ -848,7 +848,7 @@ def is_shape(shape: list[tvm.ir.PrimExpr]) ->
"PrimArrPattern":
Raises
------
ValueError
- If the argument shape is not a list/tuple/tvm.ir.Array
+ If the argument shape is not a list/tuple/tvm_ffi.Array
Note
----
@@ -856,7 +856,7 @@ def is_shape(shape: list[tvm.ir.PrimExpr]) ->
"PrimArrPattern":
puts assumptions on the shape of the tensor matched by pattern p. While
is_shape directly matches the shape (an array of PrimExpr).
"""
- if not isinstance(shape, list | tuple | tvm.ir.Array):
+ if not isinstance(shape, list | tuple | Array):
raise ValueError("is_shape takes a list or tuple as input.")
return PrimArrPattern(shape)
diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py
index 65ca991071..e6cdd8ac81 100644
--- a/python/tvm/relax/exec_builder.py
+++ b/python/tvm/relax/exec_builder.py
@@ -21,9 +21,9 @@
from enum import IntEnum
import tvm_ffi
+from tvm_ffi import Shape
import tvm
-from tvm.runtime.container import ShapeTuple
from . import _ffi_api
from .vm_build import VMExecutable
@@ -121,10 +121,10 @@ class ExecBuilder(tvm_ffi.core.Object):
if args is not None:
for arg in args:
if isinstance(arg, tuple):
- shape_tuple = ShapeTuple(arg)
+ shape_tuple = Shape(arg)
new_arg = self.convert_constant(shape_tuple)
args_.append(new_arg)
- elif isinstance(arg, tvm.runtime.Tensor | tvm.DataType |
ShapeTuple):
+ elif isinstance(arg, tvm.runtime.Tensor | tvm.DataType |
Shape):
new_arg = self.convert_constant(arg)
args_.append(new_arg)
else:
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 7c7bcc2aea..5a75e43b12 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -24,6 +24,7 @@ from typing import Any, Optional, Union
import numpy as _np # type: ignore
import tvm_ffi
+from tvm_ffi.core import String
import tvm.ir
import tvm.relax
@@ -32,7 +33,7 @@ from tvm import DataType
from tvm.runtime import Object
from ..ir import BaseFunc, Node, Span
-from ..runtime import Scriptable, String
+from ..runtime import Scriptable
from ..tirx import PrimExpr
from . import _ffi_api
@@ -685,7 +686,7 @@ class ShapeExpr(ExprWithOp):
Parameters
----------
- values: Union[List[PrimExpr], typing.Tuple[PrimExpr, ...], tvm.ir.Array]
+ values: Union[List[PrimExpr], typing.Tuple[PrimExpr, ...], tvm_ffi.Array]
The values of the shape expression.
span: Optional[Span]
@@ -697,7 +698,7 @@ class ShapeExpr(ExprWithOp):
def __init__(
self,
- values: list[PrimExpr] | tuple[PrimExpr, ...] | tvm.ir.Array,
+ values: list[PrimExpr] | tuple[PrimExpr, ...] | tvm_ffi.Array,
span: Span | None = None,
) -> None:
self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values, span)
# type: ignore
diff --git a/python/tvm/relax/frontend/nn/subroutine.py
b/python/tvm/relax/frontend/nn/subroutine.py
index 0716f19a92..abd94b19cb 100644
--- a/python/tvm/relax/frontend/nn/subroutine.py
+++ b/python/tvm/relax/frontend/nn/subroutine.py
@@ -24,6 +24,8 @@ import inspect
import re
import typing
+import tvm_ffi
+
from tvm import ir, relax
from tvm.ir import structural_equal
from tvm.relax.frontend import nn
@@ -58,7 +60,7 @@ def _get_struct_info(arg):
return arg.struct_info_
elif isinstance(arg, nn.Tensor):
return arg._expr.struct_info_
- elif isinstance(arg, tuple | list | ir.Array):
+ elif isinstance(arg, tuple | list | tvm_ffi.Array):
return relax.TupleStructInfo([_get_struct_info(field) for field in
arg])
else:
raise TypeError(f"Cannot find struct info for {arg} of type
{type(arg)}")
diff --git a/python/tvm/relax/frontend/nn/torch.py
b/python/tvm/relax/frontend/nn/torch.py
index 45dc1a74a9..0ef8baa001 100644
--- a/python/tvm/relax/frontend/nn/torch.py
+++ b/python/tvm/relax/frontend/nn/torch.py
@@ -21,9 +21,9 @@ from collections.abc import Callable
from typing import Any
import torch
+from tvm_ffi import Array, Shape
-from tvm.ir import Array
-from tvm.runtime import ShapeTuple, Tensor, _tensor
+from tvm.runtime import Tensor, _tensor
from tvm.runtime.vm import VirtualMachine
from . import core
@@ -91,7 +91,7 @@ def _tvm_to_torch(arg):
return [_tvm_to_torch(i) for i in arg]
if isinstance(arg, _tensor.Tensor):
return torch.utils.dlpack.from_dlpack(arg)
- if isinstance(arg, ShapeTuple):
+ if isinstance(arg, Shape):
return list(arg)
raise TypeError(f"Unsupported argument type: {type(arg)}")
@@ -108,7 +108,7 @@ def _torch_to_tvm(arg_name, arg_spec, arg_torch):
raise TypeError(
f"Expected argument `{arg_name}` to be `int`, but got
{type(arg_torch)}"
)
- return ShapeTuple([arg_torch])
+ return Shape([arg_torch])
if isinstance(arg_spec, _spec.Tuple):
return [
_torch_to_tvm(f"{arg_name}[{i}]", x, arg_torch[i])
diff --git a/python/tvm/relax/frontend/torch/dynamo.py
b/python/tvm/relax/frontend/torch/dynamo.py
index 9a0b6e1c58..a490054aee 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -22,6 +22,8 @@
import functools
+import tvm_ffi
+
import tvm
from tvm.relax import build as relax_build
@@ -58,7 +60,7 @@ def relax_dynamo(pipeline: tvm.transform.Pass | None = None):
"""A helper function to transfer a Tensor to torch.tensor."""
if isinstance(nd_tensor, tvm.runtime.Tensor):
return torch.from_numpy(nd_tensor.numpy())
- elif isinstance(nd_tensor, tvm.ir.Array):
+ elif isinstance(nd_tensor, tvm_ffi.Array):
return tuple(to_torch_tensor(x) for x in nd_tensor)
else:
raise ValueError(f"Unsupported type {type(nd_tensor)}")
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index d257a8c8e6..453ca2a3d6 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -19,6 +19,8 @@
from collections.abc import Callable
+import tvm_ffi
+
import tvm
import tvm.runtime
from tvm.runtime import ObjectConvertible
@@ -450,20 +452,20 @@ def render_object(val: tvm.Object) -> str:
"""
if isinstance(val, tvm.runtime.Tensor):
return str(val)
- if isinstance(val, tvm.ir.Array):
+ if isinstance(val, tvm_ffi.Array):
fields = ", ".join([render_object(val[i]) for i in range(len(val))])
return f"({fields})"
return str(val)
@tvm.register_global_func("relax.run.shape_to_tensor")
-def relax_shape_to_tensor(shape_tuple: tvm.runtime.ShapeTuple) ->
tvm.runtime.Tensor:
+def relax_shape_to_tensor(shape_tuple: tvm_ffi.Shape) -> tvm.runtime.Tensor:
"""
- Takes a ShapeTuple and convert it to Tensor.
+ Takes a Shape and convert it to Tensor.
Parameters
----------
- shape_tuple: tvm.runtime.ShapeTuple
+ shape_tuple: tvm_ffi.Shape
Shape tuple that we want to convert to Tensor at runtime
"""
return tvm.runtime.tensor([int(v) for v in shape_tuple])
diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py
index 4bff72f17f..e2f9141550 100644
--- a/python/tvm/relax/struct_info.py
+++ b/python/tvm/relax/struct_info.py
@@ -21,9 +21,10 @@
from typing import Optional, Union
import tvm_ffi
+from tvm_ffi import Array
import tvm
-from tvm.ir import Array, EnvFunc, Span, VDevice
+from tvm.ir import EnvFunc, Span, VDevice
from tvm.runtime import DataType
from tvm.tirx import PrimExpr
diff --git a/python/tvm/relax/training/optimizer.py
b/python/tvm/relax/training/optimizer.py
index 505057c736..a341f0a37b 100644
--- a/python/tvm/relax/training/optimizer.py
+++ b/python/tvm/relax/training/optimizer.py
@@ -21,6 +21,7 @@ from decimal import Decimal
from typing import Optional, Union
import numpy as np # type: ignore
+import tvm_ffi
import tvm
@@ -53,7 +54,7 @@ class Optimizer:
param_list : List[Var]
The list of variables to optimize. Will be set in `init()`.
- state : tvm.ir.Array
+ state : tvm_ffi.Array
`state` is an runtime Array representing the state of the optimizer.
Will be set in
`init()`.
@@ -102,7 +103,7 @@ class Optimizer:
dtype: str
name: str
param_list: list[Var]
- state: tvm.ir.Array
+ state: tvm_ffi.Array
def __init__(self, name: str) -> None:
self.name = name
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index cc1c0929c7..fc374a4e9f 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -27,9 +27,9 @@ from typing import Optional, Union
import numpy as np # type: ignore
import tvm_ffi
+from tvm_ffi import Array
import tvm.ir
-from tvm.ir.container import Array
from tvm.relax import Expr, StructInfo, Var
from tvm.relax.dpl import DFPattern
from tvm.runtime import Object, Tensor
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index c0445a1b49..b50a19cae1 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -26,11 +26,12 @@ from collections.abc import Callable
from typing import Any, Optional
import tvm_ffi
+from tvm_ffi import Array, Map
import tvm
from .. import tirx
-from ..ir import Array, Attrs, Map, Type, VDevice
+from ..ir import Attrs, Type, VDevice
from ..te import Tensor as te_Tensor
from ..te import create_prim_func
from ..tirx import PrimExpr
@@ -191,7 +192,7 @@ def gen_call_tir_inputs(
In the common case, the type of te_args is a Relax expression and is
converted
into a TE tensor.
- If te_args is a nested or recursive datatype (i.e list, dict,
tvm.ir.Map, tvm.ir.Array),
+ If te_args is a nested or recursive datatype (i.e list, dict,
tvm_ffi.Map, tvm_ffi.Array),
we recursive and convert any value of type Relax expression into a TE
tensor.
Common values of type int, float, and str are preserved.
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index cdb61a55af..97b9d006eb 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -21,7 +21,6 @@ from tvm_ffi import convert
from tvm_ffi._dtype import dtype as DataType, DataTypeCode
# class exposures
-from .packed_func import PackedFunc
from .object import Object
from .script_printer import Scriptable
from .object_generic import ObjectConvertible
@@ -35,7 +34,6 @@ from .executable import Executable
from ._tensor import device, cpu, cuda, opencl, vulkan, metal
from ._tensor import vpi, rocm, ext_dev, from_dlpack
from .module import load_module, enabled, system_lib, load_static_library,
num_threads
-from .container import String, ShapeTuple
from .object_generic import const
from .params import (
save_param_dict,
diff --git a/python/tvm/runtime/_tensor.py b/python/tvm/runtime/_tensor.py
index 7564ea8667..1f4da868bb 100644
--- a/python/tvm/runtime/_tensor.py
+++ b/python/tvm/runtime/_tensor.py
@@ -266,7 +266,7 @@ class Tensor(tvm_ffi.core.Tensor):
Parameters
----------
- shape: Union[tvm.runtime.ShapeTuple, Sequence[typing.SupportsInt]]
+ shape: Union[tvm_ffi.Shape, Sequence[typing.SupportsInt]]
The shape of the view.
@@ -288,8 +288,8 @@ class Tensor(tvm_ffi.core.Tensor):
"""
- if not isinstance(shape, tvm.runtime.ShapeTuple):
- shape = tvm.runtime.ShapeTuple([int(dim) for dim in shape])
+ if not isinstance(shape, tvm_ffi.Shape):
+ shape = tvm_ffi.Shape([int(dim) for dim in shape])
if dtype is None:
dtype = self.dtype
@@ -302,7 +302,7 @@ def empty(shape, dtype="float32", device=None,
mem_scope=None):
Parameters
----------
- shape : Union[tvm.runtime.ShapeTuple, Sequence[typing.SupportsInt]]
+ shape : Union[tvm_ffi.Shape, Sequence[typing.SupportsInt]]
The shape of the array.
dtype : type or str
@@ -320,8 +320,8 @@ def empty(shape, dtype="float32", device=None,
mem_scope=None):
The array tvm supported.
"""
device = device or cpu()
- if not isinstance(shape, tvm.runtime.ShapeTuple):
- shape = tvm.runtime.ShapeTuple([int(dim) for dim in shape])
+ if not isinstance(shape, tvm_ffi.Shape):
+ shape = tvm_ffi.Shape([int(dim) for dim in shape])
dtype = tvm_ffi.dtype(dtype)
arr = _ffi_api.TVMTensorAllocWithScope(shape, dtype, device, mem_scope)
return arr
diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py
deleted file mode 100644
index d0054230da..0000000000
--- a/python/tvm/runtime/container.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Runtime container structures."""
-
-from tvm_ffi import Shape as ShapeTuple
-from tvm_ffi.core import String
-
-__all__ = ["ShapeTuple", "String"]
diff --git a/python/tvm/runtime/disco/process_pool.py
b/python/tvm/runtime/disco/process_pool.py
index 16975b6c52..93444f94c2 100644
--- a/python/tvm/runtime/disco/process_pool.py
+++ b/python/tvm/runtime/disco/process_pool.py
@@ -21,9 +21,7 @@ import os
import subprocess
import sys
-from tvm_ffi import register_global_func
-
-from tvm.runtime import ShapeTuple
+from tvm_ffi import Shape, register_global_func
class DiscoPopenWorker:
@@ -188,7 +186,7 @@ def _create_process_pool(num_workers: int, num_groups: int,
entrypoint: str):
nonlocal pool
if worker_id != 0:
read_fd, write_fd = pool[worker_id - 1].start()
- return ShapeTuple([read_fd, write_fd])
+ return Shape([read_fd, write_fd])
del pool
return None
diff --git a/python/tvm/runtime/disco/session.py
b/python/tvm/runtime/disco/session.py
index b7fdec4e98..24a3e993bd 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -26,11 +26,10 @@ from collections.abc import Callable, Sequence
from typing import Any, Optional, Union
import numpy as np
-from tvm_ffi import get_global_func, register_global_func, register_object
+from tvm_ffi import Shape, get_global_func, register_global_func,
register_object
from .._tensor import Tensor
from .._tensor import tensor as _as_Tensor
-from ..container import ShapeTuple
from ..device import Device
from ..object import Object
from . import _ffi_api, process_pool # pylint: disable=unused-import
@@ -153,7 +152,7 @@ class Session(Object):
"""
func = self._get_cached_method("runtime.disco.empty")
- return func(ShapeTuple(shape), dtype, device, worker0_only, in_group)
+ return func(Shape(shape), dtype, device, worker0_only, in_group)
def shutdown(self):
"""Shut down the Disco session"""
@@ -326,7 +325,7 @@ class Session(Object):
The device IDs to be used by the underlying communication library.
"""
assert ccl in ("nccl", "rccl"), f"Unsupported CCL backend: {ccl}"
- _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type:
ignore # pylint: disable=no-member
+ _ffi_api.SessionInitCCL(self, ccl, Shape(device_ids)) # type: ignore
# pylint: disable=no-member
self._clear_ipc_memory_pool()
def broadcast(
@@ -497,7 +496,7 @@ class Session(Object):
"""
if op not in REDUCE_OPS:
raise ValueError(f"Unsupported reduce op: {op}. Available ops are:
{REDUCE_OPS.keys()}")
- op = ShapeTuple([REDUCE_OPS[op]])
+ op = Shape([REDUCE_OPS[op]])
func = self._get_cached_method("runtime.disco.allreduce")
func(src, op, in_group, dst)
diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py
index 212896ccb2..660756fa2b 100644
--- a/python/tvm/runtime/executable.py
+++ b/python/tvm/runtime/executable.py
@@ -21,10 +21,12 @@
from collections.abc import Callable
from typing import Any
+from tvm_ffi import Function
+
import tvm
from tvm.contrib import utils as _utils
-from . import Module, PackedFunc
+from . import Module
class Executable:
@@ -35,8 +37,8 @@ class Executable:
self.mod: Module = mod
self._jitted_mod: Module | None = None
- def __getitem__(self, name: str) -> PackedFunc:
- """Get the PackedFunc from the jitted module."""
+ def __getitem__(self, name: str) -> Function:
+ """Get the Function from the jitted module."""
return self.jit().get_function(name, query_imports=True)
def __call__(self, *args, **kwargs) -> Any:
diff --git a/python/tvm/runtime/packed_func.py
b/python/tvm/runtime/packed_func.py
deleted file mode 100644
index 8da25b2e6e..0000000000
--- a/python/tvm/runtime/packed_func.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=invalid-name, unused-import
-"""Packed Function namespace."""
-
-from tvm_ffi import Function as PackedFunc
-
-__all__ = ["PackedFunc"]
diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py
index 921739203e..982322d601 100644
--- a/python/tvm/runtime/vm.py
+++ b/python/tvm/runtime/vm.py
@@ -24,10 +24,10 @@ from numbers import Integral, Number
from typing import Any
import numpy as np # type: ignore
-from tvm_ffi import register_global_func
+from tvm_ffi import Function, register_global_func
import tvm
-from tvm.runtime import Device, Object, PackedFunc
+from tvm.runtime import Device, Object
from tvm.runtime.profiling import Report
from ..rpc.base import RPC_SESS_MASK
@@ -126,7 +126,7 @@ class VirtualMachine:
init_args.append(alloc_type)
self.module["vm_initialization"](*init_args)
- def __getitem__(self, key: str) -> PackedFunc:
+ def __getitem__(self, key: str) -> Function:
return self.module[key]
def invoke_closure(self, closure: Object, *args: Any) -> Object:
@@ -157,10 +157,10 @@ class VirtualMachine:
) -> None:
"""
Convenience function. Takes a function from the module and saves
- a `PackedFunc` that, when called, will invoke the function with the
given arguments.
- The `PackedFunc` can be accessed from the module using `saved_name`.
+ a `Function` that, when called, will invoke the function with the
given arguments.
+ The `Function` can be accessed from the module using `saved_name`.
This is included to facilitate timing trials:
- Invoking the returned `PackedFunc` will have less overhead from
dictionary lookups
+ Invoking the returned `Function` will have less overhead from
dictionary lookups
than normally running through the VM.
If the saved name is taken, it can be overridden, though it cannot
override
@@ -178,7 +178,7 @@ class VirtualMachine:
The name that the resulting closure should be saved under.
include_return : bool
- Whether the saved PackedFunc should return its output.
+ Whether the saved Function should return its output.
If timing over RPC, it may not be desirable to send output
between machines.
@@ -328,7 +328,7 @@ class VirtualMachine:
return get_output_rec(func_name)
- def set_instrument(self, instrument: tvm.runtime.PackedFunc) -> None:
+ def set_instrument(self, instrument: Function) -> None:
"""Set an instrumentation function.
If instrument is present, the function will be called
@@ -338,7 +338,7 @@ class VirtualMachine:
.. code:: python
def instrument(
- func: Union[VMClosure, PackedFunc],
+ func: Union[VMClosure, Function],
func_symbol: str,
before_run: bool,
ret_value: any,
@@ -359,7 +359,7 @@ class VirtualMachine:
Parameters
----------
- instrument: tvm.runtime.PackedFunc
+ instrument: tvm_ffi.Function
A instrumentation function that get invoked every VM call instr.
See Also
@@ -485,7 +485,7 @@ class VirtualMachine:
func_name : str
The name of the function.
- args: List of Tensor or other objects supported by PackedFunc.
+ args: List of Tensor or other objects supported by Function.
The arguments to the function.
Returns
diff --git a/python/tvm/s_tir/meta_schedule/arg_info.py
b/python/tvm/s_tir/meta_schedule/arg_info.py
index e01b742e99..7aaee06546 100644
--- a/python/tvm/s_tir/meta_schedule/arg_info.py
+++ b/python/tvm/s_tir/meta_schedule/arg_info.py
@@ -18,10 +18,10 @@
from typing import Any
-from tvm_ffi import register_object
+from tvm_ffi import Shape, register_object
from tvm.ir import IRModule
-from tvm.runtime import DataType, Object, ShapeTuple
+from tvm.runtime import DataType, Object
from tvm.tirx import PrimFunc
from . import _ffi_api
@@ -95,17 +95,17 @@ class TensorInfo(ArgInfo):
----------
dtype : DataType
The data type of the tensor.
- shape : ShapeTuple
+ shape : Shape
The shape of the tensor.
"""
dtype: DataType
- shape: ShapeTuple
+ shape: Shape
def __init__(
self,
dtype: DataType,
- shape: ShapeTuple | list[int],
+ shape: Shape | list[int],
) -> None:
"""Constructor
@@ -113,13 +113,10 @@ class TensorInfo(ArgInfo):
----------
dtype : DataType
The data type of the tensor.
- shape : ShapeTuple
+ shape : Shape
The shape of the tensor.
"""
- if isinstance(shape, ShapeTuple):
- shape_tuple = shape
- else:
- shape_tuple = ShapeTuple(shape)
+ shape_tuple = shape if isinstance(shape, Shape) else Shape(shape)
self.__init_handle_by_constructor__(
_ffi_api.TensorInfo, # type: ignore # pylint: disable=no-member
dtype,
diff --git a/python/tvm/s_tir/meta_schedule/utils.py
b/python/tvm/s_tir/meta_schedule/utils.py
index 1344211711..42f52a6c1e 100644
--- a/python/tvm/s_tir/meta_schedule/utils.py
+++ b/python/tvm/s_tir/meta_schedule/utils.py
@@ -24,12 +24,11 @@ from typing import Any
import numpy as np # type: ignore
import psutil # type: ignore
-from tvm_ffi import get_global_func, register_global_func
+from tvm_ffi import Array, Function, Map, get_global_func, register_global_func
from tvm.error import TVMError
-from tvm.ir import Array, IRModule, Map
+from tvm.ir import IRModule
from tvm.rpc import RPCSession
-from tvm.runtime import PackedFunc
from tvm.tirx import FloatImm, IntImm
@@ -310,21 +309,21 @@ def get_global_func_on_rpc_session(
session: RPCSession,
name: str,
extra_error_msg: str | None = None,
-) -> PackedFunc:
- """Get a PackedFunc from the global registry from an RPCSession.
+) -> Function:
+ """Get a Function from the global registry from an RPCSession.
Parameters
----------
session : RPCSession
The RPCSession to be retrieved from
name : str
- The name of the PackedFunc
+ The name of the Function
extra_error_msg : Optional[str]
Extra information to provide in the error message
Returns
-------
- result : PackedFunc
+ result : Function
The result
"""
try:
diff --git a/python/tvm/s_tir/schedule/trace.py
b/python/tvm/s_tir/schedule/trace.py
index b82df6868e..213d3269fc 100644
--- a/python/tvm/s_tir/schedule/trace.py
+++ b/python/tvm/s_tir/schedule/trace.py
@@ -20,13 +20,14 @@ import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
+from tvm_ffi import Array, Map
from tvm_ffi import register_object as _register_object
from tvm.runtime import Object
from tvm.tirx.expr import FloatImm, IntImm
from tvm.tirx.function import IndexMap
-from ...ir import Array, Map, save_json
+from ...ir import save_json
from . import _ffi_api
from .instruction import ATTR_TYPE, INPUT_RV_TYPE, Instruction
diff --git a/python/tvm/script/ir_builder/tirx/ir.py
b/python/tvm/script/ir_builder/tirx/ir.py
index ce88c56315..76f0397a8e 100644
--- a/python/tvm/script/ir_builder/tirx/ir.py
+++ b/python/tvm/script/ir_builder/tirx/ir.py
@@ -30,10 +30,13 @@ from typing import Literal
# isort: on
+import tvm_ffi
+from tvm_ffi.core import String
+
from tvm import ir, tirx
from tvm.ir import Type
from tvm.ir.base import deprecated
-from tvm.runtime import String, convert
+from tvm.runtime import convert
from tvm.target import Target
# pylint: disable=unused-import
@@ -1302,7 +1305,7 @@ def buffer_store(
"""
from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel
- if not isinstance(indices, list | tuple | ir.Array):
+ if not isinstance(indices, list | tuple | tvm_ffi.Array):
indices = [indices]
expr_indices = []
diff --git a/python/tvm/script/parser/tirx/parser.py
b/python/tvm/script/parser/tirx/parser.py
index 888113c7d4..825eda154d 100644
--- a/python/tvm/script/parser/tirx/parser.py
+++ b/python/tvm/script/parser/tirx/parser.py
@@ -20,6 +20,8 @@ import contextlib
from functools import partial
from typing import Any
+import tvm_ffi
+
import tvm
from tvm.ir import GlobalVar, PrimType
from tvm.tirx import Buffer, IterVar, PrimExpr, Var
@@ -91,7 +93,7 @@ def bind_for_value(self: Parser, node: doc.expr, var_name:
str, value: Any) -> A
res : Any
The bound value.
"""
- if isinstance(value, list | tuple | tvm.ir.Array):
+ if isinstance(value, list | tuple | tvm_ffi.Array):
for i, v in enumerate(value):
bind_for_value(self, node, f"{var_name}_{i}", v)
return value
diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py
index 79d5f3dc09..18949912b0 100644
--- a/python/tvm/target/codegen.py
+++ b/python/tvm/target/codegen.py
@@ -16,7 +16,8 @@
# under the License.
"""Code generation related functions."""
-from ..ir.container import Array
+from tvm_ffi import Array
+
from . import _ffi_api
from .target import Target
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index b008df4986..c71ac8cead 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -17,10 +17,10 @@
"""Target data structure."""
import tvm_ffi
+from tvm_ffi import Map
+from tvm_ffi.core import String
-from tvm.ir.container import Map
from tvm.runtime import Device, Object, convert
-from tvm.runtime.container import String
from . import _ffi_api
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index fdc6c5f95a..58effec4db 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -21,10 +21,11 @@ import inspect
# pylint: disable=invalid-name
from numbers import Integral as _Integral
+from tvm_ffi import Array
+
import tvm.arith._ffi_api
import tvm.tirx
import tvm.tirx._ffi_api
-from tvm.ir import Array
from tvm.runtime import convert
from . import _ffi_api
diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py
index 6b4a636f30..566f5d905b 100644
--- a/python/tvm/tirx/op.py
+++ b/python/tvm/tirx/op.py
@@ -20,10 +20,11 @@
from typing import Any
import tvm_ffi
+from tvm_ffi import Array
import tvm
from tvm import tirx
-from tvm.ir import Array, Op, PrimExpr
+from tvm.ir import Op, PrimExpr
from tvm.ir.base import Span
from tvm.runtime import const
@@ -3456,7 +3457,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
if init is not None:
init = [init]
combiner = CommReducer(lhs, rhs, result, id_elem)
- if not isinstance(axis, list | tuple | tvm.ir.Array):
+ if not isinstance(axis, list | tuple | Array):
axis = [axis]
if where is None:
where = tirx.convert(True)
diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py
index a0ff38f0f3..1f4799c8ec 100644
--- a/python/tvm/topi/image/resize.py
+++ b/python/tvm/topi/image/resize.py
@@ -434,7 +434,7 @@ def resize1d(
out_dtype: string, optional
Type to return. If left None will be same as input type.
- output_shape: tvm.tirx.container.Array, optional
+ output_shape: tvm_ffi.Array, optional
Shape to return. If left None will be inferred
(If shape is determined dynamically, pass out_dtype.shape as
output_shape)
@@ -801,7 +801,7 @@ def resize2d(
out_dtype: string, optional
Type to return. If left None will be same as input type.
- output_shape: tvm.tirx.container.Array, optional
+ output_shape: tvm_ffi.Array, optional
Shape to return. If left None will be inferred
(If shape is determined dynamically, pass out_dtype.shape as
output_shape)
@@ -1270,7 +1270,7 @@ def resize3d(
out_dtype: string, optional
Type to return. If left None will be same as input type.
- output_shape: tvm.tirx.container.Array, optional
+ output_shape: tvm_ffi.Array, optional
Shape to return. If left None will be inferred
(If shape is determined dynamically, pass out_dtype.shape as
output_shape)
diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py
index 67b1fa5dac..e45cddd515 100644
--- a/python/tvm/topi/nn/upsampling.py
+++ b/python/tvm/topi/nn/upsampling.py
@@ -52,7 +52,7 @@ def upsampling(
method : {"bilinear", "nearest_neighbor", "bicubic"}
Method to be used for upsampling.
- output_shape: tvm.tirx.container.Array, optional
+ output_shape: tvm_ffi.Array, optional
Shape to return. If left None will be inferred
(If shape is determined dynamically, pass out_dtype.shape as
output_shape)
@@ -147,7 +147,7 @@ def upsampling3d(
Refer to the ONNX Resize operator specification for details.
Available options are "half_pixel", "align_corners" and "asymmetric".
- output_shape: tvm.tirx.container.Array, optional
+ output_shape: tvm_ffi.Array, optional
Shape to return. If left None will be inferred
(If shape is determined dynamically, pass out_dtype.shape as
output_shape)
diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
index f2dd4c2b05..51128c8b5b 100644
--- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
+++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
@@ -20,6 +20,7 @@
import gc
import numpy as np
+import tvm_ffi
import tvm
import tvm.testing
@@ -37,7 +38,7 @@ def test_get_global():
# get it out from global function table
f = tvm.get_global_func("my_packed_func")
- assert isinstance(f, tvm.runtime.PackedFunc)
+ assert isinstance(f, tvm_ffi.Function)
y = f(*targs)
assert y == 10
@@ -58,7 +59,7 @@ def test_get_callback_with_node():
# get it out from global function table
f = tvm.get_global_func("my_callback_with_node")
- assert isinstance(f, tvm.runtime.PackedFunc)
+ assert isinstance(f, tvm_ffi.Function)
y = f(x, f2)
assert y.value == 10
@@ -83,7 +84,7 @@ def test_convert():
assert tuple(args) == targs
f = tvm.runtime.convert(myfunc)
- assert isinstance(f, tvm.runtime.PackedFunc)
+ assert isinstance(f, tvm_ffi.Function)
def test_byte_array():
diff --git a/tests/python/arith/test_arith_domain_touched.py
b/tests/python/arith/test_arith_domain_touched.py
index c1791ac184..9d04fad54b 100644
--- a/tests/python/arith/test_arith_domain_touched.py
+++ b/tests/python/arith/test_arith_domain_touched.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import tvm_ffi
+
import tvm
from tvm.script import tirx as T
@@ -63,7 +65,7 @@ def test_domain_touched():
assert b_domain_r[1].extent.name == "m"
b_domain_w = tvm.arith._ffi_api.DomainTouched(ir, b, False, True)
- assert isinstance(b_domain_w, tvm.container.Array)
+ assert isinstance(b_domain_w, tvm_ffi.Array)
assert len(b_domain_w) == 0
diff --git a/tests/python/contrib/test_popen_pool.py
b/tests/python/contrib/test_popen_pool.py
index 25a43b2c1b..355954caf9 100644
--- a/tests/python/contrib/test_popen_pool.py
+++ b/tests/python/contrib/test_popen_pool.py
@@ -84,13 +84,15 @@ def test_popen_worker_recycles():
def test_popen_pool_executor():
+ import tvm_ffi
+
import tvm
pool = PopenPoolExecutor(max_workers=2, timeout=0.01)
value1 = pool.submit(identity_after, 1, 100)
value2 = pool.submit(terminate_self)
value3 = pool.submit(identity_after, 3, 0)
- value4 = pool.submit(tvm.runtime.String, "xyz")
+ value4 = pool.submit(tvm_ffi.core.String, "xyz")
with pytest.raises(TimeoutError):
value1.result()
diff --git a/tests/python/disco/test_custom_allreduce.py
b/tests/python/disco/test_custom_allreduce.py
index 4aed32c052..1c45677e55 100644
--- a/tests/python/disco/test_custom_allreduce.py
+++ b/tests/python/disco/test_custom_allreduce.py
@@ -21,10 +21,11 @@ from itertools import product
import numpy as np
import pytest
+from tvm_ffi import Shape
import tvm
import tvm.testing
-from tvm.runtime import DataType, ShapeTuple, disco
+from tvm.runtime import DataType, disco
from tvm.runtime.disco import Session
@@ -60,8 +61,8 @@ def test_allreduce(shape, ccl, strategy):
falloc_ipc_storage =
sess.get_global_func("runtime.disco.cuda_ipc.alloc_storage")
falloc_tensor = sess.get_global_func("vm.builtin.alloc_tensor")
fallreduce =
sess.get_global_func("runtime.disco.cuda_ipc.custom_allreduce")
- d_storage = sess.call_packed(falloc_ipc_storage, ShapeTuple(shape),
DataType(dtype))
- d_input = sess.call_packed(falloc_tensor, d_storage, 0, ShapeTuple(shape),
DataType(dtype))
+ d_storage = sess.call_packed(falloc_ipc_storage, Shape(shape),
DataType(dtype))
+ d_input = sess.call_packed(falloc_tensor, d_storage, 0, Shape(shape),
DataType(dtype))
array_1 = np.arange(num_elements, dtype="float32").reshape(*shape)
array_2 = np.arange(start=1, stop=-(num_elements - 1), step=-1,
dtype="float32").reshape(*shape)
diff --git a/tests/python/disco/test_loader.py
b/tests/python/disco/test_loader.py
index 2d5bf130f7..b709571219 100644
--- a/tests/python/disco/test_loader.py
+++ b/tests/python/disco/test_loader.py
@@ -22,13 +22,12 @@ import json
import tempfile
import numpy as np
-from tvm_ffi import register_global_func
+from tvm_ffi import Shape, register_global_func
import tvm
import tvm.testing
from tvm import relax as rx
from tvm.contrib import tvmjs
-from tvm.runtime import ShapeTuple
from tvm.runtime import disco as di
from tvm.s_tir import dlight as dl
from tvm.script import ir as I
@@ -150,8 +149,8 @@ def test_load_shard():
sess.init_ccl("nccl", *devices)
loader = _create_loader(sess, path, param_dict, shard_info)
loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoad")
- d_0 = loader_load(loader, ShapeTuple([0]))
- d_1 = loader_load(loader, ShapeTuple([1]))
+ d_0 = loader_load(loader, Shape([0]))
+ d_1 = loader_load(loader, Shape([1]))
np.testing.assert_equal(
param_dict["x_0"][:, 0:64],
d_0.debug_get_from_remote(0).numpy(),
@@ -198,8 +197,8 @@ def test_load_presharded():
loader = _create_presharded_loader(sess, path)
loader_load =
sess.get_global_func("runtime.disco.ShardLoaderLoadPresharded")
- d_0 = loader_load(loader, ShapeTuple([0]))
- d_1 = loader_load(loader, ShapeTuple([1]))
+ d_0 = loader_load(loader, Shape([0]))
+ d_1 = loader_load(loader, Shape([1]))
np.testing.assert_equal(
param_dict["x_0"][:, 0:64],
@@ -446,7 +445,7 @@ def test_load_qkv_proj_shard(): # pylint:
disable=too-many-locals
sess.init_ccl("nccl", *devices)
loader = _create_loader(sess, path, param_dict, shard_info)
loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoad")
- d_0 = loader_load(loader, ShapeTuple([0]))
+ d_0 = loader_load(loader, Shape([0]))
np.testing.assert_equal(
np_qkv[0],
d_0.debug_get_from_remote(0).numpy(),
diff --git a/tests/python/disco/test_nvshmem.py
b/tests/python/disco/test_nvshmem.py
index a5a1b81976..29509b0f72 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -30,11 +30,11 @@ from typing import Any
import numpy as np
import pytest
+from tvm_ffi import Shape
import tvm
import tvm.testing
from tvm.exec import disco_worker as _ # pylint: disable=unused-import
-from tvm.runtime import ShapeTuple
from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
@@ -133,8 +133,8 @@ def test_nvshmem_empty(session_kind: di.Session,
num_workers: int):
init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
- a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
- b = empty_dfunc(ShapeTuple((64, 32)), "float32", device)
+ a = empty_dfunc(Shape((32, 64)), "float32", device)
+ b = empty_dfunc(Shape((64, 32)), "float32", device)
sess.sync_worker_0()
finalize_dfunc =
sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
diff --git a/tests/python/disco/test_session.py
b/tests/python/disco/test_session.py
index afb25921da..8adb1ceff0 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -25,12 +25,13 @@ import threading
import numpy as np
import pytest
+from tvm_ffi import Shape
+from tvm_ffi.core import String
import tvm
import tvm.testing
from tvm import relax as rx
from tvm.exec import disco_worker as _ # pylint: disable=unused-import
-from tvm.runtime import ShapeTuple, String
from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
@@ -185,10 +186,10 @@ def test_shape_tuple(session_kind):
num_workers = 4
sess = session_kind(num_workers=num_workers)
func: di.DPackedFunc = sess.get_global_func("tests.disco.shape_tuple")
- result: di.DRef = func(ShapeTuple([1, 2, 3]))
+ result: di.DRef = func(Shape([1, 2, 3]))
for i in range(num_workers):
value = result.debug_get_from_remote(i)
- assert isinstance(value, ShapeTuple)
+ assert isinstance(value, Shape)
assert list(value) == [1, 2, 3, 4, 5]
diff --git a/tests/python/ir/test_container_structural_equal.py
b/tests/python/ir/test_container_structural_equal.py
index 9717b0f5b6..1d9d575af8 100644
--- a/tests/python/ir/test_container_structural_equal.py
+++ b/tests/python/ir/test_container_structural_equal.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import pytest
+import tvm_ffi
from tvm_ffi.access_path import AccessPath
import tvm
@@ -114,8 +115,8 @@ def test_array_structural_equal_to_self(contents):
],
)
def test_shape_tuple_structural_equal_to_self(contents):
- a = tvm.runtime.ShapeTuple(list(contents))
- b = tvm.runtime.ShapeTuple(list(contents))
+ a = tvm_ffi.Shape(list(contents))
+ b = tvm_ffi.Shape(list(contents))
assert get_first_mismatch_ensure_symmetry(a, b) is None
diff --git a/tests/python/ir/test_node_reflection.py
b/tests/python/ir/test_node_reflection.py
index 4cc0769d58..111efa5696 100644
--- a/tests/python/ir/test_node_reflection.py
+++ b/tests/python/ir/test_node_reflection.py
@@ -19,6 +19,7 @@ import sys
import numpy as np
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -83,7 +84,7 @@ def test_make_node():
assert AA.op == A.op
assert AA.value_index == A.value_index
- y = tvm.ir.make_node("ir.IntImm", dtype=tvm.runtime.String("int32"),
value=10, span=None)
+ y = tvm.ir.make_node("ir.IntImm", dtype=tvm_ffi.core.String("int32"),
value=10, span=None)
def test_make_sum():
@@ -122,12 +123,12 @@ def test_env_func():
def test_string():
# non printable str, need to store by b64
- s1 = tvm.runtime.String("xy\x01z")
+ s1 = tvm_ffi.core.String("xy\x01z")
s2 = tvm.ir.load_json(tvm.ir.save_json(s1))
tvm.ir.assert_structural_equal(s1, s2)
# printable str, need to store by repr_str
- s1 = tvm.runtime.String("xyz")
+ s1 = tvm_ffi.core.String("xyz")
s2 = tvm.ir.load_json(tvm.ir.save_json(s1))
tvm.ir.assert_structural_equal(s1, s2)
diff --git
a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
index fc3d31ae95..473e10753f 100644
--- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
+++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
@@ -23,6 +23,8 @@ import numpy as np
import pytest
import scipy.special
import torch
+import tvm_ffi
+from tvm_ffi import Shape
import tvm
import tvm.testing
@@ -41,7 +43,6 @@ from tvm.relax.frontend.nn.llm.kv_cache import (
tree_attn,
tree_attn_with_paged_kv_cache,
)
-from tvm.runtime import ShapeTuple
from tvm.s_tir import dlight as dl
@@ -196,7 +197,7 @@ def set_global_func(head_dim, dtype):
def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
cache = fcreate(
- tvm.runtime.ShapeTuple(
+ tvm_ffi.Shape(
[
reserved_nseq,
maximum_total_seq_length,
@@ -205,12 +206,12 @@ def create_kv_cache(head_dim, dtype, rope_mode,
support_sliding_window):
int(support_sliding_window),
]
),
- tvm.runtime.ShapeTuple([0, num_layers]),
+ tvm_ffi.Shape([0, num_layers]),
num_qo_heads,
num_kv_heads,
head_dim,
head_dim, # v_head_dim
- tvm.runtime.ShapeTuple([int(AttnKind.MHA) for _ in range(num_layers)]),
+ tvm_ffi.Shape([int(AttnKind.MHA) for _ in range(num_layers)]),
False, # enable_kv_transfer
rope_mode,
rope_scale,
@@ -373,10 +374,10 @@ def apply_attention(
if not only_update_host:
fbegin_forward(
kv_cache,
- ShapeTuple(seq_ids),
- ShapeTuple(append_lengths),
+ Shape(seq_ids),
+ Shape(append_lengths),
(
- ShapeTuple(flattened_token_tree_parent_ptr)
+ Shape(flattened_token_tree_parent_ptr)
if flattened_token_tree_parent_ptr is not None
else None
),
@@ -569,7 +570,7 @@ def apply_attention(
seq_ids = [seq_id for seq_id, _ in batch]
if not only_update_host:
fcommit_accepted_token_tree_nodes(
- kv_cache, ShapeTuple(seq_ids),
ShapeTuple(accepted_leaf_indices)
+ kv_cache, Shape(seq_ids), Shape(accepted_leaf_indices)
)
for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
zip(accepted_leaf_indices, batch)
@@ -685,7 +686,7 @@ def
test_paged_attention_kv_cache_transfer(kv_cache_and_config):
remote_pos_maps = comm.bcast(remote_pos_maps, root=1)
comm.Barrier()
for seq_id in prefill_len.keys():
- fdisagg_mark_send(kv_cache, seq_id, 0,
ShapeTuple(remote_pos_maps[seq_id]), 1)
+ fdisagg_mark_send(kv_cache, seq_id, 0,
Shape(remote_pos_maps[seq_id]), 1)
for batch in prefill_operation_seq:
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v,
skip_add_sequence=True)
device.sync()
diff --git
a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py
b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py
index e55971acfb..0adbf89a94 100644
---
a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py
+++
b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py
@@ -17,10 +17,11 @@
# ruff: noqa: E501
import numpy as np
import pytest
+from tvm_ffi import Shape
import tvm
import tvm.testing
-from tvm.runtime import Device, ShapeTuple
+from tvm.runtime import Device
from tvm.runtime import disco as di
page_size = 4
@@ -54,7 +55,7 @@ def test_kv_transfer_without_disco():
init_func(uid, 2, rank)
empty_func = tvm.get_global_func("runtime.disco.nvshmem.empty")
pages = empty_func(
- ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)), "float16", dev
+ Shape((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)), "float16", dev
)
position_map_array = [0, 1, 2, 3, 4, 5, 10, 11, 12, 15, 16, 17, 18, 19, 25, 27]
np.random.seed(0)
@@ -108,7 +109,7 @@ def test_kv_transfer_page_to_page_without_disco():
init_func(uid, 2, rank)
empty_func = tvm.get_global_func("runtime.disco.nvshmem.empty")
pages = empty_func(
- ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)), "float16", dev
+ Shape((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)), "float16", dev
)
rank_1_position_map_array = [0, 1, 2, 3, 4, 5, 10, 11, 12, 15, 16, 17, 18, 19, 25, 27]
rank_0_position_map_array = list(reversed(rank_1_position_map_array))
@@ -174,7 +175,7 @@ def test_kv_transfer_with_disco():
init_func(uid, 4, rank * 2)
empty_func = sess.get_global_func("runtime.disco.nvshmem.empty")
pages = empty_func(
- ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)),
+ Shape((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)),
"float16",
Device(device_type=0, device_id=0),
)
@@ -199,7 +200,7 @@ def test_kv_transfer_with_disco():
f_view_func = sess.get_global_func("runtime.TVMTensorCreateView")
layer_view = f_view_func(
pages,
- ShapeTuple([num_pages, 2, num_kv_heads, page_size, head_dim]),
+ Shape([num_pages, 2, num_kv_heads, page_size, head_dim]),
"float16",
layer_id * num_pages * 2 * num_kv_heads * page_size * head_dim * 2,
)
diff --git a/tests/python/relax/test_contrib_vllm.py b/tests/python/relax/test_contrib_vllm.py
index 6ec59072e8..478fbe95c6 100644
--- a/tests/python/relax/test_contrib_vllm.py
+++ b/tests/python/relax/test_contrib_vllm.py
@@ -17,6 +17,7 @@
# ruff: noqa: RUF005
import numpy as np
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -52,7 +53,7 @@ def build_and_run(mod, inputs_np, target, legalize=True):
out = f(*inputs)
- if isinstance(out, tvm.ir.container.Array):
+ if isinstance(out, tvm_ffi.Array):
return [arr.numpy() for arr in out]
return out.numpy()
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 84348c41ca..4e13e906d8 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -28,6 +28,7 @@ import numpy as np
import onnx
import onnxruntime
import pytest
+import tvm_ffi
from onnx import ModelProto, TensorProto, helper
import tvm
@@ -169,7 +170,7 @@ def check_correctness(
return "other"
def _check_output(tvm_out, ort_out):
- if isinstance(tvm_out, tuple) and isinstance(ort_out, tvm.runtime.ShapeTuple | list):
+ if isinstance(tvm_out, tuple) and isinstance(ort_out, tvm_ffi.Shape | list):
assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
_check_output(tvm_out_i, ort_out_i)
@@ -177,7 +178,7 @@ def check_correctness(
if check_dtypes:
assert tvm_out.numpy().dtype == ort_out.dtype
tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
- elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
+ elif isinstance(tvm_out, tvm_ffi.Shape) and isinstance(ort_out, np.ndarray):
shape_out = tvm.runtime.tensor([int(i) for i in tvm_out])
if check_dtypes:
assert _get_numpy_subdtype(shape_out.numpy()) == _get_numpy_subdtype(ort_out)
@@ -741,7 +742,9 @@ def test_softmax_family_opset1_legacy_ir_semantics(op_name: str, expected_core_o
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 4])],
)
model = helper.make_model(
- graph, producer_name="softmax_family_opset1_ir_test", opset_imports=[helper.make_opsetid("", 1)]
+ graph,
+ producer_name="softmax_family_opset1_ir_test",
+ opset_imports=[helper.make_opsetid("", 1)],
)
tvm_model = from_onnx(model, opset=1, keep_params_in_input=True)
call_ops = collect_relax_call_ops(tvm_model["main"])
@@ -5650,6 +5653,7 @@ def test_split_to_sequence_uneven_last_chunk(axis: int):
model = helper.make_model(graph, producer_name="test_split_to_sequence_uneven")
check_correctness(model)
+
def test_quantizelinear_singleton_qparams_opset10():
"""QuantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
@@ -5722,6 +5726,7 @@ def test_dynamicquantizelinear_opset11():
x = rg.standard_normal((2, 3, 4)).astype("float32")
check_correctness(model, inputs={"x": x}, opset=11, atol=1e-5, rtol=1e-5, check_dtypes=True)
+
def test_quantizelinear_default_axis_opset10():
"""opset10 QuantizeLinear should honor default axis=1 (not hardcode
axis=0)."""
node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
@@ -5759,5 +5764,6 @@ def test_dequantizelinear_default_axis_opset10():
x = rg.integers(low=0, high=255, size=(2, 3, 4), dtype=np.uint8)
check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py
index b525e9a744..3c402f1f85 100644
--- a/tests/python/relax/test_op_gradient_numeric.py
+++ b/tests/python/relax/test_op_gradient_numeric.py
@@ -20,6 +20,7 @@ from typing import Union
import numpy as np
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -90,7 +91,7 @@ def relax_check_gradients(
return tvm.runtime.tensor(data)
def _tvm_to_numpy(data, ignore_idx=[]):
- if isinstance(data, tvm.ir.Array):
+ if isinstance(data, tvm_ffi.Array):
return [_tvm_to_numpy(d) for i, d in enumerate(data) if i not in ignore_idx]
if isinstance(data, tvm.runtime.Tensor):
return data.numpy()
diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py
index 42a055ce2e..5f7f0a79d0 100644
--- a/tests/python/relax/test_op_misc.py
+++ b/tests/python/relax/test_op_misc.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: F841
+import tvm_ffi
+
import tvm
import tvm.testing
from tvm import relax as rx
@@ -58,7 +60,7 @@ def test_call_tir_with_grad():
te_grad_kwargs={"k": 1.0},
)
assert v2.attrs.te_grad_name == "identity_k_grad"
- assert isinstance(v2.attrs.te_grad_kwargs, tvm.ir.container.Map)
+ assert isinstance(v2.attrs.te_grad_kwargs, tvm_ffi.Map)
val = next(iter(v2.attrs.te_grad_kwargs.items()))
assert val[0] == "k" and float(val[1]) == 1.0
diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py
index 91a9228e8c..a85cbc4356 100644
--- a/tests/python/relax/test_pipeline.py
+++ b/tests/python/relax/test_pipeline.py
@@ -16,6 +16,7 @@
# under the License.
import numpy as np
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -102,7 +103,7 @@ def test_pipeline_with_kv_cache():
cache_np = np.empty((num_steps, 4), dtype="float32")
vm = relax.VirtualMachine(ex, tvm.cpu())
- kv_cache = vm["create_kv_cache"](tvm.runtime.ShapeTuple([1]))
+ kv_cache = vm["create_kv_cache"](tvm_ffi.Shape([1]))
for i in range(num_steps):
x_np = np.random.rand(1, 4).astype(np.float32)
@@ -110,7 +111,7 @@ def test_pipeline_with_kv_cache():
x = tvm.runtime.tensor(x_np)
y = tvm.runtime.tensor(y_np)
np_shape = (i + 1, 4)
- kv, kv_cache = vm["main"](x, y, tvm.runtime.ShapeTuple(np_shape), kv_cache)
+ kv, kv_cache = vm["main"](x, y, tvm_ffi.Shape(np_shape), kv_cache)
cache_np[i, :] = x_np + y_np
tvm.testing.assert_allclose(kv.numpy(), cache_np[: np_shape[0], :], rtol=1e-7, atol=1e-7)
diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py
index 056833bdc2..8a2eac04d1 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -21,6 +21,7 @@ import tempfile
import numpy as np
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -235,10 +236,10 @@ class ShapeOfTest:
def test_op_shape_of(exec_mode):
unit_shape = run_cpu(ShapeOfTest, "get_scalar_shape", exec_mode=exec_mode)
- assert unit_shape == tvm.runtime.ShapeTuple([])
+ assert unit_shape == tvm_ffi.Shape([])
const_shape = run_cpu(ShapeOfTest, "get_constant_shape", exec_mode=exec_mode)
- assert const_shape == tvm.runtime.ShapeTuple([2, 2])
+ assert const_shape == tvm_ffi.Shape([2, 2])
scalar_shape = run_cpu(
ShapeOfTest,
@@ -246,7 +247,7 @@ def test_op_shape_of(exec_mode):
tvm.runtime.tensor(np.array(1, dtype="int32")),
exec_mode=exec_mode,
)
- assert scalar_shape == tvm.runtime.ShapeTuple([])
+ assert scalar_shape == tvm_ffi.Shape([])
tensor_shape = run_cpu(
ShapeOfTest,
@@ -254,7 +255,7 @@ def test_op_shape_of(exec_mode):
tvm.runtime.tensor(np.zeros((1, 2, 3)).astype("int32")),
exec_mode=exec_mode,
)
- assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3])
+ assert tensor_shape == tvm_ffi.Shape([1, 2, 3])
constrained_shape = run_cpu(
ShapeOfTest,
@@ -262,7 +263,7 @@ def test_op_shape_of(exec_mode):
tvm.runtime.tensor(np.zeros((1,)).astype("int32")),
exec_mode=exec_mode,
)
- assert constrained_shape == tvm.runtime.ShapeTuple([1])
+ assert constrained_shape == tvm_ffi.Shape([1])
@tvm.script.ir_module
@@ -286,27 +287,21 @@ def test_op_shape_to_tensor(exec_mode):
assert ShapeToTensorTest["symbolic_shape"].body.struct_info.ndim == 1
# Check its functionality
- out2d = run_cpu(
- ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 2]), exec_mode=exec_mode
- )
+ out2d = run_cpu(ShapeToTensorTest, "const_shape", tvm_ffi.Shape([3, 2]), exec_mode=exec_mode)
assert isinstance(out2d, tvm.runtime.Tensor)
assert np.array_equal(out2d.numpy(), np.array([3, 2]))
- out3d = run_cpu(
- ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2]), exec_mode=exec_mode
- )
+ out3d = run_cpu(ShapeToTensorTest, "const_shape", tvm_ffi.Shape([3, 3, 2]), exec_mode=exec_mode)
assert isinstance(out3d, tvm.runtime.Tensor)
assert np.array_equal(out3d.numpy(), np.array([3, 3, 2]))
out4d = run_cpu(
- ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2, 2]), exec_mode=exec_mode
+ ShapeToTensorTest, "const_shape", tvm_ffi.Shape([3, 3, 2, 2]), exec_mode=exec_mode
)
assert isinstance(out4d, tvm.runtime.Tensor)
assert np.array_equal(out4d.numpy(), np.array([3, 3, 2, 2]))
- outs = run_cpu(
- ShapeToTensorTest, "symbolic_shape", tvm.runtime.ShapeTuple([3, 2]), exec_mode=exec_mode
- )
+ outs = run_cpu(ShapeToTensorTest, "symbolic_shape", tvm_ffi.Shape([3, 2]), exec_mode=exec_mode)
assert isinstance(outs, tvm.runtime.Tensor)
assert np.array_equal(outs.numpy(), np.array([3, 2]))
diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py
index 65642c4cce..3eb06fc400 100644
--- a/tests/python/relax/test_runtime_builtin.py
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -17,6 +17,7 @@
# ruff: noqa: F401
import numpy as np
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -31,7 +32,7 @@ def test_make_shape():
heap = tvm.runtime.tensor(np.arange(10).astype("int64"))
s = make_shape(heap, 3, MK.USE_IMM, 10, MK.LOAD_SHAPE, 0, MK.LOAD_SHAPE, 2)
- assert s == tvm.runtime.container.ShapeTuple([10, 0, 2])
+ assert s == tvm_ffi.Shape([10, 0, 2])
def test_match_shape():
@@ -41,7 +42,7 @@ def test_match_shape():
assert heap.numpy()[2] == 0
- s = tvm.runtime.container.ShapeTuple([1, 2, 3])
+ s = tvm_ffi.Shape([1, 2, 3])
x = tvm.runtime.tensor(np.zeros([1, 2, 3]))
match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "")
@@ -70,7 +71,7 @@ def test_match_shape():
def test_check_shape_info():
check_shape_info = tvm.get_global_func("vm.builtin.check_shape_info")
- s = tvm.runtime.container.ShapeTuple([1, 2, 3])
+ s = tvm_ffi.Shape([1, 2, 3])
check_shape_info(s, 3, "")
check_shape_info(s, -1, "")
@@ -157,12 +158,12 @@ def test_attention_kv_cache():
fappend = tvm.get_global_func("vm.builtin.attention_kv_cache_append")
fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
- cache = fcreate(tvm.runtime.empty((1, 2), dtype="int32"), tvm.runtime.ShapeTuple([2, 2]), 0)
+ cache = fcreate(tvm.runtime.empty((1, 2), dtype="int32"), tvm_ffi.Shape([2, 2]), 0)
num_steps = 2
for i in range(num_steps):
cache = fappend(cache, tvm.runtime.tensor(i * np.ones((1, 2)).astype("int32")))
- res = fview(cache, tvm.runtime.ShapeTuple((num_steps, 2))).numpy()
+ res = fview(cache, tvm_ffi.Shape((num_steps, 2))).numpy()
for i in range(num_steps):
assert res[i][0] == i
assert res[i][1] == i
@@ -221,7 +222,7 @@ def test_attention_kv_cache_window_override():
current_pos = 4
cache = fcreate(
tvm.runtime.tensor(np.full((16, 2), -1).astype("int32")),
- tvm.runtime.ShapeTuple([16, 2]),
+ tvm_ffi.Shape([16, 2]),
current_pos,
)
np_all_arrays = np.zeros((0, 2)).astype("int32")
@@ -233,7 +234,7 @@ def test_attention_kv_cache_window_override():
cache = foverride(cache, tvm.runtime.tensor(np_array), 16)
current_pos = (current_pos + i) % 16
- res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy()
+ res = fview(cache, tvm_ffi.Shape((16, 2))).numpy()
# unrotate cache and assert cache matches last 16 elements
assert (
@@ -253,7 +254,7 @@ def test_attention_kv_cache_window_override_with_sinks():
cache = fcreate(
tvm.runtime.tensor(np.full((16, 2), -1).astype("int32")),
- tvm.runtime.ShapeTuple([16, 2]),
+ tvm_ffi.Shape([16, 2]),
current_pos,
)
np_all_arrays = np.zeros((0, 2)).astype("int32")
@@ -270,7 +271,7 @@ def test_attention_kv_cache_window_override_with_sinks():
current_pos += 1
has_sink = current_pos >= num_attention_sinks
- res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy()
+ res = fview(cache, tvm_ffi.Shape((16, 2))).numpy()
# unrotate cache and assert cache matches last 16 elements
assert (
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
index 6658c52581..c4d99afeea 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
@@ -21,6 +21,8 @@ import itertools
import numpy as np
import pytest
import scipy.special
+import tvm_ffi
+from tvm_ffi import Shape
import tvm
import tvm.testing
@@ -38,7 +40,6 @@ from tvm.relax.frontend.nn.llm.kv_cache import (
tree_attn_cpu,
tree_attn_with_paged_kv_cache_cpu,
)
-from tvm.runtime import ShapeTuple
from tvm.s_tir import dlight as dl
reserved_nseq = 32
@@ -165,7 +166,7 @@ def set_global_func(head_dim, dtype):
def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
cache = fcreate(
- tvm.runtime.ShapeTuple(
+ tvm_ffi.Shape(
[
reserved_nseq,
maximum_total_seq_length,
@@ -174,12 +175,12 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
int(support_sliding_window),
]
),
- tvm.runtime.ShapeTuple([0, num_layers]),
+ tvm_ffi.Shape([0, num_layers]),
num_qo_heads,
num_kv_heads,
head_dim,
head_dim, # v_head_dim
- tvm.runtime.ShapeTuple([int(AttnKind.MHA) for _ in range(num_layers)]),
+ tvm_ffi.Shape([int(AttnKind.MHA) for _ in range(num_layers)]),
False, # enable_kv_transfer
rope_mode,
rope_scale,
@@ -337,10 +338,10 @@ def apply_attention(
fbegin_forward(
kv_cache,
- ShapeTuple(seq_ids),
- ShapeTuple(append_lengths),
+ Shape(seq_ids),
+ Shape(append_lengths),
(
- ShapeTuple(flattened_token_tree_parent_ptr)
+ Shape(flattened_token_tree_parent_ptr)
if flattened_token_tree_parent_ptr is not None
else None
),
@@ -490,9 +491,7 @@ def apply_attention(
if accepted_leaf_indices is not None:
seq_ids = [seq_id for seq_id, _ in batch]
- fcommit_accepted_token_tree_nodes(
- kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
- )
+ fcommit_accepted_token_tree_nodes(kv_cache, Shape(seq_ids), Shape(accepted_leaf_indices))
for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
zip(accepted_leaf_indices, batch)
):
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index da9071c1e4..ef541b1e35 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -18,6 +18,8 @@
import pytest
import torch
+import tvm_ffi
+from tvm_ffi import Shape
import tvm
import tvm.testing
@@ -32,7 +34,6 @@ from tvm.relax.frontend.nn.llm.kv_cache import (
_merge_state_inplace,
llama_rope_with_position_map,
)
-from tvm.runtime import ShapeTuple
from tvm.s_tir import dlight as dl
reserved_nseq = 32
@@ -153,7 +154,7 @@ def create_kv_cache(rope_mode):
fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
support_sliding_window = 0
cache = fcreate(
- tvm.runtime.ShapeTuple(
+ tvm_ffi.Shape(
[
reserved_nseq,
maximum_total_seq_length,
@@ -162,12 +163,12 @@ def create_kv_cache(rope_mode):
support_sliding_window,
]
),
- tvm.runtime.ShapeTuple([0, num_layers]),
+ tvm_ffi.Shape([0, num_layers]),
num_qo_heads,
num_kv_heads,
head_dim,
head_dim, # v_head_dim
- tvm.runtime.ShapeTuple([int(AttnKind.MHA) for _ in range(num_layers)]),
+ tvm_ffi.Shape([int(AttnKind.MHA) for _ in range(num_layers)]),
False, # enable_kv_transfer
rope_mode,
rope_scale,
@@ -275,7 +276,7 @@ def apply_attention(
(num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch
)
- fbegin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(append_lengths))
+ fbegin_forward(kv_cache, Shape(seq_ids), Shape(append_lengths))
global_new_q = torch.zeros(
(num_layers, 0, num_qo_heads, head_dim), dtype=dtype_torch, device=device_torch
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
index 4683f6fdde..ef2aa35ecd 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
@@ -19,6 +19,8 @@ import itertools
import numpy as np
import pytest
import torch
+import tvm_ffi
+from tvm_ffi import Shape
import tvm
import tvm.testing
@@ -31,7 +33,6 @@ from tvm.relax.frontend.nn.llm.kv_cache import (
_kv_cache_transpose_append_mla,
_merge_state_inplace,
)
-from tvm.runtime import ShapeTuple
from tvm.s_tir import dlight as dl
np.random.seed(0)
@@ -176,7 +177,7 @@ def create_kv_cache(dtype):
fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
fdumb = tvm.get_global_func("test.dumb_function")
cache = fcreate(
- tvm.runtime.ShapeTuple(
+ tvm_ffi.Shape(
[
reserved_nseq,
maximum_total_seq_length,
@@ -185,12 +186,12 @@ def create_kv_cache(dtype):
0,
]
),
- tvm.runtime.ShapeTuple([0, num_layers]),
+ tvm_ffi.Shape([0, num_layers]),
num_attention_heads,
1, # num_kv_heads
kv_lora_rank + qk_rope_head_dim,
kv_lora_rank,
- tvm.runtime.ShapeTuple([int(AttnKind.MLA) for _ in range(num_layers)]),
+ tvm_ffi.Shape([int(AttnKind.MLA) for _ in range(num_layers)]),
False, # enable_kv_transfer
RopeMode.NONE,
1,
@@ -273,7 +274,7 @@ def apply_attention(
device=device_torch,
)
- fbegin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(append_lengths), None)
+ fbegin_forward(kv_cache, Shape(seq_ids), Shape(append_lengths), None)
global_new_q = torch.zeros(
(num_layers, 0, num_attention_heads, qk_nope_head_dim + qk_rope_head_dim),
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py
index 7b13112cac..ea63c2d21e 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py
@@ -19,6 +19,8 @@ import itertools
import numpy as np
import pytest
import torch
+import tvm_ffi
+from tvm_ffi import Shape
import tvm
import tvm.testing
@@ -32,7 +34,6 @@ from tvm.relax.frontend.nn.llm.kv_cache import (
_kv_cache_transpose_append_mla,
_merge_state_inplace,
)
-from tvm.runtime import ShapeTuple
from tvm.s_tir import dlight as dl
reserved_nseq = 32
@@ -164,7 +165,7 @@ def create_kv_cache(dtype):
fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
fdumb = tvm.get_global_func("test.dumb_function")
cache = fcreate(
- tvm.runtime.ShapeTuple(
+ tvm_ffi.Shape(
[
reserved_nseq,
maximum_total_seq_length,
@@ -173,12 +174,12 @@ def create_kv_cache(dtype):
0,
]
),
- tvm.runtime.ShapeTuple([0, num_layers]),
+ tvm_ffi.Shape([0, num_layers]),
num_attention_heads,
1, # num_kv_heads
kv_lora_rank + qk_rope_head_dim,
kv_lora_rank,
- tvm.runtime.ShapeTuple([int(AttnKind.MLA) for _ in range(num_layers)]),
+ tvm_ffi.Shape([int(AttnKind.MLA) for _ in range(num_layers)]),
False, # enable_kv_transfer
RopeMode.NONE,
1,
@@ -255,7 +256,7 @@ def apply_attention(
device=device_torch,
)
- fbegin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(append_lengths), None)
+ fbegin_forward(kv_cache, Shape(seq_ids), Shape(append_lengths), None)
global_new_q = torch.zeros(
(num_layers, 0, num_attention_heads, qk_nope_head_dim + qk_rope_head_dim),
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 4d04f01fed..aa679c649b 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -19,6 +19,8 @@ import itertools
import pytest
import torch
+import tvm_ffi
+from tvm_ffi import Shape
import tvm
import tvm.testing
@@ -37,7 +39,6 @@ from tvm.relax.frontend.nn.llm.kv_cache import (
tree_attn,
tree_attn_with_paged_kv_cache,
)
-from tvm.runtime import ShapeTuple
from tvm.s_tir import dlight as dl
reserved_nseq = 32
@@ -167,7 +168,7 @@ def set_global_func(head_dim, dtype):
def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
cache = fcreate(
- tvm.runtime.ShapeTuple(
+ tvm_ffi.Shape(
[
reserved_nseq,
maximum_total_seq_length,
@@ -176,12 +177,12 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
int(support_sliding_window),
]
),
- tvm.runtime.ShapeTuple([0, num_layers]),
+ tvm_ffi.Shape([0, num_layers]),
num_qo_heads,
num_kv_heads,
head_dim,
head_dim, # v_head_dim
- tvm.runtime.ShapeTuple([int(AttnKind.MHA) for _ in range(num_layers)]),
+ tvm_ffi.Shape([int(AttnKind.MHA) for _ in range(num_layers)]),
False, # enable_kv_transfer
rope_mode,
rope_scale,
@@ -340,10 +341,10 @@ def apply_attention(
fbegin_forward(
kv_cache,
- ShapeTuple(seq_ids),
- ShapeTuple(append_lengths),
+ Shape(seq_ids),
+ Shape(append_lengths),
(
- ShapeTuple(flattened_token_tree_parent_ptr)
+ Shape(flattened_token_tree_parent_ptr)
if flattened_token_tree_parent_ptr is not None
else None
),
@@ -531,9 +532,7 @@ def apply_attention(
if accepted_leaf_indices is not None:
seq_ids = [seq_id for seq_id, _ in batch]
- fcommit_accepted_token_tree_nodes(
- kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
- )
+ fcommit_accepted_token_tree_nodes(kv_cache, Shape(seq_ids), Shape(accepted_leaf_indices))
for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
zip(accepted_leaf_indices, batch)
):
diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py b/tests/python/relax/test_runtime_builtin_rnn_state.py
index 18cb9c15c6..35b560c89c 100644
--- a/tests/python/relax/test_runtime_builtin_rnn_state.py
+++ b/tests/python/relax/test_runtime_builtin_rnn_state.py
@@ -19,11 +19,11 @@ from collections.abc import Sequence
import numpy as np
import pytest
+from tvm_ffi import Shape
import tvm
import tvm.testing
from tvm import tirx
-from tvm.runtime import ShapeTuple
from tvm.s_tir import dlight as dl
from tvm.script import tirx as T
@@ -121,7 +121,7 @@ def test_rnn_state_get(rnn_state): # pylint: disable=redefined-outer-name
state = rnn_state
f_clear(state)
f_add_sequence(state, 0)
- f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1]))
+ f_begin_forward(state, Shape([0]), Shape([1]))
tvm_nd_0 = tvm.runtime.tensor(np.empty((1, 16, 16), "float16"), device=device)
tvm_nd_1 = tvm.runtime.tensor(np.empty((1, 32, 32), "float32"), device=device)
f_get(state, 0, 0, tvm_nd_0)
@@ -137,7 +137,7 @@ def test_rnn_state_set(rnn_state): # pylint: disable=redefined-outer-name
f_clear(state)
for seq_id in range(3):
f_add_sequence(state, seq_id)
- f_begin_forward(state, ShapeTuple([0, 2]), ShapeTuple([1, 1]))
+ f_begin_forward(state, Shape([0, 2]), Shape([1, 1]))
f_set(state, 0, 0, tvm.runtime.tensor(np.full((2, 16, 16), 2.0, "float16"), device=device))
f_set(state, 0, 1, tvm.runtime.tensor(np.full((2, 32, 32), 3.0, "float32"), device=device))
@@ -153,7 +153,7 @@ def test_rnn_state_popn(rnn_state): # pylint: disable=redefined-outer-name
f_clear(state)
f_add_sequence(state, 0)
- f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1]))
+ f_begin_forward(state, Shape([0]), Shape([1]))
f_set(state, 0, 0, tvm.runtime.tensor(np_two.reshape(1, 16, 16), device=device))
f_set(state, 0, 1, tvm.runtime.tensor(np_three.reshape(1, 32, 32), device=device))
f_end_forward(state)
@@ -171,7 +171,7 @@ def test_rnn_state_fork_sequence(rnn_state): # pylint: disable=redefined-outer-
f_clear(state)
f_add_sequence(state, 0)
- f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1]))
+ f_begin_forward(state, Shape([0]), Shape([1]))
f_set(state, 0, 0, tvm.runtime.tensor(np_two.reshape(1, 16, 16), device=device))
f_set(state, 0, 1, tvm.runtime.tensor(np_three.reshape(1, 32, 32), device=device))
f_end_forward(state)
diff --git a/tests/python/relax/test_training_optimizer_numeric.py b/tests/python/relax/test_training_optimizer_numeric.py
index c9ca6b2097..01060ffa2c 100644
--- a/tests/python/relax/test_training_optimizer_numeric.py
+++ b/tests/python/relax/test_training_optimizer_numeric.py
@@ -19,6 +19,7 @@
from collections.abc import Callable
import numpy as np
+import tvm_ffi
import tvm
import tvm.testing
@@ -42,7 +43,7 @@ def _numpy_to_tvm(data):
def _tvm_to_numpy(data):
- if isinstance(data, list | tuple | tvm.ir.Array):
+ if isinstance(data, list | tuple | tvm_ffi.Array):
return [_tvm_to_numpy(_data) for _data in data]
return data.numpy()
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index ad716a1e71..fa92842abe 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -21,6 +21,8 @@ from collections.abc import Callable
import numpy as np
import pytest
+import tvm_ffi
+from tvm_ffi import Shape
import tvm
import tvm.script
@@ -29,7 +31,6 @@ from tvm import relax, rpc, te, tirx, topi
from tvm.contrib import cc, popen_pool, utils
from tvm.relax.testing import nn
from tvm.relax.testing.vm import check_saved_func
-from tvm.runtime import ShapeTuple
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tirx as T
@@ -544,10 +545,10 @@ def test_vm_relax_symbolic_shape_tuple(exec_mode):
func = vm["main"]
- assert func(ShapeTuple([2, 3])) == (4, 9)
+ assert func(Shape([2, 3])) == (4, 9)
with pytest.raises(ValueError):
- func(ShapeTuple([2, 3, 4]))
+ func(Shape([2, 3, 4]))
with pytest.raises(TypeError):
func(R.prim_value(2))
@@ -570,7 +571,7 @@ def test_vm_relax_symbolic_prim_value(exec_mode):
assert func(2) == 4
with pytest.raises(TypeError):
- func(ShapeTuple([2]))
+ func(Shape([2]))
def test_vm_relax_multiple_symbolic_prim_value(exec_mode):
@@ -598,13 +599,13 @@ def test_vm_relax_multiple_symbolic_prim_value(exec_mode):
func = vm["main"]
- assert func(2, ShapeTuple([4, 12]), 6) == (4, 7)
+ assert func(2, Shape([4, 12]), 6) == (4, 7)
with pytest.raises(RuntimeError):
- func(2, ShapeTuple([4, 12]), 1)
+ func(2, Shape([4, 12]), 1)
with pytest.raises(tvm.TVMError):
- func(ShapeTuple([2]))
+ func(Shape([2]))
@pytest.mark.xfail(reason="Current support for R.Prim with known value is primarily for int64")
@@ -1044,8 +1045,8 @@ def test_multi_systemlib(exec_mode):
vmA = relax.VirtualMachine(tvm.runtime.system_lib("libA_"), tvm.cpu())
vmB = relax.VirtualMachine(tvm.runtime.system_lib("libB_"), tvm.cpu())
- retA = vmA["main"](tvm.runtime.ShapeTuple([1]))
- retB = vmB["main"](tvm.runtime.ShapeTuple([2]))
+ retA = vmA["main"](tvm_ffi.Shape([1]))
+ retB = vmB["main"](tvm_ffi.Shape([2]))
np.testing.assert_equal(retA.numpy(), np.array([0, 0]).astype("float32"))
np.testing.assert_equal(retB.numpy(), np.array([1, 1]).astype("float32"))
diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py
index 8108471934..66ed247f15 100644
--- a/tests/python/relax/test_vm_codegen_only.py
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -22,6 +22,7 @@ Restrictions: all shape lowered, explicit allocation.
import numpy as np
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -267,7 +268,7 @@ def test_shape_check_builtin(exec_mode):
vm = relax.VirtualMachine(ex, tvm.cpu())
x = tvm.runtime.tensor(np.zeros((1, 2)).astype("float32"))
res = vm["main"](x)
- assert res == tvm.runtime.container.ShapeTuple([2, 1, 2])
+ assert res == tvm_ffi.Shape([2, 1, 2])
# wrong input type
with pytest.raises(TypeError):
diff --git a/tests/python/relax/test_vm_execbuilder.py b/tests/python/relax/test_vm_execbuilder.py
index 82bda90cbe..9d8d19d747 100644
--- a/tests/python/relax/test_vm_execbuilder.py
+++ b/tests/python/relax/test_vm_execbuilder.py
@@ -18,6 +18,7 @@
import numpy as np
import pytest
+import tvm_ffi
import tvm
from tvm import TVMError, relax
@@ -104,9 +105,9 @@ def test_emit_cache():
x1 = ib.convert_constant("str0")
# cache constant str
assert x0 == x1
- s0 = ib.convert_constant(tvm.runtime.container.ShapeTuple([1, 2]))
- s1 = ib.convert_constant(tvm.runtime.container.ShapeTuple([1, 2]))
- s2 = ib.convert_constant(tvm.runtime.container.ShapeTuple([1, 3]))
+ s0 = ib.convert_constant(tvm_ffi.Shape([1, 2]))
+ s1 = ib.convert_constant(tvm_ffi.Shape([1, 2]))
+ s2 = ib.convert_constant(tvm_ffi.Shape([1, 3]))
assert s0 == s1
assert s1 != s2
y0 = ib.convert_constant(tvm.runtime.tensor(np.array([1, 2, 3]).astype("int32")))
diff --git a/tests/python/runtime/test_runtime_container.py b/tests/python/runtime/test_runtime_container.py
index d603b25f72..43809fb00f 100644
--- a/tests/python/runtime/test_runtime_container.py
+++ b/tests/python/runtime/test_runtime_container.py
@@ -20,15 +20,15 @@ import pickle
import random
import numpy as np
+import tvm_ffi
import tvm
import tvm.runtime
import tvm.testing
-from tvm.runtime import container as _container
def test_string():
- s = tvm.runtime.String("xyz")
+ s = tvm_ffi.core.String("xyz")
assert isinstance(s, str)
assert s.startswith("xy")
@@ -47,18 +47,18 @@ def test_string():
def test_shape_tuple():
shape = [random.randint(-10, 10) for _ in range(5)]
- stuple = _container.ShapeTuple(shape)
+ stuple = tvm_ffi.Shape(shape)
len(stuple) == len(shape)
for a, b in zip(stuple, shape):
assert a == b
# ShapleTuple vs. tuple
assert stuple == tuple(shape)
# ShapleTuple vs. ShapeTuple
- assert stuple == _container.ShapeTuple(shape)
+ assert stuple == tvm_ffi.Shape(shape)
# test pickle
z = pickle.loads(pickle.dumps(stuple))
- assert isinstance(z, tvm.runtime.ShapeTuple)
+ assert isinstance(z, tvm_ffi.Shape)
assert stuple == z
diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py
index ec112e1023..f75a48cbae 100644
--- a/tests/python/runtime/test_runtime_rpc.py
+++ b/tests/python/runtime/test_runtime_rpc.py
@@ -16,16 +16,17 @@
# under the License.
# ruff: noqa: E712, F841
+import gc
import multiprocessing
import os
import stat
import sys
import tempfile
import time
-import gc
import numpy as np
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -119,8 +120,8 @@ def test_rpc_runtime_string():
def check_remote():
func = client.get_function("rpc.test.runtime_str_concat")
- x = tvm.runtime.container.String("abc")
- y = tvm.runtime.container.String("def")
+ x = tvm_ffi.core.String("abc")
+ y = tvm_ffi.core.String("def")
assert str(func(x, y)) == "abcdef"
check_remote()
diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py
index 43d55a27fc..1b2246adb0 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -17,6 +17,7 @@
import json
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -265,7 +266,7 @@ def test_target_host_merge_2():
def test_target_tvm_object():
"""Test creating Target by using TVM Objects"""
- String = tvm.runtime.container.String
+ String = tvm_ffi.core.String
tgt = tvm.target.Target(target={"kind": "cuda", "host": {"kind": "llvm"}})
assert tgt.kind.name == "cuda"
assert tgt.host.kind.name == "llvm"
diff --git a/tests/python/tvmscript/test_tvmscript_printer_doc.py b/tests/python/tvmscript/test_tvmscript_printer_doc.py
index 8e9c7f0a74..18c3cec267 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_doc.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_doc.py
@@ -551,7 +551,7 @@ def test_doc_source_paths():
source_paths = [AccessPath.root(), AccessPath.root().attr("x")]
doc.source_paths = source_paths
- # This should triggers the __getattr__ and gets a tvm.ir.container.Array
+ # This should triggers the __getattr__ and gets a tvm_ffi.Array
assert not isinstance(doc.source_paths, list)
assert list(doc.source_paths) == source_paths