This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cf859b927a [Relax][Frontend][TFLite] Support STABLEHLO_CUSTOM_CALL (#19649)
cf859b927a is described below
commit cf859b927af9610a5e34e18209e6ce04658427c9
Author: HoYi <[email protected]>
AuthorDate: Mon Jun 1 00:53:40 2026 +0800
[Relax][Frontend][TFLite] Support STABLEHLO_CUSTOM_CALL (#19649)
## Summary
This PR adds conservative Relax TFLite frontend support for the TFLite
builtin `STABLEHLO_CUSTOM_CALL` operator.
TFLite marks `STABLEHLO_CUSTOM_CALL` as having no runtime kernel.
Importing general custom calls as executable Relax operators would
therefore give them semantics that TFLite itself does not provide. This
PR only supports the metadata-only `Sharding` custom call target, which
TensorFlow's StableHLO pipeline treats as an annotation that can be
erased.
## Design
### Sharding Annotation Lowering
`STABLEHLO_CUSTOM_CALL` now parses `StablehloCustomCallOptions` from
`BuiltinOptions2` and reads the `call_target_name`.
For `call_target_name == "Sharding"`, the frontend lowers the op to
identity: the output tensor is bound to the input expression. This
mirrors TensorFlow's handling of Sharding custom calls as metadata
annotations. The sharding spec in `backend_config` is intentionally
dropped for single-device import.
The supported subset is guarded:
- exactly one input and one output
- input and output shape/dtype metadata must match
- `has_side_effect` must be false
- `called_computations` must be empty
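All other custom-call targets raise `OpNotImplemented` with the target
name in the diagnostic. In addition, `_get_stablehlo_options` now raises
`OpNotImplemented` when an operator's `BuiltinOptions2` payload is
missing, instead of failing later with an opaque `AttributeError`.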
## Operator Support
| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `STABLEHLO_CUSTOM_CALL` | `StablehloCustomCallOptions` from `BuiltinOptions2` | identity for `Sharding`; otherwise unsupported | metadata-only `Sharding` annotations with unchanged tensor metadata |
## Tests
The tests manually build minimal StableHLO custom-call TFLite
flatbuffers and compare the supported identity path with
`tvm.ir.assert_structural_equal`. Unsupported patterns use
`pytest.raises`.
| Test | Coverage |
|---|---|
| `test_stablehlo_custom_call_sharding` | `Sharding` annotation lowers to identity |
| `test_stablehlo_custom_call_unsupported_target` | unknown external target guard |
| `test_stablehlo_custom_call_sharding_side_effect_unsupported` | side-effecting `Sharding` guard |
| `test_stablehlo_custom_call_sharding_metadata_mismatch_unsupported` | input/output metadata guard |
| `test_stablehlo_options_missing_payload_unsupported` | missing `BuiltinOptions2` payload guard |
Local validation:
```bash
python -m py_compile \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m ruff check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k stablehlo_custom_call -q
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k stablehlo -q
```
Result:
```text
py_compile: passed
ruff check: All checks passed
stablehlo_custom_call tests: 4 passed
stablehlo tests: 81 passed
```
## References
- Issue #19519 item I: remaining StableHLO operators in TFLite
- TensorFlow Lite schema marks `STABLEHLO_CUSTOM_CALL` as having no runtime support
- TensorFlow StableHLO pipeline treats `Sharding` custom calls as metadata annotations and can erase them
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 44 ++++++++
tests/python/relax/test_frontend_tflite.py | 114 +++++++++++++++++++++
2 files changed, 158 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 45cd41ce5b..2a4455eb30 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -338,6 +338,7 @@ class OperatorConverter:
             "STABLEHLO_CONVOLUTION": self._convert_stablehlo_convolution,
             "STABLEHLO_CONVERT": self._convert_stablehlo_convert,
             "STABLEHLO_COSINE": functools.partial(self._convert_stablehlo_unary, relax_op=_op.cos),
+            "STABLEHLO_CUSTOM_CALL": self._convert_stablehlo_custom_call,
             "STABLEHLO_DIVIDE": functools.partial(
                 self._convert_stablehlo_binary, relax_op=_op.divide
             ),
@@ -1743,6 +1744,13 @@ class OperatorConverter:
         from tflite.BuiltinOptions2 import BuiltinOptions2
 
         op_options = op.BuiltinOptions2()
+        if op_options is None:
+            # A malformed flatbuffer may declare a BuiltinOptions2 type without
+            # carrying the actual options table. Fail cleanly instead of raising
+            # an opaque AttributeError when accessing the missing payload.
+            raise tvm.error.OpNotImplemented(
+                f"{options_cls.__name__} is required but missing from the operator"
+            )
         # Look up the expected BuiltinOptions2 enum value by matching the class
         # name to an enum member (e.g. StablehloConcatenateOptions → 1).
         options_type = getattr(BuiltinOptions2, options_cls.__name__, None)
@@ -2162,6 +2170,42 @@ class OperatorConverter:
             relax.op.sort(data, axis=int(opts.Dimension()), descending=descending)
         )
 
+    def _convert_stablehlo_custom_call(self, op):
+        """Convert supported annotation-only STABLEHLO_CUSTOM_CALL targets."""
+        from tflite.StablehloCustomCallOptions import StablehloCustomCallOptions
+
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        opts = self._get_stablehlo_options(op, StablehloCustomCallOptions)
+        call_target_name = self._decode_tflite_string(opts.CallTargetName())
+
+        if call_target_name == "Sharding":
+            # TensorFlow treats Sharding custom calls as metadata annotations
+            # and may erase them by replacing the op with its input. Mirror
+            # that identity semantics for the safe single-input/single-output
+            # subset. The sharding spec in backend_config is intentionally
+            # dropped for single-device import. TFLite has no runtime kernel
+            # for general STABLEHLO_CUSTOM_CALL targets.
+            if opts.HasSideEffect():
+                raise tvm.error.OpNotImplemented(
+                    "STABLEHLO_CUSTOM_CALL Sharding with side effects is not supported"
+                )
+            if opts.CalledComputationsLength() != 0:
+                raise tvm.error.OpNotImplemented(
+                    "STABLEHLO_CUSTOM_CALL Sharding with called computations is not supported"
+                )
+            if len(input_tensors) != 1 or len(output_tensors) != 1:
+                raise tvm.error.OpNotImplemented(
+                    "STABLEHLO_CUSTOM_CALL Sharding requires one input and one output"
+                )
+            self._check_tensor_metadata_match(
+                input_tensors[0], output_tensors[0], "STABLEHLO_CUSTOM_CALL", "Sharding"
+            )
+            return self.get_tensor_expr(input_tensors[0])
+
+        target = call_target_name or "<empty>"
+        raise tvm.error.OpNotImplemented(f"STABLEHLO_CUSTOM_CALL target {target} is not supported")
+
     def _convert_stablehlo_while(self, op):
         """Convert STABLEHLO_WHILE to a recursive Relax private function."""
         from tflite.StablehloWhileOptions import StablehloWhileOptions
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index cc3a84e2fd..7c3e526d99 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3683,6 +3683,7 @@ _tfl_stablehlo_concat_opts = _get_tflite_schema_module("StablehloConcatenateOpti
 _tfl_stablehlo_bcast_opts = _get_tflite_schema_module("StablehloBroadcastInDimOptions")
 _tfl_stablehlo_composite_opts = _get_tflite_schema_module("StableHLOCompositeOptions")
 _tfl_stablehlo_conv_opts = _get_tflite_schema_module("StablehloConvolutionOptions")
+_tfl_stablehlo_custom_call_opts = _get_tflite_schema_module("StablehloCustomCallOptions")
 _tfl_stablehlo_dot_opts = _get_tflite_schema_module("StablehloDotGeneralOptions")
 _tfl_stablehlo_iota_opts = _get_tflite_schema_module("StablehloIotaOptions")
 _tfl_stablehlo_compare_opts = _get_tflite_schema_module("StablehloCompareOptions")
@@ -6308,6 +6309,68 @@ def _build_stablehlo_scatter_model(reducer_name="STABLEHLO_ADD", update_window_d
     )
 
 
+def _build_stablehlo_custom_call_model(
+    call_target_name="Sharding",
+    has_side_effect=False,
+    output_tensor_type=_tfl_tensor_type.FLOAT32,
+    include_options=True,
+):
+    """Build a single-input STABLEHLO_CUSTOM_CALL model.
+
+    When ``include_options`` is False the operator declares the
+    StablehloCustomCallOptions type but omits the options table, emulating a
+    malformed flatbuffer with a missing BuiltinOptions2 payload.
+    """
+    builder = flatbuffers.Builder(1024)
+
+    custom_call_opts = None
+    if include_options:
+        call_target_name_offset = builder.CreateString(call_target_name)
+        backend_config_offset = builder.CreateString("")
+        _tfl_stablehlo_custom_call_opts.StablehloCustomCallOptionsStart(builder)
+        _tfl_stablehlo_custom_call_opts.StablehloCustomCallOptionsAddCallTargetName(
+            builder, call_target_name_offset
+        )
+        _tfl_stablehlo_custom_call_opts.StablehloCustomCallOptionsAddHasSideEffect(
+            builder, has_side_effect
+        )
+        _tfl_stablehlo_custom_call_opts.StablehloCustomCallOptionsAddBackendConfig(
+            builder, backend_config_offset
+        )
+        custom_call_opts = _tfl_stablehlo_custom_call_opts.StablehloCustomCallOptionsEnd(builder)
+
+    custom_call_builtin = _get_stablehlo_builtin_operator("STABLEHLO_CUSTOM_CALL")
+    custom_call_code = _build_operator_code(builder, custom_call_builtin)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [2, 2]),
+        _build_tensor(builder, 1, [2, 2], tensor_type=output_tensor_type),
+    ]
+    custom_call_op = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options2_type=_tfl_builtin_options2.StablehloCustomCallOptions,
+        builtin_options2=custom_call_opts,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[custom_call_op],
+        inputs=[0],
+        outputs=[1],
+    )
+
+    buffers = [_build_buffer(builder) for _ in range(2)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        operator_codes=[custom_call_code],
+        buffers=buffers,
+    )
+
+
 def _build_stablehlo_while_model(
     cond_subgraph_index=1,
     body_subgraph_index=2,
@@ -6812,6 +6875,57 @@ def test_stablehlo_scatter_update_window_unsupported():
         from_tflite(tflite_model)
 
 
+def test_stablehlo_custom_call_sharding():
+    """TFLite StableHLO CUSTOM_CALL Sharding annotation lowers to identity."""
+    mod = _load_model_from_buffer(_build_stablehlo_custom_call_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = x
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_custom_call_unsupported_target():
+    """TFLite StableHLO CUSTOM_CALL rejects unknown external call targets."""
+    buf = _build_stablehlo_custom_call_model(call_target_name="custom_backend")
+    with pytest.raises(
+        tvm.error.OpNotImplemented,
+        match="STABLEHLO_CUSTOM_CALL target custom_backend is not supported",
+    ):
+        _load_model_from_buffer(buf)
+
+
+def test_stablehlo_custom_call_sharding_side_effect_unsupported():
+    """TFLite StableHLO CUSTOM_CALL rejects side-effecting Sharding calls."""
+    buf = _build_stablehlo_custom_call_model(has_side_effect=True)
+    with pytest.raises(tvm.error.OpNotImplemented, match="side effects"):
+        _load_model_from_buffer(buf)
+
+
+def test_stablehlo_custom_call_sharding_metadata_mismatch_unsupported():
+    """TFLite StableHLO CUSTOM_CALL rejects Sharding calls that change tensor metadata."""
+    buf = _build_stablehlo_custom_call_model(output_tensor_type=_tfl_tensor_type.INT32)
+    with pytest.raises(tvm.error.OpNotImplemented, match="Sharding tensor metadata mismatch"):
+        _load_model_from_buffer(buf)
+
+
+def test_stablehlo_options_missing_payload_unsupported():
+    """A StableHLO op that declares an options type but omits the payload fails cleanly."""
+    buf = _build_stablehlo_custom_call_model(include_options=False)
+    with pytest.raises(
+        tvm.error.OpNotImplemented,
+        match="StablehloCustomCallOptions is required but missing from the operator",
+    ):
+        _load_model_from_buffer(buf)
+
+
 def test_stablehlo_while():
     """TFLite STABLEHLO_WHILE lowers to a recursive Relax private function."""
     mod = _load_model_from_buffer(_build_stablehlo_while_model())