Author: Juneyoung Lee Date: 2021-01-06T12:10:33+09:00 New Revision: 29f8628d1fc8d96670e13562c4d92fc916bd0ce1
URL: https://github.com/llvm/llvm-project/commit/29f8628d1fc8d96670e13562c4d92fc916bd0ce1 DIFF: https://github.com/llvm/llvm-project/commit/29f8628d1fc8d96670e13562c4d92fc916bd0ce1.diff LOG: [Constant] Add containsPoisonElement This patch - Adds containsPoisonElement, which checks for the existence of poison in constant vector elements, - Renames containsUndefElement to containsUndefOrPoisonElement to clarify its behavior, and updates its uses accordingly. With this patch, tests of isGuaranteedNotToBeUndefOrPoison w.r.t. constant vectors are added, since its analysis of them is improved. Thanks! Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D94053 Added: Modified: llvm/include/llvm/IR/Constant.h llvm/lib/Analysis/ValueTracking.cpp llvm/lib/IR/ConstantFold.cpp llvm/lib/IR/Constants.cpp llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp llvm/unittests/Analysis/ValueTrackingTest.cpp llvm/unittests/IR/ConstantsTest.cpp Removed: ################################################################################ diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h index 97650c2051ca..0190aca27b72 100644 --- a/llvm/include/llvm/IR/Constant.h +++ b/llvm/include/llvm/IR/Constant.h @@ -101,11 +101,15 @@ class Constant : public User { /// lane, the constants still match. bool isElementWiseEqual(Value *Y) const; - /// Return true if this is a vector constant that includes any undefined - /// elements. Since it is impossible to inspect a scalable vector element- - /// wise at compile time, this function returns true only if the entire - /// vector is undef - bool containsUndefElement() const; + /// Return true if this is a vector constant that includes any undef or + /// poison elements. Since it is impossible to inspect a scalable vector + /// element-wise at compile time, this function returns true only if the + /// entire vector is undef or poison.
+ bool containsUndefOrPoisonElement() const; + + /// Return true if this is a vector constant that includes any poison + /// elements. + bool containsPoisonElement() const; /// Return true if this is a fixed width vector constant that includes /// any constant expressions. diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index e15d4f0e4b07..1c75c5fbd0db 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -4895,7 +4895,8 @@ static bool isGuaranteedNotToBeUndefOrPoison(const Value *V, return true; if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C)) - return (PoisonOnly || !C->containsUndefElement()) && + return (PoisonOnly ? !C->containsPoisonElement() + : !C->containsUndefOrPoisonElement()) && !C->containsConstantExpression(); } @@ -5636,10 +5637,10 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, // elements because those can not be back-propagated for analysis. Value *OutputZeroVal = nullptr; if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) && - !cast<Constant>(TrueVal)->containsUndefElement()) + !cast<Constant>(TrueVal)->containsUndefOrPoisonElement()) OutputZeroVal = TrueVal; else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) && - !cast<Constant>(FalseVal)->containsUndefElement()) + !cast<Constant>(FalseVal)->containsUndefOrPoisonElement()) OutputZeroVal = FalseVal; if (OutputZeroVal) { diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index 47745689ba2d..03cb108cc485 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -811,7 +811,7 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond, return true; if (C->getType()->isVectorTy()) - return !C->containsUndefElement() && !C->containsConstantExpression(); + return !C->containsPoisonElement() && !C->containsConstantExpression(); // TODO: Recursively analyze aggregates or other constants. 
return false; diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index a38302d17937..5aa819dda2b3 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -304,31 +304,42 @@ bool Constant::isElementWiseEqual(Value *Y) const { return isa<UndefValue>(CmpEq) || match(CmpEq, m_One()); } -bool Constant::containsUndefElement() const { - if (auto *VTy = dyn_cast<VectorType>(getType())) { - if (isa<UndefValue>(this)) +static bool +containsUndefinedElement(const Constant *C, + function_ref<bool(const Constant *)> HasFn) { + if (auto *VTy = dyn_cast<VectorType>(C->getType())) { + if (HasFn(C)) return true; - if (isa<ConstantAggregateZero>(this)) + if (isa<ConstantAggregateZero>(C)) return false; - if (isa<ScalableVectorType>(getType())) + if (isa<ScalableVectorType>(C->getType())) return false; for (unsigned i = 0, e = cast<FixedVectorType>(VTy)->getNumElements(); i != e; ++i) - if (isa<UndefValue>(getAggregateElement(i))) + if (HasFn(C->getAggregateElement(i))) return true; } return false; } +bool Constant::containsUndefOrPoisonElement() const { + return containsUndefinedElement( + this, [&](const auto *C) { return isa<UndefValue>(C); }); +} + +bool Constant::containsPoisonElement() const { + return containsUndefinedElement( + this, [&](const auto *C) { return isa<PoisonValue>(C); }); +} + bool Constant::containsConstantExpression() const { if (auto *VTy = dyn_cast<FixedVectorType>(getType())) { for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) if (isa<ConstantExpr>(getAggregateElement(i))) return true; } - return false; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 87d4b40a9a64..08877797c53a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -3370,7 +3370,7 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, Type *OpTy = M->getType(); auto *VecC = 
dyn_cast<Constant>(M); auto *OpVTy = dyn_cast<FixedVectorType>(OpTy); - if (OpVTy && VecC && VecC->containsUndefElement()) { + if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) { Constant *SafeReplacementConstant = nullptr; for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) { if (!isa<UndefValue>(VecC->getAggregateElement(i))) { @@ -5259,7 +5259,8 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, // It may not be safe to change a compare predicate in the presence of // undefined elements, so replace those elements with the first safe constant // that we found. - if (C->containsUndefElement()) { + // TODO: in case of poison, it is safe; let's replace undefs only. + if (C->containsUndefOrPoisonElement()) { assert(SafeReplacementConstant && "Replacement constant not set"); C = Constant::replaceUndefsWith(C, SafeReplacementConstant); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index 494c58e706e1..7718c8b0eedd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -239,8 +239,8 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { // While this is normally not behind a use-check, // let's consider division to be special since it's costly. 
if (auto *Op1C = dyn_cast<Constant>(I->getOperand(1))) { - if (!Op1C->containsUndefElement() && Op1C->isNotMinSignedValue() && - Op1C->isNotOneValue()) { + if (!Op1C->containsUndefOrPoisonElement() && + Op1C->isNotMinSignedValue() && Op1C->isNotOneValue()) { Value *BO = Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C), I->getName() + ".neg"); diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp index 0d6577452560..d70fd6eb0ba2 100644 --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -888,6 +888,30 @@ TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison) { EXPECT_EQ(isGuaranteedNotToBeUndefOrPoison(PoisonValue::get(IntegerType::get(Context, 8))), false); EXPECT_EQ(isGuaranteedNotToBePoison(UndefValue::get(IntegerType::get(Context, 8))), true); EXPECT_EQ(isGuaranteedNotToBePoison(PoisonValue::get(IntegerType::get(Context, 8))), false); + + Type *Int32Ty = Type::getInt32Ty(Context); + Constant *CU = UndefValue::get(Int32Ty); + Constant *CP = PoisonValue::get(Int32Ty); + Constant *C1 = ConstantInt::get(Int32Ty, 1); + Constant *C2 = ConstantInt::get(Int32Ty, 2); + + { + Constant *V1 = ConstantVector::get({C1, C2}); + EXPECT_TRUE(isGuaranteedNotToBeUndefOrPoison(V1)); + EXPECT_TRUE(isGuaranteedNotToBePoison(V1)); + } + + { + Constant *V2 = ConstantVector::get({C1, CU}); + EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V2)); + EXPECT_TRUE(isGuaranteedNotToBePoison(V2)); + } + + { + Constant *V3 = ConstantVector::get({C1, CP}); + EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V3)); + EXPECT_FALSE(isGuaranteedNotToBePoison(V3)); + } } TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison_assume) { diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp index 9dd1ba84cc12..afae154cca90 100644 --- a/llvm/unittests/IR/ConstantsTest.cpp +++ b/llvm/unittests/IR/ConstantsTest.cpp @@ -585,6 +585,43 @@ 
TEST(ConstantsTest, FoldGlobalVariablePtr) { Instruction::And, TheConstantExpr, TheConstant)->isNullValue()); } +// Check that containsUndefOrPoisonElement and containsPoisonElement work +// correctly + +TEST(ConstantsTest, containsUndefElemTest) { + LLVMContext Context; + + Type *Int32Ty = Type::getInt32Ty(Context); + Constant *CU = UndefValue::get(Int32Ty); + Constant *CP = PoisonValue::get(Int32Ty); + Constant *C1 = ConstantInt::get(Int32Ty, 1); + Constant *C2 = ConstantInt::get(Int32Ty, 2); + + { + Constant *V1 = ConstantVector::get({C1, C2}); + EXPECT_FALSE(V1->containsUndefOrPoisonElement()); + EXPECT_FALSE(V1->containsPoisonElement()); + } + + { + Constant *V2 = ConstantVector::get({C1, CU}); + EXPECT_TRUE(V2->containsUndefOrPoisonElement()); + EXPECT_FALSE(V2->containsPoisonElement()); + } + + { + Constant *V3 = ConstantVector::get({C1, CP}); + EXPECT_TRUE(V3->containsUndefOrPoisonElement()); + EXPECT_TRUE(V3->containsPoisonElement()); + } + + { + Constant *V4 = ConstantVector::get({CU, CP}); + EXPECT_TRUE(V4->containsUndefOrPoisonElement()); + EXPECT_TRUE(V4->containsPoisonElement()); + } +} + // Check that undefined elements in vector constants are matched // correctly for both integer and floating-point types. Just don't // crash on vectors of pointers (could be handled?). _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits