Author: Kerry McLaughlin Date: 2020-11-25T11:18:22Z New Revision: 603d40da9d532ab4706e32c07aba339e180ed865
URL: https://github.com/llvm/llvm-project/commit/603d40da9d532ab4706e32c07aba339e180ed865 DIFF: https://github.com/llvm/llvm-project/commit/603d40da9d532ab4706e32c07aba339e180ed865.diff LOG: [SVE][CodeGen] Add a DAG combine to extend mscatter indices This patch adds a target-specific DAG combine for mscatter to promote indices with element types i8 or i16 before legalisation, plus various tests with illegal types. Reviewed By: sdesmalen Differential Revision: https://reviews.llvm.org/D90945 Added: llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll Modified: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp Removed: ################################################################################ diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index b92eb1d0e4f6..e4c20cc4e6e3 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -835,6 +835,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, if (Subtarget->supportsAddressTopByteIgnored()) setTargetDAGCombine(ISD::LOAD); + setTargetDAGCombine(ISD::MSCATTER); + setTargetDAGCombine(ISD::MUL); setTargetDAGCombine(ISD::SELECT); @@ -13944,6 +13946,44 @@ static SDValue performSTORECombine(SDNode *N, return SDValue(); } +static SDValue performMSCATTERCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N); + assert(MSC && "Can only combine scatter store nodes"); + + SDLoc DL(MSC); + SDValue Chain = MSC->getChain(); + SDValue Scale = MSC->getScale(); + SDValue Index = MSC->getIndex(); + SDValue Data = MSC->getValue(); + SDValue Mask = MSC->getMask(); + SDValue BasePtr = MSC->getBasePtr(); + ISD::MemIndexType IndexType = MSC->getIndexType(); + + EVT IdxVT = Index.getValueType(); + + if (DCI.isBeforeLegalize()) { + // SVE gather/scatter requires indices of i32/i64. Promote anything smaller + // prior to legalisation so the result can be split if required. + if ((IdxVT.getVectorElementType() == MVT::i8) || + (IdxVT.getVectorElementType() == MVT::i16)) { + EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32); + if (MSC->isIndexSigned()) + Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index); + else + Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index); + + SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale }; + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), + MSC->getMemoryVT(), DL, Ops, + MSC->getMemOperand(), IndexType, + MSC->isTruncatingStore()); + } + } + + return SDValue(); +} /// Target-specific DAG combine function for NEON load/store intrinsics /// to merge base address updates. @@ -15136,6 +15176,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, break; case ISD::STORE: return performSTORECombine(N, DCI, DAG, Subtarget); + case ISD::MSCATTER: + return performMSCATTERCombine(N, DCI, DAG); case AArch64ISD::BRCOND: return performBRCONDCombine(N, DCI, DAG); case AArch64ISD::TBNZ: diff --git a/llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll b/llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll new file mode 100644 index 000000000000..c3746a61d875 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll @@ -0,0 +1,59 @@ +; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s + +; Tests that exercise various type legalisation scenarios for ISD::MSCATTER. + +; Code generate the scenario where the offset vector type is illegal. +define void @masked_scatter_nxv16i8(<vscale x 16 x i8> %data, i8* %base, <vscale x 16 x i8> %offsets, <vscale x 16 x i1> %mask) { +; CHECK-LABEL: masked_scatter_nxv16i8: +; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw] +; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw] +; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw] +; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw] +; CHECK: ret + %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %offsets + call void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask) + ret void +} + +define void @masked_scatter_nxv8i16(<vscale x 8 x i16> %data, i16* %base, <vscale x 8 x i16> %offsets, <vscale x 8 x i1> %mask) { +; CHECK-LABEL: masked_scatter_nxv8i16 +; CHECK-DAG: st1h { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #1] +; CHECK-DAG: st1h { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #1] +; CHECK: ret + %ptrs = getelementptr i16, i16* %base, <vscale x 8 x i16> %offsets + call void @llvm.masked.scatter.nxv8i16(<vscale x 8 x i16> %data, <vscale x 8 x i16*> %ptrs, i32 1, <vscale x 8 x i1> %mask) + ret void +} + +define void @masked_scatter_nxv8f32(<vscale x 8 x float> %data, float* %base, <vscale x 8 x i32> %indexes, <vscale x 8 x i1> %masks) { +; CHECK-LABEL: masked_scatter_nxv8f32 +; CHECK-DAG: st1w { z0.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, uxtw #2] +; CHECK-DAG: st1w { z1.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, uxtw #2] + %ext = zext <vscale x 8 x i32> %indexes to <vscale x 8 x i64> + %ptrs = getelementptr float, float* %base, <vscale x 8 x i64> %ext + call void @llvm.masked.scatter.nxv8f32(<vscale x 8 x float> %data, <vscale x 8 x float*> %ptrs, i32 0, <vscale x 8 x i1> %masks) + ret void +} + +; Code generate the worst case scenario when all vector types are illegal. +define void @masked_scatter_nxv32i32(<vscale x 32 x i32> %data, i32* %base, <vscale x 32 x i32> %offsets, <vscale x 32 x i1> %mask) { +; CHECK-LABEL: masked_scatter_nxv32i32: +; CHECK-NOT: unpkhi +; CHECK-DAG: st1w { z0.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2] +; CHECK-DAG: st1w { z1.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2] +; CHECK-DAG: st1w { z2.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2] +; CHECK-DAG: st1w { z3.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2] +; CHECK-DAG: st1w { z4.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2] +; CHECK-DAG: st1w { z5.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2] +; CHECK-DAG: st1w { z6.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2] +; CHECK-DAG: st1w { z7.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2] +; CHECK: ret + %ptrs = getelementptr i32, i32* %base, <vscale x 32 x i32> %offsets + call void @llvm.masked.scatter.nxv32i32(<vscale x 32 x i32> %data, <vscale x 32 x i32*> %ptrs, i32 4, <vscale x 32 x i1> %mask) + ret void +} + +declare void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8*>, i32, <vscale x 16 x i1>) +declare void @llvm.masked.scatter.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16*>, i32, <vscale x 8 x i1>) +declare void @llvm.masked.scatter.nxv8f32(<vscale x 8 x float>, <vscale x 8 x float*>, i32, <vscale x 8 x i1>) +declare void @llvm.masked.scatter.nxv32i32(<vscale x 32 x i32>, <vscale x 32 x i32*>, i32, <vscale x 32 x i1>) _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits