https://github.com/Lancern updated https://github.com/llvm/llvm-project/pull/139748
>From 8c6a8c37037634ea48eb94e022cd76c0cececb84 Mon Sep 17 00:00:00 2001 From: Sirui Mu <msrlanc...@gmail.com> Date: Thu, 15 May 2025 23:01:18 +0800 Subject: [PATCH] [CIR] Add support for indirect calls --- .../CIR/Dialect/Builder/CIRBaseBuilder.h | 8 +++ clang/include/clang/CIR/Dialect/IR/CIROps.td | 46 +++++++++++------ clang/include/clang/CIR/MissingFeatures.h | 1 - clang/lib/CIR/CodeGen/CIRGenCall.cpp | 34 ++++++++++--- clang/lib/CIR/CodeGen/CIRGenCall.h | 11 ++++- clang/lib/CIR/CodeGen/CIRGenExpr.cpp | 24 ++++++++- clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h | 1 + clang/lib/CIR/CodeGen/CIRGenTypes.h | 1 - clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 49 ++++++++++++++++--- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 20 +++++++- clang/test/CIR/CodeGen/call.cpp | 14 ++++++ clang/test/CIR/IR/call.cir | 14 ++++++ 12 files changed, 188 insertions(+), 35 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index a63bf4f8858d0..b680e4162a5ce 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -225,6 +225,14 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { callee.getFunctionType().getReturnType(), operands); } + cir::CallOp createIndirectCallOp(mlir::Location loc, + mlir::Value indirectTarget, + cir::FuncType funcType, + mlir::ValueRange operands) { + return create<cir::CallOp>(loc, indirectTarget, funcType.getReturnType(), + operands); + } + //===--------------------------------------------------------------------===// // Cast/Conversion Operators //===--------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index e08f372450285..cf55523802659 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1798,13 +1798,8 @@ 
class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []> DeclareOpInterfaceMethods<SymbolUserOpInterface>])> { let extraClassDeclaration = [{ /// Get the argument operands to the called function. - mlir::OperandRange getArgOperands() { - return getArgs(); - } - - mlir::MutableOperandRange getArgOperandsMutable() { - return getArgsMutable(); - } + mlir::OperandRange getArgOperands(); + mlir::MutableOperandRange getArgOperandsMutable(); /// Return the callee of this operation mlir::CallInterfaceCallable getCallableForCallee() { @@ -1826,6 +1821,9 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []> ::mlir::Attribute removeArgAttrsAttr() { return {}; } ::mlir::Attribute removeResAttrsAttr() { return {}; } + bool isIndirect() { return !getCallee(); } + mlir::Value getIndirectCall(); + void setArg(unsigned index, mlir::Value value) { setOperand(index, value); } @@ -1839,16 +1837,24 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []> // the upstreaming process moves on. The verifiers is also missing for now, // will add in the future. - dag commonArgs = (ins FlatSymbolRefAttr:$callee, - Variadic<CIR_AnyType>:$args); + dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee, + Variadic<CIR_AnyType>:$args); } def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> { let summary = "call a function"; let description = [{ - The `cir.call` operation represents a direct call to a function that is - within the same symbol scope as the call. The callee is encoded as a symbol - reference attribute named `callee`. + The `cir.call` operation represents a function call. It could represent + either a direct call or an indirect call. + + If the operation represents a direct call, the callee should be defined + within the same symbol scope as the call. The `callee` attribute contains a + symbol reference to the callee function. All operands of this operation are + arguments to the callee function. 
+ + If the operation represents an indirect call, the `callee` attribute is + empty. The first operand of this operation must be a pointer to the callee + function. All remaining operands are arguments to the callee function. Example: @@ -1861,13 +1867,23 @@ def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> { let arguments = commonArgs; let builders = [OpBuilder<(ins "mlir::SymbolRefAttr":$callee, - "mlir::Type":$resType, - "mlir::ValueRange":$operands), [{ + "mlir::Type":$resType, + "mlir::ValueRange":$operands), + [{ $_state.addOperands(operands); $_state.addAttribute("callee", callee); if (resType && !isa<VoidType>(resType)) $_state.addTypes(resType); - }]>]; + }]>, + OpBuilder<(ins "mlir::Value":$callee, "mlir::Type":$resType, + "mlir::ValueRange":$operands), + [{ + $_state.addOperands(callee); + $_state.addOperands(operands); + if (resType && !isa<VoidType>(resType)) + $_state.addTypes(resType); + }]>, + ]; } //===----------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h index 484822c351746..e8c8c3f3d78c1 100644 --- a/clang/include/clang/CIR/MissingFeatures.h +++ b/clang/include/clang/CIR/MissingFeatures.h @@ -93,7 +93,6 @@ struct MissingFeatures { static bool opCallChainCall() { return false; } static bool opCallNoPrototypeFunc() { return false; } static bool opCallMustTail() { return false; } - static bool opCallIndirect() { return false; } static bool opCallVirtual() { return false; } static bool opCallInAlloca() { return false; } static bool opCallAttrs() { return false; } diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp index 5c65a43641844..41d0501b37bba 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp @@ -97,6 +97,7 @@ CIRGenTypes::arrangeFreeFunctionCall(const CallArgList &args, static cir::CIRCallOpInterface emitCallLikeOp(CIRGenFunction &cgf, 
mlir::Location callLoc, + cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal, cir::FuncOp directFuncOp, const SmallVectorImpl<mlir::Value> &cirCallArgs) { CIRGenBuilderTy &builder = cgf.getBuilder(); @@ -105,7 +106,13 @@ emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc, assert(!cir::MissingFeatures::invokeOp()); assert(builder.getInsertionBlock() && "expected valid basic block"); - assert(!cir::MissingFeatures::opCallIndirect()); + + if (indirectFuncTy) { + // TODO(cir): Set calling convention for indirect calls. + assert(!cir::MissingFeatures::opCallCallConv()); + return builder.createIndirectCallOp(callLoc, indirectFuncVal, + indirectFuncTy, cirCallArgs); + } return builder.createCallOp(callLoc, directFuncOp, cirCallArgs); } @@ -134,6 +141,7 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo, cir::CIRCallOpInterface *callOp, mlir::Location loc) { QualType retTy = funcInfo.getReturnType(); + cir::FuncType cirFuncTy = getTypes().getFunctionType(funcInfo); SmallVector<mlir::Value, 16> cirCallArgs(args.size()); @@ -185,12 +193,26 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo, assert(!cir::MissingFeatures::invokeOp()); - auto directFuncOp = dyn_cast<cir::FuncOp>(calleePtr); - assert(!cir::MissingFeatures::opCallIndirect()); + cir::FuncType indirectFuncTy; + mlir::Value indirectFuncVal; + cir::FuncOp directFuncOp; + if (auto fnOp = dyn_cast<cir::FuncOp>(calleePtr)) + directFuncOp = fnOp; + else { + [[maybe_unused]] auto resultTypes = calleePtr->getResultTypes(); + [[maybe_unused]] auto funcPtrTy = + mlir::dyn_cast<cir::PointerType>(resultTypes.front()); + assert(funcPtrTy && mlir::isa<cir::FuncType>(funcPtrTy.getPointee()) && + "expected pointer to function"); + + indirectFuncTy = cirFuncTy; + indirectFuncVal = calleePtr->getResult(0); + } + assert(!cir::MissingFeatures::opCallAttrs()); - cir::CIRCallOpInterface theCall = - emitCallLikeOp(*this, loc, directFuncOp, cirCallArgs); + cir::CIRCallOpInterface 
theCall = emitCallLikeOp( + *this, loc, indirectFuncTy, indirectFuncVal, directFuncOp, cirCallArgs); if (callOp) *callOp = theCall; @@ -290,7 +312,7 @@ void CIRGenFunction::emitCallArgs( auto maybeEmitImplicitObjectSize = [&](size_t i, const Expr *arg, RValue emittedArg) { - if (callee.hasFunctionDecl() || i >= callee.getNumParams()) + if (!callee.hasFunctionDecl() || i >= callee.getNumParams()) return; auto *ps = callee.getParamDecl(i)->getAttr<PassObjectSizeAttr>(); if (!ps) diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.h b/clang/lib/CIR/CodeGen/CIRGenCall.h index 2ba1676eb6b97..e4fd9c1c506d8 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.h +++ b/clang/lib/CIR/CodeGen/CIRGenCall.h @@ -25,11 +25,20 @@ class CIRGenFunction; /// Abstract information about a function or function prototype. class CIRGenCalleeInfo { + const clang::FunctionProtoType *calleeProtoTy; clang::GlobalDecl calleeDecl; public: - explicit CIRGenCalleeInfo() : calleeDecl() {} + explicit CIRGenCalleeInfo() : calleeProtoTy(nullptr), calleeDecl() {} + CIRGenCalleeInfo(const clang::FunctionProtoType *calleeProtoTy, + clang::GlobalDecl calleeDecl) + : calleeProtoTy(calleeProtoTy), calleeDecl(calleeDecl) {} CIRGenCalleeInfo(clang::GlobalDecl calleeDecl) : calleeDecl(calleeDecl) {} + + const clang::FunctionProtoType *getCalleeFunctionProtoType() const { + return calleeProtoTy; + } + clang::GlobalDecl getCalleeDecl() const { return calleeDecl; } }; class CIRGenCallee { diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index 038696182f6c8..64b0c63c73cfd 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -937,8 +937,28 @@ CIRGenCallee CIRGenFunction::emitCallee(const clang::Expr *e) { return emitDirectCallee(cgm, funcDecl); } - cgm.errorNYI(e->getSourceRange(), "Unsupported callee kind"); - return {}; + assert(!cir::MissingFeatures::opCallPseudoDtor()); + + // Otherwise, we have an indirect reference. 
+ mlir::Value calleePtr; + QualType functionType; + if (const auto *ptrType = e->getType()->getAs<clang::PointerType>()) { + calleePtr = emitScalarExpr(e); + functionType = ptrType->getPointeeType(); + } else { + functionType = e->getType(); + calleePtr = emitLValue(e).getPointer(); + } + assert(functionType->isFunctionType()); + + GlobalDecl gd; + if (const auto *vd = + dyn_cast_or_null<VarDecl>(e->getReferencedDeclOfCallee())) + gd = GlobalDecl(vd); + + CIRGenCalleeInfo calleeInfo(functionType->getAs<FunctionProtoType>(), gd); + CIRGenCallee callee(calleeInfo, calleePtr.getDefiningOp()); + return callee; } RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *e, diff --git a/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h b/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h index 0556408fb98d1..87d5131c0b944 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h @@ -16,6 +16,7 @@ #define LLVM_CLANG_CIR_CIRGENFUNCTIONINFO_H #include "clang/AST/CanonicalType.h" +#include "clang/CIR/MissingFeatures.h" #include "llvm/ADT/FoldingSet.h" #include "llvm/Support/TrailingObjects.h" diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.h b/clang/lib/CIR/CodeGen/CIRGenTypes.h index ff8ce3f87f362..625e97002cdd5 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.h +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.h @@ -65,7 +65,6 @@ class CIRGenTypes { /// types will be in this set. llvm::SmallPtrSet<const clang::Type *, 4> recordsBeingLaidOut; - llvm::SmallPtrSet<const CIRGenFunctionInfo *, 4> functionsBeingProcessed; /// Heper for convertType. 
mlir::Type convertFunctionTypeInternal(clang::QualType ft); diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 779114e09d834..32df7d94bc7a0 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -464,15 +464,35 @@ OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) { // CallOp //===----------------------------------------------------------------------===// +mlir::OperandRange cir::CallOp::getArgOperands() { + if (isIndirect()) + return getArgs().drop_front(1); + return getArgs(); +} + +mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() { + mlir::MutableOperandRange args = getArgsMutable(); + if (isIndirect()) + return args.slice(1, args.size() - 1); + return args; +} + +mlir::Value cir::CallOp::getIndirectCall() { + assert(isIndirect()); + return getOperand(0); +} + /// Return the operand at index 'i'. Value cir::CallOp::getArgOperand(unsigned i) { - assert(!cir::MissingFeatures::opCallIndirect()); + if (isIndirect()) + ++i; return getOperand(i); } /// Return the number of operands. unsigned cir::CallOp::getNumArgOperands() { - assert(!cir::MissingFeatures::opCallIndirect()); + if (isIndirect()) + return this->getOperation()->getNumOperands() - 1; return this->getOperation()->getNumOperands(); } @@ -483,9 +503,15 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::FlatSymbolRefAttr calleeAttr; llvm::ArrayRef<mlir::Type> allResultTypes; + // If we cannot parse a string callee, it means this is an indirect call. 
if (!parser.parseOptionalAttribute(calleeAttr, "callee", result.attributes) - .has_value()) - return mlir::failure(); + .has_value()) { + OpAsmParser::UnresolvedOperand indirectVal; + // Do not resolve right now, since we need to figure out the type + if (parser.parseOperand(indirectVal).failed()) + return failure(); + ops.push_back(indirectVal); + } if (parser.parseLParen()) return mlir::failure(); @@ -517,13 +543,21 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, + mlir::Value indirectCallee, mlir::OpAsmPrinter &printer) { printer << ' '; auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op); auto ops = callLikeOp.getArgOperands(); - printer.printAttributeWithoutType(calleeSym); + if (calleeSym) { + // Direct calls + printer.printAttributeWithoutType(calleeSym); + } else { + // Indirect calls + assert(indirectCallee); + printer << indirectCallee; + } printer << "(" << ops << ")"; printer.printOptionalAttrDict(op->getAttrs(), {"callee"}); @@ -539,7 +573,8 @@ mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser, } void cir::CallOp::print(mlir::OpAsmPrinter &p) { - printCallCommon(*this, getCalleeAttr(), p); + mlir::Value indirectCallee = isIndirect() ? 
getIndirectCall() : nullptr; + printCallCommon(*this, getCalleeAttr(), indirectCallee, p); } static LogicalResult @@ -547,7 +582,7 @@ verifyCallCommInSymbolUses(mlir::Operation *op, SymbolTableCollection &symbolTable) { auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee"); if (!fnAttr) - return mlir::failure(); + return mlir::success(); auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr); if (!fn) diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 3c85bb4b6b41d..0e1ca27681ece 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -674,8 +674,15 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands, llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>( converter->convertType(fn.getFunctionType())); } else { // indirect call - assert(!cir::MissingFeatures::opCallIndirect()); - return op->emitError("Indirect calls are NYI"); + assert(!op->getOperands().empty() && + "operands list must no be empty for the indirect call"); + auto calleeTy = op->getOperands().front().getType(); + auto calleePtrTy = cast<cir::PointerType>(calleeTy); + auto calleeFuncTy = cast<cir::FuncType>(calleePtrTy.getPointee()); + calleeFuncTy.dump(); + converter->convertType(calleeFuncTy).dump(); + llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>( + converter->convertType(calleeFuncTy)); } assert(!cir::MissingFeatures::opCallLandingPad()); @@ -1501,6 +1508,15 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter, converter.addConversion([&](cir::BF16Type type) -> mlir::Type { return mlir::BFloat16Type::get(type.getContext()); }); + converter.addConversion([&](cir::FuncType type) -> mlir::Type { + auto result = converter.convertType(type.getReturnType()); + llvm::SmallVector<mlir::Type> arguments; + arguments.reserve(type.getNumInputs()); + if (converter.convertTypes(type.getInputs(), arguments).failed()) 
+ llvm_unreachable("Failed to convert function type parameters"); + auto varArg = type.isVarArg(); + return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg); + }); converter.addConversion([&](cir::RecordType type) -> mlir::Type { // Convert struct members. llvm::SmallVector<mlir::Type> llvmMembers; diff --git a/clang/test/CIR/CodeGen/call.cpp b/clang/test/CIR/CodeGen/call.cpp index 3b1ab8b5fc498..8b8f1296b5108 100644 --- a/clang/test/CIR/CodeGen/call.cpp +++ b/clang/test/CIR/CodeGen/call.cpp @@ -42,3 +42,17 @@ int f6() { // LLVM-LABEL: define i32 @_Z2f6v() { // LLVM: %{{.+}} = call i32 @_Z2f5iPib(i32 2, ptr %{{.+}}, i1 false) + +int f7(int (*ptr)(int, int)) { + return ptr(1, 2); +} + +// CIR-LABEL: cir.func @_Z2f7PFiiiE +// CIR: %[[#ptr:]] = cir.load %{{.+}} : !cir.ptr<!cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>>, !cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>> +// CIR-NEXT: %[[#a:]] = cir.const #cir.int<1> : !s32i +// CIR-NEXT: %[[#b:]] = cir.const #cir.int<2> : !s32i +// CIR-NEXT: %{{.+}} = cir.call %[[#ptr]](%[[#a]], %[[#b]]) : (!cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>, !s32i, !s32i) -> !s32i + +// LLVM-LABEL: define i32 @_Z2f7PFiiiE +// LLVM: %[[#ptr:]] = load ptr, ptr %{{.+}} +// LLVM-NEXT: %{{.+}} = call i32 %[[#ptr]](i32 1, i32 2) diff --git a/clang/test/CIR/IR/call.cir b/clang/test/CIR/IR/call.cir index 8276c0cb9e39d..e35c201b6ed48 100644 --- a/clang/test/CIR/IR/call.cir +++ b/clang/test/CIR/IR/call.cir @@ -43,4 +43,18 @@ cir.func @f6() -> !s32i { // CHECK-NEXT: cir.return %[[#c]] : !s32i // CHECK-NEXT: } +cir.func @f7(%arg0: !cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>) -> !s32i { + %0 = cir.const #cir.int<1> : !s32i + %1 = cir.const #cir.int<2> : !s32i + %2 = cir.call %arg0(%0, %1) : (!cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>, !s32i, !s32i) -> !s32i + cir.return %2 : !s32i +} + +// CHECK: cir.func @f7(%[[ptr:.+]]: !cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>) -> !s32i { +// CHECK-NEXT: %[[#a:]] = cir.const #cir.int<1> : !s32i +// 
CHECK-NEXT: %[[#b:]] = cir.const #cir.int<2> : !s32i +// CHECK-NEXT: %[[#ret:]] = cir.call %[[ptr]](%[[#a]], %[[#b]]) : (!cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>, !s32i, !s32i) -> !s32i +// CHECK-NEXT: cir.return %[[#ret]] : !s32i +// CHECK-NEXT: } + } _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits