gemini-code-assist[bot] commented on code in PR #18880:
URL: https://github.com/apache/tvm/pull/18880#discussion_r2894270018
##########
python/tvm/topi/math.py:
##########
@@ -18,13 +18,19 @@
# pylint: disable=redefined-builtin,unused-argument
import tvm
-from tvm import te
+from tvm import DataType, DataTypeCode, te
from tvm.tir import PrimExpr
from . import cpp, tag
from .utils import get_const_tuple
+def _require_float_tensor(op_name, x):
+    if DataType(x.dtype).type_code not in (DataTypeCode.FLOAT, DataTypeCode.BFLOAT):
+        raise TypeError(f"topi.{op_name} only supports floating-point inputs, but got {x.dtype}")
+    return x
Review Comment:

This is a useful helper for enforcing float-only inputs. For consistency,
consider applying the same check to the other inverse trigonometric and
hyperbolic functions that currently lack input type validation, such as
`acosh`, `asinh`, and `atanh`, so that they also reject non-floating-point
inputs with a clear error.
For example, `acosh` could be updated as follows:
```python
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def acosh(x):
    """Take arc cosh of input x.
    ...
    """
    x = _require_float_tensor("acosh", x)
    return te.compute(x.shape, lambda *i: te.acosh(x(*i)))
```
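If the check ends up on several functions (`acos`, `asin`, `acosh`, `asinh`,
`atanh`), one possible alternative to repeating the call in each body is a small
decorator. The sketch below is hypothetical: `float_only` is not part of TVM, and
`_FLOAT_DTYPES` and `FakeTensor` are stand-ins that only approximate the real
`DataType(x.dtype).type_code` check so the example runs without TVM.

```python
import functools

# Stand-in dtype set so this sketch runs without TVM. The real helper in
# topi/math.py checks DataType(x.dtype).type_code against FLOAT/BFLOAT instead.
_FLOAT_DTYPES = {"float16", "float32", "float64", "bfloat16"}


def _require_float_tensor(op_name, x):
    """Stand-in mirroring the PR helper's signature and error message."""
    if str(x.dtype) not in _FLOAT_DTYPES:
        raise TypeError(f"topi.{op_name} only supports floating-point inputs, but got {x.dtype}")
    return x


def float_only(fn):
    """Hypothetical decorator: validate the first argument, then run the op."""

    @functools.wraps(fn)
    def wrapper(x, *args, **kwargs):
        # fn.__name__ is the op's own name when this decorator sits directly on `def`.
        return fn(_require_float_tensor(fn.__name__, x), *args, **kwargs)

    return wrapper


class FakeTensor:
    """Minimal stand-in for te.Tensor; only `dtype` is read by the check."""

    def __init__(self, dtype):
        self.dtype = dtype


@float_only
def acosh(x):
    return f"acosh({x.dtype})"


@float_only
def atanh(x):
    return f"atanh({x.dtype})"
```

In `topi/math.py` such a decorator would go directly above `def`, beneath
`@tvm.te.tag_scope(...)`, so the wrapped function still carries the op's own
name; calling the helper explicitly, as in the example above, avoids the extra
indirection.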
##########
tests/python/te/test_te_create_primfunc.py:
##########
@@ -359,6 +359,31 @@ def test_constant():
    tvm.testing.assert_allclose(a_np + 2, c.numpy())
+@pytest.mark.parametrize("op_name", ["acos", "asin"])
+def test_topi_float_unary_rejects_integer_input(op_name):
+    x = te.placeholder((1, 8), dtype="int16", name="x")
+    op = getattr(topi, op_name)
+
+    with pytest.raises(
+        TypeError,
+        match=rf"topi\.{op_name} only supports floating-point inputs, but got int16",
+    ):
+        op(x)
+
+
+@pytest.mark.parametrize("op_name", ["acos", "asin"])
+def test_topi_float_unary_accepts_float_input(op_name):
+    x = te.placeholder((1, 8), dtype="float32", name="x")
+    op = getattr(topi, op_name)
+    out = op(x)
+
+    func = te.create_prim_func([x, out]).with_attr("target", tvm.target.Target("llvm"))
+    mod = tvm.IRModule({"main": func})
+    compiled = tvm.build(mod, target="llvm")
+
+    assert compiled is not None
Review Comment:

The `_require_float_tensor` function accepts both the `FLOAT` and `BFLOAT`
type codes, but this test only exercises `float32`. To cover both accepted
paths, consider extending it to also validate `bfloat16` inputs by
parameterizing the `dtype`.
You may need to add logic to skip the `bfloat16` case on platforms where
it is not supported.
```suggestion
@pytest.mark.parametrize("op_name", ["acos", "asin"])
@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
def test_topi_float_unary_accepts_float_input(op_name, dtype):
    x = te.placeholder((1, 8), dtype=dtype, name="x")
    op = getattr(topi, op_name)
    out = op(x)

    func = te.create_prim_func([x, out]).with_attr("target", tvm.target.Target("llvm"))
    mod = tvm.IRModule({"main": func})
    compiled = tvm.build(mod, target="llvm")

    assert compiled is not None
```
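For the skip logic, one option is to attach a `skipif` mark to the `bfloat16`
entry only, via `pytest.param`, so the `float32` cases always run. This is a
sketch under assumptions: the environment variable `TVM_TEST_SKIP_BF16` is
hypothetical (the review does not specify how bfloat16 support should be
detected), and `pytest.importorskip("tvm")` is used only so the module imports
cleanly where TVM is absent.

```python
import os

import pytest

# Hypothetical switch; replace with whatever capability check the CI platform needs.
SKIP_BF16 = os.environ.get("TVM_TEST_SKIP_BF16", "0") == "1"

# Only the bfloat16 parameter carries the skip mark; float32 always runs.
FLOAT_DTYPES = [
    "float32",
    pytest.param(
        "bfloat16",
        marks=pytest.mark.skipif(SKIP_BF16, reason="bfloat16 not supported on this platform"),
    ),
]


@pytest.mark.parametrize("op_name", ["acos", "asin"])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_topi_float_unary_accepts_float_input(op_name, dtype):
    tvm = pytest.importorskip("tvm")
    from tvm import te, topi

    x = te.placeholder((1, 8), dtype=dtype, name="x")
    out = getattr(topi, op_name)(x)

    func = te.create_prim_func([x, out]).with_attr("target", tvm.target.Target("llvm"))
    mod = tvm.IRModule({"main": func})
    assert tvm.build(mod, target="llvm") is not None
```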
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.