================ @@ -119,6 +123,142 @@ class OpLowerer { }); } + /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which + /// is intended to be removed by the end of lowering. This is used to allow + /// lowering of ops which need to change their return or argument types in a + /// piecemeal way - we can add the casts in to avoid updating all of the uses + /// or defs, and by the end all of the casts will be redundant. + Value *createTmpHandleCast(Value *V, Type *Ty) { + Function *CastFn = Intrinsic::getDeclaration(&M, Intrinsic::dx_cast_handle, + {Ty, V->getType()}); + CallInst *Cast = OpBuilder.getIRB().CreateCall(CastFn, {V}); + CleanupCasts.push_back(Cast); + return Cast; + } + + void cleanupHandleCasts() { + SmallVector<CallInst *> ToRemove; + SmallVector<Function *> CastFns; + + for (CallInst *Cast : CleanupCasts) { + // These casts were only put in to ease the move from `target("dx")` types + // to `dx.types.Handle` in a piecemeal way. At this point, all of the + // non-cast uses should now be `dx.types.Handle`, and remaining casts + // should all form pairs to and from the now unused `target("dx")` type. + CastFns.push_back(Cast->getCalledFunction()); + + // If the cast is not to `dx.types.Handle`, it should be the first part of + // the pair. Keep track so we can remove it once it has no more uses. + if (Cast->getType() != OpBuilder.getHandleType()) { + ToRemove.push_back(Cast); + continue; + } + // Otherwise, we're the second handle in a pair. Forward the arguments and + // remove the (second) cast. 
+ CallInst *Def = cast<CallInst>(Cast->getOperand(0)); + assert(Def->getIntrinsicID() == Intrinsic::dx_cast_handle && + "Unbalanced pair of temporary handle casts"); + Cast->replaceAllUsesWith(Def->getOperand(0)); + Cast->eraseFromParent(); + } + for (CallInst *Cast : ToRemove) { + assert(Cast->user_empty() && "Temporary handle cast still has users"); + Cast->eraseFromParent(); + } + + // Deduplicate the cast functions so that we only erase each one once. + llvm::sort(CastFns); + CastFns.erase(llvm::unique(CastFns), CastFns.end()); + for (Function *F : CastFns) + F->eraseFromParent(); + + CleanupCasts.clear(); + } + + void lowerToCreateHandle(Function &F) { + IRBuilder<> &IRB = OpBuilder.getIRB(); + Type *Int8Ty = IRB.getInt8Ty(); + Type *Int32Ty = IRB.getInt32Ty(); + + replaceFunction(F, [&](CallInst *CI) -> Error { + IRB.SetInsertPoint(CI); + + auto *It = DRM.find(CI); + assert(It != DRM.end() && "Resource not in map?"); + dxil::ResourceInfo &RI = *It; + const auto &Binding = RI.getBinding(); + + std::array<Value *, 4> Args{ + ConstantInt::get(Int8Ty, llvm::to_underlying(RI.getResourceClass())), + ConstantInt::get(Int32Ty, Binding.RecordID), CI->getArgOperand(3), + CI->getArgOperand(4)}; + Expected<CallInst *> OpCall = + OpBuilder.tryCreateOp(OpCode::CreateHandle, Args); + if (Error E = OpCall.takeError()) + return E; + + Value *Cast = createTmpHandleCast(*OpCall, CI->getType()); + + CI->replaceAllUsesWith(Cast); + CI->eraseFromParent(); + return Error::success(); + }); + } + + void lowerToBindAndAnnotateHandle(Function &F) { + IRBuilder<> &IRB = OpBuilder.getIRB(); + + replaceFunction(F, [&](CallInst *CI) -> Error { + IRB.SetInsertPoint(CI); + + auto *It = DRM.find(CI); + assert(It != DRM.end() && "Resource not in map?"); + dxil::ResourceInfo &RI = *It; + + const auto &Binding = RI.getBinding(); + std::pair<uint32_t, uint32_t> Props = RI.getAnnotateProps(); + + // For `CreateHandleFromBinding` we need the upper bound rather than the + // size, so we need to be 
careful about the difference for "unbounded". + uint32_t Unbounded = std::numeric_limits<uint32_t>::max(); ---------------- bogner wrote:
I personally think it's clearer to spell out that it's the maximum uint32_t rather than have to chase through some global in a header to figure out the sigil for something simple like this. https://github.com/llvm/llvm-project/pull/104251 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits