https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/72647
>From 3dfe86782806f048b130d46afa6293712919f672 Mon Sep 17 00:00:00 2001 From: Florian Hahn <f...@fhahn.com> Date: Fri, 14 Apr 2023 14:33:57 +0100 Subject: [PATCH 1/2] [Matrix] Convert column-vector ops feeding dot product to row-vectors. Generalize the logic used to convert column-vector ops to row-vectors to support converting chains of operations. A potential next step is to further generalize this to convert column-vector ops to row-vector ops in general, not just for operands of dot products. Dot-product handling would then be driven by the general conversion, rather than the other way around. --- .../Scalar/LowerMatrixIntrinsics.cpp | 51 ++++++++++++++----- .../LowerMatrixIntrinsics/dot-product-int.ll | 47 ++++------------- 2 files changed, 47 insertions(+), 51 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 72b9db1e73d73d..c6bb43d3a78cf3 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1332,8 +1332,8 @@ class LowerMatrixIntrinsics { if (!IsIntVec && !FMF.allowReassoc()) return; - auto CanBeFlattened = [this](Value *Op) { - if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) + auto CanBeFlattened = [](Value *Op) { + if (match(Op, m_BinOp())) return true; return match( Op, m_OneUse(m_CombineOr( @@ -1346,6 +1346,9 @@ class LowerMatrixIntrinsics { // the returned cost is < 0, the argument is cheaper to use in the // dot-product lowering. auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) { + if (ShapeMap.find(Op) == ShapeMap.end()) + return InstructionCost::getInvalid(); + if (!isa<Instruction>(Op)) return InstructionCost(0); @@ -1356,7 +1359,7 @@ class LowerMatrixIntrinsics { InstructionCost EmbedCost(0); // Roughly estimate the cost for embedding the columns into a vector. for (unsigned I = 1; I < N; ++I) - EmbedCost -= + EmbedCost += TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), std::nullopt, TTI::TCK_RecipThroughput); return EmbedCost; @@ -1378,7 +1381,7 @@ class LowerMatrixIntrinsics { // vector. InstructionCost EmbedCost(0); for (unsigned I = 1; I < N; ++I) - EmbedCost += + EmbedCost -= TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), std::nullopt, TTI::TCK_RecipThroughput); return EmbedCost; @@ -1391,7 +1394,26 @@ class LowerMatrixIntrinsics { return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) - N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); }; - auto LHSCost = GetCostForArg(LHS, LShape.NumColumns); + + SmallPtrSet<Value *, 4> Seen; + SmallVector<Value *> WorkList; + SmallVector<Value *> ToFlatten; + WorkList.push_back(LHS); + InstructionCost LHSCost(0); + while (!WorkList.empty()) { + Value *Op = WorkList.pop_back_val(); + if (!Seen.insert(Op).second) + continue; + + InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns); + if (OpCost + LHSCost >= LHSCost) + continue; + + LHSCost += OpCost; + ToFlatten.push_back(Op); + if (auto *I = dyn_cast<Instruction>(Op)) + WorkList.append(I->op_begin(), I->op_end()); + } // We compare the costs of a vector.reduce.add to sequential add. int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd; @@ -1412,16 +1434,16 @@ class LowerMatrixIntrinsics { FusedInsts.insert(MatMul); IRBuilder<> Builder(MatMul); auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened, - this](Value *Op) -> Value * { + this](Value *Op) { // Matmul must be the only user of loads because we don't use LowerLoad // for row vectors (LowerLoad results in scalar loads and shufflevectors // instead of single vector load). if (!CanBeFlattened(Op)) - return Op; + return; if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { ShapeMap[Op] = ShapeMap[Op].t(); - return Op; + return; } FusedInsts.insert(cast<Instruction>(Op)); @@ -1432,16 +1454,19 @@ class LowerMatrixIntrinsics { auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); Op->replaceAllUsesWith(NewLoad); cast<Instruction>(Op)->eraseFromParent(); - return NewLoad; + return; } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( m_Value(Arg)))) { ToRemove.push_back(cast<Instruction>(Op)); - return Arg; + Op->replaceAllUsesWith(Arg); + return; } - - return Op; }; - LHS = FlattenArg(LHS); + + for (auto *V : ToFlatten) + FlattenArg(V); + + LHS = MatMul->getArgOperand(0); // Insert mul/fmul and llvm.vector.reduce.fadd Value *Mul = diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll index 7bbd0c50048551..f15dbed1f1f513 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll @@ -119,44 +119,15 @@ entry: define <1 x i32> @add_chain_feeding_dotproduct_i32_v8_1(<8 x i32> %a, <8 x i32> %b, <8 x i32> %c, <8 x i32> %d) { ; CHECK-LABEL: @add_chain_feeding_dotproduct_i32_v8_1( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 1> -; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 2> -; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 3> -; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 4> -; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 5> -; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 6> -; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 7> -; CHECK-NEXT: [[SPLIT8:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 1> -; CHECK-NEXT: [[SPLIT10:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 2> -; CHECK-NEXT: [[SPLIT11:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 3> -; CHECK-NEXT: [[SPLIT12:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 4> -; CHECK-NEXT: [[SPLIT13:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 5> -; CHECK-NEXT: [[SPLIT14:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 6> -; CHECK-NEXT: [[SPLIT15:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 7> -; CHECK-NEXT: [[TMP0:%.*]] = add <1 x i32> [[SPLIT]], [[SPLIT8]] -; CHECK-NEXT: [[TMP1:%.*]] = add <1 x i32> [[SPLIT1]], [[SPLIT9]] -; CHECK-NEXT: [[TMP2:%.*]] = add <1 x i32> [[SPLIT2]], [[SPLIT10]] -; CHECK-NEXT: [[TMP3:%.*]] = add <1 x i32> [[SPLIT3]], [[SPLIT11]] -; CHECK-NEXT: [[TMP4:%.*]] = add <1 x i32> [[SPLIT4]], [[SPLIT12]] -; CHECK-NEXT: [[TMP5:%.*]] = add <1 x i32> [[SPLIT5]], [[SPLIT13]] -; CHECK-NEXT: [[TMP6:%.*]] = add <1 x i32> [[SPLIT6]], [[SPLIT14]] -; CHECK-NEXT: [[TMP7:%.*]] = add <1 x i32> [[SPLIT7]], [[SPLIT15]] -; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP0]], <1 x i32> [[TMP1]], <2 x i32> <i32 0, i32 1> -; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP2]], <1 x i32> [[TMP3]], <2 x i32> <i32 0, i32 1> -; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> [[TMP5]], <2 x i32> <i32 0, i32 1> -; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <1 x i32> [[TMP6]], <1 x i32> [[TMP7]], <2 x i32> <i32 0, i32 1> -; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3> -; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP10]], <2 x i32> [[TMP11]], <4 x i32> <i32 0, i32 1, i32 2, i32 3> -; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i32> [[TMP12]], <4 x i32> [[TMP13]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> -; CHECK-NEXT: [[SPLIT16:%.*]] = shufflevector <8 x i32> [[TMP14]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> -; CHECK-NEXT: [[SPLIT17:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> -; CHECK-NEXT: [[TMP15:%.*]] = add <8 x i32> [[SPLIT16]], [[SPLIT17]] -; CHECK-NEXT: [[TMP16:%.*]] = mul <8 x i32> [[TMP15]], [[D:%.*]] -; CHECK-NEXT: [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP16]]) -; CHECK-NEXT: [[TMP18:%.*]] = insertelement <1 x i32> poison, i32 [[TMP17]], i64 0 -; CHECK-NEXT: ret <1 x i32> [[TMP18]] +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> +; CHECK-NEXT: [[TMP0:%.*]] = add <8 x i32> [[SPLIT]], [[SPLIT1]] +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> +; CHECK-NEXT: [[TMP1:%.*]] = add <8 x i32> [[TMP0]], [[SPLIT2]] +; CHECK-NEXT: [[TMP2:%.*]] = mul <8 x i32> [[TMP1]], [[D:%.*]] +; CHECK-NEXT: [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP2]]) +; CHECK-NEXT: [[TMP4:%.*]] = insertelement <1 x i32> poison, i32 [[TMP3]], i64 0 +; CHECK-NEXT: ret <1 x i32> [[TMP4]] ; entry: %add.1 = add <8 x i32> %a, %b >From 5c1a3f2cb7ed28c21e8a1e976754e69e9e3bb452 Mon Sep 17 00:00:00 2001 From: Florian Hahn <f...@fhahn.com> Date: Fri, 12 Jan 2024 09:55:14 +0000 Subject: [PATCH 2/2] !fixup add comment --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index c6bb43d3a78cf3..b528762b545659 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1395,6 +1395,9 @@ class LowerMatrixIntrinsics { N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); }; + // Iterate over LHS and operations feeding LHS and check if it is profitable + // to flatten the visited ops. For each op, we compute the difference + // between the flattened and matrix versions. SmallPtrSet<Value *, 4> Seen; SmallVector<Value *> WorkList; SmallVector<Value *> ToFlatten; _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits