https://github.com/dkolsen-pgi created 
https://github.com/llvm/llvm-project/pull/128787

When a C or C++ function has a return type of `void`, the function type is now 
represented in MLIR as having no return type rather than having a return type 
of `!cir.void`.  This avoids breaking MLIR invariants that require the number 
of return types and the number of return values to match.

Change the assembly format for `cir::FuncType` from having a leading return 
type to having a trailing return type.  In other words, change
```
!cir.func<!returnType (!argTypes)>
```
to
```
!cir.func<(!argTypes) -> !returnType>
```
If the function returns `void`, instead change
```
!cir.func<!cir.void (!argTypes)>
```
to
```
!cir.func<(!argTypes)>
```

>From 3e396ab2f69d0dcf98605179f69411f04da68f49 Mon Sep 17 00:00:00 2001
From: David Olsen <dol...@nvidia.com>
Date: Tue, 25 Feb 2025 15:21:30 -0800
Subject: [PATCH] [CIR] Function type return type improvements

When a C or C++ function has a return type of `void`, the function type
is now represented in MLIR as having no return type rather than having a
return type of `!cir.void`.  This avoids breaking MLIR invariants that
require the number of return types and the number of return values to
match.

Change the assembly format for `cir::FuncType` from having a leading
return type to having a trailing return type.  In other words, change
```
!cir.func<!returnType (!argTypes)>
```
to
```
!cir.func<(!argTypes) -> !returnType>
```
If the function returns `void`, instead change
```
!cir.func<!cir.void (!argTypes)>
```
to
```
!cir.func<(!argTypes)>
```
---
 .../include/clang/CIR/Dialect/IR/CIRTypes.td  |  40 ++++--
 clang/lib/CIR/CodeGen/CIRGenTypes.cpp         |   2 +-
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       |   4 +
 clang/lib/CIR/Dialect/IR/CIRTypes.cpp         | 125 +++++++++++++-----
 clang/test/CIR/IR/func.cir                    |   8 +-
 clang/test/CIR/IR/global.cir                  |  12 +-
 clang/test/CIR/func-simple.cpp                |   4 +-
 clang/test/CIR/global-var-simple.cpp          |   6 +-
 8 files changed, 141 insertions(+), 60 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
