This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b915cac0bf [Relax][Frontend][TFLite] Add soft-NMS support for TFLite
NON_MAX_SUPPRESSION_V5 (#19426)
b915cac0bf is described below
commit b915cac0bf5071fd23e8ef1643c7932826e2e05e
Author: HoYi <[email protected]>
AuthorDate: Tue Apr 28 04:42:34 2026 +0800
[Relax][Frontend][TFLite] Add soft-NMS support for TFLite
NON_MAX_SUPPRESSION_V5 (#19426)
## Summary
This PR completes the TFLite `NON_MAX_SUPPRESSION_V5` implementation in
Relax by adding support for `soft_nms_sigma != 0`.
It extends `relax.vision.non_max_suppression` with soft-NMS attributes,
updates the TFLite frontend to consume the soft-NMS outputs correctly,
and aligns the TOPI implementation with LiteRT's reference behavior.
Relates to #19412.
## Changes
1. **Relax / TOPI soft-NMS support**
- Extend `NonMaximumSuppressionAttrs` with `soft_nms_sigma` and
`score_threshold`.
- Thread the new attributes through Relax op registration, Python
wrapper, and legalization.
- Add soft-NMS handling to TOPI classic NMS so
`relax.vision.non_max_suppression` can represent the
`NON_MAX_SUPPRESSION_V5` behavior.
2. **TFLite frontend support for `NON_MAX_SUPPRESSION_V5`**
- Remove the previous `soft_nms_sigma != 0` unsupported-path guard in
the TFLite frontend.
- Forward `soft_nms_sigma` and `score_threshold` into
`relax.vision.non_max_suppression`.
- Handle the soft-NMS return path explicitly so the frontend reads
decayed scores from the processed NMS output instead of re-reading the
original score tensor.
3. **Soft-NMS correctness fixes**
- Fix the soft-NMS path so boxes whose scores fall below the threshold
after decay are invalidated consistently.
- Keep returned indices and decayed scores aligned in both the TOPI TIR
implementation and the NumPy reference implementation.
- Update the soft-NMS candidate selection logic to re-pick the current
best candidate after each decay step, matching LiteRT's
reference behavior.
- Align the Gaussian decay formula with LiteRT: decayed scores are multiplied
by `exp(-0.5 * iou^2 / soft_nms_sigma)`.
4. **Test coverage**
- Add Relax tests for soft-NMS struct-info inference and legalization.
- Add Relax E2E tests covering reordered outputs after score decay and
other soft-NMS follow-up cases.
- Add TFLite frontend tests for `NON_MAX_SUPPRESSION_V5` with
`soft_nms_sigma != 0`.
- Add IR checks to verify that `soft_nms_sigma` and `score_threshold`
are forwarded correctly.
## Testing
```bash
python -m pytest -n 1 tests/python/relax/test_op_vision.py -k \
  "all_class_non_max_suppression or get_valid_counts or nms" -v
python -m pytest tests/python/relax/test_frontend_tflite.py -k "nms_v5" -v
```
## Results
- Relax vision tests passed locally
- TFLite `NON_MAX_SUPPRESSION_V5` coverage added for both hard-NMS and
soft-NMS paths
---
include/tvm/relax/attrs/vision.h | 8 +-
.../tvm/relax/frontend/tflite/tflite_frontend.py | 39 ++-
python/tvm/relax/op/vision/nms.py | 24 +-
python/tvm/relax/transform/legalize_ops/vision.py | 2 +
python/tvm/topi/testing/nms_python.py | 78 ++++-
python/tvm/topi/vision/nms.py | 326 ++++++++++++++-------
src/relax/op/vision/nms.cc | 29 +-
src/relax/op/vision/nms.h | 4 +-
tests/python/relax/test_frontend_tflite.py | 129 +++++++-
tests/python/relax/test_op_vision.py | 124 ++++++++
.../relax/test_tvmscript_parser_op_vision.py | 70 +++++
11 files changed, 710 insertions(+), 123 deletions(-)
diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 8971127d76..9dec6fd503 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -122,6 +122,8 @@ struct NonMaximumSuppressionAttrs
int id_index;
bool return_indices;
bool invalid_to_bottom;
+ double soft_nms_sigma;
+ double score_threshold;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
@@ -143,7 +145,11 @@ struct NonMaximumSuppressionAttrs
.def_ro("return_indices", &NonMaximumSuppressionAttrs::return_indices,
"Whether to return box indices in input data.")
.def_ro("invalid_to_bottom",
&NonMaximumSuppressionAttrs::invalid_to_bottom,
- "Whether to move all valid bounding boxes to the top.");
+ "Whether to move all valid bounding boxes to the top.")
+ .def_ro("soft_nms_sigma", &NonMaximumSuppressionAttrs::soft_nms_sigma,
+ "Sigma for soft-NMS; 0.0 means standard hard NMS.")
+ .def_ro("score_threshold",
&NonMaximumSuppressionAttrs::score_threshold,
+ "Score threshold for soft-NMS validity check; 0.0 when
unused.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NonMaximumSuppressionAttrs",
NonMaximumSuppressionAttrs, BaseAttrsNode);
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index d773d8d7ce..732950ca68 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -3617,10 +3617,6 @@ class OperatorConverter:
if isinstance(soft_nms_sigma, np.ndarray):
assert soft_nms_sigma.size == 1, "only one value is expected."
soft_nms_sigma = float(soft_nms_sigma)
- if soft_nms_sigma != 0.0:
- raise tvm.error.OpNotImplemented(
- "It is soft_nms when soft_nms_sigma != 0, which is not
supported!"
- )
scores_expand = relax.op.expand_dims(scores, axis=-1)
data = relax.op.concat([scores_expand, boxes], axis=-1)
@@ -3646,18 +3642,41 @@ class OperatorConverter:
id_index=-1,
return_indices=True,
invalid_to_bottom=False,
+ soft_nms_sigma=soft_nms_sigma,
+ score_threshold=score_threshold,
)
- selected_indices = relax.op.squeeze(nms_ret[0], axis=[0])
+ if soft_nms_sigma > 0.0:
+ processed_data = relax.op.squeeze(nms_ret[0], axis=[0])
+ indices_from_nms = nms_ret[1]
+ num_valid_from_nms = nms_ret[2]
+ else:
+ indices_from_nms = nms_ret[0]
+ num_valid_from_nms = nms_ret[1]
+
+ selected_indices = relax.op.squeeze(indices_from_nms, axis=[0])
selected_indices = relax.op.strided_slice(
selected_indices, axes=[0], begin=[0], end=[max_output_size]
)
- num_valid = relax.op.reshape(nms_ret[1], [])
+ num_valid = relax.op.reshape(num_valid_from_nms, [])
- # Clamp out-of-bound padded indices to prevent take() crash.
- num_boxes = int(self.get_tensor_shape(input_tensors[0])[0])
- safe_indices = relax.op.clip(selected_indices, min=0, max=num_boxes -
1)
- selected_scores = relax.op.take(scores, safe_indices, axis=0)
+ if soft_nms_sigma > 0.0:
+ # Extract decayed scores from the processed data (score_index=0)
+ selected_scores = relax.op.strided_slice(
+ processed_data, axes=[1], begin=[0], end=[1]
+ )
+ selected_scores = relax.op.squeeze(selected_scores, axis=[1])
+ selected_scores = relax.op.strided_slice(
+ selected_scores, axes=[0], begin=[0], end=[max_output_size]
+ )
+ selected_scores = relax.op.clip(
+ selected_scores, min=0.0, max=float(np.finfo("float32").max)
+ )
+ else:
+ # Clamp out-of-bound padded indices to prevent take() crash.
+ num_boxes = int(self.get_tensor_shape(input_tensors[0])[0])
+ safe_indices = relax.op.clip(selected_indices, min=0,
max=num_boxes - 1)
+ selected_scores = relax.op.take(scores, safe_indices, axis=0)
out = relax.Tuple([selected_indices, selected_scores, num_valid])
return out
diff --git a/python/tvm/relax/op/vision/nms.py
b/python/tvm/relax/op/vision/nms.py
index 4eb3eb7f7a..427d68d113 100644
--- a/python/tvm/relax/op/vision/nms.py
+++ b/python/tvm/relax/op/vision/nms.py
@@ -115,6 +115,8 @@ def non_max_suppression(
id_index=0,
return_indices=True,
invalid_to_bottom=False,
+ soft_nms_sigma=0.0,
+ score_threshold=0.0,
):
"""Non-maximum suppression operator for object detection.
@@ -160,12 +162,28 @@ def non_max_suppression(
Whether to move valid bounding boxes to the top of the returned tensor.
This option only affects the ``return_indices=False`` path.
+ soft_nms_sigma : float, optional
+ Sigma for soft-NMS Gaussian penalty. When ``0.0`` (default), standard
+ hard NMS is used. Positive values decay overlapping box scores instead
+ of suppressing them outright.
+
+ score_threshold : float, optional
+ Post-decay minimum score for a box to remain eligible during soft-NMS.
+ Only used when ``soft_nms_sigma > 0``. This is distinct from
+ ``get_valid_counts.score_threshold``, which filters boxes before NMS.
+ Defaults to ``0.0``.
+
Returns
-------
out : relax.Expr
- If ``return_indices`` is ``True``, returns
- ``(box_indices, valid_box_count)`` with shapes
+ The return tuple shape depends on ``soft_nms_sigma``.
+ If ``return_indices`` is ``True`` and ``soft_nms_sigma`` is ``0.0``,
+ returns a 2-tuple ``(box_indices, valid_box_count)`` with shapes
``[batch_size, num_anchors]`` and ``[batch_size, 1]``.
+ If ``return_indices`` is ``True`` and ``soft_nms_sigma > 0``,
+ returns a 3-tuple ``(out_data, box_indices, valid_box_count)`` where
+ decayed ``out_data`` is prepended and has the same shape as the input
+ data.
Otherwise returns the modified data tensor.
"""
return _ffi_api.non_max_suppression(
@@ -181,4 +199,6 @@ def non_max_suppression(
id_index,
return_indices,
invalid_to_bottom,
+ soft_nms_sigma,
+ score_threshold,
)
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py
b/python/tvm/relax/transform/legalize_ops/vision.py
index c515fc8fe8..4419549164 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -152,6 +152,8 @@ def _non_max_suppression(block_builder: BlockBuilder, call:
Call) -> Expr:
id_index=call.attrs.id_index,
return_indices=call.attrs.return_indices,
invalid_to_bottom=call.attrs.invalid_to_bottom,
+ soft_nms_sigma=call.attrs.soft_nms_sigma,
+ score_threshold=call.attrs.score_threshold,
)
diff --git a/python/tvm/topi/testing/nms_python.py
b/python/tvm/topi/testing/nms_python.py
index 7c8c20f5b4..c8711c70dd 100644
--- a/python/tvm/topi/testing/nms_python.py
+++ b/python/tvm/topi/testing/nms_python.py
@@ -46,6 +46,8 @@ def non_max_suppression_python(
id_index=0,
return_indices=True,
invalid_to_bottom=False,
+ soft_nms_sigma=0.0,
+ score_threshold=0.0,
):
"""Numpy reference for classic non_max_suppression.
@@ -62,7 +64,9 @@ def non_max_suppression_python(
Returns
-------
- If return_indices is True: (box_indices, valid_box_count)
+ If return_indices is True and soft_nms_sigma == 0.0: (box_indices,
valid_box_count)
+ If return_indices is True and soft_nms_sigma > 0.0:
+ (out_data, box_indices, valid_box_count)
Otherwise: modified data tensor
"""
batch_size, num_anchors, _ = data.shape
@@ -71,6 +75,10 @@ def non_max_suppression_python(
compacted = np.full((batch_size, num_anchors), -1, dtype="int32")
valid_box_count = np.zeros((batch_size, 1), dtype="int32")
+ is_soft_nms = soft_nms_sigma > 0.0
+ thresh = score_threshold if is_soft_nms else 0.0
+ soft_nms_scale = -0.5 / soft_nms_sigma if is_soft_nms else 0.0
+
for i in range(batch_size):
nkeep = int(valid_count[i])
if 0 < top_k < nkeep:
@@ -86,10 +94,72 @@ def non_max_suppression_python(
out_data[i, j, :] = data[i, src, :]
out_box_indices[i, j] = src
+ if is_soft_nms:
+ num_selected = 0
+ while num_selected < nkeep and (max_output_size < 0 or
num_selected < max_output_size):
+ best_idx = -1
+ best_score = thresh
+ for j in range(num_selected, nkeep):
+ if out_box_indices[i, j] >= 0 and out_data[i, j,
score_index] > best_score:
+ best_idx = j
+ best_score = out_data[i, j, score_index]
+
+ if best_idx < 0:
+ break
+
+ if best_idx != num_selected:
+ out_data[i, [num_selected, best_idx], :] = out_data[
+ i, [best_idx, num_selected], :
+ ]
+ out_box_indices[i, [num_selected, best_idx]] =
out_box_indices[
+ i, [best_idx, num_selected]
+ ]
+
+ selected_idx = num_selected
+ for j in range(selected_idx + 1, nkeep):
+ if out_box_indices[i, j] < 0 or out_data[i, j,
score_index] <= thresh:
+ continue
+
+ do_suppress = False
+ if force_suppress:
+ do_suppress = True
+ elif id_index >= 0:
+ do_suppress = (
+ out_data[i, selected_idx, id_index] == out_data[i,
j, id_index]
+ )
+ else:
+ do_suppress = True
+
+ if not do_suppress:
+ continue
+
+ iou = _iou(out_data[i, selected_idx], out_data[i, j],
coord_start)
+ if iou >= iou_threshold:
+ out_box_indices[i, j] = -1
+ else:
+ out_data[i, j, score_index] *= np.exp(soft_nms_scale *
(iou**2))
+ if out_data[i, j, score_index] <= thresh:
+ out_box_indices[i, j] = -1
+
+ num_selected += 1
+
+ valid_box_count[i, 0] = num_selected
+ if return_indices:
+ for j in range(num_selected):
+ orig_idx = out_box_indices[i, j]
+ compacted[i, j] = int(indices[i, orig_idx])
+ out_box_indices[i, j] = compacted[i, j]
+ for j in range(num_selected, num_anchors):
+ out_data[i, j, :] = -1.0
+ out_box_indices[i, j] = -1
+ else:
+ out_data[i, num_selected:, :] = -1.0
+ continue
+
# Greedy NMS
num_valid = 0
for j in range(nkeep):
- if out_data[i, j, score_index] <= 0:
+ if out_data[i, j, score_index] <= thresh:
out_data[i, j, :] = -1.0
out_box_indices[i, j] = -1
continue
@@ -102,7 +172,7 @@ def non_max_suppression_python(
# Suppress overlapping boxes
for k in range(j + 1, nkeep):
- if out_data[i, k, score_index] <= 0:
+ if out_data[i, k, score_index] <= thresh:
continue
do_suppress = False
@@ -130,6 +200,8 @@ def non_max_suppression_python(
valid_box_count[i, 0] = cnt
if return_indices:
+ if is_soft_nms:
+ return [out_data, compacted, valid_box_count]
return [compacted, valid_box_count]
if invalid_to_bottom:
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index a602527fcc..ad548978a1 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -188,6 +188,8 @@ def _classic_nms_ir(
out_data,
out_box_indices,
out_valid_box_count,
+ soft_nms_sigma=0.0,
+ score_threshold=0.0,
):
"""IR for classic single-class non-maximum suppression."""
with IRBuilder() as ib:
@@ -200,6 +202,10 @@ def _classic_nms_ir(
if out_valid_box_count is not None:
out_valid_box_count = T.buffer_proxy(out_valid_box_count)
+ is_soft_nms = soft_nms_sigma > 0.0
+ # For hard NMS the historical threshold is 0.0; for soft NMS use
score_threshold.
+ thresh = tvm.tirx.Cast(data.dtype, T.float32(score_threshold if
is_soft_nms else 0.0))
+
with T.parallel(0, batch_size) as i:
# Step 1: Reorder data by sorted score
nkeep_buf = T.alloc_buffer((1,), "int32", scope="local")
@@ -226,117 +232,220 @@ def _classic_nms_ir(
num_valid_boxes_buf = T.alloc_buffer((1,), "int32", scope="local")
num_valid_boxes = T.buffer_proxy(num_valid_boxes_buf)
num_valid_boxes[0] = T.int32(0)
+ best_idx_buf = T.alloc_buffer((1,), "int32", scope="local")
+ best_idx = T.buffer_proxy(best_idx_buf)
+ best_score_buf = T.alloc_buffer((1,), data.dtype, scope="local")
+ best_score = T.buffer_proxy(best_score_buf)
+ tmp_idx_buf = T.alloc_buffer((1,), "int32", scope="local")
+ tmp_idx = T.buffer_proxy(tmp_idx_buf)
+ tmp_val_buf = T.alloc_buffer((1,), data.dtype, scope="local")
+ tmp_val = T.buffer_proxy(tmp_val_buf)
+ zero = tvm.tirx.Cast(data.dtype, T.float32(0.0))
+
+ def compute_iou(lhs_idx, rhs_idx):
+ lhs_l = tvm.te.min(
+ out_data[i, lhs_idx, coord_start],
+ out_data[i, lhs_idx, coord_start + 2],
+ )
+ lhs_t = tvm.te.min(
+ out_data[i, lhs_idx, coord_start + 1],
+ out_data[i, lhs_idx, coord_start + 3],
+ )
+ lhs_r = tvm.te.max(
+ out_data[i, lhs_idx, coord_start],
+ out_data[i, lhs_idx, coord_start + 2],
+ )
+ lhs_b = tvm.te.max(
+ out_data[i, lhs_idx, coord_start + 1],
+ out_data[i, lhs_idx, coord_start + 3],
+ )
+ rhs_l = tvm.te.min(
+ out_data[i, rhs_idx, coord_start],
+ out_data[i, rhs_idx, coord_start + 2],
+ )
+ rhs_t = tvm.te.min(
+ out_data[i, rhs_idx, coord_start + 1],
+ out_data[i, rhs_idx, coord_start + 3],
+ )
+ rhs_r = tvm.te.max(
+ out_data[i, rhs_idx, coord_start],
+ out_data[i, rhs_idx, coord_start + 2],
+ )
+ rhs_b = tvm.te.max(
+ out_data[i, rhs_idx, coord_start + 1],
+ out_data[i, rhs_idx, coord_start + 3],
+ )
+ width = tvm.te.max(zero, tvm.te.min(lhs_r, rhs_r) -
tvm.te.max(lhs_l, rhs_l))
+ height = tvm.te.max(zero, tvm.te.min(lhs_b, rhs_b) -
tvm.te.max(lhs_t, rhs_t))
+ intersection = height * width
+ union = (
+ (lhs_r - lhs_l) * (lhs_b - lhs_t)
+ + (rhs_r - rhs_l) * (rhs_b - rhs_t)
+ - intersection
+ )
+ return tvm.tirx.Select(union <= zero, zero, intersection /
union)
+
+ if is_soft_nms:
+ # LiteRT soft-NMS selects the current highest-score candidate
each round.
+ soft_nms_scale = tvm.tirx.Cast(data.dtype, T.float32(-0.5 /
soft_nms_sigma))
- with T.serial(0, nkeep_local[0]) as j:
- # Check if box j is still valid (score > 0) and within
max_output_size
- with T.If(
- tvm.tirx.all(
- out_data[i, j, score_index] >
tvm.tirx.Cast(data.dtype, T.float32(0.0)),
+ with T.serial(0, nkeep_local[0]) as _:
+ with T.If(
tvm.tirx.Select(
max_output_size > 0,
num_valid_boxes[0] < max_output_size,
tvm.tirx.const(True),
- ),
- )
- ):
- with T.Then():
- num_valid_boxes[0] = num_valid_boxes[0] + 1
-
- # Suppress overlapping boxes
- with T.serial(0, nkeep_local[0]) as k:
- with T.If(
- tvm.tirx.all(
- k > j,
- out_data[i, k, score_index]
- > tvm.tirx.Cast(data.dtype,
T.float32(0.0)),
- )
- ):
+ )
+ ):
+ with T.Then():
+ best_idx[0] = T.int32(-1)
+ best_score[0] = thresh
+
+ with T.serial(0, nkeep_local[0]) as j:
+ with T.If(
+ tvm.tirx.all(
+ j >= num_valid_boxes[0],
+ out_box_indices[i, j] >= 0,
+ out_data[i, j, score_index] >
best_score[0],
+ )
+ ):
+ with T.Then():
+ best_idx[0] = j
+ best_score[0] = out_data[i, j,
score_index]
+
+ with T.If(best_idx[0] >= 0):
with T.Then():
- # Check class ID match (or force_suppress)
- do_suppress = tvm.tirx.const(False)
- if force_suppress:
- do_suppress = tvm.tirx.const(True)
- elif id_index >= 0:
- do_suppress = (
- out_data[i, j, id_index] ==
out_data[i, k, id_index]
- )
- else:
- do_suppress = tvm.tirx.const(True)
-
- with T.If(do_suppress):
+ with T.If(best_idx[0] !=
num_valid_boxes[0]):
with T.Then():
- # Calculate IoU
- a_l = tvm.te.min(
- out_data[i, j, coord_start],
- out_data[i, j, coord_start +
2],
- )
- a_t = tvm.te.min(
- out_data[i, j, coord_start +
1],
- out_data[i, j, coord_start +
3],
- )
- a_r = tvm.te.max(
- out_data[i, j, coord_start],
- out_data[i, j, coord_start +
2],
- )
- a_b = tvm.te.max(
- out_data[i, j, coord_start +
1],
- out_data[i, j, coord_start +
3],
+ tmp_idx[0] = out_box_indices[i,
num_valid_boxes[0]]
+ out_box_indices[
+ i, num_valid_boxes[0]
+ ] = out_box_indices[i, best_idx[0]]
+ out_box_indices[i, best_idx[0]] =
tmp_idx[0]
+
+ with T.serial(0, box_data_length)
as k:
+ tmp_val[0] = out_data[i,
num_valid_boxes[0], k]
+ out_data[i,
num_valid_boxes[0], k] = out_data[
+ i, best_idx[0], k
+ ]
+ out_data[i, best_idx[0], k] =
tmp_val[0]
+
+ with T.serial(0, nkeep_local[0]) as j:
+ with T.If(
+ tvm.tirx.all(
+ j > num_valid_boxes[0],
+ out_box_indices[i, j] >= 0,
+ out_data[i, j, score_index] >
thresh,
)
+ ):
+ with T.Then():
+ do_suppress =
tvm.tirx.const(False)
+ if force_suppress:
+ do_suppress =
tvm.tirx.const(True)
+ elif id_index >= 0:
+ do_suppress = (
+ out_data[i,
num_valid_boxes[0], id_index]
+ == out_data[i, j,
id_index]
+ )
+ else:
+ do_suppress =
tvm.tirx.const(True)
+
+ with T.If(do_suppress):
+ with T.Then():
+ iou =
compute_iou(num_valid_boxes[0], j)
+
+ with T.If(iou >=
iou_threshold):
+ with T.Then():
+
out_box_indices[i, j] = T.int32(-1)
+ with T.If(iou <
iou_threshold):
+ with T.Then():
+ out_data[i, j,
score_index] = (
+
out_data[i, j, score_index]
+ *
tvm.tirx.exp(
+
soft_nms_scale
+ * iou
+ * iou
+ )
+ )
+ with T.If(
+
out_data[i, j, score_index]
+ <= thresh
+ ):
+ with
T.Then():
+
out_box_indices[
+ i,
j
+ ] =
T.int32(-1)
+
+ num_valid_boxes[0] = num_valid_boxes[0] + 1
+
+ if return_indices:
+ out_valid_box_count[i, 0] = num_valid_boxes[0]
- b_l = tvm.te.min(
- out_data[i, k, coord_start],
- out_data[i, k, coord_start +
2],
- )
- b_t = tvm.te.min(
- out_data[i, k, coord_start +
1],
- out_data[i, k, coord_start +
3],
- )
- b_r = tvm.te.max(
- out_data[i, k, coord_start],
- out_data[i, k, coord_start +
2],
- )
- b_b = tvm.te.max(
- out_data[i, k, coord_start +
1],
- out_data[i, k, coord_start +
3],
- )
+ with T.serial(0, num_anchors) as j:
+ with T.If(j < num_valid_boxes[0]):
+ with T.Then():
+ orig_idx = out_box_indices[i, j]
+ out_box_indices[i, j] = indices[i, orig_idx]
+ with T.If(j >= num_valid_boxes[0]):
+ with T.Then():
+ with T.serial(0, box_data_length) as k:
+ out_data[i, j, k] = tvm.tirx.Cast(
+ data.dtype, T.float32(-1.0)
+ )
+ out_box_indices[i, j] = T.int32(-1)
+ else:
+ with T.serial(0, num_anchors) as j:
+ with T.If(j >= num_valid_boxes[0]):
+ with T.Then():
+ with T.serial(0, box_data_length) as k:
+ out_data[i, j, k] =
tvm.tirx.Cast(data.dtype, T.float32(-1.0))
+ else:
+ with T.serial(0, nkeep_local[0]) as j:
+ with T.If(
+ tvm.tirx.all(
+ out_data[i, j, score_index] > thresh,
+ tvm.tirx.Select(
+ max_output_size > 0,
+ num_valid_boxes[0] < max_output_size,
+ tvm.tirx.const(True),
+ ),
+ )
+ ):
+ with T.Then():
+ num_valid_boxes[0] = num_valid_boxes[0] + 1
- w = tvm.te.max(
- tvm.tirx.Cast(data.dtype,
T.float32(0.0)),
- tvm.te.min(a_r, b_r) -
tvm.te.max(a_l, b_l),
- )
- h = tvm.te.max(
- tvm.tirx.Cast(data.dtype,
T.float32(0.0)),
- tvm.te.min(a_b, b_b) -
tvm.te.max(a_t, b_t),
- )
- area = h * w
- u = (
- (a_r - a_l) * (a_b - a_t)
- + (b_r - b_l) * (b_b - b_t)
- - area
- )
- iou = tvm.tirx.Select(
- u <= tvm.tirx.Cast(data.dtype,
T.float32(0.0)),
- tvm.tirx.Cast(data.dtype,
T.float32(0.0)),
- area / u,
+ with T.serial(0, nkeep_local[0]) as k:
+ with T.If(
+ tvm.tirx.all(k > j, out_data[i, k,
score_index] > thresh)
+ ):
+ with T.Then():
+ do_suppress = tvm.tirx.const(False)
+ if force_suppress:
+ do_suppress = tvm.tirx.const(True)
+ elif id_index >= 0:
+ do_suppress = (
+ out_data[i, j, id_index] ==
out_data[i, k, id_index]
)
+ else:
+ do_suppress = tvm.tirx.const(True)
- with T.If(iou >= iou_threshold):
- with T.Then():
- out_data[i, k,
score_index] = tvm.tirx.Cast(
- data.dtype,
T.float32(-1.0)
- )
- out_box_indices[i, k] =
T.int32(-1)
+ with T.If(do_suppress):
+ with T.Then():
+ iou = compute_iou(j, k)
- with T.Else():
- # Box suppressed or beyond max_output_size
- with T.serial(0, box_data_length) as k:
- out_data[i, j, k] = tvm.tirx.Cast(data.dtype,
T.float32(-1.0))
- out_box_indices[i, j] = T.int32(-1)
+ with T.If(iou >=
iou_threshold):
+ with T.Then():
+ out_data[i, k,
score_index] = tvm.tirx.Cast(
+ data.dtype,
T.float32(-1.0)
+ )
+ out_box_indices[i, k]
= T.int32(-1)
+
+ with T.Else():
+ with T.serial(0, box_data_length) as k:
+ out_data[i, j, k] = tvm.tirx.Cast(data.dtype,
T.float32(-1.0))
+ out_box_indices[i, j] = T.int32(-1)
- # Step 3: If return_indices, remap to original indices
- if return_indices:
- if out_valid_box_count is not None:
- # Count valid boxes and remap indices
+ if return_indices:
valid_idx_buf = T.alloc_buffer((1,), "int32",
scope="local")
valid_idx = T.buffer_proxy(valid_idx_buf)
valid_idx[0] = T.int32(0)
@@ -350,7 +459,6 @@ def _classic_nms_ir(
out_valid_box_count[i, 0] = valid_idx[0]
- # Fill remaining with -1
with T.serial(0, num_anchors) as j:
with T.If(j >= valid_idx[0]):
with T.Then():
@@ -372,6 +480,8 @@ def non_max_suppression(
id_index=0,
return_indices=True,
invalid_to_bottom=False,
+ soft_nms_sigma=0.0,
+ score_threshold=0.0,
):
"""Non-maximum suppression operator for object detection.
@@ -416,10 +526,24 @@ def non_max_suppression(
invalid_to_bottom : optional, boolean
Whether to move all valid bounding boxes to the top.
+ soft_nms_sigma : optional, float
+ Sigma for soft-NMS Gaussian penalty. 0.0 means standard hard NMS.
+
+ score_threshold : optional, float
+ Post-decay minimum score for a box to remain eligible during soft-NMS.
+ Only used when ``soft_nms_sigma > 0``. This is distinct from
+ ``get_valid_counts.score_threshold``, which filters boxes before NMS.
+
Returns
-------
out : tvm.te.Tensor or tuple of tvm.te.Tensor
- If return_indices is True, returns a tuple of (box_indices,
valid_box_count).
+ The return tuple shape depends on ``soft_nms_sigma``.
+ If ``return_indices`` is ``True`` and ``soft_nms_sigma`` is ``0.0``,
+ returns a 2-tuple ``(box_indices, valid_box_count)``.
+ If ``return_indices`` is ``True`` and ``soft_nms_sigma > 0``,
+ returns a 3-tuple ``(out_data, box_indices, valid_box_count)`` where
+ decayed ``out_data`` is prepended and has the same shape as the input
+ data.
Otherwise returns the modified data tensor.
"""
batch_size = data.shape[0]
@@ -464,6 +588,7 @@ def non_max_suppression(
coord_start, score_index, id_index,
return_indices,
outs[0], outs[1], outs[2],
+ soft_nms_sigma, score_threshold,
),
dtype=[data.dtype, "int32", "int32"],
out_buffers=[out_data_buf, out_box_indices_buf,
out_valid_box_count_buf],
@@ -471,6 +596,8 @@ def non_max_suppression(
name="non_max_suppression",
tag="non_max_suppression",
)
+ if soft_nms_sigma > 0.0:
+ return [out_data, out_box_indices, out_valid_box_count]
return [out_box_indices, out_valid_box_count]
out_data, out_box_indices = te.extern(
@@ -484,6 +611,7 @@ def non_max_suppression(
coord_start, score_index, id_index,
return_indices,
outs[0], outs[1], None,
+ soft_nms_sigma, score_threshold,
),
dtype=[data.dtype, "int32"],
out_buffers=[out_data_buf, out_box_indices_buf],
diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc
index 97508d7211..9d7144b5d7 100644
--- a/src/relax/op/vision/nms.cc
+++ b/src/relax/op/vision/nms.cc
@@ -196,8 +196,8 @@ TVM_REGISTER_OP("relax.vision.get_valid_counts")
Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int
max_output_size,
double iou_threshold, bool force_suppress, int top_k,
int coord_start,
- int score_index, int id_index, bool return_indices,
- bool invalid_to_bottom) {
+ int score_index, int id_index, bool return_indices,
bool invalid_to_bottom,
+ double soft_nms_sigma, double score_threshold) {
auto attrs = tvm::ffi::make_object<NonMaximumSuppressionAttrs>();
attrs->max_output_size = max_output_size;
attrs->iou_threshold = iou_threshold;
@@ -208,6 +208,8 @@ Expr non_max_suppression(Expr data, Expr valid_count, Expr
indices, int max_outp
attrs->id_index = id_index;
attrs->return_indices = return_indices;
attrs->invalid_to_bottom = invalid_to_bottom;
+ attrs->soft_nms_sigma = soft_nms_sigma;
+ attrs->score_threshold = score_threshold;
static const Op& op = Op::Get("relax.vision.non_max_suppression");
return Call(op, {std::move(data), std::move(valid_count),
std::move(indices)}, Attrs(attrs), {});
@@ -319,7 +321,28 @@ StructInfo InferStructInfoNMS(const Call& call, const
BlockBuilder& ctx) {
}
if (attrs->return_indices) {
- // Returns (box_indices[batch, num_anchors], valid_box_count[batch, 1])
+ if (attrs->soft_nms_sigma > 0.0) {
+ // Soft-NMS returns (out_data[batch, num_anchors, elem_length],
+ // box_indices[batch, num_anchors],
+ // valid_box_count[batch, 1])
+ if (data_shape == nullptr) {
+ tvm::ffi::Array<StructInfo> fields = {
+ TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev),
+ TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev),
+ TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)};
+ return TupleStructInfo(fields);
+ }
+ auto batch = data_shape->values[0];
+ auto num_anchors = data_shape->values[1];
+ tvm::ffi::Array<StructInfo> fields = {
+ TensorStructInfo(ffi::GetRef<ShapeExpr>(data_shape),
data_sinfo->dtype, vdev),
+ TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32),
vdev),
+ TensorStructInfo(ShapeExpr({batch, IntImm(DataType::Int(64), 1)}),
DataType::Int(32),
+ vdev)};
+ return TupleStructInfo(fields);
+ }
+
+ // Hard NMS returns (box_indices[batch, num_anchors],
valid_box_count[batch, 1])
if (data_shape == nullptr) {
tvm::ffi::Array<StructInfo> fields = {
TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev),
diff --git a/src/relax/op/vision/nms.h b/src/relax/op/vision/nms.h
index 3fbd2609e2..83ca5b1bc0 100644
--- a/src/relax/op/vision/nms.h
+++ b/src/relax/op/vision/nms.h
@@ -44,8 +44,8 @@ Expr get_valid_counts(Expr data, double score_threshold, int
id_index, int score
/*! \brief Non-maximum suppression for object detection. */
Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int
max_output_size,
double iou_threshold, bool force_suppress, int top_k,
int coord_start,
- int score_index, int id_index, bool return_indices,
- bool invalid_to_bottom);
+ int score_index, int id_index, bool return_indices,
bool invalid_to_bottom,
+ double soft_nms_sigma = 0.0, double score_threshold =
0.0);
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 15ca1cacf1..908868faf0 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1360,7 +1360,7 @@ def test_batch_matmul_adj():
verify(BatchMatMulAdj, Expected)
-def _verify_nms_v5(mod, tf_func, boxes_np, scores_np):
+def _verify_nms_v5(mod, tf_func, boxes_np, scores_np, soft_nms_sigma=0.0):
"""E2E verify for NMS: only run on nightly, compare valid outputs only."""
if "CI_ENV_NIGHTLY" not in os.environ:
return
@@ -1386,9 +1386,19 @@ def _verify_nms_v5(mod, tf_func, boxes_np, scores_np):
rtol=1e-5,
atol=1e-5,
)
+ if soft_nms_sigma > 0.0:
+ np.testing.assert_allclose(
+ tf_scores.numpy(),
+ tvm_scores.numpy(),
+ rtol=1e-5,
+ atol=1e-5,
+ )
+ np.testing.assert_array_less(-1e-6, tvm_scores.numpy()[n_valid:])
-def _build_nms_v5_mod(num_boxes, max_output_size, iou_threshold,
score_threshold):
+def _build_nms_v5_mod(
+ num_boxes, max_output_size, iou_threshold, score_threshold,
soft_nms_sigma=0.0
+):
"""Convert a NonMaxSuppressionV5 TFLite model to a Relax module.
Scalar params must be Python literals (not tf.constant) so TFLite can
@@ -1409,7 +1419,7 @@ def _build_nms_v5_mod(num_boxes, max_output_size,
iou_threshold, score_threshold
max_output_size=max_output_size,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
- soft_nms_sigma=0.0,
+ soft_nms_sigma=soft_nms_sigma,
pad_to_max_output_size=True,
)
return indices, out_scores, valid
@@ -1647,6 +1657,83 @@ _NMS_V5_CASES = [
]
+_NMS_V5_SOFT_CASES = [
+ pytest.param(
+ 6,
+ 6,
+ 0.5,
+ 0.0,
+ 0.5,
+ np.array(
+ [
+ [0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.1, 1.0, 1.1],
+ [0.0, 0.0, 1.0, 0.9],
+ [0.5, 0.5, 1.5, 1.5],
+ [0.0, 0.0, 0.3, 0.3],
+ ],
+ dtype=np.float32,
+ ),
+ np.array([0.9, 0.75, 0.6, 0.5, 0.4, 0.3], dtype=np.float32),
+ id="soft_nms_basic",
+ ),
+ pytest.param(
+ 5,
+ 5,
+ 0.5,
+ 0.0,
+ 0.3,
+ np.array(
+ [
+ [0.0, 0.0, 1.0, 1.0],
+ [0.1, 0.1, 1.1, 1.1],
+ [0.2, 0.2, 1.2, 1.2],
+ [0.3, 0.3, 1.3, 1.3],
+ [2.0, 2.0, 3.0, 3.0],
+ ],
+ dtype=np.float32,
+ ),
+ np.array([0.9, 0.8, 0.7, 0.6, 0.5], dtype=np.float32),
+ id="soft_nms_tight_sigma",
+ ),
+ pytest.param(
+ 3,
+ 3,
+ 0.5,
+ 0.3,
+ 0.1,
+ np.array(
+ [
+ [0.0, 0.0, 1.0, 1.0],
+ [0.2, 0.2, 1.2, 1.2],
+ [2.0, 2.0, 3.0, 3.0],
+ ],
+ dtype=np.float32,
+ ),
+ np.array([0.9, 0.8, 0.75], dtype=np.float32),
+ id="soft_nms_threshold_hole",
+ ),
+ pytest.param(
+ 3,
+ 3,
+ 0.5,
+ 0.0,
+ 0.1,
+ np.array(
+ [
+ [0.0, 0.0, 1.0, 1.0],
+ [0.2, 0.2, 1.2, 1.2],
+ [2.0, 2.0, 3.0, 3.0],
+ ],
+ dtype=np.float32,
+ ),
+ np.array([0.9, 0.85, 0.8], dtype=np.float32),
+ id="soft_nms_reorder",
+ ),
+]
+
+
@pytest.mark.parametrize(
"num_boxes,max_output_size,iou_threshold,score_threshold,boxes,scores",
_NMS_V5_CASES,
@@ -1657,6 +1744,20 @@ def test_nms_v5(num_boxes, max_output_size,
iou_threshold, score_threshold, boxe
_verify_nms_v5(mod, tf_func, boxes, scores)
[email protected](
+
"num_boxes,max_output_size,iou_threshold,score_threshold,soft_nms_sigma,boxes,scores",
+ _NMS_V5_SOFT_CASES,
+)
+def test_nms_v5_soft(
+ num_boxes, max_output_size, iou_threshold, score_threshold,
soft_nms_sigma, boxes, scores
+):
+ """NON_MAX_SUPPRESSION_V5 with soft_nms_sigma: conversion smoke test + E2E
correctness."""
+ mod, tf_func = _build_nms_v5_mod(
+ num_boxes, max_output_size, iou_threshold, score_threshold,
soft_nms_sigma
+ )
+ _verify_nms_v5(mod, tf_func, boxes, scores, soft_nms_sigma=soft_nms_sigma)
+
+
def test_nms_v5_ir():
"""Verify the emitted Relax IR has correct structure for
NON_MAX_SUPPRESSION_V5."""
num_boxes = 6
@@ -1681,6 +1782,28 @@ def test_nms_v5_ir():
assert f"R.Tensor(({max_output_size},)" in ir
+def test_nms_v5_soft_ir():
+ """Verify the emitted Relax IR passes soft_nms_sigma for NON_MAX_SUPPRESSION_V5."""
+ num_boxes = 6
+ max_output_size = 3
+ mod, _ = _build_nms_v5_mod(
+ num_boxes=num_boxes,
+ max_output_size=max_output_size,
+ iou_threshold=0.5,
+ score_threshold=0.0,
+ soft_nms_sigma=0.5,
+ )
+
+ ir = mod.script()
+
+ # soft_nms_sigma must appear in the IR
+ assert "soft_nms_sigma=0.5" in ir
+ # score_threshold must also be forwarded
+ assert "score_threshold=0.0" in ir
+ # Soft-NMS padded scores must be clipped to non-negative values.
+ assert "R.clip(" in ir
+
+
_DETECTION_POSTPROCESS_SMOKE_CASES = [
pytest.param(
{
diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py
index b597b325f4..ef260cf188 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -302,6 +302,26 @@ def test_nms_infer_struct_info_return_indices():
)
+def test_nms_infer_struct_info_return_indices_soft_nms():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+ indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+ _check_inference(
+ bb,
+ relax.op.vision.non_max_suppression(
+ data, valid_count, indices, return_indices=True, soft_nms_sigma=0.5
+ ),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((2, 10, 6), "float32"),
+ relax.TensorStructInfo((2, 10), "int32"),
+ relax.TensorStructInfo((2, 1), "int32"),
+ ]
+ ),
+ )
+
+
def test_nms_infer_struct_info_return_data():
bb = relax.BlockBuilder()
data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
@@ -457,6 +477,52 @@ def test_nms_legalize():
id_index=0,
return_indices=True,
invalid_to_bottom=False,
+ soft_nms_sigma=0.0,
+ score_threshold=0.0,
+ )
+ return gv
+
+ mod = LegalizeOps()(NMS)
+ _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression")
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((1, 5), "int32"),
+ relax.TensorStructInfo((1, 1), "int32"),
+ ]
+ ),
+ )
+
+
+def test_nms_legalize_soft_nms():
+ @tvm.script.ir_module
+ class NMS:
+ @R.function
+ def main(
+ data: R.Tensor((1, 5, 6), "float32"),
+ valid_count: R.Tensor((1,), "int32"),
+ indices: R.Tensor((1, 5), "int32"),
+ ) -> R.Tuple(
+ R.Tensor((1, 5, 6), "float32"),
+ R.Tensor((1, 5), "int32"),
+ R.Tensor((1, 1), "int32"),
+ ):
+ gv = R.vision.non_max_suppression(
+ data,
+ valid_count,
+ indices,
+ max_output_size=-1,
+ iou_threshold=0.5,
+ force_suppress=False,
+ top_k=-1,
+ coord_start=2,
+ score_index=1,
+ id_index=0,
+ return_indices=True,
+ invalid_to_bottom=False,
+ soft_nms_sigma=0.5,
+ score_threshold=0.0,
)
return gv
@@ -466,6 +532,7 @@ def test_nms_legalize():
mod["main"].ret_struct_info,
relax.TupleStructInfo(
[
+ relax.TensorStructInfo((1, 5, 6), "float32"),
relax.TensorStructInfo((1, 5), "int32"),
relax.TensorStructInfo((1, 1), "int32"),
]
@@ -495,6 +562,8 @@ def test_nms_legalize_return_data():
id_index=0,
return_indices=False,
invalid_to_bottom=True,
+ soft_nms_sigma=0.0,
+ score_threshold=0.0,
)
return gv
@@ -577,6 +646,8 @@ def _run_nms_e2e(
id_index: int = 0,
return_indices: bool = True,
invalid_to_bottom: bool = False,
+ soft_nms_sigma: float = 0.0,
+ score_threshold: float = 0.0,
):
"""Run classic NMS through legalization and VM execution."""
@@ -603,6 +674,8 @@ def _run_nms_e2e(
id_index=id_index,
return_indices=return_indices,
invalid_to_bottom=invalid_to_bottom,
+ soft_nms_sigma=soft_nms_sigma,
+ score_threshold=score_threshold,
)
)
bb.emit_func_output(result)
@@ -660,6 +733,57 @@ def test_nms_e2e_return_indices():
tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
+@tvm.testing.requires_llvm
+def test_nms_e2e_soft_nms_reorders_by_decayed_score():
+ """Soft-NMS should re-rank by decayed scores instead of keeping the initial order."""
+
+ raw_data = np.array(
+ [
+ [
+ [0.0, 0.90, 0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.85, 0.2, 0.2, 1.2, 1.2],
+ [0.0, 0.80, 2.0, 2.0, 3.0, 3.0],
+ [-1.0, 0.99, 0.0, 0.0, 1.0, 1.0],
+ ]
+ ],
+ dtype="float32",
+ )
+ valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+ ref_out_data, ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python(
+ filtered_data_np,
+ valid_count_np,
+ filtered_indices_np,
+ max_output_size=-1,
+ iou_threshold=0.5,
+ force_suppress=True,
+ top_k=-1,
+ coord_start=2,
+ score_index=1,
+ id_index=-1,
+ return_indices=True,
+ invalid_to_bottom=False,
+ soft_nms_sigma=0.1,
+ score_threshold=0.0,
+ )
+ result = _run_nms_e2e(
+ filtered_data_np,
+ valid_count_np,
+ filtered_indices_np,
+ iou_threshold=0.5,
+ force_suppress=True,
+ id_index=-1,
+ return_indices=True,
+ invalid_to_bottom=False,
+ soft_nms_sigma=0.1,
+ score_threshold=0.0,
+ )
+
+ np.testing.assert_array_equal(ref_indices[0, :3], np.array([0, 2, 1], dtype="int32"))
+ tvm.testing.assert_allclose(result[0].numpy(), ref_out_data)
+ tvm.testing.assert_allclose(result[1].numpy(), ref_indices)
+ tvm.testing.assert_allclose(result[2].numpy(), ref_valid_box_count)
+
+
@tvm.testing.requires_llvm
def test_nms_e2e_return_indices_with_invalid_to_bottom():
"""Validate that invalid_to_bottom is a no-op when returning indices."""
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py
index 370b68769e..d4755ee367 100644
--- a/tests/python/relax/test_tvmscript_parser_op_vision.py
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -126,6 +126,8 @@ def test_non_max_suppression_return_indices():
id_index=0,
return_indices=True,
invalid_to_bottom=False,
+ soft_nms_sigma=0.0,
+ score_threshold=0.0,
)
)
return gv
@@ -150,6 +152,70 @@ def test_non_max_suppression_return_indices():
id_index=0,
return_indices=True,
invalid_to_bottom=False,
+ soft_nms_sigma=0.0,
+ score_threshold=0.0,
+ )
+ )
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
+def test_non_max_suppression_return_indices_soft_nms():
+ @R.function
+ def foo(
+ data: R.Tensor((2, 5, 6), "float32"),
+ valid_count: R.Tensor((2,), "int32"),
+ indices: R.Tensor((2, 5), "int32"),
+ ) -> R.Tuple(
+ R.Tensor((2, 5, 6), "float32"),
+ R.Tensor((2, 5), "int32"),
+ R.Tensor((2, 1), "int32"),
+ ):
+ gv: R.Tuple(
+ R.Tensor((2, 5, 6), "float32"),
+ R.Tensor((2, 5), "int32"),
+ R.Tensor((2, 1), "int32"),
+ ) = R.vision.non_max_suppression(
+ data,
+ valid_count,
+ indices,
+ max_output_size=-1,
+ iou_threshold=0.5,
+ force_suppress=False,
+ top_k=3,
+ coord_start=2,
+ score_index=1,
+ id_index=0,
+ return_indices=True,
+ invalid_to_bottom=False,
+ soft_nms_sigma=0.5,
+ score_threshold=0.0,
+ )
+ return gv
+
+ data = relax.Var("data", R.Tensor((2, 5, 6), "float32"))
+ valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+ indices = relax.Var("indices", R.Tensor((2, 5), "int32"))
+
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [data, valid_count, indices]):
+ gv = bb.emit(
+ relax.op.vision.non_max_suppression(
+ data,
+ valid_count,
+ indices,
+ max_output_size=-1,
+ iou_threshold=0.5,
+ force_suppress=False,
+ top_k=3,
+ coord_start=2,
+ score_index=1,
+ id_index=0,
+ return_indices=True,
+ invalid_to_bottom=False,
+ soft_nms_sigma=0.5,
+ score_threshold=0.0,
)
)
bb.emit_func_output(gv)
@@ -177,6 +243,8 @@ def test_non_max_suppression_return_data():
id_index=0,
return_indices=False,
invalid_to_bottom=True,
+ soft_nms_sigma=0.0,
+ score_threshold=0.0,
)
return gv
@@ -200,6 +268,8 @@ def test_non_max_suppression_return_data():
id_index=0,
return_indices=False,
invalid_to_bottom=True,
+ soft_nms_sigma=0.0,
+ score_threshold=0.0,
)
)
bb.emit_func_output(gv)