This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d6ef18e771 [Relax][Frontend][TFLite] Add DILATE operator mapping (#19481)
d6ef18e771 is described below

commit d6ef18e7711bcc083525d7707fe8fe8530a366ee
Author: as4230 <[email protected]>
AuthorDate: Fri May 1 00:06:42 2026 -0400

    [Relax][Frontend][TFLite] Add DILATE operator mapping (#19481)
    
    This PR adds TFLite frontend support for the DILATE operator, which
    extends a tensor by inserting copies of a padding value between its
    existing elements along each axis, according to the per-axis dilation
    strides.
    
    Instead of registering a new op, DILATE is decomposed into existing
    Relax primitives, applied once per axis:
    - relax.op.reshape inserts a size-1 stride axis, and a second reshape
    merges it back after padding
    - relax.op.full builds a padding tensor with (stride - 1) elements along
    that axis
    - relax.op.concat interleaves the padding between input elements
    - relax.op.strided_slice trims the trailing padding so the axis has
    TFLite's output size (d - 1) * stride + 1
    
    Both constant (static) dilations and dilations supplied at runtime as a
    graph input (dynamic) are supported.
    
    Frontend tests use hand-built .tflite fixtures: DILATE has no public TF
    Python emitter through tf.lite.TFLiteConverter, so the standard
    verify(TestClass, Expected) pattern cannot reach it. DENSIFY's fixture
    builders are extended to handle BuiltinOptions2, non-FLOAT32 tensors,
    and builtin operator codes that do not fit in the int8
    deprecated_builtin_code field (127 is written there as a placeholder).
    _finish_tflite_model now writes the TFL3 file identifier so the produced
    buffer is a valid input for tf.lite.Interpreter in the nightly E2E path.
    
    Validation:
    python -m pytest tests/python/relax/test_frontend_tflite.py -k dilate -v
    
    Addresses the DILATE item under #19412.
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |  62 +++++
 tests/python/relax/test_frontend_tflite.py         | 261 ++++++++++++++++++++-
 2 files changed, 317 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 155b6301f9..0b1097b095 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -137,6 +137,7 @@ class OperatorConverter:
             "DEPTHWISE_CONV_2D": functools.partial(self.convert_conv, 
conv_type="depthwise"),
             "DEQUANTIZE": self.convert_dequantize,
             "DETECTION_POSTPROCESS": self.convert_detection_postprocess,
+            "DILATE": self.convert_dilate,
             "DIV": functools.partial(self._convert_elemwise, 
relax_op=_op.divide),
             "ELU": self.convert_elu,
             "EQUAL": functools.partial(
@@ -3417,6 +3418,67 @@ class OperatorConverter:
 
         return out
 
+    def convert_dilate(self, op):
+        """Convert TFLite DILATE"""
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+
+        in_expr = self.get_tensor_expr(input_tensors[0])
+        in_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
+        in_dtype = self.get_tensor_type_str(input_tensors[0].tensor.Type())
+        n_dims = len(in_shape)
+
+        dilations_tensor = input_tensors[1]
+        padding_expr = self.get_tensor_expr(input_tensors[2])
+
+        # Runtime dilations bind tensor values to TIR Vars for symbolic 
+        # per-axis math.
+        if self.has_expr(dilations_tensor.tensor_idx):
+            dilations_expr = self.get_expr(dilations_tensor.tensor_idx)
+            dilations_expr = self.bb.match_cast(
+                dilations_expr, relax.TensorStructInfo([n_dims], "int32")
+            )
+            dilations_int64 = 
self.bb.normalize(relax.op.astype(dilations_expr, "int64"))
+            shape_var = self.bb.emit(relax.op.tensor_to_shape(dilations_int64))
+            stride_vars = [tirx.Var(f"dilate_stride_{i}", "int64") for i in 
range(n_dims)]
+            self.bb.match_cast(shape_var, relax.ShapeStructInfo(stride_vars))
+            strides = stride_vars
+        else:
+            strides = to_int_list(self.get_tensor_value(dilations_tensor))
+
+        # Per axis: reshape to add a size-1 stride-axis, concat (s-1) padding
+        # values along it, reshape to merge axes (length d*s), trim trailing
+        # pad to TFLite's output dim formula (d-1)*s + 1.
+        result = in_expr
+        current_shape = list(in_shape)
+        axes = list(range(n_dims))
+        ones = [1] * n_dims
+        for axis in range(n_dims):
+            d = current_shape[axis]
+            s = strides[axis]
+            expanded_shape = current_shape[: axis + 1] + [1] + 
current_shape[axis + 1 :]
+            expanded = relax.op.reshape(result, expanded_shape)
+            pad_shape = list(expanded_shape)
+            pad_shape[axis + 1] = s - 1
+            pad = relax.op.full(pad_shape, padding_expr, dtype=in_dtype)
+            concatted = relax.op.concat([expanded, pad], axis=axis + 1)
+            merged_shape = list(current_shape)
+            merged_shape[axis] = d * s
+            merged = relax.op.reshape(concatted, merged_shape)
+            # (d - 1) * s + 1 is the output dim along this axis.
+            final_dim = (d - 1) * s + 1
+            end = list(merged_shape)
+            end[axis] = final_dim
+            result = relax.op.strided_slice(
+                merged, axes=axes, begin=[0] * n_dims, end=end, strides=ones
+            )
+            current_shape = list(merged_shape)
+            current_shape[axis] = final_dim
+
+        return result
+
     def convert_detection_postprocess(self, op):
         """Convert TFLite_Detection_PostProcess"""
         flexbuffer = op.CustomOptionsAsNumpy().tobytes()
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 37211d337a..64e4a6e953 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -2995,6 +2995,7 @@ def _get_tflite_schema_enum(enum_name):
 _tfl_add_options = _get_tflite_schema_module("AddOptions")
 _tfl_buffer = _get_tflite_schema_module("Buffer")
 _tfl_conv2d_options = _get_tflite_schema_module("Conv2DOptions")
+_tfl_dilate_options = _get_tflite_schema_module("DilateOptions")
 _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
 _tfl_fully_connected_options = 
_get_tflite_schema_module("FullyConnectedOptions")
 _tfl_int32_vector = _get_tflite_schema_module("Int32Vector")
@@ -3007,6 +3008,7 @@ _tfl_tensor = _get_tflite_schema_module("Tensor")
 
 _tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator")
 _tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions")
+_tfl_builtin_options2 = _get_tflite_schema_enum("BuiltinOptions2")
 _tfl_dimension_type = _get_tflite_schema_enum("DimensionType")
 _tfl_fc_weights_format = 
_get_tflite_schema_enum("FullyConnectedOptionsWeightsFormat")
 _tfl_padding = _get_tflite_schema_enum("Padding")
@@ -3062,8 +3064,10 @@ def _tflite_shape(builder, shape):
     return _tflite_int32_vector(builder, _tfl_tensor.TensorStartShapeVector, 
shape)
 
 
-def _build_tensor(builder, buffer_idx, shape, sparsity=None):
+def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None):
     """Helper to build a TFLite tensor."""
+    if tensor_type is None:
+        tensor_type = _tfl_tensor_type.FLOAT32
     shape_vec = _tflite_shape(builder, shape)
     _tfl_tensor.TensorStart(builder)
     _tfl_tensor.TensorAddBuffer(builder, buffer_idx)
@@ -3072,7 +3076,7 @@ def _build_tensor(builder, buffer_idx, shape, 
sparsity=None):
     _tfl_tensor.TensorAddShape(builder, shape_vec)
     if sparsity is not None:
         _tfl_tensor.TensorAddSparsity(builder, sparsity)
-    _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+    _tfl_tensor.TensorAddType(builder, tensor_type)
     return _tfl_tensor.TensorEnd(builder)
 
 
@@ -3089,7 +3093,14 @@ def _build_buffer(builder, data=None):
 
 
 def _build_operator(
-    builder, opcode_index, inputs, outputs, builtin_options_type, 
builtin_options=None
+    builder,
+    opcode_index,
+    inputs,
+    outputs,
+    builtin_options_type=None,
+    builtin_options=None,
+    builtin_options2_type=None,
+    builtin_options2=None,
 ):
     inputs_vec = _tflite_int32_vector(builder, 
_tfl_operator.OperatorStartInputsVector, inputs)
     outputs_vec = _tflite_int32_vector(
@@ -3099,15 +3110,23 @@ def _build_operator(
     _tfl_operator.OperatorAddOpcodeIndex(builder, opcode_index)
     _tfl_operator.OperatorAddInputs(builder, inputs_vec)
     _tfl_operator.OperatorAddOutputs(builder, outputs_vec)
-    _tfl_operator.OperatorAddBuiltinOptionsType(builder, builtin_options_type)
+    if builtin_options_type is not None:
+        _tfl_operator.OperatorAddBuiltinOptionsType(builder, 
builtin_options_type)
     if builtin_options is not None:
         _tfl_operator.OperatorAddBuiltinOptions(builder, builtin_options)
+    if builtin_options2_type is not None:
+        _tfl_operator.OperatorAddBuiltinOptions2Type(builder, 
builtin_options2_type)
+    if builtin_options2 is not None:
+        _tfl_operator.OperatorAddBuiltinOptions2(builder, builtin_options2)
     return _tfl_operator.OperatorEnd(builder)
 
 
 def _build_operator_code(builder, builtin_op):
+    # deprecated_builtin_code is int8 (max 127). Ops past that write 127 as a
+    # placeholder and use the full builtin_code field.
+    deprecated_code = builtin_op if builtin_op < 127 else 127
     _tfl_operator_code.OperatorCodeStart(builder)
-    _tfl_operator_code.OperatorCodeAddDeprecatedBuiltinCode(builder, 
builtin_op)
+    _tfl_operator_code.OperatorCodeAddDeprecatedBuiltinCode(builder, 
deprecated_code)
     _tfl_operator_code.OperatorCodeAddBuiltinCode(builder, builtin_op)
     _tfl_operator_code.OperatorCodeAddVersion(builder, 1)
     return _tfl_operator_code.OperatorCodeEnd(builder)
@@ -3145,7 +3164,7 @@ def _finish_tflite_model(builder, *, subgraph, 
operator_codes, buffers):
     _tfl_model.ModelAddVersion(builder, 3)
     model = _tfl_model.ModelEnd(builder)
 
-    builder.Finish(model)
+    builder.Finish(model, b"TFL3")
     return bytes(builder.Output())
 
 
@@ -3517,5 +3536,235 @@ def test_densify_with_fully_connected():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def _build_dilate_only_case(
+    builder, *, input_shape, dilations, dilation_value, dynamic_dilations=False
+):
+    input_tensor_idx = 0
+    dilations_tensor_idx = 1
+    padding_value_tensor_idx = 2
+    output_tensor_idx = 3
+
+    output_shape = tuple((input_shape[i] - 1) * dilations[i] + 1 for i in 
range(len(input_shape)))
+
+    input_tensor = _build_tensor(builder, 1, input_shape)
+    dilations_tensor = _build_tensor(
+        builder, 2, [len(dilations)], tensor_type=_tfl_tensor_type.INT32
+    )
+    padding_value_tensor = _build_tensor(builder, 3, [])
+    output_tensor = _build_tensor(builder, 4, output_shape)
+
+    _tfl_dilate_options.DilateOptionsStart(builder)
+    dilate_opts = _tfl_dilate_options.DilateOptionsEnd(builder)
+
+    dilate_op = _build_operator(
+        builder,
+        0,
+        [input_tensor_idx, dilations_tensor_idx, padding_value_tensor_idx],
+        [output_tensor_idx],
+        builtin_options2_type=_tfl_builtin_options2.DilateOptions,
+        builtin_options2=dilate_opts,
+    )
+    sg_inputs = (
+        [input_tensor_idx, dilations_tensor_idx] if dynamic_dilations else 
[input_tensor_idx]
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=[input_tensor, dilations_tensor, padding_value_tensor, 
output_tensor],
+        operators=[dilate_op],
+        inputs=sg_inputs,
+        outputs=[output_tensor_idx],
+    )
+    operator_codes = [_build_operator_code(builder, 
_tfl_builtin_operator.DILATE)]
+    return subgraph, operator_codes
+
+
+def test_dilate():
+    """TFLite DILATE with constant dilations"""
+    builder = flatbuffers.Builder(1024)
+    input_shape = (3, 4)
+    dilations = [2, 2]
+    dilation_value = 0.5
+
+    subgraph, operator_codes = _build_dilate_only_case(
+        builder,
+        input_shape=input_shape,
+        dilations=dilations,
+        dilation_value=dilation_value,
+    )
+
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder, np.asarray(dilations, 
dtype=np.int32).tobytes()),
+        _build_buffer(
+            builder, np.asarray([dilation_value], dtype=np.float32).tobytes()
+        ),
+        _build_buffer(builder),
+    ]
+
+    buf = _finish_tflite_model(
+        builder, subgraph=subgraph, operator_codes=operator_codes, 
buffers=buffers
+    )
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+    mod = from_tflite(tflite_model)
+    mod["main"] = mod["main"].without_attr("params")
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tensor((5, 7), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((3, 1, 4), dtype="float32") = R.reshape(
+                    tvmgen_tensor_0, R.shape([3, 1, 4])
+                )
+                lv1: R.Tensor((3, 1, 4), dtype="float32") = R.full(
+                    R.shape([3, 1, 4]), R.const(0.5, "float32"), 
dtype="float32"
+                )
+                lv2: R.Tensor((3, 2, 4), dtype="float32") = R.concat((lv, 
lv1), axis=1)
+                lv3: R.Tensor((6, 4), dtype="float32") = R.reshape(lv2, 
R.shape([6, 4]))
+                lv4: R.Tensor((5, 4), dtype="float32") = R.strided_slice(
+                    lv3, [0, 1], [0, 0], [5, 4], [1, 1], assume_inbound=False
+                )
+                lv5: R.Tensor((5, 4, 1), dtype="float32") = R.reshape(
+                    lv4, R.shape([5, 4, 1])
+                )
+                lv6: R.Tensor((5, 4, 1), dtype="float32") = R.full(
+                    R.shape([5, 4, 1]), R.const(0.5, "float32"), 
dtype="float32"
+                )
+                lv7: R.Tensor((5, 4, 2), dtype="float32") = R.concat((lv5, 
lv6), axis=2)
+                lv8: R.Tensor((5, 8), dtype="float32") = R.reshape(lv7, 
R.shape([5, 8]))
+                gv: R.Tensor((5, 7), dtype="float32") = R.strided_slice(
+                    lv8, [0, 1], [0, 0], [5, 7], [1, 1], assume_inbound=False
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_dilate_dynamic_dilations():
+    """DILATE with runtime dilations"""
+    builder = flatbuffers.Builder(1024)
+    input_shape = (3, 4)
+    dilations_for_shape = [2, 2]
+    dilation_value = 0.5
+
+    subgraph, operator_codes = _build_dilate_only_case(
+        builder,
+        input_shape=input_shape,
+        dilations=dilations_for_shape,
+        dilation_value=dilation_value,
+        dynamic_dilations=True,
+    )
+
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder),  # dilations is a runtime input so empty buffer
+        _build_buffer(
+            builder, np.asarray([dilation_value], dtype=np.float32).tobytes()
+        ),
+        _build_buffer(builder),
+    ]
+
+    buf = _finish_tflite_model(
+        builder, subgraph=subgraph, operator_codes=operator_codes, 
buffers=buffers
+    )
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+    mod = from_tflite(tflite_model)
+    mod["main"] = mod["main"].without_attr("params")
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((3, 4), dtype="float32"),
+            tvmgen_tensor_1: R.Tensor((2,), dtype="int32"),
+        ) -> R.Tensor(dtype="float32", ndim=2):
+            R.func_attr({"num_input": 2})
+            dilate_stride_0 = T.int64()
+            dilate_stride_1 = T.int64()
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="int32") = R.match_cast(
+                    tvmgen_tensor_1, R.Tensor((2,), dtype="int32")
+                )
+                lv1: R.Tensor((2,), dtype="int64") = R.astype(lv, 
dtype="int64")
+                lv2: R.Shape(ndim=2) = R.tensor_to_shape(lv1)
+                _lv3: R.Shape([dilate_stride_0, dilate_stride_1]) = 
R.match_cast(
+                    lv2, R.Shape([dilate_stride_0, dilate_stride_1])
+                )
+                lv4: R.Tensor((3, 1, 4), dtype="float32") = R.reshape(
+                    tvmgen_tensor_0, R.shape([3, 1, 4])
+                )
+                lv5: R.Tensor((3, dilate_stride_0 - 1, 4), dtype="float32") = 
R.full(
+                    R.shape([3, dilate_stride_0 - 1, 4]),
+                    R.const(0.5, "float32"),
+                    dtype="float32",
+                )
+                lv6: R.Tensor(
+                    (3, 1 + (dilate_stride_0 - 1), 4), dtype="float32"
+                ) = R.concat((lv4, lv5), axis=1)
+                lv7: R.Tensor((3 * dilate_stride_0, 4), dtype="float32") = 
R.reshape(
+                    lv6, R.shape([3 * dilate_stride_0, 4])
+                )
+                lv8: R.Tensor(
+                    (T.min(dilate_stride_0 * 2 + 1, dilate_stride_0 * 3), 4),
+                    dtype="float32",
+                ) = R.strided_slice(
+                    lv7,
+                    [0, 1],
+                    [0, 0],
+                    [2 * dilate_stride_0 + 1, 4],
+                    [1, 1],
+                    assume_inbound=False,
+                )
+                lv9: R.Tensor(
+                    (2 * dilate_stride_0 + 1, 4, 1), dtype="float32"
+                ) = R.reshape(lv8, R.shape([2 * dilate_stride_0 + 1, 4, 1]))
+                lv10: R.Tensor(
+                    (2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1), 
dtype="float32"
+                ) = R.full(
+                    R.shape([2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1]),
+                    R.const(0.5, "float32"),
+                    dtype="float32",
+                )
+                lv11: R.Tensor(
+                    (2 * dilate_stride_0 + 1, 4, 1 + (dilate_stride_1 - 1)),
+                    dtype="float32",
+                ) = R.concat((lv9, lv10), axis=2)
+                lv12: R.Tensor(
+                    (2 * dilate_stride_0 + 1, 4 * dilate_stride_1), 
dtype="float32"
+                ) = R.reshape(
+                    lv11, R.shape([2 * dilate_stride_0 + 1, 4 * 
dilate_stride_1])
+                )
+                gv: R.Tensor(
+                    (
+                        dilate_stride_0 * 2 + 1,
+                        T.min(dilate_stride_1 * 3 + 1, dilate_stride_1 * 4),
+                    ),
+                    dtype="float32",
+                ) = R.strided_slice(
+                    lv12,
+                    [0, 1],
+                    [0, 0],
+                    [2 * dilate_stride_0 + 1, 3 * dilate_stride_1 + 1],
+                    [1, 1],
+                    assume_inbound=False,
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     pytest.main(["-s", __file__])

Reply via email to