This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cf859b927a [Relax][Frontend][TFLite] Support STABLEHLO_CUSTOM_CALL (#19649)
cf859b927a is described below

commit cf859b927af9610a5e34e18209e6ce04658427c9
Author: HoYi <[email protected]>
AuthorDate: Mon Jun 1 00:53:40 2026 +0800

    [Relax][Frontend][TFLite] Support STABLEHLO_CUSTOM_CALL (#19649)
    
    ## Summary
    
    This PR adds conservative Relax TFLite frontend support for the TFLite
    builtin `STABLEHLO_CUSTOM_CALL` operator.
    
    TFLite marks `STABLEHLO_CUSTOM_CALL` as having no runtime kernel.
    Importing general custom calls as executable Relax operators would
    therefore give them semantics that TFLite itself does not provide. This
    PR only supports the metadata-only `Sharding` custom call target, which
    TensorFlow's StableHLO pipeline treats as an annotation that can be
    erased.
    
    ## Design
    
    ### Sharding Annotation Lowering
    
    `STABLEHLO_CUSTOM_CALL` now parses `StablehloCustomCallOptions` from
    `BuiltinOptions2` and reads the `call_target_name`.
    
    For `call_target_name == "Sharding"`, the frontend lowers the op to
    identity: the output tensor is bound to the input expression. This
    mirrors TensorFlow's handling of Sharding custom calls as metadata
    annotations. The sharding spec in `backend_config` is intentionally
    dropped for single-device import.
    
    The supported subset is guarded:
    
    - exactly one input and one output
    - input and output shape/dtype metadata must match
    - `has_side_effect` must be false
    - `called_computations` must be empty
    
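    All other custom-call targets raise `OpNotImplemented` with the target
    name in the diagnostic.

    In addition, the shared `_get_stablehlo_options` helper now raises
    `OpNotImplemented` when an operator declares a `BuiltinOptions2` type but
    carries no options table. Previously this failed with an opaque
    `AttributeError`.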
    
    ## Operator Support
    
    | Operator | TFLite options | Relax lowering | Supported subset |
    |---|---|---|---|
    | `STABLEHLO_CUSTOM_CALL` | `StablehloCustomCallOptions` from `BuiltinOptions2` | identity for `Sharding`; otherwise unsupported | metadata-only `Sharding` annotations with unchanged tensor metadata |
    
    ## Tests
    
    The tests manually build minimal StableHLO custom-call TFLite flatbuffers
    and compare the supported identity path with
    `tvm.ir.assert_structural_equal`. Unsupported patterns use
    `pytest.raises`.
    
    | Test | Coverage |
    |---|---|
    | `test_stablehlo_custom_call_sharding` | `Sharding` annotation lowers to identity |
    | `test_stablehlo_custom_call_unsupported_target` | unknown external target guard |
    | `test_stablehlo_custom_call_sharding_side_effect_unsupported` | side-effecting `Sharding` guard |
    | `test_stablehlo_custom_call_sharding_metadata_mismatch_unsupported` | input/output metadata guard |
    | `test_stablehlo_options_missing_payload_unsupported` | missing `BuiltinOptions2` payload guard |
    
    Local validation:
    
    ```bash
    python -m py_compile \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m ruff check \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py \
      -k stablehlo_custom_call -q
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py \
      -k stablehlo -q
    ```
    
    Result:
    
    ```text
    py_compile: passed
    ruff check: All checks passed
    stablehlo_custom_call tests: 4 passed
    stablehlo tests: 81 passed
    ```
    
    ## References
    
    - Issue #19519 item I: remaining StableHLO operators in TFLite
    - TensorFlow Lite schema marks `STABLEHLO_CUSTOM_CALL` as having no runtime support
    - TensorFlow StableHLO pipeline erases `Sharding` custom calls as metadata annotations
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |  44 ++++++++
 tests/python/relax/test_frontend_tflite.py         | 114 +++++++++++++++++++++
 2 files changed, 158 insertions(+)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 45cd41ce5b..2a4455eb30 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -338,6 +338,7 @@ class OperatorConverter:
             "STABLEHLO_CONVOLUTION": self._convert_stablehlo_convolution,
             "STABLEHLO_CONVERT": self._convert_stablehlo_convert,
             "STABLEHLO_COSINE": 
functools.partial(self._convert_stablehlo_unary, relax_op=_op.cos),
+            "STABLEHLO_CUSTOM_CALL": self._convert_stablehlo_custom_call,
             "STABLEHLO_DIVIDE": functools.partial(
                 self._convert_stablehlo_binary, relax_op=_op.divide
             ),
@@ -1743,6 +1744,13 @@ class OperatorConverter:
         from tflite.BuiltinOptions2 import BuiltinOptions2
 
         op_options = op.BuiltinOptions2()
+        if op_options is None:
+            # A malformed flatbuffer may declare a BuiltinOptions2 type without
+            # carrying the actual options table. Fail cleanly instead of 
raising
+            # an opaque AttributeError when accessing the missing payload.
+            raise tvm.error.OpNotImplemented(
+                f"{options_cls.__name__} is required but missing from the 
operator"
+            )
         # Look up the expected BuiltinOptions2 enum value by matching the class
         # name to an enum member (e.g. StablehloConcatenateOptions → 1).
         options_type = getattr(BuiltinOptions2, options_cls.__name__, None)
@@ -2162,6 +2170,42 @@ class OperatorConverter:
             relax.op.sort(data, axis=int(opts.Dimension()), 
descending=descending)
         )
 
+    def _convert_stablehlo_custom_call(self, op):
+        """Convert STABLEHLO_CUSTOM_CALL, supporting only the "Sharding" annotation target."""
+        from tflite.StablehloCustomCallOptions import 
StablehloCustomCallOptions
+
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        opts = self._get_stablehlo_options(op, StablehloCustomCallOptions)
+        call_target_name = self._decode_tflite_string(opts.CallTargetName())
+
+        if call_target_name == "Sharding":
+            # TensorFlow treats Sharding custom calls as metadata annotations
+            # and may erase them by replacing the op with its input. Mirror
+            # that identity semantics for the safe single-input/single-output
+            # subset. The sharding spec in backend_config is intentionally
+            # dropped for single-device import. TFLite has no runtime kernel
+            # for general STABLEHLO_CUSTOM_CALL targets, so other targets are rejected below.
+            if opts.HasSideEffect():
+                raise tvm.error.OpNotImplemented(
+                    "STABLEHLO_CUSTOM_CALL Sharding with side effects is not 
supported"
+                )
+            if opts.CalledComputationsLength() != 0:
+                raise tvm.error.OpNotImplemented(
+                    "STABLEHLO_CUSTOM_CALL Sharding with called computations 
is not supported"
+                )
+            if len(input_tensors) != 1 or len(output_tensors) != 1:
+                raise tvm.error.OpNotImplemented(
+                    "STABLEHLO_CUSTOM_CALL Sharding requires one input and one 
output"
+                )
+            self._check_tensor_metadata_match(
+                input_tensors[0], output_tensors[0], "STABLEHLO_CUSTOM_CALL", 
"Sharding"
+            )
+            return self.get_tensor_expr(input_tensors[0])
+
+        target = call_target_name or "<empty>"  # keep the diagnostic readable for a blank target name
+        raise tvm.error.OpNotImplemented(f"STABLEHLO_CUSTOM_CALL target 
{target} is not supported")
+
     def _convert_stablehlo_while(self, op):
         """Convert STABLEHLO_WHILE to a recursive Relax private function."""
         from tflite.StablehloWhileOptions import StablehloWhileOptions
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index cc3a84e2fd..7c3e526d99 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3683,6 +3683,7 @@ _tfl_stablehlo_concat_opts = 
_get_tflite_schema_module("StablehloConcatenateOpti
 _tfl_stablehlo_bcast_opts = 
_get_tflite_schema_module("StablehloBroadcastInDimOptions")
 _tfl_stablehlo_composite_opts = 
_get_tflite_schema_module("StableHLOCompositeOptions")
 _tfl_stablehlo_conv_opts = 
_get_tflite_schema_module("StablehloConvolutionOptions")
+_tfl_stablehlo_custom_call_opts = 
_get_tflite_schema_module("StablehloCustomCallOptions")
 _tfl_stablehlo_dot_opts = 
_get_tflite_schema_module("StablehloDotGeneralOptions")
 _tfl_stablehlo_iota_opts = _get_tflite_schema_module("StablehloIotaOptions")
 _tfl_stablehlo_compare_opts = 
_get_tflite_schema_module("StablehloCompareOptions")
@@ -6308,6 +6309,68 @@ def 
_build_stablehlo_scatter_model(reducer_name="STABLEHLO_ADD", update_window_d
     )
 
 
+def _build_stablehlo_custom_call_model(
+    call_target_name="Sharding",
+    has_side_effect=False,
+    output_tensor_type=_tfl_tensor_type.FLOAT32,
+    include_options=True,
+):
+    """Build a single-input, single-output STABLEHLO_CUSTOM_CALL model (2x2 tensors).
+
+    When ``include_options`` is False the operator declares the
+    StablehloCustomCallOptions type but omits the options table, emulating a
+    malformed flatbuffer with a missing BuiltinOptions2 payload.
+    """
+    builder = flatbuffers.Builder(1024)
+
+    custom_call_opts = None
+    if include_options:
+        call_target_name_offset = builder.CreateString(call_target_name)
+        backend_config_offset = builder.CreateString("")  # backend_config is not read by the importer
+        
_tfl_stablehlo_custom_call_opts.StablehloCustomCallOptionsStart(builder)
+        
_tfl_stablehlo_custom_call_opts.StablehloCustomCallOptionsAddCallTargetName(
+            builder, call_target_name_offset
+        )
+        
_tfl_stablehlo_custom_call_opts.StablehloCustomCallOptionsAddHasSideEffect(
+            builder, has_side_effect
+        )
+        
_tfl_stablehlo_custom_call_opts.StablehloCustomCallOptionsAddBackendConfig(
+            builder, backend_config_offset
+        )
+        custom_call_opts = 
_tfl_stablehlo_custom_call_opts.StablehloCustomCallOptionsEnd(builder)
+
+    custom_call_builtin = 
_get_stablehlo_builtin_operator("STABLEHLO_CUSTOM_CALL")
+    custom_call_code = _build_operator_code(builder, custom_call_builtin)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [2, 2]),
+        _build_tensor(builder, 1, [2, 2], tensor_type=output_tensor_type),  # only the output dtype varies
+    ]
+    custom_call_op = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options2_type=_tfl_builtin_options2.StablehloCustomCallOptions,
+        builtin_options2=custom_call_opts,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[custom_call_op],
+        inputs=[0],
+        outputs=[1],
+    )
+
+    buffers = [_build_buffer(builder) for _ in range(2)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        operator_codes=[custom_call_code],
+        buffers=buffers,
+    )
+
+
 def _build_stablehlo_while_model(
     cond_subgraph_index=1,
     body_subgraph_index=2,
@@ -6812,6 +6875,57 @@ def test_stablehlo_scatter_update_window_unsupported():
         from_tflite(tflite_model)
 
 
+def test_stablehlo_custom_call_sharding():
+    """TFLite StableHLO CUSTOM_CALL Sharding annotation lowers to identity."""
+    mod = _load_model_from_buffer(_build_stablehlo_custom_call_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), 
dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = x
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_custom_call_unsupported_target():
+    """TFLite StableHLO CUSTOM_CALL rejects unknown external call targets."""
+    buf = _build_stablehlo_custom_call_model(call_target_name="custom_backend")
+    with pytest.raises(
+        tvm.error.OpNotImplemented,
+        match="STABLEHLO_CUSTOM_CALL target custom_backend is not supported",
+    ):
+        _load_model_from_buffer(buf)
+
+
+def test_stablehlo_custom_call_sharding_side_effect_unsupported():
+    """TFLite StableHLO CUSTOM_CALL rejects side-effecting Sharding calls."""
+    buf = _build_stablehlo_custom_call_model(has_side_effect=True)
+    with pytest.raises(tvm.error.OpNotImplemented, match="side effects"):
+        _load_model_from_buffer(buf)
+
+
+def test_stablehlo_custom_call_sharding_metadata_mismatch_unsupported():
+    """TFLite StableHLO CUSTOM_CALL rejects Sharding calls that change tensor 
metadata."""
+    buf = 
_build_stablehlo_custom_call_model(output_tensor_type=_tfl_tensor_type.INT32)  # float32 input vs. int32 output
+    with pytest.raises(tvm.error.OpNotImplemented, match="Sharding tensor 
metadata mismatch"):
+        _load_model_from_buffer(buf)
+
+
+def test_stablehlo_options_missing_payload_unsupported():
+    """A StableHLO op that declares an options type but omits the payload 
fails cleanly."""
+    buf = _build_stablehlo_custom_call_model(include_options=False)
+    with pytest.raises(
+        tvm.error.OpNotImplemented,
+        match="StablehloCustomCallOptions is required but missing from the 
operator",
+    ):
+        _load_model_from_buffer(buf)
+
+
 def test_stablehlo_while():
     """TFLite STABLEHLO_WHILE lowers to a recursive Relax private function."""
     mod = _load_model_from_buffer(_build_stablehlo_while_model())

Reply via email to