index fc8edbcf3e166..c2d45ebeefe63 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
@@ -287,32 +287,43 @@ def CIR_BoolType :
 def CIR_FuncType : CIR_Type<"Func", "func"> {
   let summary = "CIR function type";
   let description = [{
-    The `!cir.func` is a function type. It consists of a single return type, a
-    list of parameter types and can optionally be variadic.
+    The `!cir.func` is a function type. It consists of an optional return type,
+    a list of parameter types and can optionally be variadic.
 
     Example:
 
     ```mlir
-    !cir.func<!bool ()>
-    !cir.func<!s32i (!s8i, !s8i)>
-    !cir.func<!s32i (!s32i, ...)>
+    !cir.func<()>
+    !cir.func<() -> !bool>
+    !cir.func<(!s8i, !s8i)>
+    !cir.func<(!s8i, !s8i) -> !s32i>
+    !cir.func<(!s32i, ...) -> !s32i>
     ```
   }];
 
   let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs,
-                        "mlir::Type":$returnType, "bool":$varArg);
+                        "mlir::Type":$optionalReturnType, "bool":$varArg);
+  // Use a custom parser to handle argument types and the optional return type.
   let assemblyFormat = [{
-    `<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
+    `<` custom<FuncType>($optionalReturnType, $inputs, $varArg) `>`
   }];
 
   let builders = [
+    // Create a FuncType, converting the return type from C-style to
+    // MLIR-style.  If the given return type is `cir::VoidType`, ignore it
+    // and create the FuncType with no return type, which is how MLIR
+    // represents functions that return nothing.
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
       CArg<"bool", "false">:$isVarArg), [{
-      return $_get(returnType.getContext(), inputs, returnType, isVarArg);
+      return $_get(returnType.getContext(), inputs,
+                   mlir::isa<cir::VoidType>(returnType) ? nullptr : returnType,
+                   isVarArg);
     }]>
   ];
 
+  let genVerifyDecl = 1;
+
   let extraClassDeclaration = [{
     /// Returns whether the function is variadic.
     bool isVarArg() const { return getVarArg(); }
@@ -323,12 +334,17 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
     /// Returns the number of arguments to the function.
     unsigned getNumInputs() const { return getInputs().size(); }
 
-    /// Returns the result type of the function as an ArrayRef, enabling better
-    /// integration with generic MLIR utilities.
+    /// Get the C-style return type of the function, which is !cir.void if the
+    /// function returns nothing and the actual return type otherwise.
+    mlir::Type getReturnType() const;
+
+    /// Get the MLIR-style return type of the function, which is an empty
+    /// ArrayRef if the function returns nothing and a single-element ArrayRef
+    /// with the actual return type otherwise.
     llvm::ArrayRef<mlir::Type> getReturnTypes() const;
 
-    /// Returns whether the function is returns void.
-    bool isVoid() const;
+    /// Does the function type return nothing?
+    bool hasVoidReturn() const;
 
     /// Returns a clone of this function type with the given argument
     /// and result types.
diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
index 16aec10fda81e..dcfaaedc2ef57 100644
--- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
@@ -60,7 +60,7 @@ bool CIRGenTypes::isFuncTypeConvertible(const FunctionType *ft) {
 mlir::Type CIRGenTypes::convertFunctionTypeInternal(QualType qft) {
   assert(qft.isCanonical());
   const FunctionType *ft = cast<FunctionType>(qft.getTypePtr());
-  // First, check whether we can build the full fucntion type. If the function
+  // First, check whether we can build the full function type. If the function
   // type depends on an incomplete type (e.g. a struct or enum), we cannot lower
   // the function type.
   if (!isFuncTypeConvertible(ft)) {
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index bfc74d4373f34..1a0740dea1fa8 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -424,6 +424,10 @@ LogicalResult cir::FuncOp::verifyType() {
   if (!isa<cir::FuncType>(type))
     return emitOpError("requires '" + getFunctionTypeAttrName().str() +
                        "' attribute of function type");
+  if (auto rt = type.getReturnTypes();
+      !rt.empty() && mlir::isa<cir::VoidType>(rt.front()))
+    return emitOpError("the return type for a function returning void should "
+                       "be empty instead of an explicit !cir.void");
   return success();
 }
 
diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
index d1b143efb955e..67fa6c267cf0f 100644
--- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
@@ -20,11 +20,12 @@
 // CIR Custom Parser/Printer Signatures
 //===----------------------------------------------------------------------===//
 
-static mlir::ParseResult
-parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
-                  bool &isVarArg);
-static void printFuncTypeArgs(mlir::AsmPrinter &p,
-                              mlir::ArrayRef<mlir::Type> params, bool isVarArg);
+static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
+                                       mlir::Type &optionalReturnType,
+                                       llvm::SmallVector<mlir::Type> &params,
+                                       bool &isVarArg);
+static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnType,
+                          mlir::ArrayRef<mlir::Type> params, bool isVarArg);
 
 //===----------------------------------------------------------------------===//
 // Get autogenerated stuff
@@ -282,40 +283,54 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
   return get(llvm::to_vector(inputs), results[0], isVarArg());
 }
 
-mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
-                                    llvm::SmallVector<mlir::Type> &params,
-                                    bool &isVarArg) {
+// Parse the optional trailing `-> returnType` of a function type.  A function
+// returning void has no arrow and is given a null return type.
+static mlir::ParseResult parseFuncTypeReturn(mlir::AsmParser &p,
+                                             mlir::Type &optionalReturnType) {
+  if (succeeded(p.parseOptionalArrow())) {
+    // `->` found. It must be followed by the return type.
+    return p.parseType(optionalReturnType);
+  }
+  // Function has `void` return in C++, no return in MLIR.
+  optionalReturnType = {};
+  return success();
+}
+
+// Print the trailing ` -> returnType`, or nothing for a void function.
+static void printFuncTypeReturn(mlir::AsmPrinter &p,
+                                mlir::Type optionalReturnType) {
+  if (optionalReturnType)
+    p << " -> " << optionalReturnType;
+}
+
+static mlir::ParseResult
+parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
+                  bool &isVarArg) {
   isVarArg = false;
-  // `(` `)`
-  if (succeeded(p.parseOptionalRParen()))
+  if (failed(p.parseLParen()))
+    return failure();
+  if (succeeded(p.parseOptionalRParen())) {
+    // `()` empty argument list
     return mlir::success();
-
-  // `(` `...` `)`
-  if (succeeded(p.parseOptionalEllipsis())) {
-    isVarArg = true;
-    return p.parseRParen();
   }
-
-  // type (`,` type)* (`,` `...`)?
-  mlir::Type type;
-  if (p.parseType(type))
-    return mlir::failure();
-  params.push_back(type);
-  while (succeeded(p.parseOptionalComma())) {
+  do {
     if (succeeded(p.parseOptionalEllipsis())) {
+      // `...`, which must be the last thing in the list.
       isVarArg = true;
-      return p.parseRParen();
+      break;
     }
-    if (p.parseType(type))
-      return mlir::failure();
-    params.push_back(type);
-  }
-
+    mlir::Type argType;
+    if (failed(p.parseType(argType)))
+      return failure();
+    params.push_back(argType);
+  } while (succeeded(p.parseOptionalComma()));
   return p.parseRParen();
 }
 
-void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
-                       bool isVarArg) {
+static void printFuncTypeArgs(mlir::AsmPrinter &p,
+                              mlir::ArrayRef<mlir::Type> params,
+                              bool isVarArg) {
+  p << '(';
   llvm::interleaveComma(params, p,
                         [&p](mlir::Type type) { p.printType(type); });
   if (isVarArg) {
@@ -326,11 +341,56 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
   p << ')';
 }
 
+// Use a custom parser to handle the optional return and argument types without
+// an optional anchor.
+static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
+                                       mlir::Type &optionalReturnType,
+                                       llvm::SmallVector<mlir::Type> &params,
+                                       bool &isVarArg) {
+  if (failed(parseFuncTypeArgs(p, params, isVarArg)))
+    return failure();
+  return parseFuncTypeReturn(p, optionalReturnType);
+}
+
+static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnType,
+                          mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
+  printFuncTypeArgs(p, params, isVarArg);
+  printFuncTypeReturn(p, optionalReturnType);
+}
+
+/// Get the C-style return type of the function, which is !cir.void if the
+/// function returns nothing and the actual return type otherwise.
+mlir::Type FuncType::getReturnType() const {
+  if (hasVoidReturn())
+    return cir::VoidType::get(getContext());
+  return getOptionalReturnType();
+}
+
+/// Get the MLIR-style return type of the function, which is an empty
+/// ArrayRef if the function returns nothing and a single-element ArrayRef
+/// with the actual return type otherwise.
 llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
-  return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
+  if (hasVoidReturn())
+    return {};
+  // Can't use getOptionalReturnType() here because llvm::ArrayRef holds a
+  // pointer to its elements and doesn't do lifetime extension.  That would
+  // result in returning a pointer to a temporary that has gone out of scope.
+  return getImpl()->optionalReturnType;
 }
 
-bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
+/// Does the function type return nothing?
+bool FuncType::hasVoidReturn() const { return !getOptionalReturnType(); }
+
+mlir::LogicalResult
+FuncType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+                 llvm::ArrayRef<mlir::Type> argTypes, mlir::Type returnType,
+                 bool isVarArg) {
+  if (returnType && mlir::isa<cir::VoidType>(returnType)) {
+    emitError() << "!cir.func cannot have an explicit 'void' return type";
+    return mlir::failure();
+  }
+  return mlir::success();
+}
 
 //===----------------------------------------------------------------------===//
 // BoolType
diff --git a/clang/test/CIR/IR/func.cir b/clang/test/CIR/IR/func.cir
index a32c3e697ed25..4077bd33e0438 100644
--- a/clang/test/CIR/IR/func.cir
+++ b/clang/test/CIR/IR/func.cir
@@ -2,18 +2,18 @@
 
 module {
 // void empty() { }
-cir.func @empty() -> !cir.void {
+cir.func @empty() {
   cir.return
 }
-// CHECK: cir.func @empty() -> !cir.void {
+// CHECK: cir.func @empty() {
 // CHECK:   cir.return
 // CHECK: }
 
 // void voidret() { return; }
-cir.func @voidret() -> !cir.void {
+cir.func @voidret() {
   cir.return
 }
-// CHECK: cir.func @voidret() -> !cir.void {
+// CHECK: cir.func @voidret() {
 // CHECK:   cir.return
 // CHECK: }
 
diff --git a/clang/test/CIR/IR/global.cir b/clang/test/CIR/IR/global.cir
index 6c68ab0a501ff..9d187686d996c 100644
--- a/clang/test/CIR/IR/global.cir
+++ b/clang/test/CIR/IR/global.cir
@@ -30,9 +30,9 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
   cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
   cir.global @dp : !cir.ptr<!cir.double>
   cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
-  cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
-  cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
-  cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
+  cir.global @fp : !cir.ptr<!cir.func<()>>
+  cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
+  cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
 }
 
 // CHECK: cir.global @c : !cir.int<s, 8>
@@ -64,6 +64,6 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
 // CHECK: cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
 // CHECK: cir.global @dp : !cir.ptr<!cir.double>
 // CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
-// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
-// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
-// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
+// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>
+// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
+// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
diff --git a/clang/test/CIR/func-simple.cpp b/clang/test/CIR/func-simple.cpp
index 22c120d3404d3..3947055e300a0 100644
--- a/clang/test/CIR/func-simple.cpp
+++ b/clang/test/CIR/func-simple.cpp
@@ -2,12 +2,12 @@
 // RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o -  | FileCheck %s
 
 void empty() { }
-// CHECK: cir.func @empty() -> !cir.void {
+// CHECK: cir.func @empty() {
 // CHECK:   cir.return
 // CHECK: }
 
 void voidret() { return; }
-// CHECK: cir.func @voidret() -> !cir.void {
+// CHECK: cir.func @voidret() {
 // CHECK:   cir.return
 // CHECK: }
 
diff --git a/clang/test/CIR/global-var-simple.cpp b/clang/test/CIR/global-var-simple.cpp
index dfe8371668e2c..f8e233cd5fe33 100644
--- a/clang/test/CIR/global-var-simple.cpp
+++ b/clang/test/CIR/global-var-simple.cpp
@@ -92,10 +92,10 @@ char **cpp;
 // CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
 
 void (*fp)();
-// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
+// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>
 
 int (*fpii)(int) = 0;
-// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
+// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
 
 void (*fpvar)(int, ...);
-// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
+// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to