simoll updated this revision to Diff 296135.
simoll added a comment.
Herald added subscribers: llvm-commits, nikic, pengfei, hiraditya, mgorny.
Herald added a project: LLVM.
Fixed for privatized ElementCount members.
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D81083/new/
https://reviews.llvm.org/D81083
Files:
llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/include/llvm/CodeGen/ExpandVectorPredication.h
llvm/include/llvm/CodeGen/Passes.h
llvm/include/llvm/IR/IntrinsicInst.h
llvm/include/llvm/InitializePasses.h
llvm/lib/Analysis/TargetTransformInfo.cpp
llvm/lib/CodeGen/CMakeLists.txt
llvm/lib/CodeGen/ExpandVectorPredication.cpp
llvm/lib/CodeGen/TargetPassConfig.cpp
llvm/lib/IR/IntrinsicInst.cpp
llvm/test/CodeGen/AArch64/O0-pipeline.ll
llvm/test/CodeGen/AArch64/O3-pipeline.ll
llvm/test/CodeGen/ARM/O3-pipeline.ll
llvm/test/CodeGen/Generic/expand-vp.ll
llvm/test/CodeGen/X86/O0-pipeline.ll
llvm/tools/llc/llc.cpp
llvm/tools/opt/opt.cpp
Index: llvm/tools/opt/opt.cpp
===================================================================
--- llvm/tools/opt/opt.cpp
+++ llvm/tools/opt/opt.cpp
@@ -578,6 +578,7 @@
initializePostInlineEntryExitInstrumenterPass(Registry);
initializeUnreachableBlockElimLegacyPassPass(Registry);
initializeExpandReductionsPass(Registry);
+ initializeExpandVectorPredicationPass(Registry);
initializeWasmEHPreparePass(Registry);
initializeWriteBitcodePassPass(Registry);
initializeHardwareLoopsPass(Registry);
Index: llvm/tools/llc/llc.cpp
===================================================================
--- llvm/tools/llc/llc.cpp
+++ llvm/tools/llc/llc.cpp
@@ -318,6 +318,7 @@
initializeVectorization(*Registry);
initializeScalarizeMaskedMemIntrinPass(*Registry);
initializeExpandReductionsPass(*Registry);
+ initializeExpandVectorPredicationPass(*Registry);
initializeHardwareLoopsPass(*Registry);
initializeTransformUtils(*Registry);
Index: llvm/test/CodeGen/X86/O0-pipeline.ll
===================================================================
--- llvm/test/CodeGen/X86/O0-pipeline.ll
+++ llvm/test/CodeGen/X86/O0-pipeline.ll
@@ -24,6 +24,7 @@
; CHECK-NEXT: Lower constant intrinsics
; CHECK-NEXT: Remove unreachable blocks from the CFG
; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT: Expand vector predication intrinsics
; CHECK-NEXT: Scalarize Masked Memory Intrinsics
; CHECK-NEXT: Expand reduction intrinsics
; CHECK-NEXT: Expand indirectbr instructions
Index: llvm/test/CodeGen/Generic/expand-vp.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/Generic/expand-vp.ll
@@ -0,0 +1,84 @@
+; RUN: opt --expand-vec-pred -S < %s | FileCheck %s
+
+; All VP intrinsics have to be lowered into non-VP ops
+; CHECK-NOT: {{call.* @llvm.vp.add}}
+; CHECK-NOT: {{call.* @llvm.vp.sub}}
+; CHECK-NOT: {{call.* @llvm.vp.mul}}
+; CHECK-NOT: {{call.* @llvm.vp.sdiv}}
+; CHECK-NOT: {{call.* @llvm.vp.srem}}
+; CHECK-NOT: {{call.* @llvm.vp.udiv}}
+; CHECK-NOT: {{call.* @llvm.vp.urem}}
+; CHECK-NOT: {{call.* @llvm.vp.and}}
+; CHECK-NOT: {{call.* @llvm.vp.or}}
+; CHECK-NOT: {{call.* @llvm.vp.xor}}
+; CHECK-NOT: {{call.* @llvm.vp.ashr}}
+; CHECK-NOT: {{call.* @llvm.vp.lshr}}
+; CHECK-NOT: {{call.* @llvm.vp.shl}}
+
+define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) {
+ %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %r7 = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %r8 = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %r9 = call <8 x i32> @llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %rA = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %rB = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %rC = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ ret void
+}
+
+; fixed-width vectors
+; integer arith
+declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+; bit arith
+declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+
+define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i32> %i2, <vscale x 4 x i32> %f3, <vscale x 4 x i1> %m, i32 %n) {
+ %r0 = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %r1 = call <vscale x 4 x i32> @llvm.vp.sub.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %r2 = call <vscale x 4 x i32> @llvm.vp.mul.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %r3 = call <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %r4 = call <vscale x 4 x i32> @llvm.vp.srem.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %r5 = call <vscale x 4 x i32> @llvm.vp.udiv.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %r6 = call <vscale x 4 x i32> @llvm.vp.urem.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %r7 = call <vscale x 4 x i32> @llvm.vp.and.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %r8 = call <vscale x 4 x i32> @llvm.vp.or.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %r9 = call <vscale x 4 x i32> @llvm.vp.xor.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %rA = call <vscale x 4 x i32> @llvm.vp.ashr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %rB = call <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %rC = call <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ ret void
+}
+
+; scalable-width vectors
+; integer arith
+declare <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.sub.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.mul.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.srem.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.udiv.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.urem.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+; bit arith
+declare <vscale x 4 x i32> @llvm.vp.and.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.xor.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.or.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.ashr.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
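For reference when reading the CHECK-NOT lines above: with the default TTI strategy (EVLParamStrategy = Discard, OpStrategy = Convert), sanitizeStrategy promotes Discard to Convert for ops that may trap, so a fixed-width sdiv is expanded roughly as sketched below. This is an illustrative sketch, not FileCheck output; the value names and the i16 lane type of the compare are made up (the pass picks the lane type via TTI::getCmpSelInstrCost).

;   %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
; becomes approximately:
;   %n.ext    = trunc i32 %n to i16
;   %n.ins    = insertelement <8 x i16> undef, i16 %n.ext, i32 0
;   %n.splat  = shufflevector <8 x i16> %n.ins, <8 x i16> undef, <8 x i32> zeroinitializer
;   %evl.mask = icmp ult <8 x i16> <i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7>, %n.splat
;   %new.mask = and <8 x i1> %evl.mask, %m
;   %safe.div = select <8 x i1> %new.mask, <8 x i32> %i1, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
;   %r3       = sdiv <8 x i32> %i0, %safe.div

Speculatable ops (add, and, shl, ...) instead have their %evl replaced by the static vector length and are turned directly into the unpredicated instruction, which is safe because masked-off lanes of a VP result are unspecified.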
Index: llvm/test/CodeGen/ARM/O3-pipeline.ll
===================================================================
--- llvm/test/CodeGen/ARM/O3-pipeline.ll
+++ llvm/test/CodeGen/ARM/O3-pipeline.ll
@@ -37,6 +37,7 @@
; CHECK-NEXT: Constant Hoisting
; CHECK-NEXT: Partially inline calls to library functions
; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT: Expand vector predication intrinsics
; CHECK-NEXT: Scalarize Masked Memory Intrinsics
; CHECK-NEXT: Expand reduction intrinsics
; CHECK-NEXT: Dominator Tree Construction
Index: llvm/test/CodeGen/AArch64/O3-pipeline.ll
===================================================================
--- llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -57,6 +57,7 @@
; CHECK-NEXT: Constant Hoisting
; CHECK-NEXT: Partially inline calls to library functions
; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT: Expand vector predication intrinsics
; CHECK-NEXT: Scalarize Masked Memory Intrinsics
; CHECK-NEXT: Expand reduction intrinsics
; CHECK-NEXT: Stack Safety Analysis
Index: llvm/test/CodeGen/AArch64/O0-pipeline.ll
===================================================================
--- llvm/test/CodeGen/AArch64/O0-pipeline.ll
+++ llvm/test/CodeGen/AArch64/O0-pipeline.ll
@@ -22,6 +22,7 @@
; CHECK-NEXT: Lower constant intrinsics
; CHECK-NEXT: Remove unreachable blocks from the CFG
; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT: Expand vector predication intrinsics
; CHECK-NEXT: Scalarize Masked Memory Intrinsics
; CHECK-NEXT: Expand reduction intrinsics
; CHECK-NEXT: AArch64 Stack Tagging
Index: llvm/lib/IR/IntrinsicInst.cpp
===================================================================
--- llvm/lib/IR/IntrinsicInst.cpp
+++ llvm/lib/IR/IntrinsicInst.cpp
@@ -196,6 +196,12 @@
return nullptr;
}
+void VPIntrinsic::setMaskParam(Value *NewMask) {
+ auto MaskPos = GetMaskParamPos(getIntrinsicID());
+ assert(MaskPos.hasValue());
+ setArgOperand(MaskPos.getValue(), NewMask);
+}
+
Value *VPIntrinsic::getVectorLengthParam() const {
auto vlenPos = GetVectorLengthParamPos(getIntrinsicID());
if (vlenPos)
@@ -203,6 +209,12 @@
return nullptr;
}
+void VPIntrinsic::setVectorLengthParam(Value *NewEVL) {
+ auto EVLPos = GetVectorLengthParamPos(getIntrinsicID());
+ assert(EVLPos.hasValue());
+ setArgOperand(EVLPos.getValue(), NewEVL);
+}
+
Optional<int> VPIntrinsic::GetMaskParamPos(Intrinsic::ID IntrinsicID) {
switch (IntrinsicID) {
default:
Index: llvm/lib/CodeGen/TargetPassConfig.cpp
===================================================================
--- llvm/lib/CodeGen/TargetPassConfig.cpp
+++ llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -702,6 +702,11 @@
// Instrument function entry and exit, e.g. with calls to mcount().
addPass(createPostInlineEntryExitInstrumenterPass());
+ // Expand vector predication intrinsics into standard IR instructions.
+ // This pass has to run before the ScalarizeMaskedMemIntrin and ExpandReductions
+ // passes because it may emit those kinds of intrinsics.
+ addPass(createExpandVectorPredicationPass());
+
// Add scalarization of target's unsupported masked memory intrinsics pass.
// the unsupported intrinsic will be replaced with a chain of basic blocks,
// that stores/loads element one-by-one if the appropriate mask bit is set.
Index: llvm/lib/CodeGen/ExpandVectorPredication.cpp
===================================================================
--- /dev/null
+++ llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -0,0 +1,456 @@
+//===--- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements IR expansion for vector predication intrinsics, allowing
+// targets to enable vector predication until just before codegen.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ExpandVectorPredication.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace llvm;
+
+using VPLegalization = TargetTransformInfo::VPLegalization;
+
+#define DEBUG_TYPE "expand-vec-pred"
+
+STATISTIC(NumFoldedVL, "Number of folded vector length params");
+STATISTIC(NumLoweredVPOps, "Number of lowered vector predication operations");
+
+///// Helpers {
+
+/// \returns Whether the vector mask \p MaskVal has all lane bits set.
+static bool isAllTrueMask(Value *MaskVal) {
+ auto *ConstVec = dyn_cast<ConstantVector>(MaskVal);
+ if (!ConstVec)
+ return false;
+ return ConstVec->isAllOnesValue();
+}
+
+/// Computes the smallest integer bit width that can hold both the step vector
+/// <0, .., NumVectorElements - 1> and the value NumVectorElements itself.
+static unsigned getLeastLaneBitsForStepVector(unsigned NumVectorElements) {
+ unsigned MostSignificantOne =
+ llvm::countLeadingZeros<uint64_t>(NumVectorElements, ZB_Undefined);
+ return std::max<unsigned>(IntegerType::MIN_INT_BITS, 64 - MostSignificantOne);
+}
+
+/// \returns A non-excepting divisor constant for this type.
+static Constant *getSafeDivisor(Type *DivTy) {
+ assert(DivTy->isIntOrIntVectorTy());
+ return ConstantInt::get(DivTy, 1u, false);
+}
+
+/// Transfer operation properties from \p VPI to \p NewVal.
+static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
+ auto *NewInst = dyn_cast<Instruction>(&NewVal);
+ if (!NewInst || !isa<FPMathOperator>(NewVal))
+ return;
+
+ auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
+ if (!OldFMOp)
+ return;
+
+ NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
+}
+
+/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
+/// \p OldOp gets erased.
+static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
+ transferDecorations(NewOp, OldOp);
+ OldOp.replaceAllUsesWith(&NewOp);
+ OldOp.eraseFromParent();
+}
+
+//// } Helpers
+
+namespace {
+
+// Expansion pass state at function scope.
+struct CachingVPExpander {
+ Function &F;
+ const TargetTransformInfo &TTI;
+
+ /// \returns A (fixed length) vector with ascending integer indices
+ /// (<0, 1, ..., NumElems-1>).
+ Value *createStepVector(IRBuilder<> &Builder, int32_t ElemBits,
+ int32_t NumElems);
+
+ /// \returns A bitmask that is true where the lane position is less-than \p
+ /// EVLParam.
+ ///
+ /// \p Builder
+ /// Used for instruction creation.
+ /// \p EVLParam
+ /// The explicit vector length parameter to test against the lane
+ /// positions.
+ /// \p ElemCount
+ /// Static (potentially scalable) number of vector elements.
+ Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
+ ElementCount ElemCount);
+
+ Value *foldEVLIntoMask(VPIntrinsic &VPI);
+
+ /// "Remove" the %evl parameter of \p PI by setting it to the static vector
+ /// length of the operation.
+ void discardEVLParameter(VPIntrinsic &PI);
+
+ /// Lower this VP binary operator to a non-VP binary operator.
+ Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
+ VPIntrinsic &PI);
+
+ /// Query TTI and expand the vector predication in \p PI accordingly.
+ Value *expandPredication(VPIntrinsic &PI);
+
+ /// \returns A good (for fast icmp) integer bit width to expand the EVL
+ /// comparison against the step vector in.
+ unsigned getLaneBitsForEVLCompare(unsigned StaticVL);
+ std::map<unsigned, unsigned> StaticVLToBitsCache; // TODO: Use a 'SmallMap' class.
+
+public:
+ CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
+ : F(F), TTI(TTI) {}
+
+ /// Expand VP ops in \p F according to \p TTI.
+ bool expandVectorPredication();
+};
+
+//// CachingVPExpander {
+
+unsigned CachingVPExpander::getLaneBitsForEVLCompare(unsigned StaticVL) {
+ auto ItCached = StaticVLToBitsCache.find(StaticVL);
+ if (ItCached != StaticVLToBitsCache.end())
+ return ItCached->second;
+
+ // The smallest integer type that can hold <0, .., StaticVL - 1>.
+ // Choosing fewer bits than this would make the expansion invalid.
+ unsigned MinLaneBits = getLeastLaneBitsForStepVector(StaticVL);
+ LLVM_DEBUG(dbgs() << "Least lane bits for " << StaticVL << " is "
+ << MinLaneBits << "\n");
+
+ // If the EVL compare will be expanded into scalar code, choose the
+ // smallest integer type.
+ if (TTI.getRegisterBitWidth(/* Vector */ true) == 0)
+ return MinLaneBits;
+
+ // Otherwise, the generated vector operation will likely map to vector instructions.
+ // The largest bit width to fit the EVL expansion in one vector register.
+ unsigned MaxLaneBits = std::min<unsigned>(
+ IntegerType::MAX_INT_BITS, TTI.getRegisterBitWidth(true) / StaticVL);
+
+ // Many SIMD instructions are restricted in their supported lane bit widths.
+ // We choose the bit width that gives us the cheapest vector compare.
+ int Cheapest = std::numeric_limits<int>::max();
+ auto &Ctx = F.getContext();
+ unsigned CheapestLaneBits = MinLaneBits;
+ for (auto LaneBits = MinLaneBits; LaneBits < MaxLaneBits; ++LaneBits) {
+ int VecCmpCost = TTI.getCmpSelInstrCost(
+ Instruction::ICmp, VectorType::get(Type::getIntNTy(Ctx, LaneBits),
+ StaticVL, /* Scalable */ false));
+ if (VecCmpCost < Cheapest) {
+ Cheapest = VecCmpCost;
+ CheapestLaneBits = LaneBits;
+ }
+ }
+
+ StaticVLToBitsCache[StaticVL] = CheapestLaneBits;
+ return CheapestLaneBits;
+}
+
+Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder,
+ int32_t ElemBits, int32_t NumElems) {
+ // TODO add caching
+ SmallVector<Constant *, 16> ConstElems;
+
+ Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+ for (int32_t Idx = 0; Idx < NumElems; ++Idx) {
+ ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));
+ }
+
+ return ConstantVector::get(ConstElems);
+}
+
+Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
+ Value *EVLParam,
+ ElementCount ElemCount) {
+ // TODO add caching
+ if (ElemCount.isScalable()) {
+ auto *M = Builder.GetInsertBlock()->getModule();
+ auto *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
+ auto *ActiveMaskFunc = Intrinsic::getDeclaration(
+ M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
+ // `get_active_lane_mask` performs an implicit less-than comparison.
+ auto *ConstZero = Builder.getInt32(0);
+ return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
+ }
+
+ unsigned NumElems = ElemCount.getFixedValue();
+ unsigned ElemBits = getLaneBitsForEVLCompare(NumElems);
+
+ Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+ auto *ExtVLParam = Builder.CreateZExtOrTrunc(EVLParam, LaneTy);
+ auto *VLSplat = Builder.CreateVectorSplat(NumElems, ExtVLParam);
+
+ auto *IdxVec = createStepVector(Builder, ElemBits, NumElems);
+
+ return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
+}
+
+Value *
+CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
+ VPIntrinsic &VPI) {
+ assert(VPI.canIgnoreVectorLengthParam());
+
+ auto OC = static_cast<Instruction::BinaryOps>(VPI.getFunctionalOpcode());
+ assert(Instruction::isBinaryOp(OC));
+
+ auto *FirstOp = VPI.getOperand(0);
+ auto *SndOp = VPI.getOperand(1);
+
+ auto *Mask = VPI.getMaskParam();
+
+ // Blend in safe operands
+ if (Mask && !isAllTrueMask(Mask)) {
+ switch (OC) {
+ default:
+ // can safely ignore the predicate
+ break;
+
+ // Division operators need a safe divisor (here: 1) on masked-off lanes.
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::URem:
+ case Instruction::SRem:
+ // 2nd operand must not be zero
+ auto *SafeDivisor = getSafeDivisor(VPI.getType());
+ SndOp = Builder.CreateSelect(Mask, SndOp, SafeDivisor);
+ }
+ }
+
+ auto *NewBinOp = Builder.CreateBinOp(OC, FirstOp, SndOp, VPI.getName());
+
+ replaceOperation(*NewBinOp, VPI);
+ return NewBinOp;
+}
+
+void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
+ LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");
+
+ if (VPI.canIgnoreVectorLengthParam())
+ return;
+
+ Value *EVLParam = VPI.getVectorLengthParam();
+ if (!EVLParam)
+ return;
+
+ ElementCount StaticElemCount = VPI.getStaticVectorLength();
+ Value *MaxEVL = nullptr;
+ auto *Int32Ty = Type::getInt32Ty(VPI.getContext());
+ if (StaticElemCount.isScalable()) {
+ // TODO add caching
+ auto *M = VPI.getModule();
+ auto *VScaleFunc = Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
+ IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
+ auto *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
+ auto *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
+ MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
+ /*NUW*/ true, /*NSW*/ false);
+ } else {
+ MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
+ }
+ VPI.setVectorLengthParam(MaxEVL);
+}
+
+Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
+ LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');
+
+ IRBuilder<> Builder(&VPI);
+
+ // No %evl parameter and so nothing to do here
+ if (VPI.canIgnoreVectorLengthParam()) {
+ return &VPI;
+ }
+
+ // The %evl is folded into the mask, so both parameters have to be present.
+ Value *OldMaskParam = VPI.getMaskParam();
+ Value *OldEVLParam = VPI.getVectorLengthParam();
+ assert(OldMaskParam && "no mask param to fold the vl param into");
+ assert(OldEVLParam && "no EVL param to fold away");
+
+ LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
+ LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');
+
+ // Convert the %evl predication into vector mask predication.
+ ElementCount ElemCount = VPI.getStaticVectorLength();
+ auto *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
+ auto *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
+ VPI.setMaskParam(NewMaskParam);
+
+ // Drop the %evl parameter.
+ discardEVLParameter(VPI);
+ assert(VPI.canIgnoreVectorLengthParam() &&
+ "transformation did not render the evl param ineffective!");
+
+ // Reassess the modified instruction.
+ return &VPI;
+}
+
+Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
+ LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');
+
+ IRBuilder<> Builder(&VPI);
+
+ // Try lowering to an LLVM instruction first.
+ unsigned OC = VPI.getFunctionalOpcode();
+#define FIRST_BINARY_INST(X) unsigned FirstBinOp = X;
+#define LAST_BINARY_INST(X) unsigned LastBinOp = X;
+#include "llvm/IR/Instruction.def"
+
+ if (FirstBinOp <= OC && OC <= LastBinOp) {
+ return expandPredicationInBinaryOperator(Builder, VPI);
+ }
+
+ return &VPI;
+}
+
+//// } CachingVPExpander
+
+struct TransformJob {
+ VPIntrinsic *PI;
+ TargetTransformInfo::VPLegalization Strategy;
+ TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
+ : PI(PI), Strategy(InitStrat) {}
+
+ bool isDone() const { return Strategy.doNothing(); }
+};
+
+void sanitizeStrategy(Instruction &I, VPLegalization &LegalizeStrat) {
+ // Speculatable instructions do not strictly need predication.
+ if (isSafeToSpeculativelyExecute(&I))
+ return;
+
+ // Preserve the predication effect of the EVL parameter by folding
+ // it into the predicate.
+ if (LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) {
+ LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
+ }
+}
+
+/// Expand llvm.vp.* intrinsics as requested by \p TTI.
+bool CachingVPExpander::expandVectorPredication() {
+ // Holds all VP intrinsics that need transforming, with their strategies.
+ SmallVector<TransformJob, 16> Worklist;
+
+ for (auto &I : instructions(F)) {
+ auto *VPI = dyn_cast<VPIntrinsic>(&I);
+ if (!VPI)
+ continue;
+ auto VPStrat = TTI.getVPLegalizationStrategy(*VPI);
+ sanitizeStrategy(I, VPStrat);
+ if (!VPStrat.doNothing()) {
+ Worklist.emplace_back(VPI, VPStrat);
+ }
+ }
+ if (Worklist.empty())
+ return false;
+
+ LLVM_DEBUG(dbgs() << "\n:::: Transforming instructions. ::::\n");
+ for (TransformJob Job : Worklist) {
+ // Transform the EVL parameter
+ switch (Job.Strategy.EVLParamStrategy) {
+ case VPLegalization::Legal:
+ break;
+ case VPLegalization::Discard: {
+ discardEVLParameter(*Job.PI);
+ } break;
+ case VPLegalization::Convert: {
+ if (foldEVLIntoMask(*Job.PI)) {
+ ++NumFoldedVL;
+ }
+ } break;
+ }
+ Job.Strategy.EVLParamStrategy = VPLegalization::Legal;
+
+ // Replace the operator
+ switch (Job.Strategy.OpStrategy) {
+ case VPLegalization::Legal:
+ break;
+ case VPLegalization::Discard:
+ llvm_unreachable("Invalid strategy for operators.");
+ case VPLegalization::Convert: {
+ expandPredication(*Job.PI);
+ ++NumLoweredVPOps;
+ } break;
+ }
+ Job.Strategy.OpStrategy = VPLegalization::Legal;
+
+ assert(Job.isDone() && "incomplete transformation");
+ }
+
+ return true;
+}
+class ExpandVectorPredication : public FunctionPass {
+public:
+ static char ID;
+ ExpandVectorPredication() : FunctionPass(ID) {
+ initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override {
+ const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ CachingVPExpander VPExpander(F, *TTI);
+ return VPExpander.expandVectorPredication();
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<TargetTransformInfoWrapperPass>();
+ AU.setPreservesCFG();
+ }
+};
+} // namespace
+
+char ExpandVectorPredication::ID;
+INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expand-vec-pred",
+ "Expand vector predication intrinsics", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(ExpandVectorPredication, "expand-vec-pred",
+ "Expand vector predication intrinsics", false, false)
+
+FunctionPass *llvm::createExpandVectorPredicationPass() {
+ return new ExpandVectorPredication();
+}
+
+PreservedAnalyses
+ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
+ const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+ CachingVPExpander VPExpander(F, TTI);
+ if (!VPExpander.expandVectorPredication())
+ return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
+}
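Two scalable-vector paths in the pass above, sketched in IR (value names are invented and the get.active.lane.mask name mangling is approximate): discardEVLParameter computes the full vector length via llvm.vscale, and convertEVLToMask uses llvm.get.active.lane.mask instead of an explicit step vector.

; vp.add (speculatable, EVLParamStrategy = Discard, OpStrategy = Convert):
;   %vscale        = call i32 @llvm.vscale.i32()
;   %scalable_size = mul nuw i32 %vscale, 4   ; new %evl operand; dead once the call is erased
;   %r0            = add <vscale x 4 x i32> %i0, %i1
;
; vp.sdiv (may trap, so the EVL is folded into the mask first):
;   %evl.mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 %n)
;   %new.mask = and <vscale x 4 x i1> %evl.mask, %m
;   ... then the divisor is blended to 1 on masked-off lanes and a plain sdiv is emitted.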
Index: llvm/lib/CodeGen/CMakeLists.txt
===================================================================
--- llvm/lib/CodeGen/CMakeLists.txt
+++ llvm/lib/CodeGen/CMakeLists.txt
@@ -27,6 +27,7 @@
ExpandMemCmp.cpp
ExpandPostRAPseudos.cpp
ExpandReductions.cpp
+ ExpandVectorPredication.cpp
FaultMaps.cpp
FEntryInserter.cpp
FinalizeISel.cpp
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1024,6 +1024,11 @@
return TTIImpl->preferPredicatedReductionSelect(Opcode, Ty, Flags);
}
+TargetTransformInfo::VPLegalization
+TargetTransformInfo::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
+ return TTIImpl->getVPLegalizationStrategy(VPI);
+}
+
bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const {
return TTIImpl->shouldExpandReduction(II);
}
Index: llvm/include/llvm/InitializePasses.h
===================================================================
--- llvm/include/llvm/InitializePasses.h
+++ llvm/include/llvm/InitializePasses.h
@@ -150,6 +150,7 @@
void initializeExpandMemCmpPassPass(PassRegistry&);
void initializeExpandPostRAPass(PassRegistry&);
void initializeExpandReductionsPass(PassRegistry&);
+void initializeExpandVectorPredicationPass(PassRegistry &);
void initializeMakeGuardsExplicitLegacyPassPass(PassRegistry&);
void initializeExternalAAWrapperPassPass(PassRegistry&);
void initializeFEntryInserterPass(PassRegistry&);
Index: llvm/include/llvm/IR/IntrinsicInst.h
===================================================================
--- llvm/include/llvm/IR/IntrinsicInst.h
+++ llvm/include/llvm/IR/IntrinsicInst.h
@@ -255,9 +255,11 @@
/// \return the mask parameter or nullptr.
Value *getMaskParam() const;
+ void setMaskParam(Value *);
/// \return the vector length parameter or nullptr.
Value *getVectorLengthParam() const;
+ void setVectorLengthParam(Value *);
/// \return whether the vector length param can be ignored.
bool canIgnoreVectorLengthParam() const;
Index: llvm/include/llvm/CodeGen/Passes.h
===================================================================
--- llvm/include/llvm/CodeGen/Passes.h
+++ llvm/include/llvm/CodeGen/Passes.h
@@ -456,6 +456,11 @@
/// shuffles.
FunctionPass *createExpandReductionsPass();
+ /// This pass expands vector predication intrinsics into unpredicated
+ /// instructions, using selects where needed, or folds the explicit vector
+ /// length into the predicate mask.
+ FunctionPass *createExpandVectorPredicationPass();
+
// This pass expands memcmp() to load/stores.
FunctionPass *createExpandMemCmpPass();
Index: llvm/include/llvm/CodeGen/ExpandVectorPredication.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/CodeGen/ExpandVectorPredication.h
@@ -0,0 +1,23 @@
+//===-- ExpandVectorPredication.h - Expand vector predication ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
+#define LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class ExpandVectorPredicationPass
+ : public PassInfoMixin<ExpandVectorPredicationPass> {
+public:
+ PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+} // end namespace llvm
+
+#endif // LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -677,6 +677,13 @@
bool hasActiveVectorLength() const { return false; }
+ TargetTransformInfo::VPLegalization
+ getVPLegalizationStrategy(const VPIntrinsic &PI) const {
+ return TargetTransformInfo::VPLegalization(
+ /* EVLParamStrategy */ TargetTransformInfo::VPLegalization::Discard,
+ /* OpStrategy */ TargetTransformInfo::VPLegalization::Convert);
+ }
+
protected:
// Obtain the minimum required size to hold the value (without the sign)
// In case of a vector it returns the min required size for one element.
Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -57,6 +57,7 @@
class Type;
class User;
class Value;
+class VPIntrinsic;
struct KnownBits;
template <typename T> class Optional;
@@ -1326,6 +1327,38 @@
/// Intrinsics") Use of %evl is discouraged when that is not the case.
bool hasActiveVectorLength() const;
+ struct VPLegalization {
+ enum VPTransform {
+ // Keep the predicating parameter.
+ Legal = 0,
+ // Where legal, discard the predicate parameter.
+ Discard = 1,
+ // Transform into something else that is also predicating.
+ Convert = 2
+ };
+
+ // How to transform the EVL parameter.
+ // Legal: keep the EVL parameter as it is.
+ // Discard: Ignore the EVL parameter where it is safe to do so.
+ // Convert: Fold the EVL into the mask parameter.
+ VPTransform EVLParamStrategy;
+
+ // How to transform the operator.
+ // Legal: The target supports this operator.
+ // Convert: Convert this to a non-VP operation.
+ // The 'Discard' strategy is invalid.
+ VPTransform OpStrategy;
+
+ bool doNothing() const {
+ return (EVLParamStrategy == Legal) && (OpStrategy == Legal);
+ }
+ VPLegalization(VPTransform EVLParamStrategy, VPTransform OpStrategy)
+ : EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {}
+ };
+
+ /// \returns How the target needs this vector-predicated operation to be
+ /// transformed.
+ VPLegalization getVPLegalizationStrategy(const VPIntrinsic &PI) const;
/// @}
/// @}
@@ -1609,6 +1642,8 @@
virtual bool shouldExpandReduction(const IntrinsicInst *II) const = 0;
virtual unsigned getGISelRematGlobalCost() const = 0;
virtual bool hasActiveVectorLength() const = 0;
+ virtual VPLegalization
+ getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
virtual int getInstructionLatency(const Instruction *I) = 0;
};
@@ -2127,6 +2162,11 @@
return Impl.hasActiveVectorLength();
}
+ VPLegalization
+ getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
+ return Impl.getVPLegalizationStrategy(PI);
+ }
+
int getInstructionLatency(const Instruction *I) override {
return Impl.getInstructionLatency(I);
}
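To connect the VPLegalization strategies above to their IR effect, a schematic example using a speculatable fixed-width op (names taken from the Generic/expand-vp.ll test in this patch):

;   %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
;
; EVLParamStrategy = Legal:   %n is left untouched.
; EVLParamStrategy = Discard: %n is replaced by the static vector length, i32 8
;                             (or vscale * known-min count for scalable types).
; EVLParamStrategy = Convert: %n is folded into %m via a "lane index < %n" mask.
; OpStrategy       = Convert: the call itself is replaced by the unpredicated op:
;   %r0 = add <8 x i32> %i0, %i1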