This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0e5c885699 [Relax][Frontend][TFLite] Add BITCAST operator mapping (#19466)
0e5c885699 is described below
commit 0e5c8856993115cffccb3de11fc027a231aed090
Author: as4230 <[email protected]>
AuthorDate: Wed Apr 29 01:16:40 2026 -0400
[Relax][Frontend][TFLite] Add BITCAST operator mapping (#19466)
This PR adds TFLite frontend support for the BITCAST operator, which
reinterprets a tensor's bytes as a different dtype without converting
the underlying data.
The handler lowers BITCAST to relax.op.memory.view, which aliases the
input buffer with the new shape and dtype.
Frontend tests cover same-width bitcasts (float32 -> int32, uint8 -> int8),
a bitcast to a narrower dtype (int32[3] -> int16[3, 2]), and a bitcast
to a wider dtype (int16[5, 2] -> int32[5]).
Run the tests with:
`python -m pytest tests/python/relax/test_frontend_tflite.py -k bitcast -v`
Addresses the BITCAST item under #19412.
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 23 ++++++
tests/python/relax/test_frontend_tflite.py | 88 ++++++++++++++++++++++
2 files changed, 111 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 5536c369db..fe85e4da9e 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -124,6 +124,7 @@ class OperatorConverter:
"AVERAGE_POOL_2D": functools.partial(self.convert_pool2d,
pool_type="average"),
"BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
"BATCH_MATMUL": self.convert_batch_matmul,
+ "BITCAST": self.convert_bitcast,
"CAST": self.convert_cast,
"CEIL": functools.partial(self._convert_unary_elemwise,
relax_op=_op.ceil),
"CONCATENATION": self.convert_concatenation,
@@ -2486,6 +2487,28 @@ class OperatorConverter:
return relax.op.reverse_sequence(in_expr, length_expr, seq_axis,
batch_axis)
+ def convert_bitcast(self, op):
+ """Convert TFLite BITCAST to a memory view with the output shape and dtype."""
+ input_tensors = self.get_input_tensors(op)
+ output_tensors = self.get_output_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+
+ in_expr = self.get_tensor_expr(input_tensors[0])
+ input_dtype = self.get_tensor_type_str(input_tensors[0].tensor.Type())
+ output_dtype =
self.get_tensor_type_str(output_tensors[0].tensor.Type())
+ input_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
+ output_shape = to_int_list(self.get_tensor_shape(output_tensors[0]))
+
+ input_nbytes = int(np.prod(input_shape)) *
np.dtype(input_dtype).itemsize
+ output_nbytes = int(np.prod(output_shape)) *
np.dtype(output_dtype).itemsize
+ assert input_nbytes == output_nbytes, (
+ "TFLite BITCAST requires input.nbytes == output.nbytes, "
+ f"but got input={input_nbytes} bytes, output={output_nbytes} bytes"
+ )
+
+ return relax.op.memory.view(in_expr, shape=output_shape,
dtype=output_dtype)
+
def convert_cast(self, op):
"""Convert TFLite CAST"""
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 23002c8668..7a166f2fd9 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -276,6 +276,94 @@ def test_cast():
verify(Cast, Expected)
+def test_bitcast_float32_to_int32():
+ """BITCAST same-width: float32[1, 30] -> int32[1, 30], shape preserved."""
+
+ class BitcastF32ToI32(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30),
dtype=tf.float32)])
+ def func(self, x):
+ return tf.bitcast(x, tf.int32)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30),
dtype="int32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((1, 30), dtype="int32") = R.memory.view(
+ x, R.shape([1, 30]), R.dtype("int32")
+ )
+ R.output(gv)
+ return gv
+
+ verify(BitcastF32ToI32, Expected)
+
+
+def test_bitcast_uint8_to_int8():
+ """BITCAST same-width 8-bit: uint8[4] -> int8[4], shape preserved."""
+
+ class BitcastU8ToI8(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(4,),
dtype=tf.uint8)])
+ def func(self, x):
+ return tf.bitcast(x, tf.int8)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((4,), dtype="uint8")) -> R.Tensor((4,),
dtype="int8"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((4,), dtype="int8") = R.memory.view(x,
R.shape([4]), R.dtype("int8"))
+ R.output(gv)
+ return gv
+
+ verify(BitcastU8ToI8, Expected)
+
+
+def test_bitcast_int32_to_int16_widens_shape():
+ """BITCAST to a narrower dtype: int32[3] -> int16[3, 2], trailing dim added."""
+
+ class BitcastI32ToI16(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(3,),
dtype=tf.int32)])
+ def func(self, x):
+ return tf.bitcast(x, tf.int16)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3,), dtype="int32")) -> R.Tensor((3, 2),
dtype="int16"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((3, 2), dtype="int16") = R.memory.view(
+ x, R.shape([3, 2]), R.dtype("int16")
+ )
+ R.output(gv)
+ return gv
+
+ verify(BitcastI32ToI16, Expected)
+
+
+def test_bitcast_int16_to_int32_collapses_shape():
+ """BITCAST to a wider dtype: int16[5, 2] -> int32[5], trailing dim removed."""
+
+ class BitcastI16ToI32(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(5, 2),
dtype=tf.int16)])
+ def func(self, x):
+ return tf.bitcast(x, tf.int32)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((5, 2), dtype="int16")) -> R.Tensor((5,),
dtype="int32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((5,), dtype="int32") = R.memory.view(x,
R.shape([5]), R.dtype("int32"))
+ R.output(gv)
+ return gv
+
+ verify(BitcastI16ToI32, Expected)
+
+
def test_expand_dims():
class ExpandDims(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 30),
dtype=tf.float32)])