================ @@ -121,12 +137,96 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) { llvm_unreachable("No terminator in the entry block!"); } -static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) { +class ShapeCalculator { +private: + TargetMachine *TM = nullptr; + + // In AMX intrinsics we let Shape = {Row, Col}, but the + // RealCol = Col / ElementSize. We may use the RealCol + // as a new Row for other new created AMX intrinsics. + std::map<Value *, Value *> Col2Row, Row2Col; + +public: + ShapeCalculator(TargetMachine *TargetM) : TM(TargetM) {} + std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo); + std::pair<Value *, Value *> getShape(PHINode *Phi); + Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity); + Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity); +}; + +Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V, + unsigned Granularity) { + if (Col2Row.count(V)) + return Col2Row[V]; + IRBuilder<> Builder(II); + Value *RealRow = nullptr; + if (isa<ConstantInt>(V)) + RealRow = + Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity); + else if (isa<Instruction>(V)) { + // When it is not a const value and it is not a function argument, we + // create Row after the definition of V instead of + // before II. For example, II is %118, we try to getshape for %117: + // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x + // i32> %115). + // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 + // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx + // %117). + // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its + // definition is after its user(new tileload for %117). + // So, the best choice is to create %row right after the definition of + // %106. 
+ Builder.SetInsertPoint(cast<Instruction>(V)); + RealRow = Builder.CreateUDiv(V, Builder.getInt16(4)); + cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V)); + } else { + // When it is not a const value and it is a function argument, we create + // Row at the entry bb. + IRBuilder<> NewBuilder( + getFirstNonAllocaInTheEntryBlock(*II->getFunction())); + RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity)); + } + Col2Row[V] = RealRow; + return RealRow; +} + +Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V, + unsigned Granularity) { + if (Row2Col.count(V)) + return Row2Col[V]; + IRBuilder<> Builder(II); + Value *RealCol = nullptr; + if (isa<ConstantInt>(V)) + RealCol = + Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity); + else if (isa<Instruction>(V)) { + Builder.SetInsertPoint(cast<Instruction>(V)); + RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity)); + cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V)); + } else { + // When it is not a const value and it is a function argument, we create + // Row at the entry bb. ---------------- phoebewang wrote:
Row is correct.

https://github.com/llvm/llvm-project/pull/113532
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits