This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c1415d6e4d [Relax][Frontend][TFLite] Add NON_MAX_SUPPRESSION_V4 converter (#19464)
c1415d6e4d is described below
commit c1415d6e4da1965c617fbfc51fcfd8caa3863b3d
Author: as4230 <[email protected]>
AuthorDate: Wed Apr 29 01:49:02 2026 -0400
[Relax][Frontend][TFLite] Add NON_MAX_SUPPRESSION_V4 converter (#19464)
Adds the missing TFLite NonMaxSuppressionV4 frontend handler. The
underlying relax.op.vision.non_max_suppression already covers V4's
behavior with soft_nms_sigma at the default 0.0 (hard-NMS path). The
handler bridges TFLite's tensor format to the Relax op, following the
same pattern as convert_nms_v5 (#19426) but without its soft-NMS
branching.
Tests cover model conversion and structural assertions on the emitted Relax
IR; run them with `pytest tests/python/relax/test_frontend_tflite.py -k nms_v4`.
End-to-end correctness checks against TensorFlow run only on the nightly gate
(i.e. when CI_ENV_NIGHTLY is set).
Relates to #19412.
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 59 +++++++++
tests/python/relax/test_frontend_tflite.py | 147 +++++++++++++++++++++
2 files changed, 206 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index fe85e4da9e..24554c6fec 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -239,6 +239,7 @@ class OperatorConverter:
# "UNIDIRECTIONAL_SEQUENCE_LSTM":
self.convert_unidirectional_sequence_lstm,
"WHERE": self.convert_select,
"ZEROS_LIKE": self.convert_zeros_like,
+ "NON_MAX_SUPPRESSION_V4": self.convert_nms_v4,
"NON_MAX_SUPPRESSION_V5": self.convert_nms_v5,
}
@@ -3657,6 +3658,64 @@ class OperatorConverter:
num_detections = relax.op.astype(num_detections, "float32")
return relax.Tuple([detection_boxes, detection_classes,
detection_scores, num_detections])
+    def convert_nms_v4(self, op):
+        """Convert TFLite NonMaxSuppressionV4 (hard NMS).
+
+        Inputs are boxes, scores, max_output_size, iou_threshold and
+        score_threshold; the last three must be constant tensors holding a
+        single value each. Scores and boxes are packed along the last axis
+        (score first) into a batch-of-one tensor for the Relax vision ops.
+
+        Returns a Relax tuple ``(selected_indices, num_valid)``.
+        ``selected_indices`` holds the kept box indices in its first
+        ``num_valid`` slots and zeros elsewhere; when the box count is
+        static it is padded up to ``max_output_size``, mirroring TFLite's
+        fixed-size, zero-padded output. ``num_valid`` is a 0-d tensor.
+        """
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 5, "input tensor length should be 5"
+
+        boxes = self.get_tensor_expr(input_tensors[0])
+        scores = self.get_tensor_expr(input_tensors[1])
+
+        def _to_scalar(value, cast):
+            # Scalar params may arrive as 0-d or 1-element arrays.
+            # ndarray.item() handles both shapes, whereas int()/float() on
+            # an array with ndim > 0 is deprecated in NumPy.
+            if isinstance(value, np.ndarray):
+                assert value.size == 1, "only one value is expected."
+                value = value.item()
+            return cast(value)
+
+        max_output_size = _to_scalar(self.get_tensor_value(input_tensors[2]), int)
+        iou_threshold = _to_scalar(self.get_tensor_value(input_tensors[3]), float)
+        score_threshold = _to_scalar(self.get_tensor_value(input_tensors[4]), float)
+
+        # Score goes in column 0 and the box coordinates follow it; there is
+        # no class-id column, hence id_index=-1 below.
+        scores_expand = relax.op.expand_dims(scores, axis=-1)
+        data = relax.op.concat([scores_expand, boxes], axis=-1)
+        data = relax.op.expand_dims(data, axis=0)
+
+        valid_counts_ret = relax.op.vision.get_valid_counts(
+            data, score_threshold=score_threshold, id_index=-1, score_index=0
+        )
+        count = valid_counts_ret[0]
+        data = valid_counts_ret[1]
+        indices = valid_counts_ret[2]
+
+        nms_ret = relax.op.vision.non_max_suppression(
+            data=data,
+            valid_count=count,
+            indices=indices,
+            max_output_size=max_output_size,
+            iou_threshold=iou_threshold,
+            force_suppress=True,
+            top_k=-1,
+            coord_start=1,
+            score_index=0,
+            id_index=-1,
+            return_indices=True,
+            invalid_to_bottom=False,
+        )
+
+        selected_indices = relax.op.squeeze(nms_ret[0], axis=[0])
+        selected_indices = relax.op.strided_slice(
+            selected_indices, axes=[0], begin=[0], end=[max_output_size]
+        )
+        # The Relax NMS marks unused slots with a negative index (-1), while
+        # TFLite zero-fills them. Clamp so a downstream gather stays in
+        # bounds; this is a no-op on valid (non-negative) indices.
+        selected_indices = relax.op.maximum(
+            selected_indices, relax.const(0, "int32")
+        )
+
+        # strided_slice clamps ``end`` to the box count, so when
+        # max_output_size > num_boxes the result is shorter than TFLite's
+        # fixed [max_output_size] output. Zero-pad the tail to match.
+        boxes_shape = input_tensors[0].tensor.ShapeAsNumpy()
+        if np.ndim(boxes_shape) == 1 and np.size(boxes_shape) >= 1:
+            num_boxes = int(boxes_shape[0])
+        else:
+            num_boxes = -1
+        if 0 < num_boxes < max_output_size:
+            padding = relax.op.zeros((max_output_size - num_boxes,), "int32")
+            selected_indices = relax.op.concat([selected_indices, padding], axis=0)
+
+        num_valid = relax.op.reshape(nms_ret[1], [])
+
+        return relax.Tuple([selected_indices, num_valid])
+
def convert_nms_v5(self, op):
"""Convert TFLite NonMaxSuppressionV5"""
input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 7a166f2fd9..a5506c5984 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1490,6 +1490,59 @@ def test_batch_matmul_adj():
verify(BatchMatMulAdj, Expected)
+def _verify_nms_v4(mod, tf_func, boxes_np, scores_np):
+    """Check a converted NMS V4 module against TensorFlow (nightly only).
+
+    Does nothing unless CI_ENV_NIGHTLY is set. Only the leading entries that
+    TensorFlow reports as valid are compared; padded slots are ignored.
+    """
+    run_e2e = "CI_ENV_NIGHTLY" in os.environ
+    if not run_e2e:
+        return
+
+    expected_idx, expected_count = tf_func(
+        tf.constant(boxes_np), tf.constant(scores_np)
+    )
+    count = int(expected_count.numpy())
+
+    executable = tvm.compile(mod, tvm.target.Target("llvm"))
+    machine = relax.VirtualMachine(executable, tvm.cpu())
+    machine.set_input("main", boxes_np, scores_np)
+    machine.invoke_stateful("main")
+    actual_idx, actual_count = machine.get_outputs("main")
+
+    assert int(actual_count.numpy()) == count
+    np.testing.assert_array_equal(
+        expected_idx.numpy()[:count],
+        actual_idx.numpy()[:count],
+    )
+
+
+def _build_nms_v4_mod(num_boxes, max_output_size, iou_threshold, score_threshold):
+    """Trace a NonMaxSuppressionV4 tf.Module and convert it to Relax.
+
+    The scalar NMS parameters are captured as Python literals rather than
+    tf.constant values, so TFLite can infer the output shapes statically
+    during conversion.
+
+    Returns ``(mod, tf_callable)``: the Relax module and the traced
+    tf.function, which can be called to obtain reference outputs.
+    """
+    signature = [
+        tf.TensorSpec(shape=(num_boxes, 4), dtype=tf.float32),
+        tf.TensorSpec(shape=(num_boxes,), dtype=tf.float32),
+    ]
+
+    class NMSv4Module(tf.Module):
+        @tf.function(input_signature=signature)
+        def func(self, boxes, scores):
+            result = tf.raw_ops.NonMaxSuppressionV4(
+                boxes=boxes,
+                scores=scores,
+                max_output_size=max_output_size,
+                iou_threshold=iou_threshold,
+                score_threshold=score_threshold,
+                pad_to_max_output_size=True,
+            )
+            # Unpack into a plain 2-tuple (selected indices, valid count).
+            selected, num_valid = result
+            return selected, num_valid
+
+    module = NMSv4Module()
+    concrete = module.func.get_concrete_function()
+    return _get_mod_from_cfunc(concrete), module.func
+
+
def _verify_nms_v5(mod, tf_func, boxes_np, scores_np, soft_nms_sigma=0.0):
"""E2E verify for NMS: only run on nightly, compare valid outputs only."""
if "CI_ENV_NIGHTLY" not in os.environ:
@@ -1934,6 +1987,100 @@ def test_nms_v5_soft_ir():
assert "R.clip(" in ir
+# Parametrized cases for test_nms_v4. Each entry is
+# (num_boxes, max_output_size, iou_threshold, score_threshold, boxes, scores);
+# the pytest id names the behavior being exercised.
+_NMS_V4_CASES = [
+ # Overlapping boxes (rows 0 and 1 are identical) exercise IoU suppression.
+ pytest.param(
+ 6,
+ 3,
+ 0.5,
+ 0.0,
+ np.array(
+ [
+ [0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.1, 1.0, 1.1],
+ [0.0, 0.0, 1.0, 0.9],
+ [0.5, 0.5, 1.5, 1.5],
+ [0.0, 0.0, 0.3, 0.3],
+ ],
+ dtype=np.float32,
+ ),
+ np.array([0.9, 0.75, 0.6, 0.5, 0.4, 0.3], dtype=np.float32),
+ id="basic",
+ ),
+ # Seeded random boxes/scores with a non-zero score_threshold (0.4).
+ pytest.param(
+ 8,
+ 4,
+ 0.5,
+ 0.4,
+ _make_valid_boxes(np.random.default_rng(42), 8),
+ np.random.default_rng(42).random(8, dtype=np.float32),
+ id="score_threshold",
+ ),
+ # Every score (max 0.5) is below score_threshold=0.99, so no box is kept.
+ pytest.param(
+ 5,
+ 3,
+ 0.5,
+ 0.99,
+ _make_valid_boxes(np.random.default_rng(0), 5),
+ np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32),
+ id="all_suppressed",
+ ),
+ # max_output_size (10) exceeds num_boxes (4); no pair of these boxes
+ # overlaps above iou_threshold=0.5, so all four should be selected.
+ pytest.param(
+ 4,
+ 10,
+ 0.5,
+ 0.0,
+ np.array(
+ [
+ [0.0, 0.0, 0.3, 0.3],
+ [0.5, 0.5, 0.8, 0.8],
+ [0.1, 0.1, 0.4, 0.4],
+ [0.6, 0.6, 0.9, 0.9],
+ ],
+ dtype=np.float32,
+ ),
+ np.array([0.9, 0.85, 0.7, 0.65], dtype=np.float32),
+ id="max_output_size_larger_than_boxes",
+ ),
+]
+
+
+@pytest.mark.parametrize(
+    "num_boxes,max_output_size,iou_threshold,score_threshold,boxes,scores",
+    _NMS_V4_CASES,
+)
+def test_nms_v4(
+    num_boxes, max_output_size, iou_threshold, score_threshold, boxes, scores
+):
+    """NON_MAX_SUPPRESSION_V4: conversion smoke test + E2E correctness (nightly only).
+
+    Conversion runs for every case; _verify_nms_v4 performs the numerical
+    comparison against TensorFlow only when CI_ENV_NIGHTLY is set.
+    """
+    mod, tf_func = _build_nms_v4_mod(
+        num_boxes, max_output_size, iou_threshold, score_threshold
+    )
+    _verify_nms_v4(mod, tf_func, boxes, scores)
+
+
+def test_nms_v4_ir():
+    """Check the structure of the Relax IR emitted for NON_MAX_SUPPRESSION_V4."""
+    max_output_size = 3
+    mod, _ = _build_nms_v4_mod(
+        num_boxes=6,
+        max_output_size=max_output_size,
+        iou_threshold=0.5,
+        score_threshold=0.0,
+    )
+    script = mod.script()
+
+    required_fragments = [
+        # Score/id column layout passed to the vision ops.
+        "score_index=0",
+        "id_index=-1",
+        # The NMS output cap is forwarded unchanged.
+        f"max_output_size={max_output_size}",
+        # The valid-count output is a static 0-d int32 tensor.
+        'R.Tensor((), dtype="int32")',
+        # Selected indices have a static length of max_output_size.
+        f"R.Tensor(({max_output_size},)",
+        # V4 is hard NMS: soft_nms_sigma stays at its 0.0 default.
+        "soft_nms_sigma=0.0",
+    ]
+    for fragment in required_fragments:
+        assert fragment in script
+
+
_DETECTION_POSTPROCESS_SMOKE_CASES = [
pytest.param(
{