https://github.com/llvmbot created 
https://github.com/llvm/llvm-project/pull/149778

Backport 8a307ae61963a3f967052f7ea3c89aafa56934cf

Requested by: @heiher

From e781da54605c515b1d061a1368aba3e73c8a6bd5 Mon Sep 17 00:00:00 2001
From: hev <wang...@loongson.cn>
Date: Mon, 21 Jul 2025 16:36:49 +0800
Subject: [PATCH] [LoongArch] Fix failure to widen operand for
 `[X]VMSK{LT,GE,NE}Z` (#149442)

Reported-by: tangyan <tangya...@loongson.cn>
(cherry picked from commit 8a307ae61963a3f967052f7ea3c89aafa56934cf)
---
 .../LoongArch/LoongArchISelLowering.cpp       | 221 ++++++++++--------
 llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll   |  15 ++
 2 files changed, 139 insertions(+), 97 deletions(-)

diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index c47987fbf683b..12cf04bbbab56 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -4563,6 +4563,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
   llvm_unreachable("Unexpected node type for vXi1 sign extension");
 }
 
+static SDValue
+performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG,
+                            TargetLowering::DAGCombinerInfo &DCI,
+                            const LoongArchSubtarget &Subtarget) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Src = N->getOperand(0);
+  EVT SrcVT = Src.getValueType();
+
+  if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
+    return SDValue();
+
+  bool UseLASX;
+  unsigned Opc = ISD::DELETED_NODE;
+  EVT CmpVT = Src.getOperand(0).getValueType();
+  EVT EltVT = CmpVT.getVectorElementType();
+
+  if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128)
+    UseLASX = false;
+  else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
+           CmpVT.getSizeInBits() == 256)
+    UseLASX = true;
+  else
+    return SDValue();
+
+  SDValue SrcN1 = Src.getOperand(1);
+  switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
+  default:
+    break;
+  case ISD::SETEQ:
+    // x == 0 => not (vmsknez.b x)
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+    break;
+  case ISD::SETGT:
+    // x > -1 => vmskgez.b x
+    if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+    break;
+  case ISD::SETGE:
+    // x >= 0 => vmskgez.b x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+    break;
+  case ISD::SETLT:
+    // x < 0 => vmskltz.{b,h,w,d} x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
+        (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+         EltVT == MVT::i64))
+      Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+    break;
+  case ISD::SETLE:
+    // x <= -1 => vmskltz.{b,h,w,d} x
+    if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
+        (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+         EltVT == MVT::i64))
+      Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+    break;
+  case ISD::SETNE:
+    // x != 0 => vmsknez.b x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+    break;
+  }
+
+  if (Opc == ISD::DELETED_NODE)
+    return SDValue();
+
+  SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
+  EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
+  V = DAG.getZExtOrTrunc(V, DL, T);
+  return DAG.getBitcast(VT, V);
+}
+
 static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
                                      TargetLowering::DAGCombinerInfo &DCI,
                                      const LoongArchSubtarget &Subtarget) {
@@ -4577,110 +4651,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
   if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
     return SDValue();
 
-  unsigned Opc = ISD::DELETED_NODE;
   // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
+  SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget);
+  if (Res)
+    return Res;
+
+  // Generate vXi1 using [X]VMSKLTZ
+  MVT SExtVT;
+  unsigned Opc;
+  bool UseLASX = false;
+  bool PropagateSExt = false;
+
   if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) {
-    bool UseLASX;
     EVT CmpVT = Src.getOperand(0).getValueType();
-    EVT EltVT = CmpVT.getVectorElementType();
-
-    if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
-      UseLASX = false;
-    else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
-             CmpVT.getSizeInBits() <= 256)
-      UseLASX = true;
-    else
+    if (CmpVT.getSizeInBits() > 256)
       return SDValue();
-
-    SDValue SrcN1 = Src.getOperand(1);
-    switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
-    default:
-      break;
-    case ISD::SETEQ:
-      // x == 0 => not (vmsknez.b x)
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
-      break;
-    case ISD::SETGT:
-      // x > -1 => vmskgez.b x
-      if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
-      break;
-    case ISD::SETGE:
-      // x >= 0 => vmskgez.b x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
-      break;
-    case ISD::SETLT:
-      // x < 0 => vmskltz.{b,h,w,d} x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
-          (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
-           EltVT == MVT::i64))
-        Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-      break;
-    case ISD::SETLE:
-      // x <= -1 => vmskltz.{b,h,w,d} x
-      if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
-          (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
-           EltVT == MVT::i64))
-        Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-      break;
-    case ISD::SETNE:
-      // x != 0 => vmsknez.b x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
-      break;
-    }
   }
 
-  // Generate vXi1 using [X]VMSKLTZ
-  if (Opc == ISD::DELETED_NODE) {
-    MVT SExtVT;
-    bool UseLASX = false;
-    bool PropagateSExt = false;
-    switch (SrcVT.getSimpleVT().SimpleTy) {
-    default:
-      return SDValue();
-    case MVT::v2i1:
-      SExtVT = MVT::v2i64;
-      break;
-    case MVT::v4i1:
-      SExtVT = MVT::v4i32;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v4i64;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v8i1:
-      SExtVT = MVT::v8i16;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v8i32;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v16i1:
-      SExtVT = MVT::v16i8;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v16i16;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v32i1:
-      SExtVT = MVT::v32i8;
+  switch (SrcVT.getSimpleVT().SimpleTy) {
+  default:
+    return SDValue();
+  case MVT::v2i1:
+    SExtVT = MVT::v2i64;
+    break;
+  case MVT::v4i1:
+    SExtVT = MVT::v4i32;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v4i64;
       UseLASX = true;
-      break;
-    };
-    if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
-      return SDValue();
-    Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
-                        : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
-    Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-  } else {
-    Src = Src.getOperand(0);
-  }
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v8i1:
+    SExtVT = MVT::v8i16;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v8i32;
+      UseLASX = true;
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v16i1:
+    SExtVT = MVT::v16i8;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v16i16;
+      UseLASX = true;
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v32i1:
+    SExtVT = MVT::v32i8;
+    UseLASX = true;
+    break;
+  };
+  if (UseLASX && !(Subtarget.has32S() && Subtarget.hasExtLASX()))
+    return SDValue();
+  Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
+                      : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
+  Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
 
   SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src);
   EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
diff --git a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
index 0ee30120f77a6..ad57bbf9ee5c0 100644
--- a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
+++ b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
@@ -588,3 +588,18 @@ define i2 @vmsk_trunc_i64(<2 x i64> %a) {
   %res = bitcast <2 x i1> %y to i2
   ret i2 %res
 }
+
+define i4 @vmsk_eq_allzeros_v4i8(<4 x i8> %a) {
+; CHECK-LABEL: vmsk_eq_allzeros_v4i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vseqi.b $vr0, $vr0, 0
+; CHECK-NEXT:    vilvl.b $vr0, $vr0, $vr0
+; CHECK-NEXT:    vilvl.h $vr0, $vr0, $vr0
+; CHECK-NEXT:    vslli.w $vr0, $vr0, 24
+; CHECK-NEXT:    vmskltz.w $vr0, $vr0
+; CHECK-NEXT:    vpickve2gr.hu $a0, $vr0, 0
+; CHECK-NEXT:    ret
+  %1 = icmp eq <4 x i8> %a, zeroinitializer
+  %2 = bitcast <4 x i1> %1 to i4
+  ret i4 %2
+}

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to