Author: Kerry McLaughlin
Date: 2020-12-10T13:54:45Z
New Revision: abe7775f5a43e5a0d8ec237542274ba3e73937e4
URL: https://github.com/llvm/llvm-project/commit/abe7775f5a43e5a0d8ec237542274ba3e73937e4 DIFF: https://github.com/llvm/llvm-project/commit/abe7775f5a43e5a0d8ec237542274ba3e73937e4.diff LOG: [SVE][CodeGen] Extend index of masked gathers This patch changes performMSCATTERCombine to also promote the indices of masked gathers where the element type is i8 or i16, and adds various tests for gathers with illegal types. Reviewed By: sdesmalen Differential Revision: https://reviews.llvm.org/D91433 Added: Modified: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll Removed: ################################################################################ diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 5d9c66e170eab..01301abf10e3d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -849,6 +849,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, if (Subtarget->supportsAddressTopByteIgnored()) setTargetDAGCombine(ISD::LOAD); + setTargetDAGCombine(ISD::MGATHER); setTargetDAGCombine(ISD::MSCATTER); setTargetDAGCombine(ISD::MUL); @@ -14063,20 +14064,19 @@ static SDValue performSTORECombine(SDNode *N, return SDValue(); } -static SDValue performMSCATTERCombine(SDNode *N, +static SDValue performMaskedGatherScatterCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) { - MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N); - assert(MSC && "Can only combine scatter store nodes"); + MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N); + assert(MGS && "Can only combine gather load or scatter store nodes"); - SDLoc DL(MSC); - SDValue Chain = MSC->getChain(); - SDValue Scale = MSC->getScale(); - SDValue Index = MSC->getIndex(); - SDValue Data = MSC->getValue(); - SDValue Mask = MSC->getMask(); - SDValue BasePtr = MSC->getBasePtr(); - 
ISD::MemIndexType IndexType = MSC->getIndexType(); + SDLoc DL(MGS); + SDValue Chain = MGS->getChain(); + SDValue Scale = MGS->getScale(); + SDValue Index = MGS->getIndex(); + SDValue Mask = MGS->getMask(); + SDValue BasePtr = MGS->getBasePtr(); + ISD::MemIndexType IndexType = MGS->getIndexType(); EVT IdxVT = Index.getValueType(); @@ -14086,16 +14086,27 @@ static SDValue performMSCATTERCombine(SDNode *N, if ((IdxVT.getVectorElementType() == MVT::i8) || (IdxVT.getVectorElementType() == MVT::i16)) { EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32); - if (MSC->isIndexSigned()) + if (MGS->isIndexSigned()) Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index); else Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index); - SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale }; - return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), - MSC->getMemoryVT(), DL, Ops, - MSC->getMemOperand(), IndexType, - MSC->isTruncatingStore()); + if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) { + SDValue PassThru = MGT->getPassThru(); + SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale }; + return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other), + PassThru.getValueType(), DL, Ops, + MGT->getMemOperand(), + MGT->getIndexType(), MGT->getExtensionType()); + } else { + auto *MSC = cast<MaskedScatterSDNode>(MGS); + SDValue Data = MSC->getValue(); + SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale }; + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), + MSC->getMemoryVT(), DL, Ops, + MSC->getMemOperand(), IndexType, + MSC->isTruncatingStore()); + } } } @@ -15072,9 +15083,6 @@ static SDValue performGatherLoadCombine(SDNode *N, SelectionDAG &DAG, static SDValue performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) { - if (DCI.isBeforeLegalizeOps()) - return SDValue(); - SDLoc DL(N); SDValue Src = N->getOperand(0); unsigned Opc = Src->getOpcode(); @@ -15109,6 +15117,9 @@ 
performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, return DAG.getNode(SOpc, DL, N->getValueType(0), Ext); } + if (DCI.isBeforeLegalizeOps()) + return SDValue(); + if (!EnableCombineMGatherIntrinsics) return SDValue(); @@ -15296,8 +15307,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, break; case ISD::STORE: return performSTORECombine(N, DCI, DAG, Subtarget); + case ISD::MGATHER: case ISD::MSCATTER: - return performMSCATTERCombine(N, DCI, DAG); + return performMaskedGatherScatterCombine(N, DCI, DAG); case AArch64ISD::BRCOND: return performBRCONDCombine(N, DCI, DAG); case AArch64ISD::TBNZ: diff --git a/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll b/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll index 076edc1fd86da..4482730a7d74c 100644 --- a/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll +++ b/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll @@ -54,6 +54,19 @@ define <vscale x 2 x i32> @masked_gather_nxv2i32(<vscale x 2 x i32*> %ptrs, <vsc ret <vscale x 2 x i32> %data } +; Code generate the worst case scenario when all vector types are legal. +define <vscale x 16 x i8> @masked_gather_nxv16i8(i8* %base, <vscale x 16 x i8> %indices, <vscale x 16 x i1> %mask) { +; CHECK-LABEL: masked_gather_nxv16i8: +; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw] +; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw] +; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw] +; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw] +; CHECK: ret + %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %indices + %data = call <vscale x 16 x i8> @llvm.masked.gather.nxv16i8(<vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef) + ret <vscale x 16 x i8> %data +} + ; Code generate the worst case scenario when all vector types are illegal. 
define <vscale x 32 x i32> @masked_gather_nxv32i32(i32* %base, <vscale x 32 x i32> %indices, <vscale x 32 x i1> %mask) { ; CHECK-LABEL: masked_gather_nxv32i32: _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits