fhahn created this revision.
fhahn added reviewers: rjmccall, anemet, rsmith, erichkeane.
Herald added subscribers: dexonsmith, tschuett.
fhahn requested review of this revision.
Herald added projects: clang, LLVM.
Herald added a subscriber: llvm-commits.

The matrix extension requires the indices of matrix subscript
expressions to be valid; otherwise the behavior is undefined (UB).

extract/insertelement produce poison if the index is invalid, which
prevents the optimizer from scalarizing load/extract pairs, for
example. This causes very suboptimal code to be generated when using
matrix subscript expressions with variable indices for large matrices.

This patch updates IRGen to emit assumes for index expressions to
convey the information that the index must be valid.

This also adjusts the order in which operations are emitted slightly, so
indices & assumes are added before the load of the matrix value.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D102478

Files:
  clang/lib/CodeGen/CGExpr.cpp
  clang/lib/CodeGen/CGExprScalar.cpp
  clang/test/CodeGen/matrix-type-operators.c
  clang/test/CodeGenCXX/matrix-type-operators.cpp
  clang/test/CodeGenObjC/matrix-type-operators.m
  llvm/include/llvm/IR/MatrixBuilder.h

Index: llvm/include/llvm/IR/MatrixBuilder.h
===================================================================
--- llvm/include/llvm/IR/MatrixBuilder.h
+++ llvm/include/llvm/IR/MatrixBuilder.h
@@ -231,9 +231,25 @@
                : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
   }
 
-  /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.
-  Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx,
-                              unsigned NumRows, Twine const &Name = "") {
+  /// Create an assumption that \p Idx is less than \p NumElements.
+  void CreateIndexAssumption(Value *Idx, unsigned NumElements,
+                             Twine const &Name = "") {
+
+    Value *NumElts =
+        B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
+
+    auto *Cmp = B.CreateICmpULT(Idx, NumElts);
+    if (!isa<llvm::ConstantInt>(Cmp)) {
+      Function *TheFn =
+          Intrinsic::getDeclaration(getModule(), Intrinsic::assume);
+      B.CreateCall(TheFn->getFunctionType(), TheFn, {Cmp}, Name);
+    }
+  }
+
+  /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
+  /// a matrix with \p NumRows embedded in a vector.
+  Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
+                     Twine const &Name = "") {
 
     unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
                                  ColumnIdx->getType()->getScalarSizeInBits());
@@ -241,9 +257,7 @@
     RowIdx = B.CreateZExt(RowIdx, IntTy);
     ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
     Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
-    return B.CreateExtractElement(
-        Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx),
-        "matext");
+    return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
   }
 };
 
Index: clang/test/CodeGenObjC/matrix-type-operators.m
===================================================================
--- clang/test/CodeGenObjC/matrix-type-operators.m
+++ clang/test/CodeGenObjC/matrix-type-operators.m
@@ -22,9 +22,11 @@
 // CHECK-NEXT:    [[IV2_PTR:%.*]] = bitcast %0* [[IV2]] to i8*
 // CHECK-NEXT:    [[CALL1:%.*]] = call i32 bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to i32 (i8*, i8*)*)(i8* [[IV2_PTR]], i8* [[SEL2]])
 // CHECK-NEXT:    [[CONV2:%.*]] = sext i32 [[CALL1]] to i64
-// CHECK-NEXT:    [[MAT:%.*]] = load <16 x double>, <16 x double>* {{.*}} align 8
 // CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[CONV2]], 4
 // CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[CONV]]
+// CHECK-NEXT:   [[CMP:%.*]] = icmp ult i64 [[IDX2]], 16
+// CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
+// CHECK-NEXT:    [[MAT:%.*]] = load <16 x double>, <16 x double>* {{.*}} align 8
 // CHECK-NEXT:    [[MATEXT:%.*]] = extractelement <16 x double> [[MAT]], i64 [[IDX2]]
 // CHECK-NEXT:    ret double [[MATEXT]]
 //
@@ -49,12 +51,14 @@
 // CHECK-NEXT:    [[IV2_PTR:%.*]] = bitcast %0* [[IV2]] to i8*
 // CHECK-NEXT:    [[CALL1:%.*]] = call i32 bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to i32 (i8*, i8*)*)(i8* [[IV2_PTR]], i8* [[SEL2]])
 // CHECK-NEXT:    [[CONV2:%.*]] = sext i32 [[CALL1]] to i64
+// CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[CONV2]], 4
+// CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[CONV]]
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 16
+// CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
 // CHECK-NEXT:    [[M:%.*]] = load %1*, %1** %m.addr, align 8
 // CHECK-NEXT:    [[SEL3:%.*]] = load i8*, i8** @OBJC_SELECTOR_REFERENCES_, align 8, !invariant.load !7
 // CHECK-NEXT:    [[M_PTR:%.*]] = bitcast %1* [[M]] to i8*
 // CHECK-NEXT:    [[MAT:%.*]] = call <16 x double> bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to <16 x double> (i8*, i8*)*)(i8* [[M_PTR]], i8* [[SEL3]])
-// CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[CONV2]], 4
-// CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[CONV]]
 // CHECK-NEXT:    [[MATEXT:%.*]] = extractelement <16 x double> [[MAT]], i64 [[IDX2]]
 // CHECK-NEXT:    ret double [[MATEXT]]
 //
Index: clang/test/CodeGenCXX/matrix-type-operators.cpp
===================================================================
--- clang/test/CodeGenCXX/matrix-type-operators.cpp
+++ clang/test/CodeGenCXX/matrix-type-operators.cpp
@@ -219,6 +219,8 @@
   // CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[J_EXT]], 2
   // CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[I_EXT]]
   // CHECK-NEXT:    [[MAT_ADDR:%.*]] = bitcast [4 x i32]* {{.*}} to <4 x i32>*
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 4
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:    [[MAT:%.*]] = load <4 x i32>, <4 x i32>* [[MAT_ADDR]], align 4
   // CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[MAT]], i32 [[E]], i64 [[IDX2]]
   // CHECK-NEXT:    store <4 x i32> [[MATINS]], <4 x i32>* [[MAT_ADDR]], align 4
@@ -243,6 +245,8 @@
   // CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[J_EXT]], 3
   // CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[I_EXT]]
   // CHECK-NEXT:    [[MAT_ADDR:%.*]] = bitcast [24 x float]* {{.*}} to <24 x float>*
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 24
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:    [[MAT:%.*]] = load <24 x float>, <24 x float>* [[MAT_ADDR]], align 4
   // CHECK-NEXT:    [[MATINS:%.*]] = insertelement <24 x float> [[MAT]], float [[E]], i64 [[IDX2]]
   // CHECK-NEXT:    store <24 x float> [[MATINS]], <24 x float>* [[MAT_ADDR]], align 4
@@ -315,11 +319,13 @@
   // CHECK-NEXT:    [[J:%.*]] = call i32 @_ZN15UnsignedWrappercvjEv(%struct.UnsignedWrapper* {{[^,]*}} %j)
   // CHECK-NEXT:    [[J_SUB:%.*]] = sub i32 [[J]], 1
   // CHECK-NEXT:    [[J_SUB_EXT:%.*]] = zext i32 [[J_SUB]] to i64
+  // CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[J_SUB_EXT]], 4
+  // CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[I_ADD_EXT]]
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 16
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:    [[MAT_ADDR:%.*]] = load [16 x double]*, [16 x double]** %m.addr, align 8
   // CHECK-NEXT:    [[MAT_ADDR2:%.*]] = bitcast [16 x double]* [[MAT_ADDR]] to <16 x double>*
   // CHECK-NEXT:    [[MAT:%.*]] = load <16 x double>, <16 x double>* [[MAT_ADDR2]], align 8
-  // CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[J_SUB_EXT]], 4
-  // CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[I_ADD_EXT]]
   // CHECK-NEXT:    [[MATEXT:%.*]]  = extractelement <16 x double> [[MAT]], i64 [[IDX2]]
   // CHECK-NEXT:    ret double [[MATEXT]]
   return m[i + 1][j - 1];
@@ -358,6 +364,8 @@
   // CHECK-NEXT:   [[IDX1:%.*]] = mul i64 [[I2_EXT]], 4
   // CHECK-NEXT:   [[IDX2:%.*]] = add i64 [[IDX1]], [[I_EXT]]
   // CHECK-NEXT:   [[MAT_ADDR:%.*]] = bitcast [16 x float]* %result to <16 x float>*
+  // CHECK-NEXT:   [[CMP:%.*]] = icmp ult i64 [[IDX2]], 16
+  // CHECK-NEXT:   call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:   [[MAT:%.*]] = load <16 x float>, <16 x float>* [[MAT_ADDR]], align 4
   // CHECK-NEXT:   [[MATINS:%.*]] = insertelement <16 x float> [[MAT]], float 1.000000e+00, i64 [[IDX2]]
   // CHECK-NEXT:   store <16 x float> [[MATINS]], <16 x float>* [[MAT_ADDR]], align 4
@@ -386,6 +394,8 @@
   // CHECK-NEXT:   [[IDX1:%.*]] = mul i64 [[I2_EXT]], 5
   // CHECK-NEXT:   [[IDX2:%.*]] = add i64 [[IDX1]], [[I_EXT]]
   // CHECK-NEXT:   [[MAT_ADDR:%.*]] = bitcast [25 x i32]* %result to <25 x i32>*
+  // CHECK-NEXT:   [[CMP:%.*]] = icmp ult i64 [[IDX2]], 25
+  // CHECK-NEXT:   call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:   [[MAT:%.*]] = load <25 x i32>, <25 x i32>* [[MAT_ADDR]], align 4
   // CHECK-NEXT:   [[MATINS:%.*]] = insertelement <25 x i32> [[MAT]], i32 1, i64 [[IDX2]]
   // CHECK-NEXT:   store <25 x i32> [[MATINS]], <25 x i32>* [[MAT_ADDR]], align 4
Index: clang/test/CodeGen/matrix-type-operators.c
===================================================================
--- clang/test/CodeGen/matrix-type-operators.c
+++ clang/test/CodeGen/matrix-type-operators.c
@@ -874,6 +874,8 @@
   // CHECK-NEXT:    [[K_EXT:%.*]] = zext i32 [[K]] to i64
   // CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[K_EXT]], 2
   // CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[J_EXT]]
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 6
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:    [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4
   // CHECK-NEXT:    [[MATINS:%.*]] = insertelement <6 x float> [[MAT]], float [[E]], i64 [[IDX2]]
   // CHECK-NEXT:    store <6 x float> [[MATINS]], <6 x float>* [[MAT_ADDR]], align 4
@@ -890,6 +892,8 @@
   // CHECK-NEXT:    [[K:%.*]] = load i64, i64* %k.addr, align 8
   // CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[K]], 2
   // CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[J_EXT]]
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 6
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:    [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4
   // CHECK-NEXT:    [[MATINS:%.*]] = insertelement <6 x float> [[MAT]], float [[E]], i64 [[IDX2]]
   // CHECK-NEXT:    store <6 x float> [[MATINS]], <6 x float>* [[MAT_ADDR]], align 4
@@ -907,6 +911,8 @@
   // CHECK-NEXT:    [[I2_ADD:%.*]] = add nsw i32 4, [[I2]]
   // CHECK-NEXT:    [[ADD_EXT:%.*]] = sext i32 [[I2_ADD]] to i64
   // CHECK-NEXT:    [[IDX2:%.*]] = add i64 18, [[ADD_EXT]]
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 27
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:    [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4
   // CHECK-NEXT:    [[MATINS:%.*]] = insertelement <27 x i32> [[MAT]], i32 [[I1]], i64 [[IDX2]]
   // CHECK-NEXT:    store <27 x i32> [[MATINS]], <27 x i32>* [[MAT_ADDR]], align 4
@@ -980,9 +986,11 @@
   // CHECK-LABEL: @extract_int(
   // CHECK:         [[J1:%.*]] = load i64, i64* %j.addr, align 8
   // CHECK-NEXT:    [[J2:%.*]] = load i64, i64* %j.addr, align 8
-  // CHECK-NEXT:    [[MAT:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4
   // CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[J2]], 9
   // CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[J1]]
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 27
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
+  // CHECK-NEXT:    [[MAT:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4
   // CHECK-NEXT:    [[MATEXT:%.*]] = extractelement <27 x i32> [[MAT]], i64 [[IDX2]]
   // CHECK-NEXT:    ret i32 [[MATEXT]]
 
@@ -995,13 +1003,15 @@
   // CHECK-LABEL: @test_extract_matrix_pointer1(
   // CHECK:         [[J:%.*]] = load i32, i32* %j.addr, align 4
   // CHECK-NEXT:    [[J_EXT:%.*]] = zext i32 [[J]] to i64
+  // CHECK-NEXT:    [[IDX:%.*]] = add i64 3, [[J_EXT]]
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX]], 6
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:    [[PTR:%.*]] = load [6 x double]**, [6 x double]*** %ptr.addr, align 8
   // CHECK-NEXT:    [[PTR_IDX:%.*]] = getelementptr inbounds [6 x double]*, [6 x double]** [[PTR]], i64 1
   // CHECK-NEXT:    [[PTR2:%.*]] = load [6 x double]*, [6 x double]** [[PTR_IDX]], align 8
   // CHECK-NEXT:    [[PTR2_IDX:%.*]] = getelementptr inbounds [6 x double], [6 x double]* [[PTR2]], i64 2
   // CHECK-NEXT:    [[MAT_ADDR:%.*]] = bitcast [6 x double]* [[PTR2_IDX]] to <6 x double>*
   // CHECK-NEXT:    [[MAT:%.*]] = load <6 x double>, <6 x double>* [[MAT_ADDR]], align 8
-  // CHECK-NEXT:    [[IDX:%.*]] = add i64 3, [[J_EXT]]
   // CHECK-NEXT:    [[MATEXT:%.*]] = extractelement <6 x double> [[MAT]], i64 [[IDX]]
   // CHECK-NEXT:    ret double [[MATEXT]]
 
@@ -1027,13 +1037,17 @@
   // CHECK-LABEL: @insert_extract(
   // CHECK:         [[K:%.*]] = load i16, i16* %k.addr, align 2
   // CHECK-NEXT:    [[K_EXT:%.*]] = sext i16 [[K]] to i64
-  // CHECK-NEXT:    [[MAT:%.*]] = load <9 x float>, <9 x float>* [[MAT_ADDR:%.*]], align 4
   // CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[K_EXT]], 3
   // CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], 0
-  // CHECK-NEXT:    [[MATEXT:%.*]] = extractelement <9 x float> [[MAT]], i64 [[IDX]]
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 9
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
+  // CHECK-NEXT:    [[MAT:%.*]] = load <9 x float>, <9 x float>* [[MAT_ADDR:%.*]], align 4
+  // CHECK-NEXT:    [[MATEXT:%.*]] = extractelement <9 x float> [[MAT]], i64 [[IDX2]]
   // CHECK-NEXT:    [[J:%.*]] = load i64, i64* %j.addr, align 8
   // CHECK-NEXT:    [[IDX3:%.*]] = mul i64 [[J]], 3
   // CHECK-NEXT:    [[IDX4:%.*]] = add i64 [[IDX3]], 2
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX4]], 9
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:    [[MAT2:%.*]] = load <9 x float>, <9 x float>* [[MAT_ADDR]], align 4
   // CHECK-NEXT:    [[MATINS:%.*]] = insertelement <9 x float> [[MAT2]], float [[MATEXT]], i64 [[IDX4]]
   // CHECK-NEXT:    store <9 x float> [[MATINS]], <9 x float>* [[MAT_ADDR]], align 4
