https://github.com/llvmbot created https://github.com/llvm/llvm-project/pull/116797
Backport 2f4572f5e7e2d7f4626e825404c11f07d191fb05 c727b48287cc96888f9e262f23d53cf635cf3b3d Requested by: @dtcxzyw >From 0dcfce5c6d5501d8358045a5eee109a35bfcdab1 Mon Sep 17 00:00:00 2001 From: Craig Topper <craig.top...@sifive.com> Date: Sat, 16 Nov 2024 20:55:33 -0800 Subject: [PATCH 1/2] [Mips] Change vsplat_imm_eq_1 to a ComplexPattern. (#116471) Resolves a FIXME and avoids needing to workaround #116075. Adding parentheses around the (vsplat_imm_eq_1) fixes the error cited in the FIXME by changing the ComplexPattern from a leaf node to an operator. (cherry picked from commit 2f4572f5e7e2d7f4626e825404c11f07d191fb05) --- llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp | 4 ++ llvm/lib/Target/Mips/MipsISelDAGToDAG.h | 3 ++ llvm/lib/Target/Mips/MipsMSAInstrInfo.td | 59 ++++++++------------- llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp | 12 +++++ llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h | 3 ++ 5 files changed, 43 insertions(+), 38 deletions(-) diff --git a/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp b/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp index f6f32fde3b7778..a9ffd2bedf21e0 100644 --- a/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp +++ b/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp @@ -220,6 +220,10 @@ bool MipsDAGToDAGISel::selectVSplatMaskR(SDValue N, SDValue &Imm) const { return false; } +bool MipsDAGToDAGISel::selectVSplatImmEq1(SDValue N) const { + llvm_unreachable("Unimplemented function."); +} + /// Convert vector addition with vector subtraction if that allows to encode /// constant as an immediate and thus avoid extra 'ldi' instruction. /// add X, <-1, -1...> --> sub X, <1, 1...> diff --git a/llvm/lib/Target/Mips/MipsISelDAGToDAG.h b/llvm/lib/Target/Mips/MipsISelDAGToDAG.h index 6135f968078542..3485300a782c94 100644 --- a/llvm/lib/Target/Mips/MipsISelDAGToDAG.h +++ b/llvm/lib/Target/Mips/MipsISelDAGToDAG.h @@ -120,6 +120,9 @@ class MipsDAGToDAGISel : public SelectionDAGISel { /// starting at bit zero. 
virtual bool selectVSplatMaskR(SDValue N, SDValue &Imm) const; + /// Select constant vector splats whose value is 1. + virtual bool selectVSplatImmEq1(SDValue N) const; + /// Convert vector addition with vector subtraction if that allows to encode /// constant as an immediate and thus avoid extra 'ldi' instruction. /// add X, <-1, -1...> --> sub X, <1, 1...> diff --git a/llvm/lib/Target/Mips/MipsMSAInstrInfo.td b/llvm/lib/Target/Mips/MipsMSAInstrInfo.td index c4abccb24c6f35..f4c32c9dcd4212 100644 --- a/llvm/lib/Target/Mips/MipsMSAInstrInfo.td +++ b/llvm/lib/Target/Mips/MipsMSAInstrInfo.td @@ -198,14 +198,8 @@ def vsplati32 : PatFrag<(ops node:$e0), (v4i32 (build_vector node:$e0, node:$e0, node:$e0, node:$e0))>; -def vsplati64_imm_eq_1 : PatLeaf<(bitconvert (v4i32 (build_vector))), [{ - APInt Imm; - SDNode *BV = N->getOperand(0).getNode(); - EVT EltTy = N->getValueType(0).getVectorElementType(); - - return selectVSplat(BV, Imm, EltTy.getSizeInBits()) && - Imm.getBitWidth() == EltTy.getSizeInBits() && Imm == 1; -}]>; +// Any build_vector that is a constant splat with a value that equals 1 +def vsplat_imm_eq_1 : ComplexPattern<vAny, 0, "selectVSplatImmEq1">; def vsplati64 : PatFrag<(ops node:$e0), (v2i64 (build_vector node:$e0, node:$e0))>; @@ -217,7 +211,7 @@ def vsplati64_splat_d : PatFrag<(ops node:$e0), node:$e0, node:$e0, node:$e0)), - vsplati64_imm_eq_1))))>; + (vsplat_imm_eq_1)))))>; def vsplatf32 : PatFrag<(ops node:$e0), (v4f32 (build_vector node:$e0, node:$e0, @@ -352,46 +346,35 @@ def vsplat_maskr_bits_uimm6 : SplatComplexPattern<vsplat_uimm6, vAny, 1, "selectVSplatMaskR", [build_vector, bitconvert]>; -// Any build_vector that is a constant splat with a value that equals 1 -// FIXME: These should be a ComplexPattern but we can't use them because the -// ISel generator requires the uses to have a name, but providing a name -// causes other errors ("used in pattern but not operand list") -def vsplat_imm_eq_1 : PatLeaf<(build_vector), [{ - APInt Imm; - EVT 
EltTy = N->getValueType(0).getVectorElementType(); - - return selectVSplat(N, Imm, EltTy.getSizeInBits()) && - Imm.getBitWidth() == EltTy.getSizeInBits() && Imm == 1; -}]>; def vbclr_b : PatFrag<(ops node:$ws, node:$wt), - (and node:$ws, (vnot (shl vsplat_imm_eq_1, node:$wt)))>; + (and node:$ws, (vnot (shl (vsplat_imm_eq_1), node:$wt)))>; def vbclr_h : PatFrag<(ops node:$ws, node:$wt), - (and node:$ws, (vnot (shl vsplat_imm_eq_1, node:$wt)))>; + (and node:$ws, (vnot (shl (vsplat_imm_eq_1), node:$wt)))>; def vbclr_w : PatFrag<(ops node:$ws, node:$wt), - (and node:$ws, (vnot (shl vsplat_imm_eq_1, node:$wt)))>; + (and node:$ws, (vnot (shl (vsplat_imm_eq_1), node:$wt)))>; def vbclr_d : PatFrag<(ops node:$ws, node:$wt), - (and node:$ws, (vnot (shl (v2i64 vsplati64_imm_eq_1), + (and node:$ws, (vnot (shl (v2i64 (vsplat_imm_eq_1)), node:$wt)))>; def vbneg_b : PatFrag<(ops node:$ws, node:$wt), - (xor node:$ws, (shl vsplat_imm_eq_1, node:$wt))>; + (xor node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>; def vbneg_h : PatFrag<(ops node:$ws, node:$wt), - (xor node:$ws, (shl vsplat_imm_eq_1, node:$wt))>; + (xor node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>; def vbneg_w : PatFrag<(ops node:$ws, node:$wt), - (xor node:$ws, (shl vsplat_imm_eq_1, node:$wt))>; + (xor node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>; def vbneg_d : PatFrag<(ops node:$ws, node:$wt), - (xor node:$ws, (shl (v2i64 vsplati64_imm_eq_1), + (xor node:$ws, (shl (v2i64 (vsplat_imm_eq_1)), node:$wt))>; def vbset_b : PatFrag<(ops node:$ws, node:$wt), - (or node:$ws, (shl vsplat_imm_eq_1, node:$wt))>; + (or node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>; def vbset_h : PatFrag<(ops node:$ws, node:$wt), - (or node:$ws, (shl vsplat_imm_eq_1, node:$wt))>; + (or node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>; def vbset_w : PatFrag<(ops node:$ws, node:$wt), - (or node:$ws, (shl vsplat_imm_eq_1, node:$wt))>; + (or node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>; def vbset_d : PatFrag<(ops node:$ws, node:$wt), - (or node:$ws, (shl 
(v2i64 vsplati64_imm_eq_1), + (or node:$ws, (shl (v2i64 (vsplat_imm_eq_1)), node:$wt))>; def muladd : PatFrag<(ops node:$wd, node:$ws, node:$wt), @@ -3842,7 +3825,7 @@ class MSAShiftPat<SDNode Node, ValueType VT, MSAInst Insn, dag Vec> : (VT (Insn VT:$ws, VT:$wt))>; class MSABitPat<SDNode Node, ValueType VT, MSAInst Insn, PatFrag Frag> : - MSAPat<(VT (Node VT:$ws, (shl vsplat_imm_eq_1, (Frag VT:$wt)))), + MSAPat<(VT (Node VT:$ws, (shl (vsplat_imm_eq_1), (Frag VT:$wt)))), (VT (Insn VT:$ws, VT:$wt))>; multiclass MSAShiftPats<SDNode Node, string Insn> { @@ -3861,7 +3844,7 @@ multiclass MSABitPats<SDNode Node, string Insn> { def : MSABitPat<Node, v16i8, !cast<MSAInst>(Insn#_B), vsplati8imm7>; def : MSABitPat<Node, v8i16, !cast<MSAInst>(Insn#_H), vsplati16imm15>; def : MSABitPat<Node, v4i32, !cast<MSAInst>(Insn#_W), vsplati32imm31>; - def : MSAPat<(Node v2i64:$ws, (shl (v2i64 vsplati64_imm_eq_1), + def : MSAPat<(Node v2i64:$ws, (shl (v2i64 (vsplat_imm_eq_1)), (vsplati64imm63 v2i64:$wt))), (v2i64 (!cast<MSAInst>(Insn#_D) v2i64:$ws, v2i64:$wt))>; } @@ -3872,16 +3855,16 @@ defm : MSAShiftPats<sra, "SRA">; defm : MSABitPats<xor, "BNEG">; defm : MSABitPats<or, "BSET">; -def : MSAPat<(and v16i8:$ws, (vnot (shl vsplat_imm_eq_1, +def : MSAPat<(and v16i8:$ws, (vnot (shl (vsplat_imm_eq_1), (vsplati8imm7 v16i8:$wt)))), (v16i8 (BCLR_B v16i8:$ws, v16i8:$wt))>; -def : MSAPat<(and v8i16:$ws, (vnot (shl vsplat_imm_eq_1, +def : MSAPat<(and v8i16:$ws, (vnot (shl (vsplat_imm_eq_1), (vsplati16imm15 v8i16:$wt)))), (v8i16 (BCLR_H v8i16:$ws, v8i16:$wt))>; -def : MSAPat<(and v4i32:$ws, (vnot (shl vsplat_imm_eq_1, +def : MSAPat<(and v4i32:$ws, (vnot (shl (vsplat_imm_eq_1), (vsplati32imm31 v4i32:$wt)))), (v4i32 (BCLR_W v4i32:$ws, v4i32:$wt))>; -def : MSAPat<(and v2i64:$ws, (vnot (shl (v2i64 vsplati64_imm_eq_1), +def : MSAPat<(and v2i64:$ws, (vnot (shl (v2i64 (vsplat_imm_eq_1)), (vsplati64imm63 v2i64:$wt)))), (v2i64 (BCLR_D v2i64:$ws, v2i64:$wt))>; diff --git 
a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp index 7ad300c6cccd45..66c034a889c600 100644 --- a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp +++ b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp @@ -730,6 +730,18 @@ bool MipsSEDAGToDAGISel::selectVSplatUimmInvPow2(SDValue N, return false; } +// Select const vector splat of 1. +bool MipsSEDAGToDAGISel::selectVSplatImmEq1(SDValue N) const { + APInt ImmValue; + EVT EltTy = N->getValueType(0).getVectorElementType(); + + if (N->getOpcode() == ISD::BITCAST) + N = N->getOperand(0); + + return selectVSplat(N.getNode(), ImmValue, EltTy.getSizeInBits()) && + ImmValue.getBitWidth() == EltTy.getSizeInBits() && ImmValue == 1; +} + bool MipsSEDAGToDAGISel::trySelect(SDNode *Node) { unsigned Opcode = Node->getOpcode(); SDLoc DL(Node); diff --git a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h index 7b843b0e0b2552..22d8e924ac534f 100644 --- a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h +++ b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h @@ -124,6 +124,9 @@ class MipsSEDAGToDAGISel : public MipsDAGToDAGISel { /// starting at bit zero. bool selectVSplatMaskR(SDValue N, SDValue &Imm) const override; + /// Select constant vector splats whose value is 1. + bool selectVSplatImmEq1(SDValue N) const override; + bool trySelect(SDNode *Node) override; // Emits proper ABI for _mcount profiling calls. 
>From 53f85626daab1ec9e128e486501b9f5bc8d8a27c Mon Sep 17 00:00:00 2001 From: Yingwei Zheng <dtcxzyw2...@gmail.com> Date: Tue, 19 Nov 2024 21:24:40 +0800 Subject: [PATCH 2/2] [SDAG][ISel][TableGen][LoongArch] Report error for trivial bitcasts when there are predicate calls (#116075) On loongarch64 with the LSX extension, we select `VBITREV_W` for `v4i32 (xor X, (shl splat(1), Y))`: https://github.com/llvm/llvm-project/blob/8e6630391699116641cf390a10476295b7d4b95c/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td#L1583-L1584 And `vsplat_imm_eq_1` is defined as: https://github.com/llvm/llvm-project/blob/8e6630391699116641cf390a10476295b7d4b95c/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td#L77-L87 For the `(bitconvert (v4i32 (build_vector)))` case, the pattern is expected to be: ``` PATTERN: (xor:{ *:[v4i32] } v4i32:{ *:[v4i32] }:$vj, (shl:{ *:[v4i32] } (bitconvert:{ *:[v4i32] } (build_vector:{ *:[v4i32] }))<<P:Predicate_vsplat_imm_eq_1>>, v4i32:{ *:[v4i32] }:$vk)) RESULT: (VBITREV_W:{ *:[v4i32] } v4i32:{ *:[v4i32] }:$vj, v4i32:{ *:[v4i32] }:$vk) ``` However, `SimplifyTree` drops the `bitconvert` node and its predicates: https://github.com/llvm/llvm-project/blob/8e6630391699116641cf390a10476295b7d4b95c/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp#L3036-L3062 Then LLVM will match `vsplat_imm_eq_1` for any v4i32 splats and cause a miscompilation: ``` PATTERN: (xor:{ *:[v4i32] } v4i32:{ *:[v4i32] }:$vj, (shl:{ *:[v4i32] } (build_vector:{ *:[v4i32] }), v4i32:{ *:[v4i32] }:$vk)) RESULT: (VBITREV_W:{ *:[v4i32] } v4i32:{ *:[v4i32] }:$vj, v4i32:{ *:[v4i32] }:$vk) ``` This patch adds additional checks for predicates associated with the trivial bitconvert node. Unused patterns in the LoongArch target are also removed. Fixes https://github.com/llvm/llvm-project/issues/116008. 
(cherry picked from commit c727b48287cc96888f9e262f23d53cf635cf3b3d) --- .../Target/LoongArch/LoongArchLSXInstrInfo.td | 6 ++---- llvm/test/CodeGen/LoongArch/lsx/pr116008.ll | 17 +++++++++++++++++ .../TableGen/Common/CodeGenDAGPatterns.cpp | 8 ++++++++ 3 files changed, 27 insertions(+), 4 deletions(-) create mode 100644 llvm/test/CodeGen/LoongArch/lsx/pr116008.ll diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td index 0580683c3ce303..0233baecf6dd9c 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td @@ -67,8 +67,7 @@ class VecCond<SDPatternOperator OpNode, ValueType TyNode, let usesCustomInserter = 1; } -def vsplat_imm_eq_1 : PatFrags<(ops), [(build_vector), - (bitconvert (v4i32 (build_vector)))], [{ +def vsplat_imm_eq_1 : PatFrags<(ops), [(build_vector)], [{ APInt Imm; EVT EltTy = N->getValueType(0).getVectorElementType(); @@ -109,8 +108,7 @@ def vsplati32_imm_eq_31 : PatFrags<(ops), [(build_vector)], [{ return selectVSplat(N, Imm, EltTy.getSizeInBits()) && Imm.getBitWidth() == EltTy.getSizeInBits() && Imm == 31; }]>; -def vsplati64_imm_eq_63 : PatFrags<(ops), [(build_vector), - (bitconvert (v4i32 (build_vector)))], [{ +def vsplati64_imm_eq_63 : PatFrags<(ops), [(build_vector)], [{ APInt Imm; EVT EltTy = N->getValueType(0).getVectorElementType(); diff --git a/llvm/test/CodeGen/LoongArch/lsx/pr116008.ll b/llvm/test/CodeGen/LoongArch/lsx/pr116008.ll new file mode 100644 index 00000000000000..ba8ffc34931893 --- /dev/null +++ b/llvm/test/CodeGen/LoongArch/lsx/pr116008.ll @@ -0,0 +1,17 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc --mtriple=loongarch64 --mattr=+lsx < %s | FileCheck %s + +define <4 x i32> @xor_shl_splat_vec_one(i32 %x, <4 x i32> %y) nounwind { +; CHECK-LABEL: xor_shl_splat_vec_one: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vreplgr2vr.w $vr1, $a0 +; 
CHECK-NEXT: vsll.w $vr0, $vr1, $vr0 +; CHECK-NEXT: vbitrevi.w $vr0, $vr0, 0 +; CHECK-NEXT: ret +entry: + %ins = insertelement <4 x i32> poison, i32 %x, i64 0 + %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer + %shl = shl <4 x i32> %splat, %y + %xor = xor <4 x i32> %shl, splat (i32 1) + ret <4 x i32> %xor +} diff --git a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp index a8cecca0d4a54f..ca71569008d5ec 100644 --- a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp +++ b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp @@ -3042,6 +3042,14 @@ static bool SimplifyTree(TreePatternNodePtr &N) { !N->getExtType(0).empty() && N->getExtType(0) == N->getChild(0).getExtType(0) && N->getName().empty()) { + if (!N->getPredicateCalls().empty()) { + std::string Str; + raw_string_ostream OS(Str); + OS << *N + << "\n trivial bitconvert node should not have predicate calls\n"; + PrintFatalError(Str); + return false; + } N = N->getChildShared(0); SimplifyTree(N); return true; _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits