================ @@ -0,0 +1,778 @@ +//===-- clang-nvlink-wrapper/ClangNVLinkWrapper.cpp - NVIDIA linker util --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// +// +// This tool wraps around the NVIDIA linker called 'nvlink'. The NVIDIA linker +// is required to create NVPTX applications, but does not support common +// features like LTO or archives. This utility wraps around the tool to cover +// its deficiencies. This tool can be removed once NVIDIA improves their linker +// or ports it to `ld.lld`. +// +//===---------------------------------------------------------------------===// + +#include "clang/Basic/Version.h" + +#include "llvm/ADT/StringExtras.h" +#include "llvm/BinaryFormat/Magic.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/DiagnosticPrinter.h" +#include "llvm/LTO/LTO.h" +#include "llvm/Object/Archive.h" +#include "llvm/Object/ArchiveWriter.h" +#include "llvm/Object/Binary.h" +#include "llvm/Object/ELFObjectFile.h" +#include "llvm/Object/IRObjectFile.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/Object/OffloadBinary.h" +#include "llvm/Option/ArgList.h" +#include "llvm/Option/OptTable.h" +#include "llvm/Option/Option.h" +#include "llvm/Remarks/HotnessThresholdParser.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileOutputBuffer.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/Signals.h" +#include "llvm/Support/StringSaver.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/WithColor.h" + +using namespace llvm; +using namespace llvm::opt; +using namespace llvm::object; + +// Various tools (e.g., llc and opt) duplicate this series of declarations for +// options related to passes and remarks. +static cl::opt<bool> RemarksWithHotness( + "pass-remarks-with-hotness", + cl::desc("With PGO, include profile count in optimization remarks"), + cl::Hidden); + +static cl::opt<std::optional<uint64_t>, false, remarks::HotnessThresholdParser> + RemarksHotnessThreshold( + "pass-remarks-hotness-threshold", + cl::desc("Minimum profile count required for " + "an optimization remark to be output. " + "Use 'auto' to apply the threshold from profile summary."), + cl::value_desc("N or 'auto'"), cl::init(0), cl::Hidden); + +static cl::opt<std::string> + RemarksFilename("pass-remarks-output", + cl::desc("Output filename for pass remarks"), + cl::value_desc("filename")); + +static cl::opt<std::string> + RemarksPasses("pass-remarks-filter", + cl::desc("Only record optimization remarks from passes whose " + "names match the given regular expression"), + cl::value_desc("regex")); + +static cl::opt<std::string> RemarksFormat( + "pass-remarks-format", + cl::desc("The format used for serializing remarks (default: YAML)"), + cl::value_desc("format"), cl::init("yaml")); + +static cl::list<std::string> + PassPlugins("load-pass-plugin", + cl::desc("Load passes from plugin library")); + +static cl::opt<std::string> PassPipeline( + "passes", + cl::desc( + "A textual description of the pass pipeline. To have analysis passes " + "available before a certain pass, add 'require<foo-analysis>'. " + "'-passes' overrides the pass pipeline (but not all effects) from " + "specifying '--opt-level=O?' (O2 is the default) to " + "clang-linker-wrapper. Be sure to include the corresponding " + "'default<O?>' in '-passes'.")); +static cl::alias PassPipeline2("p", cl::aliasopt(PassPipeline), + cl::desc("Alias for -passes")); + +static void printVersion(raw_ostream &OS) { + OS << clang::getClangToolFullVersion("clang-nvlink-wrapper") << '\n'; +} + +/// The value of `argv[0]` when run. +static const char *Executable; + +/// Temporary files to be cleaned up. +static SmallVector<SmallString<128>> TempFiles; + +/// Codegen flags for LTO backend. +static codegen::RegisterCodeGenFlags CodeGenFlags; + +namespace { +// Must not overlap with llvm::opt::DriverFlag. +enum WrapperFlags { WrapperOnlyOption = (1 << 4) }; + +enum ID { + OPT_INVALID = 0, // This is not an option ID. +#define OPTION(...) LLVM_MAKE_OPT_ID(__VA_ARGS__), +#include "NVLinkOpts.inc" + LastOption +#undef OPTION +}; + +#define PREFIX(NAME, VALUE) \ + static constexpr StringLiteral NAME##_init[] = VALUE; \ + static constexpr ArrayRef<StringLiteral> NAME(NAME##_init, \ + std::size(NAME##_init) - 1); +#include "NVLinkOpts.inc" +#undef PREFIX + +static constexpr OptTable::Info InfoTable[] = { +#define OPTION(...) LLVM_CONSTRUCT_OPT_INFO(__VA_ARGS__), +#include "NVLinkOpts.inc" +#undef OPTION +}; + +class WrapperOptTable : public opt::GenericOptTable { +public: + WrapperOptTable() : opt::GenericOptTable(InfoTable) {} +}; + +const OptTable &getOptTable() { + static const WrapperOptTable *Table = []() { + auto Result = std::make_unique<WrapperOptTable>(); + return Result.release(); + }(); + return *Table; +} + +[[noreturn]] void reportError(Error E) { + outs().flush(); + logAllUnhandledErrors(std::move(E), WithColor::error(errs(), Executable)); + exit(EXIT_FAILURE); +} + +void diagnosticHandler(const DiagnosticInfo &DI) { + std::string ErrStorage; + raw_string_ostream OS(ErrStorage); + DiagnosticPrinterRawOStream DP(OS); + DI.print(DP); + + switch (DI.getSeverity()) { + case DS_Error: + WithColor::error(errs(), Executable) << ErrStorage << "\n"; + break; + case DS_Warning: + WithColor::warning(errs(), Executable) << ErrStorage << "\n"; + break; + case DS_Note: + WithColor::note(errs(), Executable) << ErrStorage << "\n"; + break; + case DS_Remark: + WithColor::remark(errs()) << ErrStorage << "\n"; + break; + } +} + +Expected<StringRef> createTempFile(const ArgList &Args, const Twine &Prefix, + StringRef Extension) { + SmallString<128> OutputFile; + if (Args.hasArg(OPT_save_temps)) { + (Prefix + "." + Extension).toNullTerminatedStringRef(OutputFile); + } else { + if (std::error_code EC = + sys::fs::createTemporaryFile(Prefix, Extension, OutputFile)) + return createFileError(OutputFile, EC); + } + + TempFiles.emplace_back(std::move(OutputFile)); + return TempFiles.back(); +} + +Expected<std::string> findProgram(StringRef Name, ArrayRef<StringRef> Paths) { + ErrorOr<std::string> Path = sys::findProgramByName(Name, Paths); + if (!Path) + Path = sys::findProgramByName(Name); + if (!Path) + return createStringError(Path.getError(), + "Unable to find '" + Name + "' in path"); + return *Path; +} + +std::optional<std::string> findFile(StringRef Dir, StringRef Root, + const Twine &Name) { + SmallString<128> Path; + if (Dir.starts_with("=")) + sys::path::append(Path, Root, Dir.substr(1), Name); + else + sys::path::append(Path, Dir, Name); + + if (sys::fs::exists(Path)) + return static_cast<std::string>(Path); + return std::nullopt; +} + +std::optional<std::string> +findFromSearchPaths(StringRef Name, StringRef Root, + ArrayRef<StringRef> SearchPaths) { + for (StringRef Dir : SearchPaths) + if (std::optional<std::string> File = findFile(Dir, Root, Name)) + return File; + return std::nullopt; +} + +std::optional<std::string> +searchLibraryBaseName(StringRef Name, StringRef Root, + ArrayRef<StringRef> SearchPaths) { + for (StringRef Dir : SearchPaths) + if (std::optional<std::string> File = + findFile(Dir, Root, "lib" + Name + ".a")) + return File; + return std::nullopt; +} + +/// Search for static libraries in the linker's library path given input like +/// `-lfoo` or `-l:libfoo.a`. +std::optional<std::string> searchLibrary(StringRef Input, StringRef Root, + ArrayRef<StringRef> SearchPaths) { + if (Input.starts_with(":")) + return findFromSearchPaths(Input.drop_front(), Root, SearchPaths); + return searchLibraryBaseName(Input, Root, SearchPaths); +} + +void printCommands(ArrayRef<StringRef> CmdArgs) { + if (CmdArgs.empty()) + return; + + llvm::errs() << " \"" << CmdArgs.front() << "\" "; + llvm::errs() << llvm::join(std::next(CmdArgs.begin(), " ") << "\n"; +} + +/// A minimum symbol interface that provides the necessary information to +/// extract archive members and resolve LTO symbols. +struct Symbol { + enum Flags { + None = 0, + Undefined = 1 << 0, + Weak = 1 << 1, + }; + + Symbol() : File(), Flags(None), UsedInRegularObj(false) {} + + Symbol(MemoryBufferRef File, const irsymtab::Reader::SymbolRef Sym) + : File(File), Flags(0), UsedInRegularObj(false) { + if (Sym.isUndefined()) + Flags |= Undefined; + if (Sym.isWeak()) + Flags |= Weak; + } + + Symbol(MemoryBufferRef File, const SymbolRef Sym) + : File(File), Flags(0), UsedInRegularObj(false) { + auto FlagsOrErr = Sym.getFlags(); + if (!FlagsOrErr) + reportError(FlagsOrErr.takeError()); + if (*FlagsOrErr & SymbolRef::SF_Undefined) + Flags |= Undefined; + if (*FlagsOrErr & SymbolRef::SF_Weak) + Flags |= Weak; + + auto NameOrErr = Sym.getName(); + if (!NameOrErr) + reportError(NameOrErr.takeError()); + } + + bool isWeak() const { return Flags & Weak; } + bool isUndefined() const { return Flags & Undefined; } + + MemoryBufferRef File; + uint32_t Flags; + bool UsedInRegularObj; +}; + +Expected<StringRef> runPTXAs(StringRef File, const ArgList &Args) { + std::string CudaPath = Args.getLastArgValue(OPT_cuda_path_EQ).str(); + Expected<std::string> PTXAsPath = findProgram("ptxas", {CudaPath + "/bin"}); + if (!PTXAsPath) + return PTXAsPath.takeError(); + + auto TempFileOrErr = createTempFile( + Args, sys::path::stem(Args.getLastArgValue(OPT_o, "a.out")), "cubin"); + if (!TempFileOrErr) + return TempFileOrErr.takeError(); + + SmallVector<StringRef> AssemblerArgs({*PTXAsPath, "-m64", "-c", File}); + if (Args.hasArg(OPT_verbose)) + AssemblerArgs.push_back("-v"); + if (Args.hasArg(OPT_g)) { + if (Args.hasArg(OPT_O)) + WithColor::warning(errs(), Executable) + << "Optimized debugging not supported, overriding to '-O0'\n"; + AssemblerArgs.push_back("-O0"); + } else + AssemblerArgs.push_back( + Args.MakeArgString("-O" + Args.getLastArgValue(OPT_O, "3"))); + AssemblerArgs.append({"-arch", Args.getLastArgValue(OPT_arch)}); + AssemblerArgs.append({"-o", *TempFileOrErr}); + + if (Args.hasArg(OPT_dry_run) || Args.hasArg(OPT_verbose)) + printCommands(AssemblerArgs); + if (Args.hasArg(OPT_dry_run)) + return Args.MakeArgString(*TempFileOrErr); + if (sys::ExecuteAndWait(*PTXAsPath, AssemblerArgs)) + return createStringError("'" + sys::path::filename(*PTXAsPath) + "'" + + " failed"); + return Args.MakeArgString(*TempFileOrErr); +} + +Expected<std::unique_ptr<lto::LTO>> createLTO(const ArgList &Args) { + const llvm::Triple Triple("nvptx64-nvidia-cuda"); + lto::Config Conf; + lto::ThinBackend Backend; + unsigned Jobs = 0; + if (auto *Arg = Args.getLastArg(OPT_jobs)) + if (!llvm::to_integer(Arg->getValue(), Jobs) || Jobs == 0) + reportError(createStringError("%s: expected a positive integer, got '%s'", + Arg->getSpelling().data(), + Arg->getValue())); + Backend = lto::createInProcessThinBackend( + llvm::heavyweight_hardware_concurrency(Jobs)); + + Conf.CPU = Args.getLastArgValue(OPT_arch); + Conf.Options = codegen::InitTargetOptionsFromCodeGenFlags(Triple); + + Conf.RemarksFilename = RemarksFilename; + Conf.RemarksPasses = RemarksPasses; + Conf.RemarksWithHotness = RemarksWithHotness; + Conf.RemarksHotnessThreshold = RemarksHotnessThreshold; + Conf.RemarksFormat = RemarksFormat; + + Conf.MAttrs = {Args.getLastArgValue(OPT_feature, "").str()}; + std::optional<CodeGenOptLevel> CGOptLevelOrNone = + CodeGenOpt::parseLevel(Args.getLastArgValue(OPT_O, "2")[0]); + assert(CGOptLevelOrNone && "Invalid optimization level"); + Conf.CGOptLevel = *CGOptLevelOrNone; + Conf.OptLevel = Args.getLastArgValue(OPT_O, "2")[0] - '0'; + Conf.DefaultTriple = Triple.getTriple(); + + Conf.OptPipeline = PassPipeline; + Conf.PassPlugins = PassPlugins; + + Conf.DiagHandler = diagnosticHandler; + Conf.CGFileType = CodeGenFileType::AssemblyFile; + + if (Args.hasArg(OPT_lto_emit_llvm)) { + Conf.PreCodeGenModuleHook = [&](size_t, const Module &M) { + std::error_code EC; + raw_fd_ostream LinkedBitcode(Args.getLastArgValue(OPT_o, "a.out"), EC); + if (EC) + reportError(errorCodeToError(EC)); + WriteBitcodeToFile(M, LinkedBitcode); + return false; + }; + } + + if (Args.hasArg(OPT_save_temps)) + if (Error Err = Conf.addSaveTemps( + (Args.getLastArgValue(OPT_o, "a.out") + ".").str())) + return Err; + + unsigned Partitions = 1; + if (auto *Arg = Args.getLastArg(OPT_lto_partitions)) + if (!llvm::to_integer(Arg->getValue(), Partitions) || Partitions == 0) + reportError(createStringError("%s: expected a positive integer, got '%s'", + Arg->getSpelling().data(), + Arg->getValue())); + lto::LTO::LTOKind Kind = Args.hasArg(OPT_thinlto) ? lto::LTO::LTOK_UnifiedThin + : lto::LTO::LTOK_Default; + return std::make_unique<lto::LTO>(std::move(Conf), Backend, Partitions, Kind); +} + +Expected<bool> getSymbolsFromBitcode(MemoryBufferRef Buffer, + StringMap<Symbol> &SymTab, bool IsLazy) { + Expected<IRSymtabFile> IRSymtabOrErr = readIRSymtab(Buffer); + if (!IRSymtabOrErr) + return IRSymtabOrErr.takeError(); + bool Extracted = !IsLazy; + StringMap<Symbol> PendingSymbols; + for (unsigned I = 0; I != IRSymtabOrErr->Mods.size(); ++I) { + for (const auto &IRSym : IRSymtabOrErr->TheReader.module_symbols(I)) { + if (IRSym.isFormatSpecific() || !IRSym.isGlobal()) + continue; + + Symbol &OldSym = !SymTab.count(IRSym.getName()) && IsLazy + ? PendingSymbols[IRSym.getName()] + : SymTab[IRSym.getName()]; + Symbol Sym = Symbol(Buffer, IRSym); + if (OldSym.File.getBuffer().empty()) + OldSym = Sym; + + bool ResolvesReference = + !Sym.isUndefined() && + (OldSym.isUndefined() || (OldSym.isWeak() && !Sym.isWeak())) && + !(OldSym.isWeak() && OldSym.isUndefined() && IsLazy); + Extracted |= ResolvesReference; + + Sym.UsedInRegularObj = OldSym.UsedInRegularObj; + if (ResolvesReference) + OldSym = Sym; + } + } + if (Extracted) + for (auto &[Name, Symbol] : PendingSymbols) + SymTab[Name] = Symbol; + return Extracted; +} + +Expected<bool> getSymbolsFromObject(ObjectFile &ObjFile, + StringMap<Symbol> &SymTab, bool IsLazy) { + bool Extracted = !IsLazy; + StringMap<Symbol> PendingSymbols; + for (SymbolRef ObjSym : ObjFile.symbols()) { + auto NameOrErr = ObjSym.getName(); + if (!NameOrErr) + return NameOrErr.takeError(); + + Symbol &OldSym = !SymTab.count(*NameOrErr) && IsLazy + ? PendingSymbols[*NameOrErr] + : SymTab[*NameOrErr]; + Symbol Sym = Symbol(ObjFile.getMemoryBufferRef(), ObjSym); + if (OldSym.File.getBuffer().empty()) + OldSym = Sym; + + bool ResolvesReference = OldSym.isUndefined() && !Sym.isUndefined() && + (!OldSym.isWeak() || !IsLazy); + Extracted |= ResolvesReference; + + if (ResolvesReference) + OldSym = Sym; + OldSym.UsedInRegularObj = true; + } + if (Extracted) + for (auto &[Name, Symbol] : PendingSymbols) + SymTab[Name] = Symbol; + return Extracted; +} + +Expected<bool> getSymbols(MemoryBufferRef Buffer, StringMap<Symbol> &SymTab, + bool IsLazy) { + switch (identify_magic(Buffer.getBuffer())) { + case file_magic::bitcode: { + return getSymbolsFromBitcode(Buffer, SymTab, IsLazy); + } + case file_magic::elf_relocatable: { + Expected<std::unique_ptr<ObjectFile>> ObjFile = + ObjectFile::createObjectFile(Buffer); + if (!ObjFile) + return ObjFile.takeError(); + return getSymbolsFromObject(**ObjFile, SymTab, IsLazy); + } + default: + return createStringError("Unsupported file type"); + } +} + +Expected<SmallVector<StringRef>> getInput(const ArgList &Args) { + SmallVector<StringRef> LibraryPaths; + for (const opt::Arg *Arg : Args.filtered(OPT_library_path)) + LibraryPaths.push_back(Arg->getValue()); + + bool WholeArchive = false; + SmallVector<std::pair<std::unique_ptr<MemoryBuffer>, bool>> InputFiles; + for (const opt::Arg *Arg : Args.filtered( + OPT_INPUT, OPT_library, OPT_whole_archive, OPT_no_whole_archive)) { + if (Arg->getOption().matches(OPT_whole_archive) || + Arg->getOption().matches(OPT_no_whole_archive)) { + WholeArchive = Arg->getOption().matches(OPT_whole_archive); + continue; + } + + std::optional<std::string> Filename = + Arg->getOption().matches(OPT_library) + ? searchLibrary(Arg->getValue(), /*Root=*/"", LibraryPaths) + : std::string(Arg->getValue()); + + if (!Filename && Arg->getOption().matches(OPT_library)) + return createStringError("unable to find library -l%s", Arg->getValue()); + + if (!Filename || !sys::fs::exists(*Filename) || + sys::fs::is_directory(*Filename)) + continue; + + ErrorOr<std::unique_ptr<MemoryBuffer>> BufferOrErr = + MemoryBuffer::getFileOrSTDIN(*Filename); + if (std::error_code EC = BufferOrErr.getError()) + return createFileError(*Filename, EC); + + MemoryBufferRef Buffer = **BufferOrErr; + switch (identify_magic(Buffer.getBuffer())) { + case file_magic::bitcode: + case file_magic::elf_relocatable: + InputFiles.emplace_back(std::move(*BufferOrErr), /*IsLazy=*/false); + break; + case file_magic::archive: { + Expected<std::unique_ptr<llvm::object::Archive>> LibFile = + object::Archive::create(Buffer); + if (!LibFile) + return LibFile.takeError(); + Error Err = Error::success(); + for (auto Child : (*LibFile)->children(Err)) { + auto ChildBufferOrErr = Child.getMemoryBufferRef(); + if (!ChildBufferOrErr) + return ChildBufferOrErr.takeError(); + std::unique_ptr<MemoryBuffer> ChildBuffer = + MemoryBuffer::getMemBufferCopy( + ChildBufferOrErr->getBuffer(), + ChildBufferOrErr->getBufferIdentifier()); + InputFiles.emplace_back(std::move(ChildBuffer), !WholeArchive); + } + if (Err) + return Err; + break; + } + default: + return createStringError("Unsupported file type"); + } + } + + bool Extracted = true; + StringMap<Symbol> SymTab; + SmallVector<std::unique_ptr<MemoryBuffer>> LinkerInput; + while (Extracted) { + Extracted = false; + for (auto &[Input, IsLazy] : InputFiles) { + if (!Input) + continue; + + // Archive members only extract if they define needed symbols. We will + // re-scan all the inputs if any files were extracted for the link job. + Expected<bool> ExtractOrErr = getSymbols(*Input, SymTab, IsLazy); + if (!ExtractOrErr) + return ExtractOrErr.takeError(); + + Extracted |= *ExtractOrErr; + if (!*ExtractOrErr) + continue; + + LinkerInput.emplace_back(std::move(Input)); + } + } + InputFiles.clear(); + + // Extract any bitcode files to be passed to the LTO pipeline. + SmallVector<std::unique_ptr<MemoryBuffer>> BitcodeFiles; + for (auto &Input : LinkerInput) + if (identify_magic(Input->getBuffer()) == file_magic::bitcode) + BitcodeFiles.emplace_back(std::move(Input)); + llvm::erase_if(LinkerInput, [](const auto &F) { return !F; }); + + // Run the LTO pipeline on the extracted inputs. + SmallVector<StringRef> Files; + if (!BitcodeFiles.empty()) { + auto LTOBackendOrErr = createLTO(Args); + if (!LTOBackendOrErr) + return LTOBackendOrErr.takeError(); + lto::LTO <OBackend = **LTOBackendOrErr; + for (auto &BitcodeFile : BitcodeFiles) { + Expected<std::unique_ptr<lto::InputFile>> BitcodeFileOrErr = + llvm::lto::InputFile::create(*BitcodeFile); + if (!BitcodeFileOrErr) + return BitcodeFileOrErr.takeError(); + + const auto Symbols = (*BitcodeFileOrErr)->symbols(); + SmallVector<lto::SymbolResolution, 16> Resolutions(Symbols.size()); + size_t Idx = 0; + for (auto &Sym : Symbols) { + lto::SymbolResolution &Res = Resolutions[Idx++]; + Symbol ObjSym = SymTab[Sym.getName()]; + // We will use this as the prevailing symbol in LTO if it is not + // undefined and it is from the file that contained the canonical + // definition. + Res.Prevailing = !Sym.isUndefined() && ObjSym.File == *BitcodeFile; + + // We need LTO to preseve the following global symbols: + // 1) Symbols used in regular objects. + // 2) Prevailing symbols that are needed visible to the gpu runtime. + Res.VisibleToRegularObj = + ObjSym.UsedInRegularObj || + (Res.Prevailing && + (Sym.getVisibility() != GlobalValue::HiddenVisibility && + !Sym.canBeOmittedFromSymbolTable())); + + // Identify symbols that must be exported dynamically and can be + // referenced by other files, (i.e. the runtime). + Res.ExportDynamic = + Sym.getVisibility() != GlobalValue::HiddenVisibility && + !Sym.canBeOmittedFromSymbolTable(); + + // The NVIDIA platform does not support any symbol preemption. + Res.FinalDefinitionInLinkageUnit = true; + + // We do not support linker redefined symbols (e.g. --wrap) for device + // image linking, so the symbols will not be changed after LTO. + Res.LinkerRedefined = false; + } + + // Add the bitcode file with its resolved symbols to the LTO job. + if (Error Err = LTOBackend.add(std::move(*BitcodeFileOrErr), Resolutions)) + return Err; + } + + // Run the LTO job to compile the bitcode. + size_t MaxTasks = LTOBackend.getMaxTasks(); + SmallVector<StringRef> LTOFiles(MaxTasks); + auto AddStream = + [&](size_t Task, + const Twine &ModuleName) -> std::unique_ptr<CachedFileStream> { + int FD = -1; + auto &TempFile = LTOFiles[Task]; + if (Args.hasArg(OPT_lto_emit_asm)) + TempFile = Args.getLastArgValue(OPT_o, "a.out"); + else { + auto TempFileOrErr = createTempFile( + Args, sys::path::stem(Args.getLastArgValue(OPT_o, "a.out")), "s"); + if (!TempFileOrErr) + reportError(TempFileOrErr.takeError()); + TempFile = Args.MakeArgString(*TempFileOrErr); + } + if (std::error_code EC = sys::fs::openFileForWrite(TempFile, FD)) + reportError(errorCodeToError(EC)); + return std::make_unique<CachedFileStream>( + std::make_unique<llvm::raw_fd_ostream>(FD, true)); + }; + + if (Error Err = LTOBackend.run(AddStream)) + return Err; + + if (Args.hasArg(OPT_lto_emit_llvm) || Args.hasArg(OPT_lto_emit_asm)) + return Files; + + for (StringRef LTOFile : LTOFiles) { + auto FileOrErr = runPTXAs(LTOFile, Args); + if (!FileOrErr) + return FileOrErr.takeError(); + Files.emplace_back(*FileOrErr); + } + } + + // Copy all of the input files to a new file ending in `.cubin`. The 'nvlink' + // linker requires all NVPTX inputs to have this extension for some reason. + for (auto &Input : LinkerInput) { + auto TempFileOrErr = createTempFile( + Args, sys::path::stem(Input->getBufferIdentifier()), "cubin"); + if (!TempFileOrErr) + return TempFileOrErr.takeError(); + Expected<std::unique_ptr<FileOutputBuffer>> OutputOrErr = + FileOutputBuffer::create(*TempFileOrErr, Input->getBuffer().size()); + if (!OutputOrErr) + return OutputOrErr.takeError(); + std::unique_ptr<FileOutputBuffer> Output = std::move(*OutputOrErr); + llvm::copy(Input->getBuffer(), Output->getBufferStart()); + if (Error E = Output->commit()) + return E; + Files.emplace_back(Args.MakeArgString(*TempFileOrErr)); + } + + return Files; +} + +Error runNVLink(ArrayRef<StringRef> Files, const ArgList &Args) { + if (Args.hasArg(OPT_lto_emit_asm) || Args.hasArg(OPT_lto_emit_llvm)) + return Error::success(); + + std::string CudaPath = Args.getLastArgValue(OPT_cuda_path_EQ).str(); + Expected<std::string> NVLinkPath = findProgram("nvlink", {CudaPath + "/bin"}); + if (!NVLinkPath) + return NVLinkPath.takeError(); + + ArgStringList NewLinkerArgs; + for (const opt::Arg *Arg : Args) { + // Do not forward arguments only intended for the linker wrapper. + if (Arg->getOption().hasFlag(WrapperOnlyOption)) + continue; + + // Do not forward any inputs that we have processed. + if (Arg->getOption().matches(OPT_INPUT) || + Arg->getOption().matches(OPT_library)) + continue; + + Arg->render(Args, NewLinkerArgs); + } + + llvm::transform(Files, std::back_inserter(NewLinkerArgs), + [&](StringRef Arg) { return Args.MakeArgString(Arg); }); + + SmallVector<StringRef> LinkerArgs({*NVLinkPath}); + if (!Args.hasArg(OPT_o)) + LinkerArgs.append({"-o", "a.out"}); + for (StringRef Arg : NewLinkerArgs) + LinkerArgs.push_back(Arg); + + if (Args.hasArg(OPT_dry_run) || Args.hasArg(OPT_verbose)) + printCommands(LinkerArgs); + if (Args.hasArg(OPT_dry_run)) + return Error::success(); + if (sys::ExecuteAndWait(*NVLinkPath, LinkerArgs)) + return createStringError("'" + sys::path::filename(*NVLinkPath) + "'" + + " failed"); + return Error::success(); +} + +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + InitializeAllTargetInfos(); + InitializeAllTargets(); + InitializeAllTargetMCs(); + InitializeAllAsmParsers(); + InitializeAllAsmPrinters(); + + Executable = argv[0]; + sys::PrintStackTraceOnErrorSignal(argv[0]); + + const OptTable &Tbl = getOptTable(); + BumpPtrAllocator Alloc; + StringSaver Saver(Alloc); + auto Args = Tbl.parseArgs(argc, argv, OPT_INVALID, Saver, [&](StringRef Err) { + reportError(createStringError(inconvertibleErrorCode(), Err)); + }); + + if (Args.hasArg(OPT_help) || Args.hasArg(OPT_help_hidden)) { + Tbl.printHelp( + outs(), "clang-nvlink-wrapper [options] <options to passed to nvlink>", + "A utility that wraps around the NVIDIA 'nvlink' linker.\n" + "This enables static linking and LTO handling for NVPTX targets.", + Args.hasArg(OPT_help_hidden), Args.hasArg(OPT_help_hidden)); + return EXIT_SUCCESS; + } + + if (Args.hasArg(OPT_version)) ---------------- Artem-B wrote:
Should we exit here? I think that's the 'standard' expectation. At least clang does so. https://github.com/llvm/llvm-project/pull/96561 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits