fhahn created this revision. fhahn added reviewers: rjmccall, anemet, rsmith, erichkeane. Herald added subscribers: dexonsmith, tschuett. fhahn requested review of this revision. Herald added projects: clang, LLVM. Herald added a subscriber: llvm-commits.
The matrix extension requires the indices for matrix subscript expressions to be valid, and it is UB otherwise. extract/insertelement produce poison if the index is invalid, which prevents the optimizer from being able to scalarize load/extract pairs, for example, and causes very suboptimal code to be generated when using matrix subscript expressions with variable indices for large matrices. This patch updates IRGen to emit assumes for index expressions to convey the information that the index must be valid. This also adjusts the order in which operations are emitted slightly, so indices & assumes are added before the load of the matrix value. Repository: rG LLVM Github Monorepo https://reviews.llvm.org/D102478 Files: clang/lib/CodeGen/CGExpr.cpp clang/lib/CodeGen/CGExprScalar.cpp clang/test/CodeGen/matrix-type-operators.c clang/test/CodeGenCXX/matrix-type-operators.cpp clang/test/CodeGenObjC/matrix-type-operators.m llvm/include/llvm/IR/MatrixBuilder.h
Index: llvm/include/llvm/IR/MatrixBuilder.h =================================================================== --- llvm/include/llvm/IR/MatrixBuilder.h +++ llvm/include/llvm/IR/MatrixBuilder.h @@ -231,9 +231,25 @@ : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS)); } - /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix. - Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx, - unsigned NumRows, Twine const &Name = "") { + /// Create an assumption that \p Idx is less than \p NumElements. + void CreateIndexAssumption(Value *Idx, unsigned NumElements, + Twine const &Name = "") { + + Value *NumElts = + B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements); + + auto *Cmp = B.CreateICmpULT(Idx, NumElts); + if (!isa<llvm::ConstantInt>(Cmp)) { + Function *TheFn = + Intrinsic::getDeclaration(getModule(), Intrinsic::assume); + B.CreateCall(TheFn->getFunctionType(), TheFn, {Cmp}, Name); + } + } + + /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from + /// a matrix with \p NumRows embedded in a vector. 
+ Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, + Twine const &Name = "") { unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), ColumnIdx->getType()->getScalarSizeInBits()); @@ -241,9 +257,7 @@ RowIdx = B.CreateZExt(RowIdx, IntTy); ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); Value *NumRowsV = B.getIntN(MaxWidth, NumRows); - return B.CreateExtractElement( - Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx), - "matext"); + return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx); } }; Index: clang/test/CodeGenObjC/matrix-type-operators.m =================================================================== --- clang/test/CodeGenObjC/matrix-type-operators.m +++ clang/test/CodeGenObjC/matrix-type-operators.m @@ -22,9 +22,11 @@ // CHECK-NEXT: [[IV2_PTR:%.*]] = bitcast %0* [[IV2]] to i8* // CHECK-NEXT: [[CALL1:%.*]] = call i32 bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to i32 (i8*, i8*)*)(i8* [[IV2_PTR]], i8* [[SEL2]]) // CHECK-NEXT: [[CONV2:%.*]] = sext i32 [[CALL1]] to i64 -// CHECK-NEXT: [[MAT:%.*]] = load <16 x double>, <16 x double>* {{.*}} align 8 // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[CONV2]], 4 // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[CONV]] +// CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 16 +// CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) +// CHECK-NEXT: [[MAT:%.*]] = load <16 x double>, <16 x double>* {{.*}} align 8 // CHECK-NEXT: [[MATEXT:%.*]] = extractelement <16 x double> [[MAT]], i64 [[IDX2]] // CHECK-NEXT: ret double [[MATEXT]] // @@ -49,12 +51,14 @@ // CHECK-NEXT: [[IV2_PTR:%.*]] = bitcast %0* [[IV2]] to i8* // CHECK-NEXT: [[CALL1:%.*]] = call i32 bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to i32 (i8*, i8*)*)(i8* [[IV2_PTR]], i8* [[SEL2]]) // CHECK-NEXT: [[CONV2:%.*]] = sext i32 [[CALL1]] to i64 +// CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[CONV2]], 4 +// CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[CONV]] +// CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 16 +// CHECK-NEXT: call 
void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[M:%.*]] = load %1*, %1** %m.addr, align 8 // CHECK-NEXT: [[SEL3:%.*]] = load i8*, i8** @OBJC_SELECTOR_REFERENCES_, align 8, !invariant.load !7 // CHECK-NEXT: [[M_PTR:%.*]] = bitcast %1* [[M]] to i8* // CHECK-NEXT: [[MAT:%.*]] = call <16 x double> bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to <16 x double> (i8*, i8*)*)(i8* [[M_PTR]], i8* [[SEL3]]) -// CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[CONV2]], 4 -// CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[CONV]] // CHECK-NEXT: [[MATEXT:%.*]] = extractelement <16 x double> [[MAT]], i64 [[IDX2]] // CHECK-NEXT: ret double [[MATEXT]] // Index: clang/test/CodeGenCXX/matrix-type-operators.cpp =================================================================== --- clang/test/CodeGenCXX/matrix-type-operators.cpp +++ clang/test/CodeGenCXX/matrix-type-operators.cpp @@ -219,6 +219,8 @@ // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[J_EXT]], 2 // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[I_EXT]] // CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [4 x i32]* {{.*}} to <4 x i32>* + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 4 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT:%.*]] = load <4 x i32>, <4 x i32>* [[MAT_ADDR]], align 4 // CHECK-NEXT: [[MATINS:%.*]] = insertelement <4 x i32> [[MAT]], i32 [[E]], i64 [[IDX2]] // CHECK-NEXT: store <4 x i32> [[MATINS]], <4 x i32>* [[MAT_ADDR]], align 4 @@ -243,6 +245,8 @@ // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[J_EXT]], 3 // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[I_EXT]] // CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [24 x float]* {{.*}} to <24 x float>* + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 24 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT:%.*]] = load <24 x float>, <24 x float>* [[MAT_ADDR]], align 4 // CHECK-NEXT: [[MATINS:%.*]] = insertelement <24 x float> [[MAT]], float [[E]], i64 [[IDX2]] // CHECK-NEXT: store <24 x float> [[MATINS]], <24 x float>* [[MAT_ADDR]], align 4 @@ -315,11 
+319,13 @@ // CHECK-NEXT: [[J:%.*]] = call i32 @_ZN15UnsignedWrappercvjEv(%struct.UnsignedWrapper* {{[^,]*}} %j) // CHECK-NEXT: [[J_SUB:%.*]] = sub i32 [[J]], 1 // CHECK-NEXT: [[J_SUB_EXT:%.*]] = zext i32 [[J_SUB]] to i64 + // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[J_SUB_EXT]], 4 + // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[I_ADD_EXT]] + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 16 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT_ADDR:%.*]] = load [16 x double]*, [16 x double]** %m.addr, align 8 // CHECK-NEXT: [[MAT_ADDR2:%.*]] = bitcast [16 x double]* [[MAT_ADDR]] to <16 x double>* // CHECK-NEXT: [[MAT:%.*]] = load <16 x double>, <16 x double>* [[MAT_ADDR2]], align 8 - // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[J_SUB_EXT]], 4 - // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[I_ADD_EXT]] // CHECK-NEXT: [[MATEXT:%.*]] = extractelement <16 x double> [[MAT]], i64 [[IDX2]] // CHECK-NEXT: ret double [[MATEXT]] return m[i + 1][j - 1]; @@ -358,6 +364,8 @@ // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[I2_EXT]], 4 // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[I_EXT]] // CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [16 x float]* %result to <16 x float>* + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 16 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT:%.*]] = load <16 x float>, <16 x float>* [[MAT_ADDR]], align 4 // CHECK-NEXT: [[MATINS:%.*]] = insertelement <16 x float> [[MAT]], float 1.000000e+00, i64 [[IDX2]] // CHECK-NEXT: store <16 x float> [[MATINS]], <16 x float>* [[MAT_ADDR]], align 4 @@ -386,6 +394,8 @@ // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[I2_EXT]], 5 // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[I_EXT]] // CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [25 x i32]* %result to <25 x i32>* + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 25 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT:%.*]] = load <25 x i32>, <25 x i32>* [[MAT_ADDR]], align 4 // CHECK-NEXT: [[MATINS:%.*]] = 
insertelement <25 x i32> [[MAT]], i32 1, i64 [[IDX2]] // CHECK-NEXT: store <25 x i32> [[MATINS]], <25 x i32>* [[MAT_ADDR]], align 4 Index: clang/test/CodeGen/matrix-type-operators.c =================================================================== --- clang/test/CodeGen/matrix-type-operators.c +++ clang/test/CodeGen/matrix-type-operators.c @@ -874,6 +874,8 @@ // CHECK-NEXT: [[K_EXT:%.*]] = zext i32 [[K]] to i64 // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[K_EXT]], 2 // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[J_EXT]] + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 6 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4 // CHECK-NEXT: [[MATINS:%.*]] = insertelement <6 x float> [[MAT]], float [[E]], i64 [[IDX2]] // CHECK-NEXT: store <6 x float> [[MATINS]], <6 x float>* [[MAT_ADDR]], align 4 @@ -890,6 +892,8 @@ // CHECK-NEXT: [[K:%.*]] = load i64, i64* %k.addr, align 8 // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[K]], 2 // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[J_EXT]] + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 6 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4 // CHECK-NEXT: [[MATINS:%.*]] = insertelement <6 x float> [[MAT]], float [[E]], i64 [[IDX2]] // CHECK-NEXT: store <6 x float> [[MATINS]], <6 x float>* [[MAT_ADDR]], align 4 @@ -907,6 +911,8 @@ // CHECK-NEXT: [[I2_ADD:%.*]] = add nsw i32 4, [[I2]] // CHECK-NEXT: [[ADD_EXT:%.*]] = sext i32 [[I2_ADD]] to i64 // CHECK-NEXT: [[IDX2:%.*]] = add i64 18, [[ADD_EXT]] + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 27 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 // CHECK-NEXT: [[MATINS:%.*]] = insertelement <27 x i32> [[MAT]], i32 [[I1]], i64 [[IDX2]] // CHECK-NEXT: store <27 x i32> [[MATINS]], <27 x i32>* [[MAT_ADDR]], align 4 @@ 
-980,9 +986,11 @@ // CHECK-LABEL: @extract_int( // CHECK: [[J1:%.*]] = load i64, i64* %j.addr, align 8 // CHECK-NEXT: [[J2:%.*]] = load i64, i64* %j.addr, align 8 - // CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[J2]], 9 // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[J1]] + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 27 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) + // CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 // CHECK-NEXT: [[MATEXT:%.*]] = extractelement <27 x i32> [[MAT]], i64 [[IDX2]] // CHECK-NEXT: ret i32 [[MATEXT]] @@ -995,13 +1003,15 @@ // CHECK-LABEL: @test_extract_matrix_pointer1( // CHECK: [[J:%.*]] = load i32, i32* %j.addr, align 4 // CHECK-NEXT: [[J_EXT:%.*]] = zext i32 [[J]] to i64 + // CHECK-NEXT: [[IDX:%.*]] = add i64 3, [[J_EXT]] + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX]], 6 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[PTR:%.*]] = load [6 x double]**, [6 x double]*** %ptr.addr, align 8 // CHECK-NEXT: [[PTR_IDX:%.*]] = getelementptr inbounds [6 x double]*, [6 x double]** [[PTR]], i64 1 // CHECK-NEXT: [[PTR2:%.*]] = load [6 x double]*, [6 x double]** [[PTR_IDX]], align 8 // CHECK-NEXT: [[PTR2_IDX:%.*]] = getelementptr inbounds [6 x double], [6 x double]* [[PTR2]], i64 2 // CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [6 x double]* [[PTR2_IDX]] to <6 x double>* // CHECK-NEXT: [[MAT:%.*]] = load <6 x double>, <6 x double>* [[MAT_ADDR]], align 8 - // CHECK-NEXT: [[IDX:%.*]] = add i64 3, [[J_EXT]] // CHECK-NEXT: [[MATEXT:%.*]] = extractelement <6 x double> [[MAT]], i64 [[IDX]] // CHECK-NEXT: ret double [[MATEXT]] @@ -1027,13 +1037,17 @@ // CHECK-LABEL: @insert_extract( // CHECK: [[K:%.*]] = load i16, i16* %k.addr, align 2 // CHECK-NEXT: [[K_EXT:%.*]] = sext i16 [[K]] to i64 - // CHECK-NEXT: [[MAT:%.*]] = load <9 x float>, <9 x float>* [[MAT_ADDR:%.*]], align 4 // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[K_EXT]], 3 // 
CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], 0 - // CHECK-NEXT: [[MATEXT:%.*]] = extractelement <9 x float> [[MAT]], i64 [[IDX]] + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 9 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) + // CHECK-NEXT: [[MAT:%.*]] = load <9 x float>, <9 x float>* [[MAT_ADDR:%.*]], align 4 + // CHECK-NEXT: [[MATEXT:%.*]] = extractelement <9 x float> [[MAT]], i64 [[IDX2]] // CHECK-NEXT: [[J:%.*]] = load i64, i64* %j.addr, align 8 // CHECK-NEXT: [[IDX3:%.*]] = mul i64 [[J]], 3 // CHECK-NEXT: [[IDX4:%.*]] = add i64 [[IDX3]], 2 + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX4]], 9 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT2:%.*]] = load <9 x float>, <9 x float>* [[MAT_ADDR]], align 4 // CHECK-NEXT: [[MATINS:%.*]] = insertelement <9 x float> [[MAT2]], float [[MATEXT]], i64 [[IDX4]] // CHECK-NEXT: store <9 x float> [[MATINS]], <9 x float>* [[MAT_ADDR]], align 4 @@ -1068,9 +1082,13 @@ // CHECK-NEXT: [[IDX1:%.*]] = mul i64 [[J_EXT]], 2 // CHECK-NEXT: [[IDX2:%.*]] = add i64 [[IDX1]], [[I_EXT]] // CHECK-NEXT: [[MAT_PTR:%.*]] = bitcast [6 x float]* %mat to <6 x float>* + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 6 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_PTR]], align 4 // CHECK-NEXT: [[EXT:%.*]] = extractelement <6 x float> [[MAT]], i64 [[IDX2]] // CHECK-NEXT: [[SUM:%.*]] = fadd float [[EXT]], {{.*}} + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2]], 6 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[MAT2:%.*]] = load <6 x float>, <6 x float>* [[MAT_PTR]], align 4 // CHECK-NEXT: [[INS:%.*]] = insertelement <6 x float> [[MAT2]], float [[SUM]], i64 [[IDX2]] // CHECK-NEXT: store <6 x float> [[INS]], <6 x float>* [[MAT_PTR]], align 4 @@ -1085,23 +1103,29 @@ // CHECK-NEXT: [[I1_EXT:%.*]] = sext i32 [[I1]] to i64 // CHECK-NEXT: [[J1:%.*]] = load i32, i32* %j.addr, align 4 // CHECK-NEXT: [[J1_EXT:%.*]] = sext 
i32 [[J1]] to i64 - // CHECK-NEXT: [[A:%.*]] = load <27 x i32>, <27 x i32>* %0, align 4 // CHECK-NEXT: [[IDX1_1:%.*]] = mul i64 [[J1_EXT]], 9 // CHECK-NEXT: [[IDX1_2:%.*]] = add i64 [[IDX1_1]], [[I1_EXT]] + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX1_2]], 27 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) + // CHECK-NEXT: [[A:%.*]] = load <27 x i32>, <27 x i32>* %0, align 4 // CHECK-NEXT: [[MI1:%.*]] = extractelement <27 x i32> [[A]], i64 [[IDX1_2]] // CHECK-NEXT: [[MI1_EXT:%.*]] = sext i32 [[MI1]] to i64 // CHECK-NEXT: [[J2:%.*]] = load i32, i32* %j.addr, align 4 // CHECK-NEXT: [[J2_EXT:%.*]] = sext i32 [[J2]] to i64 // CHECK-NEXT: [[I2:%.*]] = load i32, i32* %i.addr, align 4 // CHECK-NEXT: [[I2_EXT:%.*]] = sext i32 [[I2]] to i64 - // CHECK-NEXT: [[A2:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 // CHECK-NEXT: [[IDX2_1:%.*]] = mul i64 [[I2_EXT]], 9 // CHECK-NEXT: [[IDX2_2:%.*]] = add i64 [[IDX2_1]], [[J2_EXT]] + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX2_2]], 27 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) + // CHECK-NEXT: [[A2:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 // CHECK-NEXT: [[MI2:%.*]] = extractelement <27 x i32> [[A2]], i64 [[IDX2_2]] // CHECK-NEXT: [[MI3:%.*]] = add nsw i32 [[MI2]], 2 // CHECK-NEXT: [[MI3_EXT:%.*]] = sext i32 [[MI3]] to i64 // CHECK-NEXT: [[IDX3_1:%.*]] = mul i64 [[MI3_EXT]], 5 // CHECK-NEXT: [[IDX3_2:%.*]] = add i64 [[IDX3_1]], [[MI1_EXT]] + // CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX3_2]], 25 + // CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) // CHECK-NEXT: [[B:%.*]] = load <25 x double>, <25 x double>* [[B_PTR:%.*]], align 8 // CHECK-NEXT: [[INS:%.*]] = insertelement <25 x double> [[B]], double 1.500000e+00, i64 [[IDX3_2]] // CHECK-NEXT: store <25 x double> [[INS]], <25 x double>* [[B_PTR]], align 8 Index: clang/lib/CodeGen/CGExprScalar.cpp =================================================================== --- clang/lib/CodeGen/CGExprScalar.cpp +++ clang/lib/CodeGen/CGExprScalar.cpp @@ 
-1754,13 +1754,17 @@ // integer value. Value *RowIdx = Visit(E->getRowIdx()); Value *ColumnIdx = Visit(E->getColumnIdx()); + + auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>(); + unsigned NumRows = MatrixTy->getNumRows(); + llvm::MatrixBuilder<CGBuilderTy> MB(Builder); + Value *Idx = MB.CreateIndex(RowIdx, ColumnIdx, NumRows); + MB.CreateIndexAssumption(Idx, MatrixTy->getNumElementsFlattened()); + Value *Matrix = Visit(E->getBase()); // TODO: Should we emit bounds checks with SanitizerKind::ArrayBounds? - llvm::MatrixBuilder<CGBuilderTy> MB(Builder); - return MB.CreateExtractElement( - Matrix, RowIdx, ColumnIdx, - E->getBase()->getType()->castAs<ConstantMatrixType>()->getNumRows()); + return Builder.CreateExtractElement(Matrix, Idx, "matrixext"); } static int getMaskElt(llvm::ShuffleVectorInst *SVI, unsigned Idx, Index: clang/lib/CodeGen/CGExpr.cpp =================================================================== --- clang/lib/CodeGen/CGExpr.cpp +++ clang/lib/CodeGen/CGExpr.cpp @@ -35,6 +35,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/MatrixBuilder.h" #include "llvm/Support/ConvertUTF.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/Path.h" @@ -1939,6 +1940,10 @@ return EmitLoadOfGlobalRegLValue(LV); if (LV.isMatrixElt()) { + llvm::Value *Idx = LV.getMatrixIdx(); + auto MatTy = LV.getType()->getAs<ConstantMatrixType>(); + llvm::MatrixBuilder<CGBuilderTy> MB(Builder); + MB.CreateIndexAssumption(Idx, MatTy->getNumElementsFlattened()); llvm::LoadInst *Load = Builder.CreateLoad(LV.getMatrixAddress(), LV.isVolatileQualified()); return RValue::get( @@ -2080,9 +2085,13 @@ return EmitStoreThroughGlobalRegLValue(Src, Dst); if (Dst.isMatrixElt()) { - llvm::Value *Vec = Builder.CreateLoad(Dst.getMatrixAddress()); - Vec = Builder.CreateInsertElement(Vec, Src.getScalarVal(), - Dst.getMatrixIdx(), "matins"); + llvm::Value *Idx = Dst.getMatrixIdx(); + auto MatTy = 
Dst.getType()->getAs<ConstantMatrixType>(); + llvm::MatrixBuilder<CGBuilderTy> MB(Builder); + MB.CreateIndexAssumption(Idx, MatTy->getNumElementsFlattened()); + llvm::Instruction *Load = Builder.CreateLoad(Dst.getMatrixAddress()); + llvm::Value *Vec = + Builder.CreateInsertElement(Load, Src.getScalarVal(), Idx, "matins"); Builder.CreateStore(Vec, Dst.getMatrixAddress(), Dst.isVolatileQualified()); return;
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits