gemini-code-assist[bot] commented on code in PR #18880:
URL: https://github.com/apache/tvm/pull/18880#discussion_r2894270018


##########
python/tvm/topi/math.py:
##########
@@ -18,13 +18,19 @@
 
 # pylint: disable=redefined-builtin,unused-argument
 import tvm
-from tvm import te
+from tvm import DataType, DataTypeCode, te
 from tvm.tir import PrimExpr
 
 from . import cpp, tag
 from .utils import get_const_tuple
 
 
+def _require_float_tensor(op_name, x):
+    if DataType(x.dtype).type_code not in (DataTypeCode.FLOAT, DataTypeCode.BFLOAT):
+        raise TypeError(f"topi.{op_name} only supports floating-point inputs, but got {x.dtype}")
+    return x

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This is a useful helper for enforcing float-only inputs. For consistency,
   consider applying the same check to the other inverse trigonometric and
   inverse hyperbolic functions that currently lack input type validation,
   such as `acosh`, `asinh`, and `atanh`, so that they also reject
   non-floating-point inputs with a clear `TypeError`.
   
   For example, `acosh` could be updated as follows:
   ```python
   @tvm.te.tag_scope(tag=tag.ELEMWISE)
   def acosh(x):
       """Take arc cosh of input x.
       ...
       """
       x = _require_float_tensor("acosh", x)
       return te.compute(x.shape, lambda *i: te.acosh(x(*i)))
   ```
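If the check ends up in many ops, a decorator is one way to keep it out of each function body. The sketch below is a hypothetical, TVM-free illustration of that pattern: `require_float_input`, `FakeTensor`, and the prefix-based `_is_float_dtype` are stand-ins (not TVM APIs) for `_require_float_tensor`, `te.Tensor`, and the `DataType(x.dtype).type_code` test, so the idea can be run without building TVM.

```python
import functools
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for te.Tensor: only the dtype string is modeled."""

    dtype: str


def _is_float_dtype(dtype: str) -> bool:
    # Rough approximation of the FLOAT/BFLOAT type-code test using name prefixes;
    # the helper under review inspects DataType(dtype).type_code instead.
    return dtype.startswith(("float", "bfloat"))


def require_float_input(func):
    """Hypothetical decorator that runs the float-only check before func."""

    @functools.wraps(func)  # keeps func.__name__, used in the error message
    def wrapper(x, *args, **kwargs):
        if not _is_float_dtype(x.dtype):
            raise TypeError(
                f"topi.{func.__name__} only supports floating-point inputs, "
                f"but got {x.dtype}"
            )
        return func(x, *args, **kwargs)

    return wrapper


@require_float_input
def acosh(x):
    # Placeholder body; a real topi op would return te.compute(...) here.
    return f"acosh<{x.dtype}>"


if __name__ == "__main__":
    print(acosh(FakeTensor("float32")))  # acosh<float32>
    try:
        acosh(FakeTensor("int16"))
    except TypeError as err:
        print(err)  # topi.acosh only supports floating-point inputs, but got int16
```

Whether a decorator beats an explicit `_require_float_tensor(...)` call is a style choice; the explicit call keeps each op's validation visible in its own body.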



##########
tests/python/te/test_te_create_primfunc.py:
##########
@@ -359,6 +359,31 @@ def test_constant():
     tvm.testing.assert_allclose(a_np + 2, c.numpy())
 
 
+@pytest.mark.parametrize("op_name", ["acos", "asin"])
+def test_topi_float_unary_rejects_integer_input(op_name):
+    x = te.placeholder((1, 8), dtype="int16", name="x")
+    op = getattr(topi, op_name)
+
+    with pytest.raises(
+        TypeError,
+        match=rf"topi\.{op_name} only supports floating-point inputs, but got int16",
+    ):
+        op(x)
+
+
+@pytest.mark.parametrize("op_name", ["acos", "asin"])
+def test_topi_float_unary_accepts_float_input(op_name):
+    x = te.placeholder((1, 8), dtype="float32", name="x")
+    op = getattr(topi, op_name)
+    out = op(x)
+
+    func = te.create_prim_func([x, out]).with_attr("target", tvm.target.Target("llvm"))
+    mod = tvm.IRModule({"main": func})
+    compiled = tvm.build(mod, target="llvm")
+
+    assert compiled is not None

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `_require_float_tensor` helper accepts both the `FLOAT` and `BFLOAT`
   type codes, but this test only exercises `float32`. To cover the `BFLOAT`
   path as well, consider extending the test to `bfloat16` inputs by
   parametrizing over `dtype`.
   
   You may need to add logic to skip the `bfloat16` test on platforms where 
it's not supported.
   
   ```suggestion
   @pytest.mark.parametrize("op_name", ["acos", "asin"])
   @pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
   def test_topi_float_unary_accepts_float_input(op_name, dtype):
       x = te.placeholder((1, 8), dtype=dtype, name="x")
       op = getattr(topi, op_name)
       out = op(x)
   
       func = te.create_prim_func([x, out]).with_attr("target", tvm.target.Target("llvm"))
       mod = tvm.IRModule({"main": func})
       compiled = tvm.build(mod, target="llvm")
   
       assert compiled is not None
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