@@ -1068,9 +1082,13 @@
   // CHECK-NEXT:    [[IDX1:%.*]] = mul i64 [[J_EXT]], 2
   // CHECK-NEXT:    [[IDX2:%.*]] = add i64 [[IDX1]], [[I_EXT]]
   // CHECK-NEXT:    [[MAT_PTR:%.*]] = bitcast [6 x float]* %mat to <6 x float>*
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 6
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:    [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_PTR]], align 4
   // CHECK-NEXT:    [[EXT:%.*]] = extractelement <6 x float> [[MAT]], i64 [[IDX2]]
   // CHECK-NEXT:    [[SUM:%.*]] = fadd float [[EXT]], {{.*}}
+  // CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX2]], 6
+  // CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:    [[MAT2:%.*]] = load <6 x float>, <6 x float>* [[MAT_PTR]], align 4
   // CHECK-NEXT:    [[INS:%.*]] = insertelement <6 x float> [[MAT2]], float [[SUM]], i64 [[IDX2]]
   // CHECK-NEXT:    store <6 x float> [[INS]], <6 x float>* [[MAT_PTR]], align 4
@@ -1085,23 +1103,29 @@
   // CHECK-NEXT:  [[I1_EXT:%.*]] = sext i32 [[I1]] to i64
   // CHECK-NEXT:  [[J1:%.*]] = load i32, i32* %j.addr, align 4
   // CHECK-NEXT:  [[J1_EXT:%.*]] = sext i32 [[J1]] to i64
-  // CHECK-NEXT:  [[A:%.*]] = load <27 x i32>, <27 x i32>* %0, align 4
   // CHECK-NEXT:  [[IDX1_1:%.*]] = mul i64 [[J1_EXT]], 9
   // CHECK-NEXT:  [[IDX1_2:%.*]] = add i64 [[IDX1_1]], [[I1_EXT]]
+  // CHECK-NEXT:  [[CMP:%.*]] = icmp ult i64 [[IDX1_2]], 27
+  // CHECK-NEXT:  call void @llvm.assume(i1 [[CMP]])
+  // CHECK-NEXT:  [[A:%.*]] = load <27 x i32>, <27 x i32>* %0, align 4
   // CHECK-NEXT:  [[MI1:%.*]] = extractelement <27 x i32> [[A]], i64 [[IDX1_2]]
   // CHECK-NEXT:  [[MI1_EXT:%.*]] = sext i32 [[MI1]] to i64
   // CHECK-NEXT:  [[J2:%.*]] = load i32, i32* %j.addr, align 4
   // CHECK-NEXT:  [[J2_EXT:%.*]] = sext i32 [[J2]] to i64
   // CHECK-NEXT:  [[I2:%.*]] = load i32, i32* %i.addr, align 4
   // CHECK-NEXT:  [[I2_EXT:%.*]] = sext i32 [[I2]] to i64
-  // CHECK-NEXT:  [[A2:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4
   // CHECK-NEXT:  [[IDX2_1:%.*]] = mul i64 [[I2_EXT]], 9
   // CHECK-NEXT:  [[IDX2_2:%.*]] = add i64 [[IDX2_1]], [[J2_EXT]]
+  // CHECK-NEXT:  [[CMP:%.*]] = icmp ult i64 [[IDX2_2]], 27
+  // CHECK-NEXT:  call void @llvm.assume(i1 [[CMP]])
+  // CHECK-NEXT:  [[A2:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4
   // CHECK-NEXT:  [[MI2:%.*]] = extractelement <27 x i32> [[A2]], i64 [[IDX2_2]]
   // CHECK-NEXT:  [[MI3:%.*]] = add nsw i32 [[MI2]], 2
   // CHECK-NEXT:  [[MI3_EXT:%.*]] = sext i32 [[MI3]] to i64
   // CHECK-NEXT:  [[IDX3_1:%.*]] = mul i64 [[MI3_EXT]], 5
   // CHECK-NEXT:  [[IDX3_2:%.*]] = add i64 [[IDX3_1]], [[MI1_EXT]]
+  // CHECK-NEXT:  [[CMP:%.*]] = icmp ult i64 [[IDX3_2]], 25
+  // CHECK-NEXT:  call void @llvm.assume(i1 [[CMP]])
   // CHECK-NEXT:  [[B:%.*]] = load <25 x double>, <25 x double>* [[B_PTR:%.*]], align 8
   // CHECK-NEXT:  [[INS:%.*]] = insertelement <25 x double> [[B]], double 1.500000e+00, i64 [[IDX3_2]]
   // CHECK-NEXT:  store <25 x double> [[INS]], <25 x double>* [[B_PTR]], align 8
Index: clang/lib/CodeGen/CGExprScalar.cpp
===================================================================
--- clang/lib/CodeGen/CGExprScalar.cpp
+++ clang/lib/CodeGen/CGExprScalar.cpp
@@ -1754,13 +1754,17 @@
   // integer value.
   Value *RowIdx = Visit(E->getRowIdx());
   Value *ColumnIdx = Visit(E->getColumnIdx());
+
+  auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
+  unsigned NumRows = MatrixTy->getNumRows();
+  llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+  Value *Idx = MB.CreateIndex(RowIdx, ColumnIdx, NumRows);
+  MB.CreateIndexAssumption(Idx, MatrixTy->getNumElementsFlattened());
+
   Value *Matrix = Visit(E->getBase());
 
   // TODO: Should we emit bounds checks with SanitizerKind::ArrayBounds?
-  llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
-  return MB.CreateExtractElement(
-      Matrix, RowIdx, ColumnIdx,
-      E->getBase()->getType()->castAs<ConstantMatrixType>()->getNumRows());
+  return Builder.CreateExtractElement(Matrix, Idx, "matrixext");
 }
 
 static int getMaskElt(llvm::ShuffleVectorInst *SVI, unsigned Idx,
Index: clang/lib/CodeGen/CGExpr.cpp
===================================================================
--- clang/lib/CodeGen/CGExpr.cpp
+++ clang/lib/CodeGen/CGExpr.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/MatrixBuilder.h"
 #include "llvm/Support/ConvertUTF.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/Path.h"
@@ -1939,6 +1940,10 @@
     return EmitLoadOfGlobalRegLValue(LV);
 
   if (LV.isMatrixElt()) {
+    llvm::Value *Idx = LV.getMatrixIdx();
+    auto MatTy = LV.getType()->getAs<ConstantMatrixType>();
+    llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+    MB.CreateIndexAssumption(Idx, MatTy->getNumElementsFlattened());
     llvm::LoadInst *Load =
         Builder.CreateLoad(LV.getMatrixAddress(), LV.isVolatileQualified());
     return RValue::get(
@@ -2080,9 +2085,13 @@
       return EmitStoreThroughGlobalRegLValue(Src, Dst);
 
     if (Dst.isMatrixElt()) {
-      llvm::Value *Vec = Builder.CreateLoad(Dst.getMatrixAddress());
-      Vec = Builder.CreateInsertElement(Vec, Src.getScalarVal(),
-                                        Dst.getMatrixIdx(), "matins");
+      llvm::Value *Idx = Dst.getMatrixIdx();
+      auto MatTy = Dst.getType()->getAs<ConstantMatrixType>();
+      llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+      MB.CreateIndexAssumption(Idx, MatTy->getNumElementsFlattened());
+      llvm::Instruction *Load = Builder.CreateLoad(Dst.getMatrixAddress());
+      llvm::Value *Vec =
+          Builder.CreateInsertElement(Load, Src.getScalarVal(), Idx, "matins");
       Builder.CreateStore(Vec, Dst.getMatrixAddress(),
                           Dst.isVolatileQualified());
       return;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to