gemini-code-assist[bot] commented on code in PR #18623:
URL: https://github.com/apache/tvm/pull/18623#discussion_r2650339047
##########
python/tvm/dlight/gpu/matmul.py:
##########
@@ -1040,8 +1052,12 @@ def is_inner_reduction(block_stmt, iter_infos):
return tensorize_sch
elif target.kind.name == "metal":
try:
- return MetalMatmul().apply(func, target, _)
- except: # pylint: disable=bare-except
+ result = MetalMatmul().apply(func, target, _)
+ if result is not None:
+ return result
+ except Exception as e: # pylint: disable=broad-except
+ import warnings
+ warnings.warn(f"MetalMatmul failed: {e}, falling back to
generic Matmul")
Review Comment:

Catching `Exception` and binding it as `e` is a good improvement over a bare
`except`, and the warning message is very helpful for debugging. Together they
make the fallback mechanism much more robust and informative.
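For readers skimming the thread, here is a minimal, self-contained Python sketch of the try/warn/fall-back shape being reviewed. It is illustrative only: the real code in `python/tvm/dlight/gpu/matmul.py` has different signatures, and `run_specialized`/`run_generic` are hypothetical stand-ins.

```python
# Minimal sketch of the fallback pattern under review; not the actual TVM code.
# `run_specialized` and `run_generic` are hypothetical placeholders.
import warnings


def schedule_with_fallback(func, target, run_specialized, run_generic):
    """Prefer the specialized schedule; warn and fall back if it fails or declines."""
    try:
        result = run_specialized(func, target)
        if result is not None:
            return result
    except Exception as e:  # pylint: disable=broad-except
        warnings.warn(f"Specialized schedule failed: {e}, falling back to generic schedule")
    return run_generic(func, target)
```

One small follow-up worth considering: hoisting the `warnings` import to module scope (as in the sketch) avoids re-importing it inside the exception handler.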
##########
python/tvm/relax/frontend/nn/llm/position_embedding.py:
##########
@@ -197,8 +197,8 @@ def yarn_find_correction_range(
max_position_embeddings: int,
):
"""Find the correction range based on the number of rotations"""
- low = tir.floor(yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings))
- high = tir.ceil(yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings))
+ low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)
+ high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)
Review Comment:

The `tir.floor` and `tir.ceil` calls have been removed from the computation of
`low` and `high`. While the subsequent calculations can handle floating-point
values, this deviates from reference implementations of YaRN (e.g., in
HuggingFace Transformers), which explicitly use `floor` and `ceil` to define the
correction range, and it could alter the scaling behavior. Could you clarify
whether this change is intentional? If it is an oversight, it would be safer to
restore the `floor` and `ceil` calls to align with the reference algorithm.
```suggestion
low = tir.floor(yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings))
high = tir.ceil(yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings))
```
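For reference, here is a small self-contained Python sketch of the correction-range computation as it appears in common YaRN reference implementations (paraphrased; the default constants are illustrative and not taken from this file):

```python
import math


def find_correction_dim(num_rotations, dim, base=10000.0, max_position_embeddings=2048):
    # Dimension index at which `num_rotations` full rotations fit into the
    # original context length (inverse of the per-dimension wavelength formula).
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


def find_correction_range(low_rot, high_rot, dim, base=10000.0, max_position_embeddings=2048):
    # Reference implementations floor/ceil here so the interpolation ramp
    # starts and ends on integer dimension indices, then clamp to [0, dim - 1].
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)
```

Dropping the floor/ceil shifts where the linear ramp built from `low` and `high` starts and ends by a fraction of a dimension, which is why it is worth confirming the change is intentional.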
##########
src/runtime/metal/metal_module.mm:
##########
@@ -227,8 +240,12 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args)
// launch
MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
+
+ // LOG(DEBUG) << "Dispatching kernel...";
[encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
+ // LOG(DEBUG) << "Kernel dispatched, encoding...";
[encoder endEncoding];
+ // LOG(DEBUG) << "Encoding complete";
Review Comment:

This commented-out log appears to be for debugging. Please remove it.
##########
src/runtime/metal/metal_module.mm:
##########
@@ -227,8 +240,12 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args)
// launch
MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
+
+ // LOG(DEBUG) << "Dispatching kernel...";
Review Comment:

This blank line and commented-out log appear to be for debugging. Please
remove them.
##########
src/runtime/metal/metal_module.mm:
##########
@@ -227,8 +240,12 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args)
// launch
MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
+
+ // LOG(DEBUG) << "Dispatching kernel...";
[encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
+ // LOG(DEBUG) << "Kernel dispatched, encoding...";
Review Comment:

This commented-out log appears to be for debugging. Please remove it.
##########
src/runtime/metal/metal_module.mm:
##########
@@ -209,6 +216,12 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args)
ThreadWorkLoad wl = launch_param_config_.Extract(args);
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
+
+ // LOG(INFO) << "Launching " << func_name_
+ // << " with grid(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << ")"
+ // << " block(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2) << ")"
+ // << " blockSize=" << blockSize << " max=" << maxTotalThreadsPerThreadgroup;
+
Review Comment:

These commented-out logging statements and the surrounding blank lines appear to
be leftover debugging aids. Please remove them before merging to keep the codebase clean.
##########
src/target/source/codegen_metal.cc:
##########
@@ -423,10 +441,21 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO
temp << "NAN";
} else {
temp << std::scientific << op->value;
- if (op->dtype.bits() == 32)
+ if (op->dtype.bits() == 32) {
temp << 'f';
- else if (op->dtype.bits() == 16)
- temp << 'h';
+ } else if (op->dtype.bits() == 16) {
+ if (op->dtype.is_bfloat16()) {
+ // bfloat16 requires explicit cast, can't use 'h' suffix
+ std::ostringstream temp2;
+ temp2 << "bfloat(" << temp.str() << "f)";
+ MarkConst(temp2.str());
+ os << temp2.str();
+ return;
+ } else {
+ // half (float16) uses 'h' suffix
+ temp << 'h';
+ }
+ }
Review Comment:

Handling `bfloat16` literals by generating an explicit `bfloat(...)` cast is
the correct approach, since Metal does not support a literal suffix for this
type. This is a necessary change for bf16 codegen.
##########
src/target/source/codegen_metal.cc:
##########
@@ -378,9 +393,12 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT
<< "Cannot find variable allocation for simdgroup: " << var;
const std::string& dtype_str = it->second;
f_check_simdgroup_shape(op->args[3], op->args[4]);
+ // Cast the fill value to the correct dtype for Metal's make_filled_simdgroup_matrix
+ // Metal requires the value to be convertible to the matrix element type (e.g., bfloat)
+ // Without the cast, float literals like 0.0f fail with bfloat matrices
os << PrintExpr(var) << "[" << PrintExpr(op->args[1]) << "] = make_filled_simdgroup_matrix<"
<< dtype_str << ", " << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ">("
- << PrintExpr(op->args[2]) << ")";
+ << dtype_str << "(" << PrintExpr(op->args[2]) << "))";
Review Comment:

Explicitly casting the fill value to the matrix element type is a crucial
fix for `bfloat16` support, as Metal requires this for
`make_filled_simdgroup_matrix`. This prevents compilation errors with float
literals. Good catch!
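To make the before/after concrete, here is a short Python sketch of the emitted Metal call; the dtype, shape, and fill value below are hypothetical placeholders rather than values taken from the PR.

```python
# Illustrative only: the emitted Metal call before vs. after the cast.
# dtype/shape/fill values below are hypothetical placeholders.
dtype_str, rows, cols, fill = "bfloat", 8, 8, "0.000000e+00f"
before = f"make_filled_simdgroup_matrix<{dtype_str}, {rows}, {cols}>({fill})"
after = f"make_filled_simdgroup_matrix<{dtype_str}, {rows}, {cols}>({dtype_str}({fill}))"
print(before)  # rejected for bfloat matrices, per the PR's rationale
print(after)   # accepted: the fill value is explicitly converted to the element type
```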