https://github.com/Lancern created https://github.com/llvm/llvm-project/pull/136810
This PR upstreams support for scalar arguments in `cir.call` operation. Related to #132487 . >From 557cae2daea53723010390cdf545721dd9ad7de4 Mon Sep 17 00:00:00 2001 From: Sirui Mu <msrlanc...@gmail.com> Date: Wed, 23 Apr 2025 12:14:40 +0800 Subject: [PATCH] [CIR] Upstream cir.call with scalar arguments --- .../CIR/Dialect/Builder/CIRBaseBuilder.h | 9 +- clang/include/clang/CIR/Dialect/IR/CIROps.td | 46 ++- .../clang/CIR/Interfaces/CIROpInterfaces.td | 17 +- clang/include/clang/CIR/MissingFeatures.h | 12 +- clang/lib/CIR/CodeGen/CIRGenCall.cpp | 277 +++++++++++++++++- clang/lib/CIR/CodeGen/CIRGenCall.h | 32 +- clang/lib/CIR/CodeGen/CIRGenExpr.cpp | 24 +- clang/lib/CIR/CodeGen/CIRGenFunction.h | 48 ++- clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h | 27 +- clang/lib/CIR/CodeGen/CIRGenTypes.cpp | 15 +- clang/lib/CIR/CodeGen/CIRGenTypes.h | 7 +- clang/lib/CIR/CodeGen/TargetInfo.cpp | 8 +- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 58 +++- clang/test/CIR/CodeGen/call.cpp | 12 + clang/test/CIR/IR/call.cir | 15 + clang/test/CIR/IR/invalid-call.cir | 27 ++ 16 files changed, 588 insertions(+), 46 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index 539268c6270f4..0a6e47ea43a8c 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -214,14 +214,15 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { //===--------------------------------------------------------------------===// cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee, - mlir::Type returnType) { - auto op = create<cir::CallOp>(loc, callee, returnType); + mlir::Type returnType, mlir::ValueRange operands) { + auto op = create<cir::CallOp>(loc, callee, returnType, operands); return op; } - cir::CallOp createCallOp(mlir::Location loc, cir::FuncOp callee) { + cir::CallOp createCallOp(mlir::Location loc, cir::FuncOp callee, + 
mlir::ValueRange operands) { return createCallOp(loc, mlir::SymbolRefAttr::get(callee), - callee.getFunctionType().getReturnType()); + callee.getFunctionType().getReturnType(), operands); } //===--------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index bb19de31b4fa5..aa7a9b2de664f 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1496,6 +1496,10 @@ def FuncOp : CIR_Op<"func", [ return getFunctionType().getReturnTypes(); } + // TODO(cir): this should be an operand attribute, but for now we just hard- + // wire this as a function. Will later add a $no_proto argument to this op. + bool getNoProto() { return false; } + //===------------------------------------------------------------------===// // SymbolOpInterface Methods //===------------------------------------------------------------------===// @@ -1516,6 +1520,41 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []> !listconcat(extra_traits, [DeclareOpInterfaceMethods<CIRCallOpInterface>, DeclareOpInterfaceMethods<SymbolUserOpInterface>])> { + let extraClassDeclaration = [{ + /// Get the argument operands to the called function. + mlir::OperandRange getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + mlir::MutableOperandRange getArgOperandsMutable() { + llvm_unreachable("NYI"); + } + + /// Return the callee of this operation + mlir::CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType<mlir::SymbolRefAttr>("callee"); + } + + /// Set the callee for this operation. 
+ void setCalleeFromCallable(::mlir::CallInterfaceCallable callee) { + (*this)->setAttr(getCalleeAttrName(), + mlir::cast<mlir::SymbolRefAttr>(callee)); + } + + ::mlir::ArrayAttr getArgAttrsAttr() { return {}; } + ::mlir::ArrayAttr getResAttrsAttr() { return {}; } + + void setResAttrsAttr(::mlir::ArrayAttr attrs) {} + void setArgAttrsAttr(::mlir::ArrayAttr attrs) {} + + ::mlir::Attribute removeArgAttrsAttr() { return {}; } + ::mlir::Attribute removeResAttrsAttr() { return {}; } + + void setArg(unsigned index, mlir::Value value) { + setOperand(index, value); + } + }]; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; let hasVerifier = 0; @@ -1525,7 +1564,8 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []> // the upstreaming process moves on. The verifiers is also missing for now, // will add in the future. - dag commonArgs = (ins FlatSymbolRefAttr:$callee); + dag commonArgs = (ins FlatSymbolRefAttr:$callee, + Variadic<CIR_AnyType>:$args); } def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> { @@ -1546,7 +1586,9 @@ def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> { let arguments = commonArgs; let builders = [OpBuilder<(ins "mlir::SymbolRefAttr":$callee, - "mlir::Type":$resType), [{ + "mlir::Type":$resType, + "mlir::ValueRange":$operands), [{ + $_state.addOperands(operands); $_state.addAttribute("callee", callee); if (resType && !isa<VoidType>(resType)) $_state.addTypes(resType); diff --git a/clang/include/clang/CIR/Interfaces/CIROpInterfaces.td b/clang/include/clang/CIR/Interfaces/CIROpInterfaces.td index c6c6356118ac6..8227ce4bea5a3 100644 --- a/clang/include/clang/CIR/Interfaces/CIROpInterfaces.td +++ b/clang/include/clang/CIR/Interfaces/CIROpInterfaces.td @@ -21,9 +21,24 @@ let cppNamespace = "::cir" in { // The CIRCallOpInterface must be used instead of CallOpInterface when looking // at arguments and other bits of CallOp. 
This creates a level of abstraction // that's useful for handling indirect calls and other details. - def CIRCallOpInterface : OpInterface<"CIRCallOpInterface", []> { + def CIRCallOpInterface : OpInterface<"CIRCallOpInterface", [CallOpInterface]> { // Currently we don't have any methods defined in CIRCallOpInterface. We'll // add more methods as the upstreaming proceeds. + let methods = [ + InterfaceMethod<"", "mlir::Operation::operand_iterator", + "arg_operand_begin", (ins)>, + InterfaceMethod<"", "mlir::Operation::operand_iterator", + "arg_operand_end", (ins)>, + InterfaceMethod< + "Return the operand at index 'i', accounts for indirect call or " + "exception info", + "mlir::Value", "getArgOperand", + (ins "unsigned":$i)>, + InterfaceMethod< + "Return the number of operands, accounts for indirect call or " + "exception info", + "unsigned", "getNumArgOperands", (ins)>, + ]; } def CIRGlobalValueInterface diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h index 6bfc1199aea55..370d82d26ebe7 100644 --- a/clang/include/clang/CIR/MissingFeatures.h +++ b/clang/include/clang/CIR/MissingFeatures.h @@ -76,7 +76,13 @@ struct MissingFeatures { // CallOp handling static bool opCallBuiltinFunc() { return false; } static bool opCallPseudoDtor() { return false; } - static bool opCallArgs() { return false; } + static bool opCallAggregateArgs() { return false; } + static bool opCallPaddingArgs() { return false; } + static bool opCallABIExtendArg() { return false; } + static bool opCallABIIndirectArg() { return false; } + static bool opCallWidenArg() { return false; } + static bool opCallBitcastArg() { return false; } + static bool opCallImplicitObjectSizeArgs() { return false; } static bool opCallReturn() { return false; } static bool opCallArgEvaluationOrder() { return false; } static bool opCallCallConv() { return false; } @@ -90,6 +96,9 @@ struct MissingFeatures { static bool opCallAttrs() { return false; } static bool 
opCallSurroundingTry() { return false; } static bool opCallASTAttr() { return false; } + static bool opCallVariadic() { return false; } + static bool opCallObjCMethod() { return false; } + static bool opCallExtParameterInfo() { return false; } // ScopeOp handling static bool opScopeCleanupRegion() { return false; } @@ -157,6 +166,7 @@ struct MissingFeatures { static bool emitCheckedInBoundsGEP() { return false; } static bool preservedAccessIndexRegion() { return false; } static bool bitfields() { return false; } + static bool msabi() { return false; } static bool typeChecks() { return false; } static bool lambdaFieldToName() { return false; } diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp index 69266f79a88a5..bea91f8ec0ec7 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp @@ -18,15 +18,110 @@ using namespace clang; using namespace clang::CIRGen; -CIRGenFunctionInfo *CIRGenFunctionInfo::create(CanQualType resultType) { - void *buffer = operator new(totalSizeToAlloc<ArgInfo>(1)); +CIRGenFunctionInfo * +CIRGenFunctionInfo::create(CanQualType resultType, + llvm::ArrayRef<CanQualType> argTypes) { + void *buffer = operator new(totalSizeToAlloc<ArgInfo>(argTypes.size() + 1)); CIRGenFunctionInfo *fi = new (buffer) CIRGenFunctionInfo(); + fi->numArgs = argTypes.size(); fi->getArgsBuffer()[0].type = resultType; + for (unsigned i = 0; i < argTypes.size(); ++i) + fi->getArgsBuffer()[i + 1].type = argTypes[i]; return fi; } +namespace { + +/// Encapsulates information about the way function arguments from +/// CIRGenFunctionInfo should be passed to actual CIR function. +class ClangToCIRArgMapping { + static constexpr unsigned invalidIndex = ~0U; + unsigned totalNumCIRArgs; + + /// Arguments of CIR function corresponding to single Clang argument. + struct CIRArgs { + // Argument is expanded to CIR arguments at positions + // [FirstArgIndex, FirstArgIndex + NumberOfArgs). 
+ unsigned firstArgIndex = 0; + unsigned numberOfArgs = 0; + + CIRArgs() : firstArgIndex(invalidIndex), numberOfArgs(0) {} + }; + + SmallVector<CIRArgs, 8> argInfo; + +public: + ClangToCIRArgMapping(const ASTContext &astContext, + const CIRGenFunctionInfo &funcInfo) + : totalNumCIRArgs(0), argInfo(funcInfo.arg_size()) { + construct(astContext, funcInfo); + } + + unsigned totalCIRArgs() const { return totalNumCIRArgs; } + + /// Returns index of first CIR argument corresponding to argNo, and their + /// quantity. + std::pair<unsigned, unsigned> getCIRArgs(unsigned argNo) const { + assert(argNo < argInfo.size()); + return std::make_pair(argInfo[argNo].firstArgIndex, + argInfo[argNo].numberOfArgs); + } + +private: + void construct(const ASTContext &astContext, + const CIRGenFunctionInfo &funcInfo); +}; + +void ClangToCIRArgMapping::construct(const ASTContext &astContext, + const CIRGenFunctionInfo &funcInfo) { + unsigned cirArgNo = 0; + + assert(!cir::MissingFeatures::opCallABIIndirectArg()); + + unsigned argNo = 0; + unsigned numArgs = funcInfo.arg_size(); + for (const auto *i = funcInfo.arg_begin(); argNo < numArgs; ++i, ++argNo) { + assert(i != funcInfo.arg_end()); + const cir::ABIArgInfo &ai = i->info; + // Collect data about CIR arguments corresponding to Clang argument ArgNo. + auto &cirArgs = argInfo[argNo]; + + assert(!cir::MissingFeatures::opCallPaddingArgs()); + + switch (ai.getKind()) { + default: + assert(!cir::MissingFeatures::abiArgInfo()); + // For now we just fall through. More argument kinds will be added later + // as the upstreaming proceeds. + [[fallthrough]]; + case cir::ABIArgInfo::Direct: + // Postpone splitting structs into elements since this makes it way + // more complicated for analysis to obtain information on the original + // arguments. + // + // TODO(cir): a LLVM lowering prepare pass should break this down into + // the appropriated pieces. 
+ assert(!cir::MissingFeatures::opCallABIExtendArg()); + cirArgs.numberOfArgs = 1; + break; + } + + if (cirArgs.numberOfArgs > 0) { + cirArgs.firstArgIndex = cirArgNo; + cirArgNo += cirArgs.numberOfArgs; + } + } + + assert(argNo == argInfo.size()); + assert(!cir::MissingFeatures::opCallInAlloca()); + + totalNumCIRArgs = cirArgNo; +} + +} // namespace + CIRGenCallee CIRGenCallee::prepareConcreteCallee(CIRGenFunction &cgf) const { assert(!cir::MissingFeatures::opCallVirtual()); return *this; @@ -34,6 +129,7 @@ CIRGenCallee CIRGenCallee::prepareConcreteCallee(CIRGenFunction &cgf) const { static const CIRGenFunctionInfo & arrangeFreeFunctionLikeCall(CIRGenTypes &cgt, CIRGenModule &cgm, + const CallArgList &args, const FunctionType *fnType) { if (const auto *proto = dyn_cast<FunctionProtoType>(fnType)) { if (proto->isVariadic()) @@ -44,22 +140,26 @@ arrangeFreeFunctionLikeCall(CIRGenTypes &cgt, CIRGenModule &cgm, cast<FunctionNoProtoType>(fnType))) cgm.errorNYI("call to function without a prototype"); - assert(!cir::MissingFeatures::opCallArgs()); + SmallVector<CanQualType, 16> argTypes; + for (const CallArg &arg : args) + argTypes.push_back(cgt.getASTContext().getCanonicalParamType(arg.ty)); CanQualType retType = fnType->getReturnType() ->getCanonicalTypeUnqualified() .getUnqualifiedType(); - return cgt.arrangeCIRFunctionInfo(retType); + return cgt.arrangeCIRFunctionInfo(retType, argTypes); } const CIRGenFunctionInfo & -CIRGenTypes::arrangeFreeFunctionCall(const FunctionType *fnType) { - return arrangeFreeFunctionLikeCall(*this, cgm, fnType); +CIRGenTypes::arrangeFreeFunctionCall(const CallArgList &args, + const FunctionType *fnType) { + return arrangeFreeFunctionLikeCall(*this, cgm, args, fnType); } -static cir::CIRCallOpInterface emitCallLikeOp(CIRGenFunction &cgf, - mlir::Location callLoc, - cir::FuncOp directFuncOp) { +static cir::CIRCallOpInterface +emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc, + cir::FuncOp directFuncOp, + const 
SmallVectorImpl<mlir::Value> &cirCallArgs) { CIRGenBuilderTy &builder = cgf.getBuilder(); assert(!cir::MissingFeatures::opCallSurroundingTry()); @@ -68,20 +168,70 @@ static cir::CIRCallOpInterface emitCallLikeOp(CIRGenFunction &cgf, assert(builder.getInsertionBlock() && "expected valid basic block"); assert(!cir::MissingFeatures::opCallIndirect()); - return builder.createCallOp(callLoc, directFuncOp); + return builder.createCallOp(callLoc, directFuncOp, cirCallArgs); } RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo, const CIRGenCallee &callee, ReturnValueSlot returnValue, + const CallArgList &args, cir::CIRCallOpInterface *callOp, mlir::Location loc) { QualType retTy = funcInfo.getReturnType(); const cir::ABIArgInfo &retInfo = funcInfo.getReturnInfo(); - assert(!cir::MissingFeatures::opCallArgs()); + ClangToCIRArgMapping cirFuncArgs(cgm.getASTContext(), funcInfo); + SmallVector<mlir::Value, 16> cirCallArgs(cirFuncArgs.totalCIRArgs()); + assert(!cir::MissingFeatures::emitLifetimeMarkers()); + // Translate all of the arguments as necessary to match the CIR lowering. + assert(funcInfo.arg_size() == args.size() && + "Mismatch between function signature & arguments."); + unsigned argNo = 0; + const auto *infoIter = funcInfo.arg_begin(); + for (auto i = args.begin(), e = args.end(); i != e; + ++i, ++infoIter, ++argNo) { + const cir::ABIArgInfo &argInfo = infoIter->info; + + // Insert a padding argument to ensure proper alignment. 
+ assert(!cir::MissingFeatures::opCallPaddingArgs()); + + unsigned firstCIRArg; + unsigned numCIRArgs; + std::tie(firstCIRArg, numCIRArgs) = cirFuncArgs.getCIRArgs(argNo); + + switch (argInfo.getKind()) { + case cir::ABIArgInfo::Direct: { + if (!mlir::isa<cir::RecordType>(argInfo.getCoerceToType()) && + argInfo.getCoerceToType() == convertType(infoIter->type) && + argInfo.getDirectOffset() == 0) { + assert(numCIRArgs == 1); + assert(!cir::MissingFeatures::opCallAggregateArgs()); + mlir::Value v = i->getKnownRValue().getScalarVal(); + + assert(!cir::MissingFeatures::opCallExtParameterInfo()); + + // We might have to widen integers, but we should never truncate. + assert(!cir::MissingFeatures::opCallWidenArg()); + + // If the argument doesn't match, perform a bitcast to coerce it. This + // can happen due to trivial type mismatches. + assert(!cir::MissingFeatures::opCallBitcastArg()); + + cirCallArgs[firstCIRArg] = v; + break; + } + + assert(!cir::MissingFeatures::opCallAggregateArgs()); + cgm.errorNYI("aggregate function call argument"); + break; + } + default: + cgm.errorNYI("unsupported argument kind"); + } + } + const CIRGenCallee &concreteCallee = callee.prepareConcreteCallee(*this); mlir::Operation *calleePtr = concreteCallee.getFunctionPointer(); @@ -102,7 +252,8 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo, assert(!cir::MissingFeatures::opCallIndirect()); assert(!cir::MissingFeatures::opCallAttrs()); - cir::CIRCallOpInterface theCall = emitCallLikeOp(*this, loc, directFuncOp); + cir::CIRCallOpInterface theCall = + emitCallLikeOp(*this, loc, directFuncOp, cirCallArgs); if (callOp) *callOp = theCall; @@ -152,3 +303,105 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo, return ret; } + +void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e, + clang::QualType argType) { + assert(argType->isReferenceType() == e->isGLValue() && + "reference binding to unmaterialized r-value!"); + + if (e->isGLValue()) 
{ + assert(e->getObjectKind() == OK_Ordinary); + return args.add(emitReferenceBindingToExpr(e), argType); + } + + bool hasAggregateEvalKind = hasAggregateEvaluationKind(argType); + + if (hasAggregateEvalKind) { + assert(!cir::MissingFeatures::opCallAggregateArgs()); + cgm.errorNYI(e->getSourceRange(), "aggregate function call argument"); + } + + args.add(emitAnyExprToTemp(e), argType); +} + +/// Similar to emitAnyExpr(), however, the result will always be accessible +/// even if no aggregate location is provided. +RValue CIRGenFunction::emitAnyExprToTemp(const Expr *e) { + assert(!cir::MissingFeatures::opCallAggregateArgs()); + + if (hasAggregateEvaluationKind(e->getType())) + cgm.errorNYI(e->getSourceRange(), "emit aggregate value to temp"); + + return emitAnyExpr(e); +} + +void CIRGenFunction::emitCallArgs( + CallArgList &args, PrototypeWrapper prototype, + llvm::iterator_range<clang::CallExpr::const_arg_iterator> argRange, + AbstractCallee callee, unsigned paramsToSkip) { + llvm::SmallVector<QualType, 16> argTypes; + + assert(!cir::MissingFeatures::opCallCallConv()); + + // First, if a prototype was provided, use those argument types. + assert(!cir::MissingFeatures::opCallVariadic()); + if (prototype.p) { + assert(!cir::MissingFeatures::opCallObjCMethod()); + + const auto *fpt = cast<const FunctionProtoType *>(prototype.p); + argTypes.assign(fpt->param_type_begin() + paramsToSkip, + fpt->param_type_end()); + } + + // If we still have any arguments, emit them using the type of the argument. + for (auto *a : llvm::drop_begin(argRange, argTypes.size())) + argTypes.push_back(a->getType()); + assert(argTypes.size() == (size_t)(argRange.end() - argRange.begin())); + + // We must evaluate arguments from right to left in the MS C++ ABI, because + // arguments are destroyed left to right in the callee.
As a special case, + // there are certain language constructs that require left-to-right + // evaluation, and in those cases we consider the evaluation order requirement + // to trump the "destruction order is reverse construction order" guarantee. + auto leftToRight = true; + assert(!cir::MissingFeatures::msabi()); + + auto maybeEmitImplicitObjectSize = [&](size_t i, const Expr *arg, + RValue emittedArg) { + if (!callee.hasFunctionDecl() || i >= callee.getNumParams()) + return; + auto *ps = callee.getParamDecl(i)->getAttr<PassObjectSizeAttr>(); + if (!ps) + return; + + assert(!cir::MissingFeatures::opCallImplicitObjectSizeArgs()); + cgm.errorNYI("emit implicit object size for call arg"); + }; + + // Evaluate each argument in the appropriate order. + size_t callArgsStart = args.size(); + for (size_t i = 0; i != argTypes.size(); ++i) { + size_t idx = leftToRight ? i : argTypes.size() - i - 1; + CallExpr::const_arg_iterator currentArg = argRange.begin() + idx; + size_t initialArgSize = args.size(); + + emitCallArg(args, *currentArg, argTypes[idx]); + + // In particular, we depend on it being the last arg in Args, and the + // objectsize bits depend on there only being one arg if !LeftToRight. + assert(initialArgSize + 1 == args.size() && + "The code below depends on only adding one arg per emitCallArg"); + (void)initialArgSize; + + // Since pointer arguments are never emitted as LValue, it is safe to emit + // non-null argument check for r-value only.
+ if (!args.back().hasLValue()) { + RValue rvArg = args.back().getKnownRValue(); + assert(!cir::MissingFeatures::sanitizers()); + maybeEmitImplicitObjectSize(idx, *currentArg, rvArg); + } + } + // Un-reverse the arguments evaluated right-to-left so they match the callee. + if (!leftToRight) + std::reverse(args.begin() + callArgsStart, args.end()); +} diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.h b/clang/lib/CIR/CodeGen/CIRGenCall.h index 4427fda863d7e..0e7ab11bfa96c 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.h +++ b/clang/lib/CIR/CodeGen/CIRGenCall.h @@ -14,6 +14,7 @@ #ifndef CLANG_LIB_CODEGEN_CIRGENCALL_H #define CLANG_LIB_CODEGEN_CIRGENCALL_H +#include "CIRGenValue.h" #include "mlir/IR/Operation.h" #include "clang/AST/GlobalDecl.h" #include "llvm/ADT/SmallVector.h" @@ -77,9 +78,36 @@ class CIRGenCallee { /// The decl must be either a ParmVarDecl or ImplicitParamDecl. class FunctionArgList : public llvm::SmallVector<const clang::VarDecl *, 16> {}; -struct CallArg {}; +struct CallArg { +private: + union { + RValue rv; + LValue lv; // This argument is semantically a load from this l-value + }; + bool hasLV; + + /// A data-flow flag to make sure getRValue and/or copyInto are not + /// called twice for duplicated IR emission. + mutable bool isUsed; + +public: + clang::QualType ty; + + CallArg(RValue rv, clang::QualType ty) + : rv(rv), hasLV(false), isUsed(false), ty(ty) {} -class CallArgList : public llvm::SmallVector<CallArg, 8> {}; + bool hasLValue() const { return hasLV; } + + RValue getKnownRValue() const { + assert(!hasLV && !isUsed); + return rv; + } +}; + +class CallArgList : public llvm::SmallVector<CallArg, 8> { +public: + void add(RValue rvalue, clang::QualType type) { emplace_back(rvalue, type); } +}; /// Contains the address where the return value of a function can be stored, and /// whether the address is volatile or not.
diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index 0a518c0fd935d..562daf0e5c897 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -393,6 +393,17 @@ mlir::Value CIRGenFunction::emitLoadOfScalar(LValue lvalue, return loadOp; } +RValue CIRGenFunction::emitReferenceBindingToExpr(const Expr *e) { + // Emit the expression as an lvalue. + LValue lv = emitLValue(e); + assert(lv.isSimple()); + auto value = lv.getPointer(); + + assert(!cir::MissingFeatures::sanitizers()); + + return RValue::get(value); +} + /// Given an expression that represents a value lvalue, this /// method emits the address of the lvalue, then loads the result as an rvalue, /// returning the rvalue. @@ -855,10 +866,15 @@ RValue CIRGenFunction::emitCall(clang::QualType calleeTy, const auto *fnType = cast<FunctionType>(pointeeTy); assert(!cir::MissingFeatures::sanitizers()); - assert(!cir::MissingFeatures::opCallArgs()); + + CallArgList args; + assert(!cir::MissingFeatures::opCallArgEvaluationOrder()); + + emitCallArgs(args, dyn_cast<FunctionProtoType>(fnType), e->arguments(), + e->getDirectCallee()); const CIRGenFunctionInfo &funcInfo = - cgm.getTypes().arrangeFreeFunctionCall(fnType); + cgm.getTypes().arrangeFreeFunctionCall(args, fnType); assert(!cir::MissingFeatures::opCallNoPrototypeFunc()); assert(!cir::MissingFeatures::opCallChainCall()); @@ -866,8 +882,8 @@ RValue CIRGenFunction::emitCall(clang::QualType calleeTy, assert(!cir::MissingFeatures::opCallMustTail()); cir::CIRCallOpInterface callOp; - RValue callResult = - emitCall(funcInfo, callee, returnValue, &callOp, getLoc(e->getExprLoc())); + RValue callResult = emitCall(funcInfo, callee, returnValue, args, &callOp, + getLoc(e->getExprLoc())); assert(!cir::MissingFeatures::generateDebugInfo()); diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index f533d0ab53cd2..04caea04cf303 100644 --- 
a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -93,6 +93,10 @@ class CIRGenFunction : public CIRGenTypeCache { return getEvaluationKind(type) == cir::TEK_Scalar; } + static bool hasAggregateEvaluationKind(clang::QualType type) { + return getEvaluationKind(type) == cir::TEK_Aggregate; + } + CIRGenFunction(CIRGenModule &cgm, CIRGenBuilderTy &builder, bool suppressNewContext = false); ~CIRGenFunction(); @@ -161,6 +165,17 @@ class CIRGenFunction : public CIRGenTypeCache { const clang::LangOptions &getLangOpts() const { return cgm.getLangOpts(); } + // Wrapper for function prototype sources. Wraps either a FunctionProtoType or + // an ObjCMethodDecl. + struct PrototypeWrapper { + llvm::PointerUnion<const clang::FunctionProtoType *, + const clang::ObjCMethodDecl *> + p; + + PrototypeWrapper(const clang::FunctionProtoType *ft) : p(ft) {} + PrototypeWrapper(const clang::ObjCMethodDecl *md) : p(md) {} + }; + /// An abstract representation of regular/ObjC call/message targets. class AbstractCallee { /// The function declaration of the callee. @@ -169,6 +184,23 @@ class CIRGenFunction : public CIRGenTypeCache { public: AbstractCallee() : calleeDecl(nullptr) {} AbstractCallee(const clang::FunctionDecl *fd) : calleeDecl(fd) {} + + bool hasFunctionDecl() const { + return llvm::isa_and_nonnull<clang::FunctionDecl>(calleeDecl); + } + + unsigned getNumParams() const { + if (const auto *fd = llvm::dyn_cast<clang::FunctionDecl>(calleeDecl)) + return fd->getNumParams(); + return llvm::cast<clang::ObjCMethodDecl>(calleeDecl)->param_size(); + } + + const clang::ParmVarDecl *getParamDecl(unsigned I) const { + if (const auto *fd = llvm::dyn_cast<clang::FunctionDecl>(calleeDecl)) + return fd->getParamDecl(I); + return *(llvm::cast<clang::ObjCMethodDecl>(calleeDecl)->param_begin() + + I); + } }; void finishFunction(SourceLocation endLoc); @@ -444,6 +476,10 @@ class CIRGenFunction : public CIRGenTypeCache { /// should be returned. 
RValue emitAnyExpr(const clang::Expr *e); + /// Similarly to emitAnyExpr(), however, the result will always be accessible + /// even if no aggregate location is provided. + RValue emitAnyExprToTemp(const clang::Expr *e); + LValue emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e); AutoVarEmission emitAutoVarAlloca(const clang::VarDecl &d); @@ -462,9 +498,16 @@ class CIRGenFunction : public CIRGenTypeCache { RValue emitCall(const CIRGenFunctionInfo &funcInfo, const CIRGenCallee &callee, ReturnValueSlot returnValue, - cir::CIRCallOpInterface *callOp, mlir::Location loc); + const CallArgList &args, cir::CIRCallOpInterface *callOp, + mlir::Location loc); RValue emitCall(clang::QualType calleeTy, const CIRGenCallee &callee, const clang::CallExpr *e, ReturnValueSlot returnValue); + void emitCallArg(CallArgList &args, const clang::Expr *e, + clang::QualType argType); + void emitCallArgs( + CallArgList &args, PrototypeWrapper prototype, + llvm::iterator_range<clang::CallExpr::const_arg_iterator> argRange, + AbstractCallee callee = AbstractCallee(), unsigned paramsToSkip = 0); RValue emitCallExpr(const clang::CallExpr *e, ReturnValueSlot returnValue = ReturnValueSlot()); CIRGenCallee emitCallee(const clang::Expr *e); @@ -489,6 +532,9 @@ class CIRGenFunction : public CIRGenTypeCache { mlir::Value emitPromotedScalarExpr(const Expr *e, QualType promotionType); + /// Emits a reference binding to the passed in expression. + RValue emitReferenceBindingToExpr(const Expr *e); + /// Emit the computation of the specified expression of scalar type. 
mlir::Value emitScalarExpr(const clang::Expr *e); diff --git a/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h b/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h index c4a2b238c96ae..4319f7a2be225 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h @@ -33,11 +33,14 @@ class CIRGenFunctionInfo final CIRGenFunctionInfoArgInfo> { using ArgInfo = CIRGenFunctionInfoArgInfo; + unsigned numArgs; + ArgInfo *getArgsBuffer() { return getTrailingObjects<ArgInfo>(); } const ArgInfo *getArgsBuffer() const { return getTrailingObjects<ArgInfo>(); } public: - static CIRGenFunctionInfo *create(CanQualType resultType); + static CIRGenFunctionInfo *create(CanQualType resultType, + llvm::ArrayRef<CanQualType> argTypes); void operator delete(void *p) { ::operator delete(p); } @@ -45,14 +48,39 @@ class CIRGenFunctionInfo final // these have to be public. friend class TrailingObjects; + using const_arg_iterator = const ArgInfo *; + using arg_iterator = ArgInfo *; + // This function has to be CamelCase because llvm::FoldingSet requires so.
// NOLINTNEXTLINE(readability-identifier-naming) - static void Profile(llvm::FoldingSetNodeID &id, CanQualType resultType) { + static void Profile(llvm::FoldingSetNodeID &id, CanQualType resultType, + llvm::ArrayRef<clang::CanQualType> argTypes) { resultType.Profile(id); + for (auto i : argTypes) + i.Profile(id); } - void Profile(llvm::FoldingSetNodeID &id) { getReturnType().Profile(id); } + void Profile(llvm::FoldingSetNodeID &id) { + // Hash the same key as the static Profile() so FoldingSet lookups match. + getReturnType().Profile(id); + for (const auto &i : arguments()) + i.type.Profile(id); + } + llvm::MutableArrayRef<ArgInfo> arguments() { + return llvm::MutableArrayRef<ArgInfo>(arg_begin(), numArgs); + } + llvm::ArrayRef<ArgInfo> arguments() const { + return llvm::ArrayRef<ArgInfo>(arg_begin(), numArgs); + } + + const_arg_iterator arg_begin() const { return getArgsBuffer() + 1; } + const_arg_iterator arg_end() const { return getArgsBuffer() + 1 + numArgs; } + arg_iterator arg_begin() { return getArgsBuffer() + 1; } + arg_iterator arg_end() { return getArgsBuffer() + 1 + numArgs; } + + unsigned arg_size() const { return numArgs; } + CanQualType getReturnType() const { return getArgsBuffer()[0].type; } cir::ABIArgInfo &getReturnInfo() { return getArgsBuffer()[0].info; } diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp index b11f8466607f8..071b8c856c45a 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp @@ -493,11 +493,14 @@ bool CIRGenTypes::isZeroInitializable(clang::QualType t) { return true; } -const CIRGenFunctionInfo & -CIRGenTypes::arrangeCIRFunctionInfo(CanQualType returnType) { +const CIRGenFunctionInfo &CIRGenTypes::arrangeCIRFunctionInfo( + CanQualType returnType, llvm::ArrayRef<clang::CanQualType> argTypes) { + assert(llvm::all_of(argTypes, + [](CanQualType T) { return T.isCanonicalAsParam(); })); + // Lookup or create unique function info.
llvm::FoldingSetNodeID id; - CIRGenFunctionInfo::Profile(id, returnType); + CIRGenFunctionInfo::Profile(id, returnType, argTypes); void *insertPos = nullptr; CIRGenFunctionInfo *fi = functionInfos.FindNodeOrInsertPos(id, insertPos); @@ -507,7 +510,7 @@ CIRGenTypes::arrangeCIRFunctionInfo(CanQualType returnType) { assert(!cir::MissingFeatures::opCallCallConv()); // Construction the function info. We co-allocate the ArgInfos. - fi = CIRGenFunctionInfo::create(returnType); + fi = CIRGenFunctionInfo::create(returnType, argTypes); functionInfos.InsertNode(fi, insertPos); bool inserted = functionsBeingProcessed.insert(fi).second; @@ -524,7 +527,9 @@ CIRGenTypes::arrangeCIRFunctionInfo(CanQualType returnType) { if (retInfo.canHaveCoerceToType() && retInfo.getCoerceToType() == nullptr) retInfo.setCoerceToType(convertType(fi->getReturnType())); - assert(!cir::MissingFeatures::opCallArgs()); + for (CIRGenFunctionInfoArgInfo &i : fi->arguments()) + if (i.info.canHaveCoerceToType() && i.info.getCoerceToType() == nullptr) + i.info.setCoerceToType(convertType(i.type)); bool erased = functionsBeingProcessed.erase(fi); (void)erased; diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.h b/clang/lib/CIR/CodeGen/CIRGenTypes.h index 5b4027601ca3a..38f4b389c8db9 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.h +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.h @@ -121,9 +121,12 @@ class CIRGenTypes { /// LLVM zeroinitializer. 
bool isZeroInitializable(clang::QualType ty); - const CIRGenFunctionInfo &arrangeFreeFunctionCall(const FunctionType *fnType); + const CIRGenFunctionInfo &arrangeFreeFunctionCall(const CallArgList &args, + const FunctionType *fnType); - const CIRGenFunctionInfo &arrangeCIRFunctionInfo(CanQualType returnType); + const CIRGenFunctionInfo & + arrangeCIRFunctionInfo(CanQualType returnType, + llvm::ArrayRef<clang::CanQualType> argTypes); }; } // namespace clang::CIRGen diff --git a/clang/lib/CIR/CodeGen/TargetInfo.cpp b/clang/lib/CIR/CodeGen/TargetInfo.cpp index 0d0ffb93d4e7e..1af8b9fedac6f 100644 --- a/clang/lib/CIR/CodeGen/TargetInfo.cpp +++ b/clang/lib/CIR/CodeGen/TargetInfo.cpp @@ -32,7 +32,13 @@ void X8664ABIInfo::computeInfo(CIRGenFunctionInfo &funcInfo) const { // Top level CIR has unlimited arguments and return types. Lowering for ABI // specific concerns should happen during a lowering phase. Assume everything // is direct for now. - assert(!cir::MissingFeatures::opCallArgs()); + for (auto it = funcInfo.arg_begin(), ie = funcInfo.arg_end(); it != ie; + ++it) { + if (testIfIsVoidTy(it->type)) + it->info = cir::ABIArgInfo::getIgnore(); + else + it->info = cir::ABIArgInfo::getDirect(cgt.convertType(it->type)); + } CanQualType retTy = funcInfo.getReturnType(); if (testIfIsVoidTy(retTy)) diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 146c91b253f39..3036a354407b1 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -446,8 +446,31 @@ OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) { // CallOp //===----------------------------------------------------------------------===// +mlir::Operation::operand_iterator cir::CallOp::arg_operand_begin() { + assert(!cir::MissingFeatures::opCallIndirect()); + return operand_begin(); +} + +mlir::Operation::operand_iterator cir::CallOp::arg_operand_end() { + return operand_end(); +} + +/// Return the operand at index 'i'. 
+Value cir::CallOp::getArgOperand(unsigned i) { + assert(!cir::MissingFeatures::opCallIndirect()); + return getOperand(i); +} + +/// Return the number of operands. +unsigned cir::CallOp::getNumArgOperands() { + assert(!cir::MissingFeatures::opCallIndirect()); + return this->getOperation()->getNumOperands(); +} + static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::OperationState &result) { + llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops; + llvm::SMLoc opsLoc; mlir::FlatSymbolRefAttr calleeAttr; llvm::ArrayRef<mlir::Type> allResultTypes; @@ -458,9 +481,9 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, if (parser.parseLParen()) return mlir::failure(); - // TODO(cir): parse argument list here - assert(!cir::MissingFeatures::opCallArgs()); - + opsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(ops)) + return mlir::failure(); if (parser.parseRParen()) return mlir::failure(); @@ -477,6 +500,9 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, allResultTypes = opsFnTy.getResults(); result.addTypes(allResultTypes); + if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands)) + return mlir::failure(); + return mlir::success(); } @@ -485,11 +511,11 @@ static void printCallCommon(mlir::Operation *op, mlir::OpAsmPrinter &printer) { printer << ' '; + auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op); + auto ops = callLikeOp.getArgOperands(); + printer.printAttributeWithoutType(calleeSym); - printer << "("; - // TODO(cir): print call args here - assert(!cir::MissingFeatures::opCallArgs()); - printer << ")"; + printer << "(" << ops << ")"; printer.printOptionalAttrDict(op->getAttrs(), {"callee"}); @@ -525,9 +551,23 @@ verifyCallCommInSymbolUses(mlir::Operation *op, // Verify that the operand and result types match the callee. Note that // argument-checking is disabled for functions without a prototype. 
auto fnType = fn.getFunctionType(); + if (!fn.getNoProto()) { + unsigned numCallOperands = callIf.getNumArgOperands(); + unsigned numFnOpOperands = fnType.getNumInputs(); + + assert(!cir::MissingFeatures::opCallVariadic()); + + if (numCallOperands != numFnOpOperands) + return op->emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = numFnOpOperands; i != e; ++i) + if (callIf.getArgOperand(i).getType() != fnType.getInput(i)) + return op->emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << callIf.getArgOperand(i).getType() << " for operand number " << i; + } - // TODO(cir): verify function arguments - assert(!cir::MissingFeatures::opCallArgs()); + assert(!cir::MissingFeatures::opCallCallConv()); // Void function must not return any results. if (fnType.hasVoidReturn() && op->getNumResults() != 0) diff --git a/clang/test/CIR/CodeGen/call.cpp b/clang/test/CIR/CodeGen/call.cpp index 9082fbc9f6860..8fec8673cd691 100644 --- a/clang/test/CIR/CodeGen/call.cpp +++ b/clang/test/CIR/CodeGen/call.cpp @@ -17,3 +17,15 @@ int f4() { // CHECK-LABEL: cir.func @f4() -> !s32i // CHECK: %[[#x:]] = cir.call @f3() : () -> !s32i // CHECK-NEXT: cir.store %[[#x]], %{{.+}} : !s32i, !cir.ptr<!s32i> + +int f5(int a, int *b, bool c); +int f6() { + int b = 1; + return f5(2, &b, false); +} + +// CHECK-LABEL: cir.func @f6() -> !s32i +// CHECK: %[[#b:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] +// CHECK: %[[#a:]] = cir.const #cir.int<2> : !s32i +// CHECK-NEXT: %[[#c:]] = cir.const #false +// CHECK-NEXT: %{{.+}} = cir.call @f5(%[[#a]], %[[#b]], %[[#c]]) : (!s32i, !cir.ptr<!s32i>, !cir.bool) -> !s32i diff --git a/clang/test/CIR/IR/call.cir b/clang/test/CIR/IR/call.cir index 3c3fbf3d4d987..8276c0cb9e39d 100644 --- a/clang/test/CIR/IR/call.cir +++ b/clang/test/CIR/IR/call.cir @@ -28,4 +28,19 @@ cir.func @f4() -> !s32i { // CHECK-NEXT: cir.return %[[#x]] : !s32i // CHECK-NEXT: } +cir.func @f5(!s32i, !s32i) -> !s32i
+cir.func @f6() -> !s32i { + %0 = cir.const #cir.int<1> : !s32i + %1 = cir.const #cir.int<2> : !s32i + %2 = cir.call @f5(%0, %1) : (!s32i, !s32i) -> !s32i + cir.return %2 : !s32i +} + +// CHECK: cir.func @f6() -> !s32i { +// CHECK-NEXT: %[[#a:]] = cir.const #cir.int<1> : !s32i +// CHECK-NEXT: %[[#b:]] = cir.const #cir.int<2> : !s32i +// CHECK-NEXT: %[[#c:]] = cir.call @f5(%[[#a]], %[[#b]]) : (!s32i, !s32i) -> !s32i +// CHECK-NEXT: cir.return %[[#c]] : !s32i +// CHECK-NEXT: } + } diff --git a/clang/test/CIR/IR/invalid-call.cir b/clang/test/CIR/IR/invalid-call.cir index 64b6d56e0fa88..8a584bae70878 100644 --- a/clang/test/CIR/IR/invalid-call.cir +++ b/clang/test/CIR/IR/invalid-call.cir @@ -41,3 +41,30 @@ cir.func @f7() { %0 = cir.call @f6() : () -> !s32i cir.return } + +// ----- + +!s32i = !cir.int<s, 32> +!u32i = !cir.int<u, 32> + +cir.func @f8(!s32i, !s32i) +cir.func @f9() { + %0 = cir.const #cir.int<1> : !s32i + // expected-error @below {{incorrect number of operands for callee}} + cir.call @f8(%0) : (!s32i) -> () + cir.return +} + +// ----- + +!s32i = !cir.int<s, 32> +!u32i = !cir.int<u, 32> + +cir.func @f10(!s32i, !s32i) +cir.func @f11() { + %0 = cir.const #cir.int<1> : !s32i + %1 = cir.const #cir.int<2> : !u32i + // expected-error @below {{operand type mismatch: expected operand type '!cir.int<s, 32>', but provided '!cir.int<u, 32>' for operand number 1}} + cir.call @f10(%0, %1) : (!s32i, !u32i) -> () + cir.return +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits