This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b915cac0bf [Relax][Frontend][TFLite] Add soft-NMS support for TFLite 
NON_MAX_SUPPRESSION_V5 (#19426)
b915cac0bf is described below

commit b915cac0bf5071fd23e8ef1643c7932826e2e05e
Author: HoYi <[email protected]>
AuthorDate: Tue Apr 28 04:42:34 2026 +0800

    [Relax][Frontend][TFLite] Add soft-NMS support for TFLite NON_MAX_SUPPRESSION_V5 (#19426)
    
    ## Summary
    
    This PR completes the TFLite `NON_MAX_SUPPRESSION_V5` implementation in
    Relax by adding support for `soft_nms_sigma != 0`.
    
    It extends `relax.vision.non_max_suppression` with soft-NMS attributes,
    updates the TFLite frontend to consume the soft-NMS outputs correctly,
    and aligns the TOPI implementation with LiteRT's reference behavior.
    
    Relates to #19412.
    
    ## Changes
    
    1. **Relax / TOPI soft-NMS support**
    - Extend `NonMaximumSuppressionAttrs` with `soft_nms_sigma` and
    `score_threshold`.
    - Thread the new attributes through Relax op registration, Python
    wrapper, and legalization.
    - Add soft-NMS handling to TOPI classic NMS so
    `relax.vision.non_max_suppression` can represent the
    `NON_MAX_SUPPRESSION_V5` behavior.
    
    2. **TFLite frontend support for `NON_MAX_SUPPRESSION_V5`**
    - Remove the previous `soft_nms_sigma != 0` unsupported-path guard in
    the TFLite frontend.
    - Forward `soft_nms_sigma` and `score_threshold` into
    `relax.vision.non_max_suppression`.
    - Handle the soft-NMS return path explicitly so the frontend reads
    decayed scores from the processed NMS output instead of re-reading the
    original score tensor.
    
    3. **Soft-NMS correctness fixes**
    - Fix the soft-NMS path so boxes whose scores fall to or below
    `score_threshold` after decay are invalidated consistently.
    - Keep returned indices and decayed scores aligned in both the TOPI TIR
    implementation and the NumPy reference implementation.
    - Update the soft-NMS candidate selection logic to re-pick the current
    best candidate after each decay step, matching LiteRT's reference
    behavior.
    - Align the Gaussian decay formula with LiteRT.
    
    4. **Test coverage**
    - Add Relax tests for soft-NMS struct-info inference and legalization.
    - Add Relax E2E tests covering reordered outputs after score decay and
    other soft-NMS follow-up cases.
    - Add TFLite frontend tests for `NON_MAX_SUPPRESSION_V5` with
    `soft_nms_sigma != 0`.
    - Add IR checks to verify that `soft_nms_sigma` and `score_threshold`
    are forwarded correctly.
    
    ## Testing
    
    ```bash
    python -m pytest -n 1 tests/python/relax/test_op_vision.py -k "all_class_non_max_suppression or get_valid_counts or nms" -v
    python -m pytest tests/python/relax/test_frontend_tflite.py -k "nms_v5" -v
    ```
    
    ## Result:
    - Relax vision tests passed locally
    - TFLite `NON_MAX_SUPPRESSION_V5` coverage added for both hard-NMS and
    soft-NMS paths
---
 include/tvm/relax/attrs/vision.h                   |   8 +-
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |  39 ++-
 python/tvm/relax/op/vision/nms.py                  |  24 +-
 python/tvm/relax/transform/legalize_ops/vision.py  |   2 +
 python/tvm/topi/testing/nms_python.py              |  78 ++++-
 python/tvm/topi/vision/nms.py                      | 326 ++++++++++++++-------
 src/relax/op/vision/nms.cc                         |  29 +-
 src/relax/op/vision/nms.h                          |   4 +-
 tests/python/relax/test_frontend_tflite.py         | 129 +++++++-
 tests/python/relax/test_op_vision.py               | 124 ++++++++
 .../relax/test_tvmscript_parser_op_vision.py       |  70 +++++
 11 files changed, 710 insertions(+), 123 deletions(-)

diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 8971127d76..9dec6fd503 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -122,6 +122,8 @@ struct NonMaximumSuppressionAttrs
   int id_index;
   bool return_indices;
   bool invalid_to_bottom;
+  double soft_nms_sigma;
+  double score_threshold;
 
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -143,7 +145,11 @@ struct NonMaximumSuppressionAttrs
         .def_ro("return_indices", &NonMaximumSuppressionAttrs::return_indices,
                 "Whether to return box indices in input data.")
         .def_ro("invalid_to_bottom", 
&NonMaximumSuppressionAttrs::invalid_to_bottom,
-                "Whether to move all valid bounding boxes to the top.");
+                "Whether to move all valid bounding boxes to the top.")
+        .def_ro("soft_nms_sigma", &NonMaximumSuppressionAttrs::soft_nms_sigma,
+                "Sigma for soft-NMS; 0.0 means standard hard NMS.")
+        .def_ro("score_threshold", 
&NonMaximumSuppressionAttrs::score_threshold,
+                "Score threshold for soft-NMS validity check; 0.0 when 
unused.");
   }
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NonMaximumSuppressionAttrs",
                                     NonMaximumSuppressionAttrs, BaseAttrsNode);
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index d773d8d7ce..732950ca68 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -3617,10 +3617,6 @@ class OperatorConverter:
         if isinstance(soft_nms_sigma, np.ndarray):
             assert soft_nms_sigma.size == 1, "only one value is expected."
             soft_nms_sigma = float(soft_nms_sigma)
-        if soft_nms_sigma != 0.0:
-            raise tvm.error.OpNotImplemented(
-                "It is soft_nms when soft_nms_sigma != 0, which is not 
supported!"
-            )
 
         scores_expand = relax.op.expand_dims(scores, axis=-1)
         data = relax.op.concat([scores_expand, boxes], axis=-1)
@@ -3646,18 +3642,41 @@ class OperatorConverter:
             id_index=-1,
             return_indices=True,
             invalid_to_bottom=False,
+            soft_nms_sigma=soft_nms_sigma,
+            score_threshold=score_threshold,
         )
 
-        selected_indices = relax.op.squeeze(nms_ret[0], axis=[0])
+        if soft_nms_sigma > 0.0:
+            processed_data = relax.op.squeeze(nms_ret[0], axis=[0])
+            indices_from_nms = nms_ret[1]
+            num_valid_from_nms = nms_ret[2]
+        else:
+            indices_from_nms = nms_ret[0]
+            num_valid_from_nms = nms_ret[1]
+
+        selected_indices = relax.op.squeeze(indices_from_nms, axis=[0])
         selected_indices = relax.op.strided_slice(
             selected_indices, axes=[0], begin=[0], end=[max_output_size]
         )
-        num_valid = relax.op.reshape(nms_ret[1], [])
+        num_valid = relax.op.reshape(num_valid_from_nms, [])
 
-        # Clamp out-of-bound padded indices to prevent take() crash.
-        num_boxes = int(self.get_tensor_shape(input_tensors[0])[0])
-        safe_indices = relax.op.clip(selected_indices, min=0, max=num_boxes - 
1)
-        selected_scores = relax.op.take(scores, safe_indices, axis=0)
+        if soft_nms_sigma > 0.0:
+            # Extract decayed scores from the processed data (score_index=0)
+            selected_scores = relax.op.strided_slice(
+                processed_data, axes=[1], begin=[0], end=[1]
+            )
+            selected_scores = relax.op.squeeze(selected_scores, axis=[1])
+            selected_scores = relax.op.strided_slice(
+                selected_scores, axes=[0], begin=[0], end=[max_output_size]
+            )
+            selected_scores = relax.op.clip(
+                selected_scores, min=0.0, max=float(np.finfo("float32").max)
+            )
+        else:
+            # Clamp out-of-bound padded indices to prevent take() crash.
+            num_boxes = int(self.get_tensor_shape(input_tensors[0])[0])
+            safe_indices = relax.op.clip(selected_indices, min=0, 
max=num_boxes - 1)
+            selected_scores = relax.op.take(scores, safe_indices, axis=0)
 
         out = relax.Tuple([selected_indices, selected_scores, num_valid])
         return out
diff --git a/python/tvm/relax/op/vision/nms.py 
b/python/tvm/relax/op/vision/nms.py
index 4eb3eb7f7a..427d68d113 100644
--- a/python/tvm/relax/op/vision/nms.py
+++ b/python/tvm/relax/op/vision/nms.py
@@ -115,6 +115,8 @@ def non_max_suppression(
     id_index=0,
     return_indices=True,
     invalid_to_bottom=False,
+    soft_nms_sigma=0.0,
+    score_threshold=0.0,
 ):
     """Non-maximum suppression operator for object detection.
 
@@ -160,12 +162,28 @@ def non_max_suppression(
         Whether to move valid bounding boxes to the top of the returned tensor.
         This option only affects the ``return_indices=False`` path.
 
+    soft_nms_sigma : float, optional
+        Sigma for soft-NMS Gaussian penalty. When ``0.0`` (default), standard
+        hard NMS is used. Positive values decay overlapping box scores instead
+        of suppressing them outright.
+
+    score_threshold : float, optional
+        Post-decay minimum score for a box to remain eligible during soft-NMS.
+        Only used when ``soft_nms_sigma > 0``. This is distinct from
+        ``get_valid_counts.score_threshold``, which filters boxes before NMS.
+        Defaults to ``0.0``.
+
     Returns
     -------
     out : relax.Expr
-        If ``return_indices`` is ``True``, returns
-        ``(box_indices, valid_box_count)`` with shapes
+        The return tuple shape depends on ``soft_nms_sigma``.
+        If ``return_indices`` is ``True`` and ``soft_nms_sigma`` is ``0.0``,
+        returns a 2-tuple ``(box_indices, valid_box_count)`` with shapes
         ``[batch_size, num_anchors]`` and ``[batch_size, 1]``.
+        If ``return_indices`` is ``True`` and ``soft_nms_sigma > 0``,
+        returns a 3-tuple ``(out_data, box_indices, valid_box_count)`` where
+        decayed ``out_data`` is prepended and has the same shape as the input
+        data.
         Otherwise returns the modified data tensor.
     """
     return _ffi_api.non_max_suppression(
@@ -181,4 +199,6 @@ def non_max_suppression(
         id_index,
         return_indices,
         invalid_to_bottom,
+        soft_nms_sigma,
+        score_threshold,
     )
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py 
b/python/tvm/relax/transform/legalize_ops/vision.py
index c515fc8fe8..4419549164 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -152,6 +152,8 @@ def _non_max_suppression(block_builder: BlockBuilder, call: 
Call) -> Expr:
         id_index=call.attrs.id_index,
         return_indices=call.attrs.return_indices,
         invalid_to_bottom=call.attrs.invalid_to_bottom,
+        soft_nms_sigma=call.attrs.soft_nms_sigma,
+        score_threshold=call.attrs.score_threshold,
     )
 
 
diff --git a/python/tvm/topi/testing/nms_python.py 
b/python/tvm/topi/testing/nms_python.py
index 7c8c20f5b4..c8711c70dd 100644
--- a/python/tvm/topi/testing/nms_python.py
+++ b/python/tvm/topi/testing/nms_python.py
@@ -46,6 +46,8 @@ def non_max_suppression_python(
     id_index=0,
     return_indices=True,
     invalid_to_bottom=False,
+    soft_nms_sigma=0.0,
+    score_threshold=0.0,
 ):
     """Numpy reference for classic non_max_suppression.
 
@@ -62,7 +64,9 @@ def non_max_suppression_python(
 
     Returns
     -------
-    If return_indices is True: (box_indices, valid_box_count)
+    If return_indices is True and soft_nms_sigma == 0.0: (box_indices, 
valid_box_count)
+    If return_indices is True and soft_nms_sigma > 0.0:
+        (out_data, box_indices, valid_box_count)
     Otherwise: modified data tensor
     """
     batch_size, num_anchors, _ = data.shape
@@ -71,6 +75,10 @@ def non_max_suppression_python(
     compacted = np.full((batch_size, num_anchors), -1, dtype="int32")
     valid_box_count = np.zeros((batch_size, 1), dtype="int32")
 
+    is_soft_nms = soft_nms_sigma > 0.0
+    thresh = score_threshold if is_soft_nms else 0.0
+    soft_nms_scale = -0.5 / soft_nms_sigma if is_soft_nms else 0.0
+
     for i in range(batch_size):
         nkeep = int(valid_count[i])
         if 0 < top_k < nkeep:
@@ -86,10 +94,72 @@ def non_max_suppression_python(
             out_data[i, j, :] = data[i, src, :]
             out_box_indices[i, j] = src
 
+        if is_soft_nms:
+            num_selected = 0
+            while num_selected < nkeep and (max_output_size < 0 or 
num_selected < max_output_size):
+                best_idx = -1
+                best_score = thresh
+                for j in range(num_selected, nkeep):
+                    if out_box_indices[i, j] >= 0 and out_data[i, j, 
score_index] > best_score:
+                        best_idx = j
+                        best_score = out_data[i, j, score_index]
+
+                if best_idx < 0:
+                    break
+
+                if best_idx != num_selected:
+                    out_data[i, [num_selected, best_idx], :] = out_data[
+                        i, [best_idx, num_selected], :
+                    ]
+                    out_box_indices[i, [num_selected, best_idx]] = 
out_box_indices[
+                        i, [best_idx, num_selected]
+                    ]
+
+                selected_idx = num_selected
+                for j in range(selected_idx + 1, nkeep):
+                    if out_box_indices[i, j] < 0 or out_data[i, j, 
score_index] <= thresh:
+                        continue
+
+                    do_suppress = False
+                    if force_suppress:
+                        do_suppress = True
+                    elif id_index >= 0:
+                        do_suppress = (
+                            out_data[i, selected_idx, id_index] == out_data[i, 
j, id_index]
+                        )
+                    else:
+                        do_suppress = True
+
+                    if not do_suppress:
+                        continue
+
+                    iou = _iou(out_data[i, selected_idx], out_data[i, j], 
coord_start)
+                    if iou >= iou_threshold:
+                        out_box_indices[i, j] = -1
+                    else:
+                        out_data[i, j, score_index] *= np.exp(soft_nms_scale * 
(iou**2))
+                        if out_data[i, j, score_index] <= thresh:
+                            out_box_indices[i, j] = -1
+
+                num_selected += 1
+
+            valid_box_count[i, 0] = num_selected
+            if return_indices:
+                for j in range(num_selected):
+                    orig_idx = out_box_indices[i, j]
+                    compacted[i, j] = int(indices[i, orig_idx])
+                    out_box_indices[i, j] = compacted[i, j]
+                for j in range(num_selected, num_anchors):
+                    out_data[i, j, :] = -1.0
+                    out_box_indices[i, j] = -1
+            else:
+                out_data[i, num_selected:, :] = -1.0
+            continue
+
         # Greedy NMS
         num_valid = 0
         for j in range(nkeep):
-            if out_data[i, j, score_index] <= 0:
+            if out_data[i, j, score_index] <= thresh:
                 out_data[i, j, :] = -1.0
                 out_box_indices[i, j] = -1
                 continue
@@ -102,7 +172,7 @@ def non_max_suppression_python(
 
             # Suppress overlapping boxes
             for k in range(j + 1, nkeep):
-                if out_data[i, k, score_index] <= 0:
+                if out_data[i, k, score_index] <= thresh:
                     continue
 
                 do_suppress = False
@@ -130,6 +200,8 @@ def non_max_suppression_python(
             valid_box_count[i, 0] = cnt
 
     if return_indices:
+        if is_soft_nms:
+            return [out_data, compacted, valid_box_count]
         return [compacted, valid_box_count]
 
     if invalid_to_bottom:
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index a602527fcc..ad548978a1 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -188,6 +188,8 @@ def _classic_nms_ir(
     out_data,
     out_box_indices,
     out_valid_box_count,
+    soft_nms_sigma=0.0,
+    score_threshold=0.0,
 ):
     """IR for classic single-class non-maximum suppression."""
     with IRBuilder() as ib:
@@ -200,6 +202,10 @@ def _classic_nms_ir(
         if out_valid_box_count is not None:
             out_valid_box_count = T.buffer_proxy(out_valid_box_count)
 
+        is_soft_nms = soft_nms_sigma > 0.0
+        # For hard NMS the historical threshold is 0.0; for soft NMS use 
score_threshold.
+        thresh = tvm.tirx.Cast(data.dtype, T.float32(score_threshold if 
is_soft_nms else 0.0))
+
         with T.parallel(0, batch_size) as i:
             # Step 1: Reorder data by sorted score
             nkeep_buf = T.alloc_buffer((1,), "int32", scope="local")
@@ -226,117 +232,220 @@ def _classic_nms_ir(
             num_valid_boxes_buf = T.alloc_buffer((1,), "int32", scope="local")
             num_valid_boxes = T.buffer_proxy(num_valid_boxes_buf)
             num_valid_boxes[0] = T.int32(0)
+            best_idx_buf = T.alloc_buffer((1,), "int32", scope="local")
+            best_idx = T.buffer_proxy(best_idx_buf)
+            best_score_buf = T.alloc_buffer((1,), data.dtype, scope="local")
+            best_score = T.buffer_proxy(best_score_buf)
+            tmp_idx_buf = T.alloc_buffer((1,), "int32", scope="local")
+            tmp_idx = T.buffer_proxy(tmp_idx_buf)
+            tmp_val_buf = T.alloc_buffer((1,), data.dtype, scope="local")
+            tmp_val = T.buffer_proxy(tmp_val_buf)
+            zero = tvm.tirx.Cast(data.dtype, T.float32(0.0))
+
+            def compute_iou(lhs_idx, rhs_idx):
+                lhs_l = tvm.te.min(
+                    out_data[i, lhs_idx, coord_start],
+                    out_data[i, lhs_idx, coord_start + 2],
+                )
+                lhs_t = tvm.te.min(
+                    out_data[i, lhs_idx, coord_start + 1],
+                    out_data[i, lhs_idx, coord_start + 3],
+                )
+                lhs_r = tvm.te.max(
+                    out_data[i, lhs_idx, coord_start],
+                    out_data[i, lhs_idx, coord_start + 2],
+                )
+                lhs_b = tvm.te.max(
+                    out_data[i, lhs_idx, coord_start + 1],
+                    out_data[i, lhs_idx, coord_start + 3],
+                )
+                rhs_l = tvm.te.min(
+                    out_data[i, rhs_idx, coord_start],
+                    out_data[i, rhs_idx, coord_start + 2],
+                )
+                rhs_t = tvm.te.min(
+                    out_data[i, rhs_idx, coord_start + 1],
+                    out_data[i, rhs_idx, coord_start + 3],
+                )
+                rhs_r = tvm.te.max(
+                    out_data[i, rhs_idx, coord_start],
+                    out_data[i, rhs_idx, coord_start + 2],
+                )
+                rhs_b = tvm.te.max(
+                    out_data[i, rhs_idx, coord_start + 1],
+                    out_data[i, rhs_idx, coord_start + 3],
+                )
+                width = tvm.te.max(zero, tvm.te.min(lhs_r, rhs_r) - 
tvm.te.max(lhs_l, rhs_l))
+                height = tvm.te.max(zero, tvm.te.min(lhs_b, rhs_b) - 
tvm.te.max(lhs_t, rhs_t))
+                intersection = height * width
+                union = (
+                    (lhs_r - lhs_l) * (lhs_b - lhs_t)
+                    + (rhs_r - rhs_l) * (rhs_b - rhs_t)
+                    - intersection
+                )
+                return tvm.tirx.Select(union <= zero, zero, intersection / 
union)
+
+            if is_soft_nms:
+                # LiteRT soft-NMS selects the current highest-score candidate 
each round.
+                soft_nms_scale = tvm.tirx.Cast(data.dtype, T.float32(-0.5 / 
soft_nms_sigma))
 
-            with T.serial(0, nkeep_local[0]) as j:
-                # Check if box j is still valid (score > 0) and within 
max_output_size
-                with T.If(
-                    tvm.tirx.all(
-                        out_data[i, j, score_index] > 
tvm.tirx.Cast(data.dtype, T.float32(0.0)),
+                with T.serial(0, nkeep_local[0]) as _:
+                    with T.If(
                         tvm.tirx.Select(
                             max_output_size > 0,
                             num_valid_boxes[0] < max_output_size,
                             tvm.tirx.const(True),
-                        ),
-                    )
-                ):
-                    with T.Then():
-                        num_valid_boxes[0] = num_valid_boxes[0] + 1
-
-                        # Suppress overlapping boxes
-                        with T.serial(0, nkeep_local[0]) as k:
-                            with T.If(
-                                tvm.tirx.all(
-                                    k > j,
-                                    out_data[i, k, score_index]
-                                    > tvm.tirx.Cast(data.dtype, 
T.float32(0.0)),
-                                )
-                            ):
+                        )
+                    ):
+                        with T.Then():
+                            best_idx[0] = T.int32(-1)
+                            best_score[0] = thresh
+
+                            with T.serial(0, nkeep_local[0]) as j:
+                                with T.If(
+                                    tvm.tirx.all(
+                                        j >= num_valid_boxes[0],
+                                        out_box_indices[i, j] >= 0,
+                                        out_data[i, j, score_index] > 
best_score[0],
+                                    )
+                                ):
+                                    with T.Then():
+                                        best_idx[0] = j
+                                        best_score[0] = out_data[i, j, 
score_index]
+
+                            with T.If(best_idx[0] >= 0):
                                 with T.Then():
-                                    # Check class ID match (or force_suppress)
-                                    do_suppress = tvm.tirx.const(False)
-                                    if force_suppress:
-                                        do_suppress = tvm.tirx.const(True)
-                                    elif id_index >= 0:
-                                        do_suppress = (
-                                            out_data[i, j, id_index] == 
out_data[i, k, id_index]
-                                        )
-                                    else:
-                                        do_suppress = tvm.tirx.const(True)
-
-                                    with T.If(do_suppress):
+                                    with T.If(best_idx[0] != 
num_valid_boxes[0]):
                                         with T.Then():
-                                            # Calculate IoU
-                                            a_l = tvm.te.min(
-                                                out_data[i, j, coord_start],
-                                                out_data[i, j, coord_start + 
2],
-                                            )
-                                            a_t = tvm.te.min(
-                                                out_data[i, j, coord_start + 
1],
-                                                out_data[i, j, coord_start + 
3],
-                                            )
-                                            a_r = tvm.te.max(
-                                                out_data[i, j, coord_start],
-                                                out_data[i, j, coord_start + 
2],
-                                            )
-                                            a_b = tvm.te.max(
-                                                out_data[i, j, coord_start + 
1],
-                                                out_data[i, j, coord_start + 
3],
+                                            tmp_idx[0] = out_box_indices[i, 
num_valid_boxes[0]]
+                                            out_box_indices[
+                                                i, num_valid_boxes[0]
+                                            ] = out_box_indices[i, best_idx[0]]
+                                            out_box_indices[i, best_idx[0]] = 
tmp_idx[0]
+
+                                            with T.serial(0, box_data_length) 
as k:
+                                                tmp_val[0] = out_data[i, 
num_valid_boxes[0], k]
+                                                out_data[i, 
num_valid_boxes[0], k] = out_data[
+                                                    i, best_idx[0], k
+                                                ]
+                                                out_data[i, best_idx[0], k] = 
tmp_val[0]
+
+                                    with T.serial(0, nkeep_local[0]) as j:
+                                        with T.If(
+                                            tvm.tirx.all(
+                                                j > num_valid_boxes[0],
+                                                out_box_indices[i, j] >= 0,
+                                                out_data[i, j, score_index] > 
thresh,
                                             )
+                                        ):
+                                            with T.Then():
+                                                do_suppress = 
tvm.tirx.const(False)
+                                                if force_suppress:
+                                                    do_suppress = 
tvm.tirx.const(True)
+                                                elif id_index >= 0:
+                                                    do_suppress = (
+                                                        out_data[i, 
num_valid_boxes[0], id_index]
+                                                        == out_data[i, j, 
id_index]
+                                                    )
+                                                else:
+                                                    do_suppress = 
tvm.tirx.const(True)
+
+                                                with T.If(do_suppress):
+                                                    with T.Then():
+                                                        iou = 
compute_iou(num_valid_boxes[0], j)
+
+                                                        with T.If(iou >= 
iou_threshold):
+                                                            with T.Then():
+                                                                
out_box_indices[i, j] = T.int32(-1)
+                                                        with T.If(iou < 
iou_threshold):
+                                                            with T.Then():
+                                                                out_data[i, j, 
score_index] = (
+                                                                    
out_data[i, j, score_index]
+                                                                    * 
tvm.tirx.exp(
+                                                                        
soft_nms_scale
+                                                                        * iou
+                                                                        * iou
+                                                                    )
+                                                                )
+                                                                with T.If(
+                                                                    
out_data[i, j, score_index]
+                                                                    <= thresh
+                                                                ):
+                                                                    with 
T.Then():
+                                                                        
out_box_indices[
+                                                                            i, 
j
+                                                                        ] = 
T.int32(-1)
+
+                                    num_valid_boxes[0] = num_valid_boxes[0] + 1
+
+                if return_indices:
+                    out_valid_box_count[i, 0] = num_valid_boxes[0]
 
-                                            b_l = tvm.te.min(
-                                                out_data[i, k, coord_start],
-                                                out_data[i, k, coord_start + 
2],
-                                            )
-                                            b_t = tvm.te.min(
-                                                out_data[i, k, coord_start + 
1],
-                                                out_data[i, k, coord_start + 
3],
-                                            )
-                                            b_r = tvm.te.max(
-                                                out_data[i, k, coord_start],
-                                                out_data[i, k, coord_start + 
2],
-                                            )
-                                            b_b = tvm.te.max(
-                                                out_data[i, k, coord_start + 
1],
-                                                out_data[i, k, coord_start + 
3],
-                                            )
+                    with T.serial(0, num_anchors) as j:
+                        with T.If(j < num_valid_boxes[0]):
+                            with T.Then():
+                                orig_idx = out_box_indices[i, j]
+                                out_box_indices[i, j] = indices[i, orig_idx]
+                        with T.If(j >= num_valid_boxes[0]):
+                            with T.Then():
+                                with T.serial(0, box_data_length) as k:
+                                    out_data[i, j, k] = tvm.tirx.Cast(
+                                        data.dtype, T.float32(-1.0)
+                                    )
+                                out_box_indices[i, j] = T.int32(-1)
+                else:
+                    with T.serial(0, num_anchors) as j:
+                        with T.If(j >= num_valid_boxes[0]):
+                            with T.Then():
+                                with T.serial(0, box_data_length) as k:
+                                    out_data[i, j, k] = 
tvm.tirx.Cast(data.dtype, T.float32(-1.0))
+            else:
+                with T.serial(0, nkeep_local[0]) as j:
+                    with T.If(
+                        tvm.tirx.all(
+                            out_data[i, j, score_index] > thresh,
+                            tvm.tirx.Select(
+                                max_output_size > 0,
+                                num_valid_boxes[0] < max_output_size,
+                                tvm.tirx.const(True),
+                            ),
+                        )
+                    ):
+                        with T.Then():
+                            num_valid_boxes[0] = num_valid_boxes[0] + 1
 
-                                            w = tvm.te.max(
-                                                tvm.tirx.Cast(data.dtype, 
T.float32(0.0)),
-                                                tvm.te.min(a_r, b_r) - 
tvm.te.max(a_l, b_l),
-                                            )
-                                            h = tvm.te.max(
-                                                tvm.tirx.Cast(data.dtype, 
T.float32(0.0)),
-                                                tvm.te.min(a_b, b_b) - 
tvm.te.max(a_t, b_t),
-                                            )
-                                            area = h * w
-                                            u = (
-                                                (a_r - a_l) * (a_b - a_t)
-                                                + (b_r - b_l) * (b_b - b_t)
-                                                - area
-                                            )
-                                            iou = tvm.tirx.Select(
-                                                u <= tvm.tirx.Cast(data.dtype, 
T.float32(0.0)),
-                                                tvm.tirx.Cast(data.dtype, 
T.float32(0.0)),
-                                                area / u,
+                            with T.serial(0, nkeep_local[0]) as k:
+                                with T.If(
+                                    tvm.tirx.all(k > j, out_data[i, k, 
score_index] > thresh)
+                                ):
+                                    with T.Then():
+                                        do_suppress = tvm.tirx.const(False)
+                                        if force_suppress:
+                                            do_suppress = tvm.tirx.const(True)
+                                        elif id_index >= 0:
+                                            do_suppress = (
+                                                out_data[i, j, id_index] == 
out_data[i, k, id_index]
                                             )
+                                        else:
+                                            do_suppress = tvm.tirx.const(True)
 
-                                            with T.If(iou >= iou_threshold):
-                                                with T.Then():
-                                                    out_data[i, k, 
score_index] = tvm.tirx.Cast(
-                                                        data.dtype, 
T.float32(-1.0)
-                                                    )
-                                                    out_box_indices[i, k] = 
T.int32(-1)
+                                        with T.If(do_suppress):
+                                            with T.Then():
+                                                iou = compute_iou(j, k)
 
-                    with T.Else():
-                        # Box suppressed or beyond max_output_size
-                        with T.serial(0, box_data_length) as k:
-                            out_data[i, j, k] = tvm.tirx.Cast(data.dtype, 
T.float32(-1.0))
-                        out_box_indices[i, j] = T.int32(-1)
+                                                with T.If(iou >= 
iou_threshold):
+                                                    with T.Then():
+                                                        out_data[i, k, 
score_index] = tvm.tirx.Cast(
+                                                            data.dtype, 
T.float32(-1.0)
+                                                        )
+                                                        out_box_indices[i, k] 
= T.int32(-1)
+
+                        with T.Else():
+                            with T.serial(0, box_data_length) as k:
+                                out_data[i, j, k] = tvm.tirx.Cast(data.dtype, 
T.float32(-1.0))
+                            out_box_indices[i, j] = T.int32(-1)
 
-            # Step 3: If return_indices, remap to original indices
-            if return_indices:
-                if out_valid_box_count is not None:
-                    # Count valid boxes and remap indices
+                if return_indices:
                     valid_idx_buf = T.alloc_buffer((1,), "int32", 
scope="local")
                     valid_idx = T.buffer_proxy(valid_idx_buf)
                     valid_idx[0] = T.int32(0)
@@ -350,7 +459,6 @@ def _classic_nms_ir(
 
                     out_valid_box_count[i, 0] = valid_idx[0]
 
-                    # Fill remaining with -1
                     with T.serial(0, num_anchors) as j:
                         with T.If(j >= valid_idx[0]):
                             with T.Then():
@@ -372,6 +480,8 @@ def non_max_suppression(
     id_index=0,
     return_indices=True,
     invalid_to_bottom=False,
+    soft_nms_sigma=0.0,
+    score_threshold=0.0,
 ):
     """Non-maximum suppression operator for object detection.
 
@@ -416,10 +526,24 @@ def non_max_suppression(
     invalid_to_bottom : optional, boolean
         Whether to move all valid bounding boxes to the top.
 
+    soft_nms_sigma : optional, float
+        Sigma for soft-NMS Gaussian penalty. 0.0 means standard hard NMS.
+
+    score_threshold : optional, float
+        Post-decay minimum score for a box to remain eligible during soft-NMS.
+        Only used when ``soft_nms_sigma > 0``. This is distinct from
+        ``get_valid_counts.score_threshold``, which filters boxes before NMS.
+
     Returns
     -------
     out : tvm.te.Tensor or tuple of tvm.te.Tensor
-        If return_indices is True, returns a tuple of (box_indices, 
valid_box_count).
+        The return tuple shape depends on ``soft_nms_sigma``.
+        If ``return_indices`` is ``True`` and ``soft_nms_sigma`` is ``0.0``,
+        returns a 2-tuple ``(box_indices, valid_box_count)``.
+        If ``return_indices`` is ``True`` and ``soft_nms_sigma > 0``,
+        returns a 3-tuple ``(out_data, box_indices, valid_box_count)`` where
+        decayed ``out_data`` is prepended and has the same shape as the input
+        data.
         Otherwise returns the modified data tensor.
     """
     batch_size = data.shape[0]
@@ -464,6 +588,7 @@ def non_max_suppression(
                 coord_start, score_index, id_index,
                 return_indices,
                 outs[0], outs[1], outs[2],
+                soft_nms_sigma, score_threshold,
             ),
             dtype=[data.dtype, "int32", "int32"],
             out_buffers=[out_data_buf, out_box_indices_buf, 
out_valid_box_count_buf],
@@ -471,6 +596,8 @@ def non_max_suppression(
             name="non_max_suppression",
             tag="non_max_suppression",
         )
+        if soft_nms_sigma > 0.0:
+            return [out_data, out_box_indices, out_valid_box_count]
         return [out_box_indices, out_valid_box_count]
 
     out_data, out_box_indices = te.extern(
@@ -484,6 +611,7 @@ def non_max_suppression(
             coord_start, score_index, id_index,
             return_indices,
             outs[0], outs[1], None,
+            soft_nms_sigma, score_threshold,
         ),
         dtype=[data.dtype, "int32"],
         out_buffers=[out_data_buf, out_box_indices_buf],
diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc
index 97508d7211..9d7144b5d7 100644
--- a/src/relax/op/vision/nms.cc
+++ b/src/relax/op/vision/nms.cc
@@ -196,8 +196,8 @@ TVM_REGISTER_OP("relax.vision.get_valid_counts")
 
 Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int 
max_output_size,
                          double iou_threshold, bool force_suppress, int top_k, 
int coord_start,
-                         int score_index, int id_index, bool return_indices,
-                         bool invalid_to_bottom) {
+                         int score_index, int id_index, bool return_indices, 
bool invalid_to_bottom,
+                         double soft_nms_sigma, double score_threshold) {
   auto attrs = tvm::ffi::make_object<NonMaximumSuppressionAttrs>();
   attrs->max_output_size = max_output_size;
   attrs->iou_threshold = iou_threshold;
@@ -208,6 +208,8 @@ Expr non_max_suppression(Expr data, Expr valid_count, Expr 
indices, int max_outp
   attrs->id_index = id_index;
   attrs->return_indices = return_indices;
   attrs->invalid_to_bottom = invalid_to_bottom;
+  attrs->soft_nms_sigma = soft_nms_sigma;
+  attrs->score_threshold = score_threshold;
 
   static const Op& op = Op::Get("relax.vision.non_max_suppression");
   return Call(op, {std::move(data), std::move(valid_count), 
std::move(indices)}, Attrs(attrs), {});
@@ -319,7 +321,28 @@ StructInfo InferStructInfoNMS(const Call& call, const 
BlockBuilder& ctx) {
   }
 
   if (attrs->return_indices) {
-    // Returns (box_indices[batch, num_anchors], valid_box_count[batch, 1])
+    if (attrs->soft_nms_sigma > 0.0) {
+      // Soft-NMS returns (out_data[batch, num_anchors, elem_length],
+      //                   box_indices[batch, num_anchors],
+      //                   valid_box_count[batch, 1])
+      if (data_shape == nullptr) {
+        tvm::ffi::Array<StructInfo> fields = {
+            TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev),
+            TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev),
+            TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)};
+        return TupleStructInfo(fields);
+      }
+      auto batch = data_shape->values[0];
+      auto num_anchors = data_shape->values[1];
+      tvm::ffi::Array<StructInfo> fields = {
+          TensorStructInfo(ffi::GetRef<ShapeExpr>(data_shape), 
data_sinfo->dtype, vdev),
+          TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), 
vdev),
+          TensorStructInfo(ShapeExpr({batch, IntImm(DataType::Int(64), 1)}), 
DataType::Int(32),
+                           vdev)};
+      return TupleStructInfo(fields);
+    }
+
+    // Hard NMS returns (box_indices[batch, num_anchors], 
valid_box_count[batch, 1])
     if (data_shape == nullptr) {
       tvm::ffi::Array<StructInfo> fields = {
           TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev),
diff --git a/src/relax/op/vision/nms.h b/src/relax/op/vision/nms.h
index 3fbd2609e2..83ca5b1bc0 100644
--- a/src/relax/op/vision/nms.h
+++ b/src/relax/op/vision/nms.h
@@ -44,8 +44,8 @@ Expr get_valid_counts(Expr data, double score_threshold, int 
id_index, int score
 /*! \brief Non-maximum suppression for object detection. */
 Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int 
max_output_size,
                          double iou_threshold, bool force_suppress, int top_k, 
int coord_start,
-                         int score_index, int id_index, bool return_indices,
-                         bool invalid_to_bottom);
+                         int score_index, int id_index, bool return_indices, 
bool invalid_to_bottom,
+                         double soft_nms_sigma = 0.0, double score_threshold = 
0.0);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 15ca1cacf1..908868faf0 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1360,7 +1360,7 @@ def test_batch_matmul_adj():
     verify(BatchMatMulAdj, Expected)
 
 
-def _verify_nms_v5(mod, tf_func, boxes_np, scores_np):
+def _verify_nms_v5(mod, tf_func, boxes_np, scores_np, soft_nms_sigma=0.0):
     """E2E verify for NMS: only run on nightly, compare valid outputs only."""
     if "CI_ENV_NIGHTLY" not in os.environ:
         return
@@ -1386,9 +1386,19 @@ def _verify_nms_v5(mod, tf_func, boxes_np, scores_np):
         rtol=1e-5,
         atol=1e-5,
     )
+    if soft_nms_sigma > 0.0:
+        np.testing.assert_allclose(
+            tf_scores.numpy(),
+            tvm_scores.numpy(),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+        np.testing.assert_array_less(-1e-6, tvm_scores.numpy()[n_valid:])
 
 
-def _build_nms_v5_mod(num_boxes, max_output_size, iou_threshold, 
score_threshold):
+def _build_nms_v5_mod(
+    num_boxes, max_output_size, iou_threshold, score_threshold, 
soft_nms_sigma=0.0
+):
     """Convert a NonMaxSuppressionV5 TFLite model to a Relax module.
 
     Scalar params must be Python literals (not tf.constant) so TFLite can
@@ -1409,7 +1419,7 @@ def _build_nms_v5_mod(num_boxes, max_output_size, 
iou_threshold, score_threshold
                 max_output_size=max_output_size,
                 iou_threshold=iou_threshold,
                 score_threshold=score_threshold,
-                soft_nms_sigma=0.0,
+                soft_nms_sigma=soft_nms_sigma,
                 pad_to_max_output_size=True,
             )
             return indices, out_scores, valid
@@ -1647,6 +1657,83 @@ _NMS_V5_CASES = [
 ]
 
 
+_NMS_V5_SOFT_CASES = [  # like _NMS_V5_CASES, with soft_nms_sigma inserted before boxes
+    pytest.param(
+        6,  # num_boxes
+        6,  # max_output_size
+        0.5,  # iou_threshold
+        0.0,  # score_threshold
+        0.5,  # soft_nms_sigma
+        np.array(  # boxes
+            [
+                [0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.1, 1.0, 1.1],
+                [0.0, 0.0, 1.0, 0.9],
+                [0.5, 0.5, 1.5, 1.5],
+                [0.0, 0.0, 0.3, 0.3],
+            ],
+            dtype=np.float32,
+        ),
+        np.array([0.9, 0.75, 0.6, 0.5, 0.4, 0.3], dtype=np.float32),  # scores
+        id="soft_nms_basic",
+    ),
+    pytest.param(
+        5,
+        5,
+        0.5,
+        0.0,
+        0.3,
+        np.array(  # boxes 0-3 mutually overlap; box 4 is disjoint
+            [
+                [0.0, 0.0, 1.0, 1.0],
+                [0.1, 0.1, 1.1, 1.1],
+                [0.2, 0.2, 1.2, 1.2],
+                [0.3, 0.3, 1.3, 1.3],
+                [2.0, 2.0, 3.0, 3.0],
+            ],
+            dtype=np.float32,
+        ),
+        np.array([0.9, 0.8, 0.7, 0.6, 0.5], dtype=np.float32),
+        id="soft_nms_tight_sigma",
+    ),
+    pytest.param(
+        3,
+        3,
+        0.5,
+        0.3,  # non-zero score_threshold
+        0.1,
+        np.array(  # box 1 overlaps box 0; box 2 is disjoint
+            [
+                [0.0, 0.0, 1.0, 1.0],
+                [0.2, 0.2, 1.2, 1.2],
+                [2.0, 2.0, 3.0, 3.0],
+            ],
+            dtype=np.float32,
+        ),
+        np.array([0.9, 0.8, 0.75], dtype=np.float32),
+        id="soft_nms_threshold_hole",
+    ),
+    pytest.param(
+        3,
+        3,
+        0.5,
+        0.0,
+        0.1,
+        np.array(  # box 1 overlaps box 0; box 2 is disjoint
+            [
+                [0.0, 0.0, 1.0, 1.0],
+                [0.2, 0.2, 1.2, 1.2],
+                [2.0, 2.0, 3.0, 3.0],
+            ],
+            dtype=np.float32,
+        ),
+        np.array([0.9, 0.85, 0.8], dtype=np.float32),
+        id="soft_nms_reorder",
+    ),
+]
+
+
 @pytest.mark.parametrize(
     "num_boxes,max_output_size,iou_threshold,score_threshold,boxes,scores",
     _NMS_V5_CASES,
@@ -1657,6 +1744,20 @@ def test_nms_v5(num_boxes, max_output_size, 
iou_threshold, score_threshold, boxe
     _verify_nms_v5(mod, tf_func, boxes, scores)
 
 
+@pytest.mark.parametrize(
+    "num_boxes,max_output_size,iou_threshold,score_threshold,soft_nms_sigma,boxes,scores",
+    _NMS_V5_SOFT_CASES,
+)
+def test_nms_v5_soft(
+    num_boxes, max_output_size, iou_threshold, score_threshold, soft_nms_sigma, boxes, scores
+):
+    """NON_MAX_SUPPRESSION_V5 with soft_nms_sigma: conversion smoke test + E2E correctness."""
+    mod, tf_func = _build_nms_v5_mod(
+        num_boxes, max_output_size, iou_threshold, score_threshold, soft_nms_sigma
+    )
+    _verify_nms_v5(mod, tf_func, boxes, scores, soft_nms_sigma=soft_nms_sigma)
+
+
 def test_nms_v5_ir():
     """Verify the emitted Relax IR has correct structure for 
NON_MAX_SUPPRESSION_V5."""
     num_boxes = 6
@@ -1681,6 +1782,28 @@ def test_nms_v5_ir():
     assert f"R.Tensor(({max_output_size},)" in ir
 
 
+def test_nms_v5_soft_ir():
+    """Verify the emitted Relax IR passes soft_nms_sigma for NON_MAX_SUPPRESSION_V5."""
+    num_boxes = 6
+    max_output_size = 3
+    mod, _ = _build_nms_v5_mod(
+        num_boxes=num_boxes,
+        max_output_size=max_output_size,
+        iou_threshold=0.5,
+        score_threshold=0.0,
+        soft_nms_sigma=0.5,
+    )
+
+    ir = mod.script()
+
+    # soft_nms_sigma must appear in the IR
+    assert "soft_nms_sigma=0.5" in ir
+    # score_threshold must also be forwarded
+    assert "score_threshold=0.0" in ir
+    # Soft-NMS padded scores must be clipped to non-negative values.
+    assert "R.clip(" in ir
+
+
 _DETECTION_POSTPROCESS_SMOKE_CASES = [
     pytest.param(
         {
diff --git a/tests/python/relax/test_op_vision.py 
b/tests/python/relax/test_op_vision.py
index b597b325f4..ef260cf188 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -302,6 +302,26 @@ def test_nms_infer_struct_info_return_indices():
     )
 
 
+def test_nms_infer_struct_info_return_indices_soft_nms():  # soft-NMS yields a 3-tuple
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+    _check_inference(
+        bb,
+        relax.op.vision.non_max_suppression(
+            data, valid_count, indices, return_indices=True, soft_nms_sigma=0.5
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 10, 6), "float32"),  # out_data (shape of data)
+                relax.TensorStructInfo((2, 10), "int32"),  # box_indices
+                relax.TensorStructInfo((2, 1), "int32"),  # valid_box_count
+            ]
+        ),
+    )
+
+
 def test_nms_infer_struct_info_return_data():
     bb = relax.BlockBuilder()
     data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
@@ -457,6 +477,52 @@ def test_nms_legalize():
                 id_index=0,
                 return_indices=True,
                 invalid_to_bottom=False,
+                soft_nms_sigma=0.0,
+                score_threshold=0.0,
+            )
+            return gv
+
+    mod = LegalizeOps()(NMS)
+    _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression")
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((1, 5), "int32"),
+                relax.TensorStructInfo((1, 1), "int32"),
+            ]
+        ),
+    )
+
+
+def test_nms_legalize_soft_nms():
+    @tvm.script.ir_module
+    class NMS:
+        @R.function
+        def main(
+            data: R.Tensor((1, 5, 6), "float32"),
+            valid_count: R.Tensor((1,), "int32"),
+            indices: R.Tensor((1, 5), "int32"),
+        ) -> R.Tuple(
+            R.Tensor((1, 5, 6), "float32"),
+            R.Tensor((1, 5), "int32"),
+            R.Tensor((1, 1), "int32"),
+        ):
+            gv = R.vision.non_max_suppression(
+                data,
+                valid_count,
+                indices,
+                max_output_size=-1,
+                iou_threshold=0.5,
+                force_suppress=False,
+                top_k=-1,
+                coord_start=2,
+                score_index=1,
+                id_index=0,
+                return_indices=True,
+                invalid_to_bottom=False,
+                soft_nms_sigma=0.5,
+                score_threshold=0.0,
             )
             return gv
 
@@ -466,6 +532,7 @@ def test_nms_legalize():
         mod["main"].ret_struct_info,
         relax.TupleStructInfo(
             [
+                relax.TensorStructInfo((1, 5, 6), "float32"),
                 relax.TensorStructInfo((1, 5), "int32"),
                 relax.TensorStructInfo((1, 1), "int32"),
             ]
@@ -495,6 +562,8 @@ def test_nms_legalize_return_data():
                 id_index=0,
                 return_indices=False,
                 invalid_to_bottom=True,
+                soft_nms_sigma=0.0,
+                score_threshold=0.0,
             )
             return gv
 
@@ -577,6 +646,8 @@ def _run_nms_e2e(
     id_index: int = 0,
     return_indices: bool = True,
     invalid_to_bottom: bool = False,
+    soft_nms_sigma: float = 0.0,
+    score_threshold: float = 0.0,
 ):
     """Run classic NMS through legalization and VM execution."""
 
@@ -603,6 +674,8 @@ def _run_nms_e2e(
                 id_index=id_index,
                 return_indices=return_indices,
                 invalid_to_bottom=invalid_to_bottom,
+                soft_nms_sigma=soft_nms_sigma,
+                score_threshold=score_threshold,
             )
         )
         bb.emit_func_output(result)
@@ -660,6 +733,57 @@ def test_nms_e2e_return_indices():
     tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
 
 
+@tvm.testing.requires_llvm
+def test_nms_e2e_soft_nms_reorders_by_decayed_score():
+    """Soft-NMS should re-rank by decayed scores instead of keeping the initial order."""
+
+    raw_data = np.array(
+        [
+            [
+                [0.0, 0.90, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.85, 0.2, 0.2, 1.2, 1.2],
+                [0.0, 0.80, 2.0, 2.0, 3.0, 3.0],
+                [-1.0, 0.99, 0.0, 0.0, 1.0, 1.0],
+            ]
+        ],
+        dtype="float32",
+    )
+    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+    ref_out_data, ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=-1,
+        iou_threshold=0.5,
+        force_suppress=True,
+        top_k=-1,
+        coord_start=2,
+        score_index=1,
+        id_index=-1,
+        return_indices=True,
+        invalid_to_bottom=False,
+        soft_nms_sigma=0.1,
+        score_threshold=0.0,
+    )
+    result = _run_nms_e2e(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        iou_threshold=0.5,
+        force_suppress=True,
+        id_index=-1,
+        return_indices=True,
+        invalid_to_bottom=False,
+        soft_nms_sigma=0.1,
+        score_threshold=0.0,
+    )
+
+    np.testing.assert_array_equal(ref_indices[0, :3], np.array([0, 2, 1], dtype="int32"))
+    tvm.testing.assert_allclose(result[0].numpy(), ref_out_data)
+    tvm.testing.assert_allclose(result[1].numpy(), ref_indices)
+    tvm.testing.assert_allclose(result[2].numpy(), ref_valid_box_count)
+
+
 @tvm.testing.requires_llvm
 def test_nms_e2e_return_indices_with_invalid_to_bottom():
     """Validate that invalid_to_bottom is a no-op when returning indices."""
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py 
b/tests/python/relax/test_tvmscript_parser_op_vision.py
index 370b68769e..d4755ee367 100644
--- a/tests/python/relax/test_tvmscript_parser_op_vision.py
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -126,6 +126,8 @@ def test_non_max_suppression_return_indices():
                 id_index=0,
                 return_indices=True,
                 invalid_to_bottom=False,
+                soft_nms_sigma=0.0,
+                score_threshold=0.0,
             )
         )
         return gv
@@ -150,6 +152,70 @@ def test_non_max_suppression_return_indices():
                 id_index=0,
                 return_indices=True,
                 invalid_to_bottom=False,
+                soft_nms_sigma=0.0,
+                score_threshold=0.0,
+            )
+        )
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_non_max_suppression_return_indices_soft_nms():
+    @R.function
+    def foo(
+        data: R.Tensor((2, 5, 6), "float32"),
+        valid_count: R.Tensor((2,), "int32"),
+        indices: R.Tensor((2, 5), "int32"),
+    ) -> R.Tuple(
+        R.Tensor((2, 5, 6), "float32"),
+        R.Tensor((2, 5), "int32"),
+        R.Tensor((2, 1), "int32"),
+    ):
+        gv: R.Tuple(
+            R.Tensor((2, 5, 6), "float32"),
+            R.Tensor((2, 5), "int32"),
+            R.Tensor((2, 1), "int32"),
+        ) = R.vision.non_max_suppression(
+            data,
+            valid_count,
+            indices,
+            max_output_size=-1,
+            iou_threshold=0.5,
+            force_suppress=False,
+            top_k=3,
+            coord_start=2,
+            score_index=1,
+            id_index=0,
+            return_indices=True,
+            invalid_to_bottom=False,
+            soft_nms_sigma=0.5,
+            score_threshold=0.0,
+        )
+        return gv
+
+    data = relax.Var("data", R.Tensor((2, 5, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices = relax.Var("indices", R.Tensor((2, 5), "int32"))
+
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [data, valid_count, indices]):
+        gv = bb.emit(
+            relax.op.vision.non_max_suppression(
+                data,
+                valid_count,
+                indices,
+                max_output_size=-1,
+                iou_threshold=0.5,
+                force_suppress=False,
+                top_k=3,
+                coord_start=2,
+                score_index=1,
+                id_index=0,
+                return_indices=True,
+                invalid_to_bottom=False,
+                soft_nms_sigma=0.5,
+                score_threshold=0.0,
             )
         )
         bb.emit_func_output(gv)
@@ -177,6 +243,8 @@ def test_non_max_suppression_return_data():
             id_index=0,
             return_indices=False,
             invalid_to_bottom=True,
+            soft_nms_sigma=0.0,
+            score_threshold=0.0,
         )
         return gv
 
@@ -200,6 +268,8 @@ def test_non_max_suppression_return_data():
                 id_index=0,
                 return_indices=False,
                 invalid_to_bottom=True,
+                soft_nms_sigma=0.0,
+                score_threshold=0.0,
             )
         )
         bb.emit_func_output(gv)

Reply via email to