https://github.com/llvmbot created https://github.com/llvm/llvm-project/pull/152921
Backport d8b1b46cd39c91830bcf49ed91d80f38f78c2168 Requested by: @dtcxzyw >From 7ef6f5bdc487cd277fcfa4ac3b4f812f657bbb66 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng <dtcxzyw2...@gmail.com> Date: Sun, 10 Aug 2025 22:55:04 +0800 Subject: [PATCH] [IR] Handle fabs LHS in `fcmpImpliesClass` (#152913) Closes https://github.com/llvm/llvm-project/issues/152824. (cherry picked from commit d8b1b46cd39c91830bcf49ed91d80f38f78c2168) --- .../IR/GenericFloatingPointPredicateUtils.h | 24 +++++++------ .../InstSimplify/floating-point-arithmetic.ll | 21 +++++++++-- llvm/unittests/Analysis/ValueTrackingTest.cpp | 36 +++++++++++++++++++ 3 files changed, 69 insertions(+), 12 deletions(-) diff --git a/llvm/include/llvm/IR/GenericFloatingPointPredicateUtils.h b/llvm/include/llvm/IR/GenericFloatingPointPredicateUtils.h index 8aac9d5b49dbb..448a6e913eb86 100644 --- a/llvm/include/llvm/IR/GenericFloatingPointPredicateUtils.h +++ b/llvm/include/llvm/IR/GenericFloatingPointPredicateUtils.h @@ -135,6 +135,12 @@ template <typename ContextT> class GenericFloatingPointPredicateUtils { if (Mode.Input != DenormalMode::IEEE) return {Invalid, fcAllFlags, fcAllFlags}; + auto ExactClass = [IsFabs, Src](FPClassTest Mask) { + if (IsFabs) + Mask = llvm::inverse_fabs(Mask); + return exactClass(Src, Mask); + }; + switch (Pred) { case FCmpInst::FCMP_OEQ: // Match x == 0.0 return exactClass(Src, fcZero); @@ -151,26 +157,24 @@ template <typename ContextT> class GenericFloatingPointPredicateUtils { case FCmpInst::FCMP_UNO: return exactClass(Src, fcNan); case FCmpInst::FCMP_OGT: // x > 0 - return exactClass(Src, fcPosSubnormal | fcPosNormal | fcPosInf); + return ExactClass(fcPosSubnormal | fcPosNormal | fcPosInf); case FCmpInst::FCMP_UGT: // isnan(x) || x > 0 - return exactClass(Src, fcPosSubnormal | fcPosNormal | fcPosInf | fcNan); + return ExactClass(fcPosSubnormal | fcPosNormal | fcPosInf | fcNan); case FCmpInst::FCMP_OGE: // x >= 0 - return exactClass(Src, fcPositive | fcNegZero); + return ExactClass(fcPositive | fcNegZero); case FCmpInst::FCMP_UGE: // isnan(x) || x >= 0 - return exactClass(Src, fcPositive | fcNegZero | fcNan); + return ExactClass(fcPositive | fcNegZero | fcNan); case FCmpInst::FCMP_OLT: // x < 0 - return exactClass(Src, fcNegSubnormal | fcNegNormal | fcNegInf); + return ExactClass(fcNegSubnormal | fcNegNormal | fcNegInf); case FCmpInst::FCMP_ULT: // isnan(x) || x < 0 - return exactClass(Src, fcNegSubnormal | fcNegNormal | fcNegInf | fcNan); + return ExactClass(fcNegSubnormal | fcNegNormal | fcNegInf | fcNan); case FCmpInst::FCMP_OLE: // x <= 0 - return exactClass(Src, fcNegative | fcPosZero); + return ExactClass(fcNegative | fcPosZero); case FCmpInst::FCMP_ULE: // isnan(x) || x <= 0 - return exactClass(Src, fcNegative | fcPosZero | fcNan); + return ExactClass(fcNegative | fcPosZero | fcNan); default: llvm_unreachable("all compare types are handled"); } - - return {Invalid, fcAllFlags, fcAllFlags}; } const bool IsDenormalRHS = (OrigClass & fcSubnormal) == OrigClass; diff --git a/llvm/test/Transforms/InstSimplify/floating-point-arithmetic.ll b/llvm/test/Transforms/InstSimplify/floating-point-arithmetic.ll index ab4448b460bfc..820fff433e9e0 100644 --- a/llvm/test/Transforms/InstSimplify/floating-point-arithmetic.ll +++ b/llvm/test/Transforms/InstSimplify/floating-point-arithmetic.ll @@ -213,7 +213,7 @@ define double @fmul_nnan_ninf_nneg_n0.0_commute(i127 %x) { define float @fmul_ninf_nnan_mul_zero_nsz(float nofpclass(inf nan) %f) { ; CHECK-LABEL: @fmul_ninf_nnan_mul_zero_nsz( -; CHECK-NEXT: ret float 0.000000e+00 +; CHECK-NEXT: ret float 0.000000e+00 ; %r = fmul nsz float %f, 0.0 ret float %r @@ -221,7 +221,7 @@ define float @fmul_ninf_nnan_mul_zero_nsz(float nofpclass(inf nan) %f) { define float @fmul_ninf_nnan_mul_nzero_nsz(float nofpclass(inf nan) %f) { ; CHECK-LABEL: @fmul_ninf_nnan_mul_nzero_nsz( -; CHECK-NEXT: ret float 0.000000e+00 +; CHECK-NEXT: ret float 0.000000e+00 ; %r = fmul nsz float %f, -0.0 ret float %r @@ -1255,3 +1255,20 @@ define i1 @fptrunc_round_unknown_positive(double %unknown) { %cmp = fcmp nnan oge float %op, 0.0 ret i1 %cmp } + +define half @fabs_select_fabs(half noundef %x) { +; CHECK-LABEL: @fabs_select_fabs( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[ABS1:%.*]] = call half @llvm.fabs.f16(half [[X:%.*]]) +; CHECK-NEXT: [[CMP:%.*]] = fcmp ogt half [[ABS1]], 0xH0000 +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], half [[X]], half 0xH0000 +; CHECK-NEXT: [[ABS2:%.*]] = call half @llvm.fabs.f16(half [[SEL]]) +; CHECK-NEXT: ret half [[ABS2]] +; +entry: + %abs1 = call half @llvm.fabs.f16(half %x) + %cmp = fcmp ogt half %abs1, 0xH0000 + %sel = select i1 %cmp, half %x, half 0xH0000 + %abs2 = call half @llvm.fabs.f16(half %sel) + ret half %abs2 +} diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp index 7a48105a1dc99..bf396499e35ca 100644 --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ValueTracking.h" +#include "llvm/ADT/FloatingPointMode.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/FloatingPointPredicateUtils.h" #include "llvm/AsmParser/Parser.h" @@ -2208,6 +2209,41 @@ TEST_F(ComputeKnownFPClassTest, Constants) { } } +TEST_F(ComputeKnownFPClassTest, fcmpImpliesClass_fabs_zero) { + parseAssembly("define float @test(float %x) {\n" + " %A = call float @llvm.fabs.f32(float %x)\n" + " ret float %A\n" + "}\n"); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_OEQ, *F, A, fcZero)), + fcZero); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_UEQ, *F, A, fcZero)), + fcZero | fcNan); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_UNE, *F, A, fcZero)), + ~fcZero); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_ONE, *F, A, fcZero)), + ~fcNan & ~fcZero); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_ORD, *F, A, fcZero)), + ~fcNan); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_UNO, *F, A, fcZero)), + fcNan); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_OGT, *F, A, fcZero)), + fcSubnormal | fcNormal | fcInf); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_UGT, *F, A, fcZero)), + fcSubnormal | fcNormal | fcInf | fcNan); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_OGE, *F, A, fcZero)), + ~fcNan); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_UGE, *F, A, fcZero)), + fcAllFlags); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_OLT, *F, A, fcZero)), + fcNone); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_ULT, *F, A, fcZero)), + fcNan); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_OLE, *F, A, fcZero)), + fcZero); + EXPECT_EQ(std::get<1>(fcmpImpliesClass(FCmpInst::FCMP_ULE, *F, A, fcZero)), + fcZero | fcNan); +} + TEST_F(ValueTrackingTest, isNonZeroRecurrence) { parseAssembly(R"( define i1 @test(i8 %n, i8 %r) { _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits