https://github.com/dkolsen-pgi created https://github.com/llvm/llvm-project/pull/128787
When a C or C++ function has a return type of `void`, the function type is now represented in MLIR as having no return type rather than having a return type of `!cir.void`. This avoids breaking MLIR invariants that require the number of return types and the number of return values to match. Change the assembly format for `cir::FuncType` from having a leading return type to having a trailing return type. In other words, change ``` !cir.func<!returnType (!argTypes)> ``` to ``` !cir.func<(!argTypes) -> !returnType> ``` Unless the function returns `void`, in which case change ``` !cir.func<!cir.void (!argTypes)> ``` to ``` !cir.func<(!argTypes)> ``` >From 3e396ab2f69d0dcf98605179f69411f04da68f49 Mon Sep 17 00:00:00 2001 From: David Olsen <dol...@nvidia.com> Date: Tue, 25 Feb 2025 15:21:30 -0800 Subject: [PATCH] [CIR] Function type return type improvements When a C or C++ function has a return type of `void`, the function type is now represented in MLIR as having no return type rather than having a return type of `!cir.void`. This avoids breaking MLIR invariants that require the number of return types and the number of return values to match. Change the assembly format for `cir::FuncType` from having a leading return type to having a trailing return type. 
In other words, change ``` !cir.func<!returnType (!argTypes)> ``` to ``` !cir.func<(!argTypes) -> !returnType> ``` Unless the function returns `void`, in which case change ``` !cir.func<!cir.void (!argTypes)> ``` to ``` !cir.func<(!argTypes)> ``` --- .../include/clang/CIR/Dialect/IR/CIRTypes.td | 40 ++++-- clang/lib/CIR/CodeGen/CIRGenTypes.cpp | 2 +- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 4 + clang/lib/CIR/Dialect/IR/CIRTypes.cpp | 125 +++++++++++++----- clang/test/CIR/IR/func.cir | 8 +- clang/test/CIR/IR/global.cir | 12 +- clang/test/CIR/func-simple.cpp | 4 +- clang/test/CIR/global-var-simple.cpp | 6 +- 8 files changed, 141 insertions(+), 60 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index fc8edbcf3e166..c2d45ebeefe63 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -287,32 +287,43 @@ def CIR_FuncType : CIR_Type<"Func", "func"> { let summary = "CIR function type"; let description = [{ - The `!cir.func` is a function type. It consists of a single return type, a - list of parameter types and can optionally be variadic. + The `!cir.func` is a function type. It consists of an optional return type, + a list of parameter types and can optionally be variadic. Example: ```mlir - !cir.func<!bool ()> - !cir.func<!s32i (!s8i, !s8i)> - !cir.func<!s32i (!s32i, ...)> + !cir.func<()> + !cir.func<() -> !bool> + !cir.func<(!s8i, !s8i)> + !cir.func<(!s8i, !s8i) -> !s32i> + !cir.func<(!s32i, ...) 
-> !s32i> ``` }]; let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, - "mlir::Type":$returnType, "bool":$varArg); + "mlir::Type":$optionalReturnType, "bool":$varArg); + // Use a custom parser to handle the argument types and optional return let assemblyFormat = [{ - `<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>` + `<` custom<FuncType>($optionalReturnType, $inputs, $varArg) `>` }]; let builders = [ + // Create a FuncType, converting the return type from C-style to + // MLIR-style. If the given return type is `cir::VoidType`, ignore it + // and create the FuncType with no return type, which is how MLIR + // represents function types. TypeBuilderWithInferredContext<(ins "llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType, CArg<"bool", "false">:$isVarArg), [{ - return $_get(returnType.getContext(), inputs, returnType, isVarArg); + return $_get(returnType.getContext(), inputs, + mlir::isa<cir::VoidType>(returnType) ? nullptr : returnType, + isVarArg); }]> ]; + let genVerifyDecl = 1; + let extraClassDeclaration = [{ /// Returns whether the function is variadic. bool isVarArg() const { return getVarArg(); } @@ -323,12 +334,17 @@ def CIR_FuncType : CIR_Type<"Func", "func"> { /// Returns the number of arguments to the function. unsigned getNumInputs() const { return getInputs().size(); } - /// Returns the result type of the function as an ArrayRef, enabling better - /// integration with generic MLIR utilities. + /// Get the C-style return type of the function, which is !cir.void if the + /// function returns nothing and the actual return type otherwise. + mlir::Type getReturnType() const; + + /// Get the MLIR-style return type of the function, which is an empty + /// ArrayRef if the function returns nothing and a single-element ArrayRef + /// with the actual return type otherwise. llvm::ArrayRef<mlir::Type> getReturnTypes() const; - /// Returns whether the function is returns void. 
- bool isVoid() const; + /// Does the function type return nothing? + bool hasVoidReturn() const; /// Returns a clone of this function type with the given argument /// and result types. diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp index 16aec10fda81e..dcfaaedc2ef57 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp @@ -60,7 +60,7 @@ bool CIRGenTypes::isFuncTypeConvertible(const FunctionType *ft) { mlir::Type CIRGenTypes::convertFunctionTypeInternal(QualType qft) { assert(qft.isCanonical()); const FunctionType *ft = cast<FunctionType>(qft.getTypePtr()); - // First, check whether we can build the full fucntion type. If the function + // First, check whether we can build the full function type. If the function // type depends on an incomplete type (e.g. a struct or enum), we cannot lower // the function type. if (!isFuncTypeConvertible(ft)) { diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index bfc74d4373f34..1a0740dea1fa8 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -424,6 +424,10 @@ LogicalResult cir::FuncOp::verifyType() { if (!isa<cir::FuncType>(type)) return emitOpError("requires '" + getFunctionTypeAttrName().str() + "' attribute of function type"); + if (auto rt = type.getReturnTypes(); + !rt.empty() && mlir::isa<cir::VoidType>(rt.front())) + return emitOpError("The return type for a function returning void should " + "be empty instead of an explicit !cir.void"); return success(); } diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index d1b143efb955e..67fa6c267cf0f 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -20,11 +20,12 @@ // CIR Custom Parser/Printer Signatures //===----------------------------------------------------------------------===// -static mlir::ParseResult 
-parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params, - bool &isVarArg); -static void printFuncTypeArgs(mlir::AsmPrinter &p, - mlir::ArrayRef<mlir::Type> params, bool isVarArg); +static mlir::ParseResult parseFuncType(mlir::AsmParser &p, + mlir::Type &optionalReturnTypes, + llvm::SmallVector<mlir::Type> &params, + bool &isVarArg); +static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnTypes, + mlir::ArrayRef<mlir::Type> params, bool isVarArg); //===----------------------------------------------------------------------===// // Get autogenerated stuff @@ -282,40 +283,55 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const { return get(llvm::to_vector(inputs), results[0], isVarArg()); } -mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p, - llvm::SmallVector<mlir::Type> &params, - bool &isVarArg) { +// A special parser is needed for functions returning void to handle the missing +// type. +static mlir::ParseResult parseFuncTypeReturn(mlir::AsmParser &p, + mlir::Type &optionalReturnType) { + if (succeeded(p.parseOptionalArrow())) { + // `->` found. It must be followed by the return type. + return p.parseType(optionalReturnType); + } + // Function has `void` return in C++, no return in MLIR. + optionalReturnType = {}; + return success(); +} + +// A special pretty-printer for functions that may or may not return a result. 
+static void printFuncTypeReturn(mlir::AsmPrinter &p, + mlir::Type optionalReturnType) { + if (optionalReturnType) + p << " -> " << optionalReturnType; +} + +static mlir::ParseResult +parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params, + bool &isVarArg) { isVarArg = false; - // `(` `)` - if (succeeded(p.parseOptionalRParen())) + if (failed(p.parseLParen())) + return failure(); + if (succeeded(p.parseOptionalRParen())) { + // `()` empty argument list return mlir::success(); - - // `(` `...` `)` - if (succeeded(p.parseOptionalEllipsis())) { - isVarArg = true; - return p.parseRParen(); } - - // type (`,` type)* (`,` `...`)? - mlir::Type type; - if (p.parseType(type)) - return mlir::failure(); - params.push_back(type); - while (succeeded(p.parseOptionalComma())) { + do { if (succeeded(p.parseOptionalEllipsis())) { + // `...`, which must be the last thing in the list. isVarArg = true; - return p.parseRParen(); + break; + } else { + mlir::Type argType; + if (failed(p.parseType(argType))) + return failure(); + params.push_back(argType); } - if (p.parseType(type)) - return mlir::failure(); - params.push_back(type); - } - + } while (succeeded(p.parseOptionalComma())); return p.parseRParen(); } -void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params, - bool isVarArg) { +static void printFuncTypeArgs(mlir::AsmPrinter &p, + mlir::ArrayRef<mlir::Type> params, + bool isVarArg) { + p << '('; llvm::interleaveComma(params, p, [&p](mlir::Type type) { p.printType(type); }); if (isVarArg) { @@ -326,11 +342,56 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params, p << ')'; } +// Use a custom parser to handle the optional return and argument types without +// an optional anchor. 
+static mlir::ParseResult parseFuncType(mlir::AsmParser &p, + mlir::Type &optionalReturnType, + llvm::SmallVector<mlir::Type> &params, + bool &isVarArg) { + if (failed(parseFuncTypeArgs(p, params, isVarArg))) + return failure(); + return parseFuncTypeReturn(p, optionalReturnType); +} + +static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnType, + mlir::ArrayRef<mlir::Type> params, bool isVarArg) { + printFuncTypeArgs(p, params, isVarArg); + printFuncTypeReturn(p, optionalReturnType); +} + +/// Get the C-style return type of the function, which is !cir.void if the +/// function returns nothing and the actual return type otherwise. +mlir::Type FuncType::getReturnType() const { + if (hasVoidReturn()) + return cir::VoidType::get(getContext()); + return getOptionalReturnType(); +} + +/// Get the MLIR-style return type of the function, which is an empty +/// ArrayRef if the function returns nothing and a single-element ArrayRef +/// with the actual return type otherwise. llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const { - return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType; + if (hasVoidReturn()) + return {}; + // Can't use getOptionalReturnType() here because llvm::ArrayRef holds a + // pointer to its elements and doesn't do lifetime extension. That would + // result in returning a pointer to a temporary that has gone out of scope. + return getImpl()->optionalReturnType; } -bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); } +// Does the function type return nothing? 
+bool FuncType::hasVoidReturn() const { return !getOptionalReturnType(); } + +mlir::LogicalResult +FuncType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, + llvm::ArrayRef<mlir::Type> argTypes, mlir::Type returnType, + bool isVarArg) { + if (returnType && mlir::isa<cir::VoidType>(returnType)) { + emitError() << "!cir.func cannot have an explicit 'void' return type"; + return mlir::failure(); + } + return mlir::success(); +} //===----------------------------------------------------------------------===// // BoolType diff --git a/clang/test/CIR/IR/func.cir b/clang/test/CIR/IR/func.cir index a32c3e697ed25..4077bd33e0438 100644 --- a/clang/test/CIR/IR/func.cir +++ b/clang/test/CIR/IR/func.cir @@ -2,18 +2,18 @@ module { // void empty() { } -cir.func @empty() -> !cir.void { +cir.func @empty() { cir.return } -// CHECK: cir.func @empty() -> !cir.void { +// CHECK: cir.func @empty() { // CHECK: cir.return // CHECK: } // void voidret() { return; } -cir.func @voidret() -> !cir.void { +cir.func @voidret() { cir.return } -// CHECK: cir.func @voidret() -> !cir.void { +// CHECK: cir.func @voidret() { // CHECK: cir.return // CHECK: } diff --git a/clang/test/CIR/IR/global.cir b/clang/test/CIR/IR/global.cir index 6c68ab0a501ff..9d187686d996c 100644 --- a/clang/test/CIR/IR/global.cir +++ b/clang/test/CIR/IR/global.cir @@ -30,9 +30,9 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} { cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>> cir.global @dp : !cir.ptr<!cir.double> cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>> - cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>> - cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>> - cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>> + cir.global @fp : !cir.ptr<!cir.func<()>> + cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>> + cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>> } 
// CHECK: cir.global @c : !cir.int<s, 8> @@ -64,6 +64,6 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} { // CHECK: cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>> // CHECK: cir.global @dp : !cir.ptr<!cir.double> // CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>> -// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>> -// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>> -// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>> +// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>> +// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>> +// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>> diff --git a/clang/test/CIR/func-simple.cpp b/clang/test/CIR/func-simple.cpp index 22c120d3404d3..3947055e300a0 100644 --- a/clang/test/CIR/func-simple.cpp +++ b/clang/test/CIR/func-simple.cpp @@ -2,12 +2,12 @@ // RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s void empty() { } -// CHECK: cir.func @empty() -> !cir.void { +// CHECK: cir.func @empty() { // CHECK: cir.return // CHECK: } void voidret() { return; } -// CHECK: cir.func @voidret() -> !cir.void { +// CHECK: cir.func @voidret() { // CHECK: cir.return // CHECK: } diff --git a/clang/test/CIR/global-var-simple.cpp b/clang/test/CIR/global-var-simple.cpp index dfe8371668e2c..f8e233cd5fe33 100644 --- a/clang/test/CIR/global-var-simple.cpp +++ b/clang/test/CIR/global-var-simple.cpp @@ -92,10 +92,10 @@ char **cpp; // CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>> void (*fp)(); -// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>> +// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>> int (*fpii)(int) = 0; -// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>> +// CHECK: cir.global @fpii = #cir.ptr<null> : 
!cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>> void (*fpvar)(int, ...); -// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>> +// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>> _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits