| Field | Value |
|---|---|
| Issue | 136117 |
| Summary | Issues enabling ND mesh resharding in Spmdization pass: incorrect axis comparison and resharding assertion failures |
| Labels | new issue |
| Assignees | |
| Reporter | zhangdianchen |

When trying to enable ND mesh resharding in the Spmdization pass of MLIR, I encountered several issues that cause incorrect behavior or assertion failures. Below is a detailed breakdown:
### 1. Incorrect detection logic in detectMoveLastSplitAxisInResharding
**Problem Reproduction:**
When executing the following resharding sequence:
```
%sharding1 = mesh.sharding @mesh_3d split_axes = [[0, 1], [2]] : !mesh.sharding
%in1_sharded1 = mesh.shard %in1 to %sharding1 : tensor<8x16xi8>
%sharding2 = mesh.sharding @mesh_3d split_axes = [[0], [1, 2]] : !mesh.sharding
%in1_sharded2 = mesh.shard %in1_sharded1 to %sharding2 annotate_for_users : tensor<8x16xi8>
```
The pass is expected to detect a valid last-axis movement and insert a mesh.all_to_all operation. Instead, it crashes with the following assertion:
``mlir::TypedValue<mlir::ShapedType> mlir::sharding::reshardOn1DGrid(...): Assertion `targetShard && "Did not find any pattern to apply."' failed.``
**Root Cause:**
In detectMoveLastSplitAxisInResharding, the logic checks:
```
if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
    targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
    sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
        targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().back())
  continue;
```
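In the example [[0, 1], [2]] -> [[0], [1, 2]], this compares the last split axis on both sides:
- source.split_axes[0].back() = 1 vs. target.split_axes[1].back() = 2, which do not match, so the candidate is skipped.

The moved axis leaves the end of the source list and arrives at the front of the target list, so the check should instead compare source.back() with target.front():
`sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().front()`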
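
The difference between the two comparisons can be reproduced outside MLIR. The following is a minimal sketch, assuming split axes are modeled as plain nested `std::vector<int>`; `SplitAxes`, `lastMatchesLast`, and `lastMatchesFirst` are hypothetical names used only for illustration and are not part of the MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for split_axes: one list of grid axes per tensor axis.
// This is NOT the MLIR data structure, only a model of it.
using SplitAxes = std::vector<std::vector<int>>;

// Models the existing check: last source grid axis vs. last target grid axis.
bool lastMatchesLast(const SplitAxes &source, std::size_t sourceTensorAxis,
                     const SplitAxes &target, std::size_t targetTensorAxis) {
  const std::vector<int> &src = source[sourceTensorAxis];
  const std::vector<int> &tgt = target[targetTensorAxis];
  assert(!src.empty() && !tgt.empty() && "axis lists must be non-empty");
  return src.back() == tgt.back();
}

// Models the proposed check: last source grid axis vs. first target grid axis.
bool lastMatchesFirst(const SplitAxes &source, std::size_t sourceTensorAxis,
                      const SplitAxes &target, std::size_t targetTensorAxis) {
  const std::vector<int> &src = source[sourceTensorAxis];
  const std::vector<int> &tgt = target[targetTensorAxis];
  assert(!src.empty() && !tgt.empty() && "axis lists must be non-empty");
  return src.back() == tgt.front();
}
```

For source [[0, 1], [2]] and target [[0], [1, 2]] with sourceTensorAxis = 0 and targetTensorAxis = 1, `lastMatchesLast` compares 1 with 2 and rejects the candidate, while `lastMatchesFirst` compares 1 with 1 and accepts it.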
**Additional incorrect check:**
```
if (!llvm::equal(
        llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().begin(),
                         sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().end() - 1),
        llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().begin(),
                         targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().end() - 1)))
  continue;
```
This compares the wrong slices: both ranges drop their last element, so in the example it ends up comparing [0] with [1]. Instead, it should skip the first element of the target list and compare:
```
if (llvm::equal(
        llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().begin(),
                         sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().end() - 1),
        llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().begin() + 1,
                         targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().end())))
  continue;
```
This now compares [0] with [2] — which is correct.
### 2. Incorrect ShardingTarget construction in targetShardingInMoveLastAxis
Bypassing the assertion above leads to a second failure:
``mlir::TypedValue<mlir::ShapedType> mlir::sharding::reshardOn1DGrid(...): Assertion `actualTargetSharding == targetSharding' failed.``
**Root Cause:**
In targetShardingInMoveLastAxis, the targetShardingSplitAxes are built in the wrong order: the moved grid axis is appended to the target list instead of being placed at its front.
Current result:
```
actualTargetSharding: split_axes = [[0], [2, 1]]
targetSharding: split_axes = [[0], [1, 2]]
```
**Fix:** Instead of:
`targetSplitAxes.push_back(gridAxis);`
Use:
```
targetSplitAxes.insert(targetSplitAxes.begin(), gridAxis);
```
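
The ordering difference can be checked in isolation. This is a minimal sketch, assuming the target split axes for one tensor axis are modeled as a `std::vector<int>`; `appendMovedAxis` and `prependMovedAxis` are hypothetical helper names and do not exist in MLIR.

```cpp
#include <vector>

// Models the current behavior: the moved grid axis is appended to the
// target tensor axis's existing split axes.
std::vector<int> appendMovedAxis(std::vector<int> targetSplitAxes,
                                 int gridAxis) {
  targetSplitAxes.push_back(gridAxis);
  return targetSplitAxes;
}

// Models the proposed behavior: the moved grid axis is inserted at the
// front of the target tensor axis's existing split axes.
std::vector<int> prependMovedAxis(std::vector<int> targetSplitAxes,
                                  int gridAxis) {
  targetSplitAxes.insert(targetSplitAxes.begin(), gridAxis);
  return targetSplitAxes;
}
```

Starting from the source list [2] for tensor axis 1 and moving grid axis 1, `appendMovedAxis` yields [2, 1] (the mismatching actualTargetSharding), while `prependMovedAxis` yields [1, 2] (the requested targetSharding).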
### 3. Bug in handlePartialAxesDuringResharding
In the following snippet:
```
llvm::SmallVector<GridAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
              std::back_inserter(allReduceGridAxes),
              [&targetShardingPartialAxesSet](Axis a) {
                return targetShardingPartialAxesSet.contains(a);
              });
```
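The copy should write to remainingPartialAxes, not allReduceGridAxes. As written, this call never populates remainingPartialAxes and instead adds the axes shared by source and target to allReduceGridAxes. Corrected version: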
```
llvm::copy_if(sourceShardingPartialAxesSet,
              std::back_inserter(remainingPartialAxes),
              [&targetShardingPartialAxesSet](Axis a) {
                return targetShardingPartialAxesSet.contains(a);
              });
```
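
The intended filtering can be illustrated with the standard library alone. This is a minimal sketch, assuming partial axes are modeled as `std::set<int>` and using `std::copy_if` in place of `llvm::copy_if`; the free function `remainingPartialAxes` is a hypothetical name chosen to mirror the variable in the snippet above.

```cpp
#include <algorithm>
#include <iterator>
#include <set>
#include <vector>

// Collects the source partial axes that are also partial in the target,
// i.e. the axes that stay partial after resharding.
std::vector<int> remainingPartialAxes(const std::set<int> &sourcePartialAxes,
                                      const std::set<int> &targetPartialAxes) {
  std::vector<int> remaining;
  std::copy_if(sourcePartialAxes.begin(), sourcePartialAxes.end(),
               std::back_inserter(remaining),  // write to the result vector
               [&targetPartialAxes](int axis) {
                 return targetPartialAxes.count(axis) != 0;
               });
  return remaining;
}
```

With source partial axes {0, 1, 2} and target partial axes {1, 2, 3}, the result is [1, 2]: only the axes present on both sides are kept.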
Please let me know if a patch is desired — I'm happy to contribute a PR for these changes.