Author: Joshua Cao Date: 2023-04-10T19:40:52-07:00 New Revision: 585742cbfccd734b19c75dff9709b20367506668
URL: https://github.com/llvm/llvm-project/commit/585742cbfccd734b19c75dff9709b20367506668 DIFF: https://github.com/llvm/llvm-project/commit/585742cbfccd734b19c75dff9709b20367506668.diff LOG: [SCEV] When computing trip count, only zext if necessary This patch improves on https://reviews.llvm.org/D110587. To summarize that patch: given backedge-taken count BC, trip count TC is `BC + 1`. However, we don't know whether `BC + 1` might overflow, so that patch changed the TC computation to `1 + zext(BC)`. This patch adds the zext only when necessary, by looking at the constant range of BC. If we can determine that BC cannot be the max value for its bitwidth, then we know adding 1 will not overflow, and the zext is not needed. We apply loop guards before computing TC to get more data. The primary motivation is to support my work on more precise trip multiples in https://reviews.llvm.org/D141823. For example: ``` void test(unsigned n) { __builtin_assume(n % 6 == 0); for (unsigned i = 0; i < n; ++i) foo(); } ``` Prior to this patch, we had `TC = 1 + zext(-1 + 6 * ((6 umax %n) /u 6))<nuw>`. SCEV range computation is able to determine that the BC cannot be the max value, so the zext is not needed. The result is `TC -> (6 * ((6 umax %n) /u 6))<nuw>`. From here, we would be able to determine that %n is a multiple of 6. There was one change in LoopCacheAnalysis/LoopInterchange required. Before this patch, if a loop has BC = false, it would compute `TC -> 1 + zext(false) -> 1`, which was fine. After this patch, it computes `TC -> 1 + false = true`. LoopCacheAnalysis would then sign-extend the `true`, which was not the intended behavior. I modified LoopCacheAnalysis so that it only zero-extends trip counts. This patch is not NFC, but it also does not change any SCEV outputs. I would like to get this patch out first to make the work on trip multiples easier. 
Differential Revision: https://reviews.llvm.org/D147117 Added: Modified: llvm/lib/Analysis/LoopCacheAnalysis.cpp llvm/lib/Analysis/ScalarEvolution.cpp Removed: ################################################################################ diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp index 46198f78b6433..c3a56639b5c8f 100644 --- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp +++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp @@ -297,7 +297,7 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L, Type *WiderType = SE.getWiderType(Stride->getType(), TripCount->getType()); const SCEV *CacheLineSize = SE.getConstant(WiderType, CLS); Stride = SE.getNoopOrAnyExtend(Stride, WiderType); - TripCount = SE.getNoopOrAnyExtend(TripCount, WiderType); + TripCount = SE.getNoopOrZeroExtend(TripCount, WiderType); const SCEV *Numerator = SE.getMulExpr(Stride, TripCount); RefCost = SE.getUDivExpr(Numerator, CacheLineSize); @@ -323,8 +323,8 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L, const SCEV *TripCount = computeTripCount(*AR->getLoop(), *Sizes.back(), SE); Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType()); - RefCost = SE.getMulExpr(SE.getNoopOrAnyExtend(RefCost, WiderType), - SE.getNoopOrAnyExtend(TripCount, WiderType)); + RefCost = SE.getMulExpr(SE.getNoopOrZeroExtend(RefCost, WiderType), + SE.getNoopOrZeroExtend(TripCount, WiderType)); } LLVM_DEBUG(dbgs().indent(4) @@ -334,7 +334,7 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L, // Attempt to fold RefCost into a constant. if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost)) - return ConstantCost->getValue()->getSExtValue(); + return ConstantCost->getValue()->getZExtValue(); LLVM_DEBUG(dbgs().indent(4) << "RefCost is not a constant! 
Setting to RefCost=InvalidCost " diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 174eea8d364ab..52bd161cf9ddf 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8045,6 +8045,12 @@ const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount, if (!Extend) return getAddExpr(ExitCount, getOne(ExitCountType)); + ConstantRange ExitCountRange = + getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED); + if (!ExitCountRange.contains( + APInt::getMaxValue(ExitCountRange.getBitWidth()))) + return getAddExpr(ExitCount, getOne(ExitCountType)); + auto *WiderType = Type::getIntNTy(ExitCountType->getContext(), 1 + ExitCountType->getScalarSizeInBits()); return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType), @@ -8227,15 +8233,14 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, return 1; // Get the trip count - const SCEV *TCExpr = getTripCountFromExitCount(ExitCount); + const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L)); const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr); if (!TC) // Attempt to factor more general cases. Returns the greatest power of // two divisor. If overflow happens, the trip count expression is still // divisible by the greatest power of 2 divisor returned. - return 1U << std::min((uint32_t)31, - GetMinTrailingZeros(applyLoopGuards(TCExpr, L))); + return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr)); ConstantInt *Result = TC->getValue(); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits