gemini-code-assist[bot] commented on code in PR #18626:
URL: https://github.com/apache/tvm/pull/18626#discussion_r2652665467
##########
python/tvm/relax/transform/legalize_ops/statistical.py:
##########
@@ -53,6 +53,42 @@ def _te_variance(x: te.Tensor, axis: List[tir.IntImm],
keepdims: bool) -> te.Ten
# return _te_mean(x * x, axis, keepdims) - mean * mean
+def _te_median(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) ->
te.Tensor:
Review Comment:

The return type hint for `_te_median` is `te.Tensor`, but the function returns
a `Tuple[te.Tensor, te.Tensor]` (values and indices) when a single axis is
provided. The annotation should be updated to reflect both possible return
types for better type safety and clarity. You'll also need to bring `Union`
and `Tuple` into scope; since the file already uses `List`, extend its
existing `typing` import (e.g. `from typing import List, Tuple, Union`).
```suggestion
def _te_median(
    x: te.Tensor, axis: List[tir.IntImm], keepdims: bool
) -> Union[te.Tensor, Tuple[te.Tensor, te.Tensor]]:
```
##########
src/relax/op/tensor/statistical.cc:
##########
@@ -180,6 +180,66 @@ StructInfo InferStructInfoScan(const Call& call, const
BlockBuilder& ctx) {
}
}
+StructInfo InferStructInfoStatisticalExtension(const Call& call, const
BlockBuilder& ctx) {
+ TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ const auto* attrs = call->attrs.as<StatisticalAttrs>();
+
+ std::vector<int> axes;
+ if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
+ axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value());
+ }
+
+ int out_ndim;
+ if (attrs->keepdims) {
+ out_ndim = data_sinfo->ndim;
+ } else if (!attrs->axis.defined()) {
+ out_ndim = 0;
+ } else if (data_sinfo->IsUnknownNdim()) {
+ out_ndim = kUnknownNDim;
+ } else {
+ out_ndim = data_sinfo->ndim - axes.size();
+ ICHECK_GE(out_ndim, 0);
+ }
+
+ // The inference rule for median operator output shapes:
+ // - axes is None || len(axes) > 1, keepdims is false -> return the
zero-rank shape;
+ // - axes is None || len(axes) > 1, keepdims is true -> return the shape
whose ndim
+ // is the same as input and every value is 1.
+ // - len(axes) == 1, keepdims is false -> the returned shape does not
contain the input axis.
+ // - len(axes) == 1, keepdims is true -> the returned shape has value 1 at
the positions of the
+ // input axis
+ const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+ if (data_shape == nullptr) {
+ if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim)
{
+ return TensorStructInfo(
+ ShapeExpr(ffi::Array<PrimExpr>(out_ndim, IntImm(DataType::Int(64),
/*value=*/1))),
+ data_sinfo->dtype, data_sinfo->vdevice);
+ } else {
+ return out_ndim == 0 ?
TensorStructInfo(ShapeExpr(ffi::Array<PrimExpr>()), data_sinfo->dtype,
+ data_sinfo->vdevice)
+ : TensorStructInfo(data_sinfo->dtype, out_ndim,
data_sinfo->vdevice);
+ }
+ }
+
+ ffi::Array<PrimExpr> out_shape;
+ out_shape.reserve(out_ndim);
+ for (int i = 0; i < data_sinfo->ndim; ++i) {
+ if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) ==
axes.end()) {
+ out_shape.push_back(data_shape->values[i]);
+ } else if (attrs->keepdims) {
+ out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1));
+ }
+ }
+ ICHECK_EQ(static_cast<int>(out_shape.size()), out_ndim);
+
+ if (!attrs->axis.defined() || axes.size() > 1)
+ return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype,
data_sinfo->vdevice);
+ else
+ return TupleStructInfo({
+ TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype,
data_sinfo->vdevice),
+ TensorStructInfo(ShapeExpr(out_shape), DataType::Int(64),
data_sinfo->vdevice)});
+}
Review Comment:

There's a critical correctness issue in
`InferStructInfoStatisticalExtension`. When the input tensor has an unknown
shape (`data_shape == nullptr`), the function incorrectly returns a
`TensorStructInfo` for `relax.median`, even when `axis` specifies a single
dimension. According to PyTorch's behavior and the op's design, it should
return a `TupleStructInfo` containing both the values and the indices.
This happens because of an early return inside the `if (data_shape ==
nullptr)` block, which doesn't account for the tuple-return case.
To fix this, the rule of returning a `TupleStructInfo` when a single axis is
provided must be applied whether or not the input shape is known. One
possible fix is to refactor the function as follows:
1. Decide up front whether a tuple should be returned.
2. Compute the `value_sinfo`.
3. Before returning, wrap it in a `TupleStructInfo` together with an `int64`
   indices struct info of the same shape.

Base that decision on `attrs->axis.value().size()` rather than `axes.size()`.
`axes` is only populated when the input ndim is known, so it is empty for
inputs of unknown rank such as `R.Tensor("float32")`.
##########
python/tvm/relax/transform/legalize_ops/statistical.py:
##########
@@ -53,6 +53,42 @@ def _te_variance(x: te.Tensor, axis: List[tir.IntImm],
keepdims: bool) -> te.Ten
# return _te_mean(x * x, axis, keepdims) - mean * mean
+def _te_median(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) ->
te.Tensor:
+ # currently only supports one axis or no axis ~ same pytorch
+ # todo: support multiple axis ~ same numpy
+ shape_prod = _compute_shape_prod(x, axis)
+ mid_index = shape_prod // 2
+ if shape_prod % 2 == 1:
+ mid_index = mid_index
+ else:
+ mid_index = mid_index - 1
Review Comment:

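The logic for calculating `mid_index` can be simplified. The current
implementation uses a verbose if/else block whose odd branch,
`mid_index = mid_index`, is a no-op. The whole block can be replaced with a
single expression that handles both odd and even `shape_prod` values:
- `shape_prod = 5` yields index 2.
- `shape_prod = 4` yields index 1, the lower of the two middle positions.

Both match the current code. If `shape_prod` can be a symbolic expression,
the single expression also avoids branching in Python on
`shape_prod % 2 == 1`.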
```suggestion
mid_index = (shape_prod - 1) // 2
```
##########
tests/python/relax/test_op_statistical.py:
##########
@@ -275,5 +276,145 @@ def
test_scan_opinfer_struct_info_wrong_input_type(scan_op: Callable):
bb.normalize(scan_op(x1, axis=1))
+def
test_test_statistical_ext_infer_struct_info_wrong_input_type_infer_struct_info():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+ x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
+
+ _check_inference(
+ bb,
+ relax.op.median(x0, axis=[1]),
+ relax.TupleStructInfo([relax.TensorStructInfo((2, 4, 5), "float32"),
+ relax.TensorStructInfo((2, 4, 5), "int64")]),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x0, axis=[1], keepdims=True),
+ relax.TupleStructInfo([relax.TensorStructInfo((2, 1, 4, 5), "float32"),
+ relax.TensorStructInfo((2, 1, 4, 5), "int64")]),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x0, axis=[1]),
+ relax.TupleStructInfo([relax.TensorStructInfo((2, 4, 5), "float32"),
+ relax.TensorStructInfo((2, 4, 5), "int64")]),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x1, axis=[1], keepdims=True),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+ _check_inference(
+ bb, relax.op.median(x2, axis=[1]),
relax.TensorStructInfo(dtype="float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x2, axis=[1], keepdims=True),
+ relax.TensorStructInfo(dtype="float32"),
+ )
+ _check_inference(bb, relax.op.median(x2, axis=None),
relax.TensorStructInfo((), "float32"))
+ _check_inference(
+ bb,
+ relax.op.median(x1, axis=None, keepdims=True),
+ relax.TensorStructInfo((1, 1, 1, 1), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x4, axis=[1]),
+ relax.TupleStructInfo([relax.TensorStructInfo((2, 4, 5), "float32",
vdev0),
+ relax.TensorStructInfo((2, 4, 5), "int64",
vdev0)]),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x3, axis=[1], keepdims=True),
+ relax.TupleStructInfo([relax.TensorStructInfo((2, 1, 4, 5), dtype=""),
+ relax.TensorStructInfo((2, 1, 4, 5),
dtype="int64")])
+ )
+ _check_inference(bb, relax.op.median(x3, axis=None),
relax.TensorStructInfo((), dtype=""))
+ _check_inference(
+ bb,
+ relax.op.median(x3, axis=None, keepdims=True),
+ relax.TensorStructInfo((1, 1, 1, 1), dtype=""),
+ )
Review Comment:

This new test function for `median` has a few issues:
1. The function name
`test_test_statistical_ext_infer_struct_info_wrong_input_type_infer_struct_info`
appears to be a copy-paste error. It repeats `test_` and `infer_struct_info`,
and none of its checks exercise wrong input types. It should be renamed to
something more descriptive, such as `test_median_infer_struct_info`.
2. The check for `relax.op.median(x0, axis=[1])` appears twice. One of the
duplicates can be removed.
3. Several checks validate incorrect behavior caused by a bug in
`InferStructInfoStatisticalExtension`. When `median` is called with a single
axis, it should return a `TupleStructInfo`, but the checks for `x1` and `x2`
expect a `TensorStructInfo`. Once the underlying bug is fixed, these checks
should expect a `TupleStructInfo`, as in the suggested version below:
```python
def test_median_infer_struct_info():
    bb = relax.BlockBuilder()
    vdev0 = VDevice("llvm")
    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
    x2 = relax.Var("x", R.Tensor("float32"))
    x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
    x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
    _check_inference(
        bb,
        relax.op.median(x0, axis=[1]),
        relax.TupleStructInfo(
            [
                relax.TensorStructInfo((2, 4, 5), "float32"),
                relax.TensorStructInfo((2, 4, 5), "int64"),
            ]
        ),
    )
    _check_inference(
        bb,
        relax.op.median(x0, axis=[1], keepdims=True),
        relax.TupleStructInfo(
            [
                relax.TensorStructInfo((2, 1, 4, 5), "float32"),
                relax.TensorStructInfo((2, 1, 4, 5), "int64"),
            ]
        ),
    )
    _check_inference(
        bb,
        relax.op.median(x1, axis=[1], keepdims=True),
        relax.TupleStructInfo(
            [
                relax.TensorStructInfo(dtype="float32", ndim=4),
                relax.TensorStructInfo(dtype="int64", ndim=4),
            ]
        ),
    )
    _check_inference(
        bb,
        relax.op.median(x2, axis=[1]),
        relax.TupleStructInfo(
            [
                relax.TensorStructInfo(dtype="float32"),
                relax.TensorStructInfo(dtype="int64"),
            ]
        ),
    )
    _check_inference(
        bb,
        relax.op.median(x2, axis=[1], keepdims=True),
        relax.TupleStructInfo(
            [
                relax.TensorStructInfo(dtype="float32"),
                relax.TensorStructInfo(dtype="int64"),
            ]
        ),
    )
    _check_inference(bb, relax.op.median(x2, axis=None), relax.TensorStructInfo((), "float32"))
    _check_inference(
        bb,
        relax.op.median(x1, axis=None, keepdims=True),
        relax.TensorStructInfo((1, 1, 1, 1), "float32"),
    )
    _check_inference(
        bb,
        relax.op.median(x4, axis=[1]),
        relax.TupleStructInfo(
            [
                relax.TensorStructInfo((2, 4, 5), "float32", vdev0),
                relax.TensorStructInfo((2, 4, 5), "int64", vdev0),
            ]
        ),
    )
    _check_inference(
        bb,
        relax.op.median(x3, axis=[1], keepdims=True),
        relax.TupleStructInfo(
            [
                relax.TensorStructInfo((2, 1, 4, 5), dtype=""),
                relax.TensorStructInfo((2, 1, 4, 5), dtype="int64"),
            ]
        ),
    )
    _check_inference(bb, relax.op.median(x3, axis=None), relax.TensorStructInfo((), dtype=""))
    _check_inference(
        bb,
        relax.op.median(x3, axis=None, keepdims=True),
        relax.TensorStructInfo((1, 1, 1, 1), dtype=""),
    )
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]