================
@@ -2564,20 +2564,24 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal,
            "Flattened type on RHS must have the same number or more elements "
            "than vector on LHS.");
 
+    bool IsRowMajor = CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
+                      LangOptions::MatrixMemoryLayout::MatrixRowMajor;
+
     llvm::Value *V = CGF.Builder.CreateLoad(
         CGF.CreateIRTempWithoutCast(DestTy, "flatcast.tmp"));
-    // V is an allocated temporary to build the truncated matrix into.
-    for (unsigned I = 0, E = MatTy->getNumElementsFlattened(); I < E; I++) {
-      unsigned ColMajorIndex =
-          (I % MatTy->getNumRows()) * MatTy->getNumColumns() +
-          (I / MatTy->getNumRows());
-      RValue RVal = CGF.EmitLoadOfLValue(LoadList[ColMajorIndex], Loc);
-      assert(RVal.isScalar() &&
-             "All flattened source values should be scalars.");
-      llvm::Value *Cast = CGF.EmitScalarConversion(
-          RVal.getScalarVal(), LoadList[ColMajorIndex].getType(),
-          MatTy->getElementType(), Loc);
-      V = CGF.Builder.CreateInsertElement(V, Cast, I);
+    // V is an allocated temporary for constructing the matrix.
+    for (unsigned Row = 0, RE = MatTy->getNumRows(); Row < RE; Row++) {
+      for (unsigned Col = 0, CE = MatTy->getNumColumns(); Col < CE; Col++) {
+        unsigned LoadIdx = MatTy->getRowMajorFlattenedIndex(Row, Col);
----------------
Icohedron wrote:

The `LoadList` is **always** in row-major order, regardless of the default
matrix memory layout. So I use the index calculation
`Row * MatTy->getNumColumns() + Col` to index into `LoadList`, which is
equivalent to the row-major flattened index of `MatTy`
(`MatTy->getRowMajorFlattenedIndex(Row, Col)`).
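The following standalone sketch makes the indexing concrete. `MatrixShape`, `rowMajorIndex`, `colMajorIndex`, and `buildMatrix` are hypothetical stand-ins, not Clang APIs. Plain `int` vectors replace the `LValue`/`llvm::Value` machinery. The column-major destination index `Col * Rows + Row` is an assumption, since the diff excerpt ends before the destination index is computed. The sketch shows that the source index into `LoadList` is always the row-major one, and only the destination index depends on the layout.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the shape queries used in the diff
// (getNumRows, getNumColumns, getRowMajorFlattenedIndex); NOT Clang's
// ConstantMatrixType.
struct MatrixShape {
  unsigned Rows;
  unsigned Cols;

  // Row-major flattened index: rows are contiguous.
  unsigned rowMajorIndex(unsigned Row, unsigned Col) const {
    return Row * Cols + Col;
  }
  // Column-major flattened index (assumed formula): columns are contiguous.
  unsigned colMajorIndex(unsigned Row, unsigned Col) const {
    return Col * Rows + Row;
  }
};

// Copies a row-major source list (playing the role of LoadList) into a flat
// destination (playing the role of V) whose layout is selected by IsRowMajor.
// The source index never depends on IsRowMajor; only the destination does.
std::vector<int> buildMatrix(const MatrixShape &M,
                             const std::vector<int> &LoadList,
                             bool IsRowMajor) {
  // Mirrors the diff's requirement: the source has at least as many
  // elements as the destination matrix.
  assert(LoadList.size() >= M.Rows * M.Cols &&
         "source must have at least Rows * Cols elements");
  std::vector<int> V(M.Rows * M.Cols);
  for (unsigned Row = 0; Row < M.Rows; ++Row) {
    for (unsigned Col = 0; Col < M.Cols; ++Col) {
      unsigned LoadIdx = M.rowMajorIndex(Row, Col);
      unsigned MatIdx = IsRowMajor ? M.rowMajorIndex(Row, Col)
                                   : M.colMajorIndex(Row, Col);
      V[MatIdx] = LoadList[LoadIdx];
    }
  }
  return V;
}
```

For a 2x3 matrix and `LoadList = {0, 1, 2, 3, 4, 5}`, the row-major result is the list unchanged. The column-major result stores it column by column as `{0, 3, 1, 4, 2, 5}`.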
Alternatively, I could flatten the loops and iterate over `LoadList` directly,
doing Row/Col arithmetic only to compute the index into the destination matrix
`V`, which does depend on the default matrix memory layout. However, I would
then have to reintroduce the `getRowCol()` helper functions to convert `I`
into a Row/Col pair for `V` and then build a flattened index for `V` from
that, which is the pattern you didn't like in the matrix truncation PR.
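The flattened-loop alternative might look roughly like the sketch below. `getRowCol` here is a hypothetical reimplementation for illustration, not the helper from the earlier PR. The column-major index `Col * Rows + Row` is an assumed formula. The loop walks `LoadList` linearly, so the source index needs no arithmetic. It still has to decompose `I` into (Row, Col) just to rebuild a flattened index for the destination.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical getRowCol()-style helper: converts a row-major flat index I
// into a (Row, Col) pair for a matrix with Cols columns.
std::pair<unsigned, unsigned> getRowCol(unsigned I, unsigned Cols) {
  return std::make_pair(I / Cols, I % Cols);
}

// Flattened-loop variant: reads the row-major source (like LoadList)
// linearly and only does Row/Col arithmetic for the destination (like V).
std::vector<int> buildMatrixFlat(unsigned Rows, unsigned Cols,
                                 const std::vector<int> &LoadList,
                                 bool IsRowMajor) {
  assert(LoadList.size() >= Rows * Cols &&
         "source must have at least Rows * Cols elements");
  std::vector<int> V(Rows * Cols);
  for (unsigned I = 0, E = Rows * Cols; I < E; ++I) {
    std::pair<unsigned, unsigned> RC = getRowCol(I, Cols);
    unsigned Row = RC.first;
    unsigned Col = RC.second;
    // Row-major destination reuses I; column-major needs the assumed
    // Col * Rows + Row mapping built from the decomposed pair.
    unsigned MatIdx = IsRowMajor ? I : Col * Rows + Row;
    V[MatIdx] = LoadList[I];
  }
  return V;
}
```

In this sketch the result matches a nested Row/Col loop element for element. The choice between the two shapes is therefore about readability rather than behavior.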

https://github.com/llvm/llvm-project/pull/184429
_______________________________________________
cfe-commits mailing list
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits