Author: Alexey Bataev
Date: 2024-12-02T09:32:35+01:00
New Revision: 9f72c9837c553063ab0cbacc1a472a73c0ec2a4b

URL: https://github.com/llvm/llvm-project/commit/9f72c9837c553063ab0cbacc1a472a73c0ec2a4b
DIFF: https://github.com/llvm/llvm-project/commit/9f72c9837c553063ab0cbacc1a472a73c0ec2a4b.diff

LOG: [SLP]Check that the operand of abs does not overflow before making it part of the minbitwidth transformation

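The operand of the abs intrinsic must be checked to ensure it can be safely
truncated before it is made part of the minbitwidth transformation; otherwise
abs may be computed on a narrowed value that no longer matches the original
operand.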

Fixes #112577

(cherry picked from commit 709abacdc350d63c61888607edb28ce272daa0a0)

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index ab2b96cdc42db8..746ba51a981fe0 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -15440,9 +15440,25 @@ bool BoUpSLP::collectValuesToDemote(
                 MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL)));
       });
     };
+    auto AbsChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
+      assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!");
+      return all_of(E.Scalars, [&](Value *V) {
+        auto *I = cast<Instruction>(V);
+        unsigned SignBits = OrigBitWidth - BitWidth;
+        APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth - 1);
+        unsigned Op0SignBits =
+            ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT);
+        return SignBits <= Op0SignBits &&
+               ((SignBits != Op0SignBits &&
+                 !isKnownNonNegative(I->getOperand(0), SimplifyQuery(*DL))) ||
+                MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL)));
+      });
+    };
     if (ID != Intrinsic::abs) {
       Operands.push_back(getOperandEntry(&E, 1));
       CallChecker = CompChecker;
+    } else {
+      CallChecker = AbsChecker;
     }
     InstructionCost BestCost =
         std::numeric_limits<InstructionCost::CostType>::max();

diff  --git a/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll b/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll
index a936b076138d07..51b635837d3b59 100644
--- a/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll
+++ b/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll
@@ -8,8 +8,10 @@ define i32 @test(i32 %n) {
 ; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <2 x i32> poison, i32 [[N]], i32 0
 ; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x i32> [[TMP0]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = add <2 x i32> [[TMP1]], <i32 1, i32 2>
-; CHECK-NEXT:    [[TMP3:%.*]] = mul <2 x i32> [[TMP2]], <i32 273837369, i32 273837369>
-; CHECK-NEXT:    [[TMP4:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[TMP3]], i1 false)
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <2 x i32> [[TMP2]] to <2 x i64>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nuw nsw <2 x i64> [[TMP3]], <i64 273837369, i64 273837369>
+; CHECK-NEXT:    [[TMP8:%.*]] = call <2 x i64> @llvm.abs.v2i64(<2 x i64> [[TMP7]], i1 true)
+; CHECK-NEXT:    [[TMP4:%.*]] = trunc <2 x i64> [[TMP8]] to <2 x i32>
 ; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i32> [[TMP4]], i32 0
 ; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x i32> [[TMP4]], i32 1
 ; CHECK-NEXT:    [[RES1:%.*]] = add i32 [[TMP5]], [[TMP6]]


        
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to