gemini-code-assist[bot] commented on code in PR #18626:
URL: https://github.com/apache/tvm/pull/18626#discussion_r2652665467


##########
python/tvm/relax/transform/legalize_ops/statistical.py:
##########
@@ -53,6 +53,42 @@ def _te_variance(x: te.Tensor, axis: List[tir.IntImm], 
keepdims: bool) -> te.Ten
     # return _te_mean(x * x, axis, keepdims) - mean * mean
 
 
+def _te_median(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> 
te.Tensor:

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The return type hint for `_te_median` is `te.Tensor`, but the function can 
return a `Tuple[te.Tensor, te.Tensor]` (values and indices) when an axis is 
provided. The annotation should be updated to reflect the actual return types 
for better type safety and clarity. You'll also need to import `Union` and 
`Tuple` from `typing` at the top of the file, e.g. by adding them to the 
existing `typing` import that provides `List`.
   
   ```suggestion
   def _te_median(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> 
Union[te.Tensor, Tuple[te.Tensor, te.Tensor]]:
   ```



##########
src/relax/op/tensor/statistical.cc:
##########
@@ -180,6 +180,66 @@ StructInfo InferStructInfoScan(const Call& call, const 
BlockBuilder& ctx) {
   }
 }
 
+StructInfo InferStructInfoStatisticalExtension(const Call& call, const 
BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<StatisticalAttrs>();
+
+  std::vector<int> axes;
+  if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
+    axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value());
+  }
+
+  int out_ndim;
+  if (attrs->keepdims) {
+    out_ndim = data_sinfo->ndim;
+  } else if (!attrs->axis.defined()) {
+    out_ndim = 0;
+  } else if (data_sinfo->IsUnknownNdim()) {
+    out_ndim = kUnknownNDim;
+  } else {
+    out_ndim = data_sinfo->ndim - axes.size();
+    ICHECK_GE(out_ndim, 0);
+  }
+
+  // The inference rule for median operator output shapes:
+  // - axes is None || len(axes) > 1, keepdims is false -> return the 
zero-rank shape;
+  // - axes is None || len(axes) > 1, keepdims is true -> return the shape 
whose ndim
+  // is the same as input and every value is 1.
+  // - len(axes) == 1, keepdims is false -> the returned shape does not 
contain the input axis.
+  // - len(axes) == 1, keepdims is true -> the returned shape has value 1 at 
the positions of the
+  // input axis
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr) {
+    if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) 
{
+      return TensorStructInfo(
+          ShapeExpr(ffi::Array<PrimExpr>(out_ndim, IntImm(DataType::Int(64), 
/*value=*/1))),
+          data_sinfo->dtype, data_sinfo->vdevice);
+    } else {
+      return out_ndim == 0 ? 
TensorStructInfo(ShapeExpr(ffi::Array<PrimExpr>()), data_sinfo->dtype,
+                                              data_sinfo->vdevice)
+                           : TensorStructInfo(data_sinfo->dtype, out_ndim, 
data_sinfo->vdevice);
+    }
+  }
+
+  ffi::Array<PrimExpr> out_shape;
+  out_shape.reserve(out_ndim);
+  for (int i = 0; i < data_sinfo->ndim; ++i) {
+    if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == 
axes.end()) {
+      out_shape.push_back(data_shape->values[i]);
+    } else if (attrs->keepdims) {
+      out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1));
+    }
+  }
+  ICHECK_EQ(static_cast<int>(out_shape.size()), out_ndim);
+
+  if (!attrs->axis.defined() || axes.size() > 1)
+    return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, 
data_sinfo->vdevice);
+  else
+    return TupleStructInfo({
+        TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, 
data_sinfo->vdevice),
+        TensorStructInfo(ShapeExpr(out_shape), DataType::Int(64), 
data_sinfo->vdevice)});
+}

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   There's a critical correctness issue in 
`InferStructInfoStatisticalExtension`. When the input tensor has an unknown 
shape (`data_shape == nullptr`), the function incorrectly returns a 
`TensorStructInfo` for `relax.median` even when `axis` is specified with a 
single dimension. According to PyTorch's behavior and the op's design, it 
should return a `TupleStructInfo` containing both values and indices.
   
   This happens because of an early return within the `if (data_shape == 
nullptr)` block, which doesn't account for the tuple-return case.
   
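   To fix this, the decision to return a `TupleStructInfo` when a single axis 
is provided needs to be made regardless of whether the input shape is known. 
Note that this decision cannot rely on `axes.size()`. `axes` is only populated 
when the input rank is known, so for an input of unknown rank it stays empty 
even when `axis` is given. Check `attrs->axis.value().size() == 1` instead. A 
possible fix is to refactor the function as follows:
- determine up front whether a tuple should be returned;
- compute the value struct info (`value_sinfo`);
- if needed, wrap it in a `TupleStructInfo` together with an `int64` indices 
struct info of the same shape (or the same `ndim` when the shape is unknown), 
then return.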



##########
python/tvm/relax/transform/legalize_ops/statistical.py:
##########
@@ -53,6 +53,42 @@ def _te_variance(x: te.Tensor, axis: List[tir.IntImm], 
keepdims: bool) -> te.Ten
     # return _te_mean(x * x, axis, keepdims) - mean * mean
 
 
+def _te_median(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> 
te.Tensor:
+    # currently only supports one axis or no axis ~ same pytorch
+    # todo: support multiple axis ~ same numpy
+    shape_prod = _compute_shape_prod(x, axis)
+    mid_index = shape_prod // 2
+    if shape_prod % 2 == 1:
+        mid_index = mid_index
+    else:
+        mid_index = mid_index - 1

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
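   The logic for calculating `mid_index` can be simplified. The odd branch 
(`mid_index = mid_index`) is a no-op, and the whole if/else block can be 
replaced with a single expression that handles both odd and even `shape_prod` 
values correctly:
- for odd `n`, `(n - 1) // 2 == n // 2`;
- for even `n`, `(n - 1) // 2 == n // 2 - 1`.
For even sizes this selects the lower of the two middle elements, which matches 
what `torch.median` returns.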
   
   ```suggestion
       mid_index = (shape_prod - 1) // 2
   ```



##########
tests/python/relax/test_op_statistical.py:
##########
@@ -275,5 +276,145 @@ def 
test_scan_opinfer_struct_info_wrong_input_type(scan_op: Callable):
         bb.normalize(scan_op(x1, axis=1))
 
 
+def 
test_test_statistical_ext_infer_struct_info_wrong_input_type_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+    x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
+
+    _check_inference(
+        bb,
+        relax.op.median(x0, axis=[1]),
+        relax.TupleStructInfo([relax.TensorStructInfo((2, 4, 5), "float32"),
+                               relax.TensorStructInfo((2, 4, 5), "int64")]),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x0, axis=[1], keepdims=True),
+        relax.TupleStructInfo([relax.TensorStructInfo((2, 1, 4, 5), "float32"),
+                               relax.TensorStructInfo((2, 1, 4, 5), "int64")]),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x0, axis=[1]),
+        relax.TupleStructInfo([relax.TensorStructInfo((2, 4, 5), "float32"),
+                               relax.TensorStructInfo((2, 4, 5), "int64")]),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x1, axis=[1], keepdims=True),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+    _check_inference(
+        bb, relax.op.median(x2, axis=[1]), 
relax.TensorStructInfo(dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x2, axis=[1], keepdims=True),
+        relax.TensorStructInfo(dtype="float32"),
+    )
+    _check_inference(bb, relax.op.median(x2, axis=None), 
relax.TensorStructInfo((), "float32"))
+    _check_inference(
+        bb,
+        relax.op.median(x1, axis=None, keepdims=True),
+        relax.TensorStructInfo((1, 1, 1, 1), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x4, axis=[1]),
+        relax.TupleStructInfo([relax.TensorStructInfo((2, 4, 5), "float32", 
vdev0),
+                               relax.TensorStructInfo((2, 4, 5), "int64", 
vdev0)]),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x3, axis=[1], keepdims=True),
+        relax.TupleStructInfo([relax.TensorStructInfo((2, 1, 4, 5), dtype=""),
+                               relax.TensorStructInfo((2, 1, 4, 5), 
dtype="int64")])
+    )
+    _check_inference(bb, relax.op.median(x3, axis=None), 
relax.TensorStructInfo((), dtype=""))
+    _check_inference(
+        bb,
+        relax.op.median(x3, axis=None, keepdims=True),
+        relax.TensorStructInfo((1, 1, 1, 1), dtype=""),
+    )

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   This new test function for `median` has a few issues:
   1. The function name 
`test_test_statistical_ext_infer_struct_info_wrong_input_type_infer_struct_info`
 seems to be a copy-paste error. It should be renamed to something more 
descriptive, like `test_median_infer_struct_info`.
   2. There's a duplicated test case for `relax.op.median(x0, axis=[1])`. One 
of them can be removed.
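   3. Several checks validate incorrect behavior caused by the bug in 
`InferStructInfoStatisticalExtension` that drops the tuple return for inputs 
whose shape is not known. When `median` is called with a single axis, it should 
return a `TupleStructInfo`. However, the single-axis checks for `x1` 
(`axis=[1], keepdims=True`) and `x2` (`axis=[1]`, with and without `keepdims`) 
expect a `TensorStructInfo`. These checks should be corrected to expect a 
`TupleStructInfo` of values and `int64` indices once the underlying bug is 
fixed. The `axis=None` checks are already correct.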
   
   ```python
   def test_median_infer_struct_info():
       bb = relax.BlockBuilder()
       vdev0 = VDevice("llvm")
       x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
       x1 = relax.Var("x", R.Tensor("float32", ndim=4))
       x2 = relax.Var("x", R.Tensor("float32"))
       x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
       x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
   
       _check_inference(
           bb,
           relax.op.median(x0, axis=[1]),
           relax.TupleStructInfo([relax.TensorStructInfo((2, 4, 5), "float32"),
                                  relax.TensorStructInfo((2, 4, 5), "int64")]),
       )
       _check_inference(
           bb,
           relax.op.median(x0, axis=[1], keepdims=True),
           relax.TupleStructInfo([relax.TensorStructInfo((2, 1, 4, 5), 
"float32"),
                                  relax.TensorStructInfo((2, 1, 4, 5), 
"int64")]),
       )
       _check_inference(
           bb,
           relax.op.median(x1, axis=[1], keepdims=True),
           relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", 
ndim=4),
                                  relax.TensorStructInfo(dtype="int64", 
ndim=4)]),
       )
       _check_inference(
           bb, relax.op.median(x2, axis=[1]),
           relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32"),
                                  relax.TensorStructInfo(dtype="int64")])
       )
       _check_inference(
           bb,
           relax.op.median(x2, axis=[1], keepdims=True),
           relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32"),
                                  relax.TensorStructInfo(dtype="int64")])
       )
       _check_inference(bb, relax.op.median(x2, axis=None), 
relax.TensorStructInfo((), "float32"))
       _check_inference(
           bb,
           relax.op.median(x1, axis=None, keepdims=True),
           relax.TensorStructInfo((1, 1, 1, 1), "float32"),
       )
       _check_inference(
           bb,
           relax.op.median(x4, axis=[1]),
           relax.TupleStructInfo([relax.TensorStructInfo((2, 4, 5), "float32", 
vdev0),
                                  relax.TensorStructInfo((2, 4, 5), "int64", 
vdev0)]),
       )
       _check_inference(
           bb,
           relax.op.median(x3, axis=[1], keepdims=True),
           relax.TupleStructInfo([relax.TensorStructInfo((2, 1, 4, 5), 
dtype=""),
                                  relax.TensorStructInfo((2, 1, 4, 5), 
dtype="int64")])
       )
       _check_inference(bb, relax.op.median(x3, axis=None), 
relax.TensorStructInfo((), dtype=""))
       _check_inference(
           bb,
           relax.op.median(x3, axis=None, keepdims=True),
           relax.TensorStructInfo((1, 1, 1, 1), dtype=""),
       )
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to