================
@@ -3085,9 +3089,114 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, 
SelectionDAG &DAG) const {
                       MachinePointerInfo(SV));
 }
 
+/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
+static std::optional<std::pair<SDValue, SDValue>>
+replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+  const EVT ResVT = LD->getValueType(0);
+  const EVT MemVT = LD->getMemoryVT();
+
+  // If we're doing sign/zero extension as part of the load, avoid lowering to
+  // a LoadV node. TODO: consider relaxing this restriction.
+  if (ResVT != MemVT)
+    return std::nullopt;
+
+  const auto NumEltsAndEltVT =
+      getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
+  if (!NumEltsAndEltVT)
+    return std::nullopt;
+  const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
+
+  Align Alignment = LD->getAlign();
+  const auto &TD = DAG.getDataLayout();
+  Align PrefAlign = 
TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext()));
+  if (Alignment < PrefAlign) {
+    // This load is not sufficiently aligned, so bail out and let this vector
+    // load be scalarized.  Note that we may still be able to emit smaller
+    // vector loads.  For example, if we are loading a <4 x float> with an
+    // alignment of 8, this check will fail but the legalizer will try again
+    // with 2 x <2 x float>, which will succeed with an alignment of 8.
+    return std::nullopt;
+  }
+
+  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
+  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
+  // loaded type to i16 and propagate the "real" type as the memory type.
+  const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
+
+  unsigned Opcode;
+  switch (NumElts) {
+  default:
+    return std::nullopt; // only 2, 4 and 8 element shapes are handled below
+  case 2:
+    Opcode = NVPTXISD::LoadV2;
+    break;
+  case 4:
+    Opcode = NVPTXISD::LoadV4;
+    break;
+  case 8:
+    Opcode = NVPTXISD::LoadV8;
+    break;
+  }
+  auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
+  ListVTs.push_back(MVT::Other); // extra result read back as the chain below
+  SDVTList LdResVTs = DAG.getVTList(ListVTs);
+
+  SDLoc DL(LD);
+
+  // Copy the original load's operands unchanged.
+  SmallVector<SDValue, 8> OtherOps(LD->ops());
+
+  // The select routine does not have access to the LoadSDNode instance, so
+  // pass along the extension information as an extra trailing operand.
+  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+
+  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, 
MemVT,
+                                          LD->getMemOperand());
+
+  SmallVector<SDValue> ScalarRes;
+  if (EltVT.isVector()) {
+    assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
+    assert(NumElts * EltVT.getVectorNumElements() ==
+           ResVT.getVectorNumElements());
+    // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
+    // into individual elements.
+    for (const unsigned I : llvm::seq(NumElts)) {
+      SDValue SubVector = NewLD.getValue(I);
+      DAG.ExtractVectorElements(SubVector, ScalarRes);
+    }
+  } else {
+    for (const unsigned I : llvm::seq(NumElts)) {
+      SDValue Res = NewLD.getValue(I);
+      if (LoadEltVT != EltVT) // element was widened to i16 above; narrow it
+        Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
+      ScalarRes.push_back(Res);
+    }
+  }
+
+  SDValue LoadChain = NewLD.getValue(NumElts); // result after the data values
+
+  const MVT BuildVecVT =
+      MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size());
+  SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes);
+  SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec);
+
+  return {{LoadValue, LoadChain}}; // {rebuilt vector value, new load's chain}
+}
+
 static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
                               SmallVectorImpl<SDValue> &Results,
-                              const NVPTXSubtarget &STI);
+                              const NVPTXSubtarget &STI) {
+  if (auto Res = replaceLoadVector(N, DAG, STI))
+    Results.append({Res->first, Res->second}); // no-op when nothing was built
+}
+
+static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG,
+                               const NVPTXSubtarget &STI) {
+  if (auto Res = replaceLoadVector(N, DAG, STI))
+    return DAG.getMergeValues({Res->first, Res->second}, SDLoc(N));
+  return SDValue();
----------------
dakersnar wrote:

In what circumstances will we return an empty SDValue?

https://github.com/llvm/llvm-project/pull/155198
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to