gemini-code-assist[bot] commented on code in PR #18869:
URL: https://github.com/apache/tvm/pull/18869#discussion_r2881702073
##########
include/tvm/s_tir/data_layout.h:
##########
@@ -158,6 +160,22 @@ class Layout : public ObjectRef {
return undef;
}
+ /*!
+ * \brief Packs the Given Array of IterVars into a Single IterVar. Each
IterVar in the Array
+ * should represent either a single primal axis or one or more
subordinate axis
+ * \param iters Array of iter vars to be packed
+ * \return A packed iter var
+ */
Review Comment:

The documentation for `PackIterVar` states that it can pack a "single primal
axis or one or more subordinate axis". However, the implementation in
`src/s_tir/data_layout.cc` only supports packing subordinate axes and includes
a check `Packed Axis can contain only Subordinate Axes`. This discrepancy
should be resolved. Please update the documentation to reflect the
implementation's behavior.
##########
src/s_tir/data_layout.cc:
##########
@@ -218,63 +306,120 @@ inline bool GetStoreRule(ffi::Array<PrimExpr>*
index_rule, ffi::Array<PrimExpr>*
return false;
}
- for (size_t i = 0; i < dst_layout.ndim(); ++i) {
- const auto& store_axis = dst_layout[i];
- const IterVar& store_axis_impl = dst_layout->axes[i];
- PrimExpr index_store(0);
-
- for (size_t j = 0; j < src_layout.ndim(); ++j) {
- const auto& orig_axis = src_layout[j];
- const IterVar& orig_axis_impl = src_layout->axes[j];
- if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
- if (orig_axis.IsPrimal()) {
- PrimExpr orig_var = orig_axis_impl->var;
- const int32_t factor = src_layout.FactorOf(orig_axis);
- if (factor > 0) {
- orig_var = orig_var * factor;
- }
- index_store = index_store + orig_var;
- } else {
- PrimExpr factor(1);
- for (size_t k = j + 1; k < src_layout.ndim(); ++k) {
- if (LayoutAxis::Get(orig_axis_impl) ==
LayoutAxis::Get(src_layout->axes[k])) {
- factor = factor * src_layout->axes[k]->dom->extent;
+ std::vector<bool> exists(128, false);
+ PrimExpr norm_indexes[128];
+ for (auto& it : norm_indexes) it = PrimExpr(0);
+
+ for (size_t i = 0; i < src_layout.ndim(); i++) {
+ auto factor = src_layout.PackedAxisAt(i)->dom->extent;
+ auto src_unpacked_axes = Layout::UnpackIterVar(src_layout.PackedAxisAt(i));
+
+ if (src_unpacked_axes.size() == 1 &&
LayoutAxis::Get(src_unpacked_axes[0]).IsPrimal()) {
+ const auto& prim_axis = LayoutAxis::Get(src_unpacked_axes[0]);
+ int64_t offset = src_layout.FactorOf(prim_axis);
+ if (offset == -1)
+ norm_indexes[prim_axis.name()[0] - 'A'] =
+ norm_indexes[prim_axis.name()[0] - 'A'] +
src_layout.PackedAxisAt(i);
+ else
+ norm_indexes[prim_axis.name()[0] - 'A'] =
+ norm_indexes[prim_axis.name()[0] - 'A'] +
+ src_layout.PackedAxisAt(i) * src_layout.FactorOf(prim_axis);
+ exists[prim_axis.name()[0]] = true;
+ } else {
+ int64_t value = 1;
+ std::vector<int> index_divs(src_unpacked_axes.size());
+ for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
+ index_divs[j] = value;
+ const auto* extent =
src_unpacked_axes[j]->dom->extent.as<IntImmNode>();
+ TVM_FFI_ICHECK(extent) << "Expected Integer Extents for Offset
Calculation";
+ index_divs.push_back(value);
+ value = value * extent->value;
+ }
Review Comment:

There appears to be a bug in this loop. The vector `index_divs` is already
constructed with `src_unpacked_axes.size()` elements (line 330 of the PR).
Inside the loop, both `index_divs[j] = value;` and
`index_divs.push_back(value);` are executed. As a result, every stride is
stored twice and the vector ends up twice its intended size. After the
subsequent `std::reverse`, any code that iterates over the whole vector or
relies on its size will see spurious entries, which can lead to incorrect
layout transformations. The `push_back` call should be removed.
```cpp
for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
  index_divs[j] = value;
  const auto* extent = src_unpacked_axes[j]->dom->extent.as<IntImmNode>();
  TVM_FFI_ICHECK(extent) << "Expected Integer Extents for Offset Calculation";
  value = value * extent->value;
}
```
##########
python/tvm/s_tir/data_layout.py:
##########
@@ -41,7 +41,8 @@ def __len__(self):
return _ffi_api.LayoutNdim(self) # type: ignore
def __contains__(self, axis):
- return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
+ # Note: We do a weaker check for packed axis assuming layout is valid
+ return not any(bkt in axis for bkt in "[]") and axis in self.name
Review Comment:

The current implementation of `__contains__` uses a substring check
(`axis in self.name`). This produces false positives and is inconsistent with
`index_of`. For example, for the layout `NCHW16c`, `__contains__("16")`
returns `True` even though "16" is not an axis. Similarly, for `OIHW[4o4i]`,
`"i" in layout` is `True` while `layout.index_of("i")` returns `-1`. A more
robust implementation should check membership against the layout's actual
axes rather than against its name string.
##########
src/s_tir/data_layout.cc:
##########
@@ -218,63 +306,120 @@ inline bool GetStoreRule(ffi::Array<PrimExpr>*
index_rule, ffi::Array<PrimExpr>*
return false;
}
- for (size_t i = 0; i < dst_layout.ndim(); ++i) {
- const auto& store_axis = dst_layout[i];
- const IterVar& store_axis_impl = dst_layout->axes[i];
- PrimExpr index_store(0);
-
- for (size_t j = 0; j < src_layout.ndim(); ++j) {
- const auto& orig_axis = src_layout[j];
- const IterVar& orig_axis_impl = src_layout->axes[j];
- if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
- if (orig_axis.IsPrimal()) {
- PrimExpr orig_var = orig_axis_impl->var;
- const int32_t factor = src_layout.FactorOf(orig_axis);
- if (factor > 0) {
- orig_var = orig_var * factor;
- }
- index_store = index_store + orig_var;
- } else {
- PrimExpr factor(1);
- for (size_t k = j + 1; k < src_layout.ndim(); ++k) {
- if (LayoutAxis::Get(orig_axis_impl) ==
LayoutAxis::Get(src_layout->axes[k])) {
- factor = factor * src_layout->axes[k]->dom->extent;
+ std::vector<bool> exists(128, false);
+ PrimExpr norm_indexes[128];
+ for (auto& it : norm_indexes) it = PrimExpr(0);
+
+ for (size_t i = 0; i < src_layout.ndim(); i++) {
+ auto factor = src_layout.PackedAxisAt(i)->dom->extent;
+ auto src_unpacked_axes = Layout::UnpackIterVar(src_layout.PackedAxisAt(i));
+
+ if (src_unpacked_axes.size() == 1 &&
LayoutAxis::Get(src_unpacked_axes[0]).IsPrimal()) {
+ const auto& prim_axis = LayoutAxis::Get(src_unpacked_axes[0]);
+ int64_t offset = src_layout.FactorOf(prim_axis);
+ if (offset == -1)
+ norm_indexes[prim_axis.name()[0] - 'A'] =
+ norm_indexes[prim_axis.name()[0] - 'A'] +
src_layout.PackedAxisAt(i);
+ else
+ norm_indexes[prim_axis.name()[0] - 'A'] =
+ norm_indexes[prim_axis.name()[0] - 'A'] +
+ src_layout.PackedAxisAt(i) * src_layout.FactorOf(prim_axis);
+ exists[prim_axis.name()[0]] = true;
+ } else {
+ int64_t value = 1;
+ std::vector<int> index_divs(src_unpacked_axes.size());
+ for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
+ index_divs[j] = value;
+ const auto* extent =
src_unpacked_axes[j]->dom->extent.as<IntImmNode>();
+ TVM_FFI_ICHECK(extent) << "Expected Integer Extents for Offset
Calculation";
+ index_divs.push_back(value);
+ value = value * extent->value;
+ }
+ std::reverse(index_divs.begin(), index_divs.end());
+
+ for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
+ const int extent =
src_unpacked_axes[j]->dom->extent.as<IntImmNode>()->value;
+ const LayoutAxis& store_axis_impl =
LayoutAxis::Get(src_unpacked_axes[j]);
+ const LayoutAxis& sub_axis = store_axis_impl.ToSubordinate(); /* Not
Needed */
Review Comment:

The comment `/* Not Needed */` suggests this variable `sub_axis` is unused.
If it's indeed not needed, it should be removed to improve code clarity and
avoid confusion.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]