pengfei updated this revision to Diff 416945.
pengfei added a comment.

Address Yuanke's comment.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D122104/new/

https://reviews.llvm.org/D122104

Files:
  clang/include/clang/CodeGen/CGFunctionInfo.h
  clang/lib/CodeGen/CGCall.cpp
  clang/lib/CodeGen/CodeGenFunction.cpp
  clang/lib/CodeGen/TargetInfo.cpp
  clang/test/CodeGen/aarch64-neon-tbl.c
  clang/test/CodeGen/regcall2.c

Index: clang/test/CodeGen/regcall2.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/regcall2.c
@@ -0,0 +1,28 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py
+// RUN: %clang_cc1 -emit-llvm %s -o - -ffreestanding -target-feature +avx512vl -triple=x86_64-pc-win32     | FileCheck %s --check-prefix=Win
+// RUN: %clang_cc1 -emit-llvm %s -o - -ffreestanding -target-feature +avx512vl -triple=x86_64-pc-linux-gnu | FileCheck %s --check-prefix=Lin
+
+#include <immintrin.h>
+
+typedef struct {
+  __m512d r1[4]; // Lowered to [4 x <8 x double>] in the Lin checks below.
+  __m512 r2[4];  // Lowered to [4 x <16 x float>] in the Lin checks below.
+} __sVector;
+__sVector A;
+
+__sVector __regcall foo(int a) {
+  return A; // Returns the global aggregate by value (regcall return path).
+}
+
+double __regcall bar(__sVector a) {
+  return a.r1[0][4]; // Element 4 of the first r1 vector of the by-value arg.
+}
+
+// FIXME: Win64 still passes and returns __sVector indirectly; should it match Linux?
+// Win: define dso_local x86_regcallcc void @__regcall3__foo(%struct.__sVector* noalias sret(%struct.__sVector) align 64 %agg.result, i32 noundef %a) #0
+// Win: define dso_local x86_regcallcc double @__regcall3__bar(%struct.__sVector* noundef %a) #0
+// Win: attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-builtins" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+avx,+avx2,+avx512f,+avx512vl,+crc32,+cx8,+f16c,+fma,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" }
+
+// Lin: define dso_local x86_regcallcc %struct.__sVector @__regcall3__foo(i32 noundef %a) #0
+// Lin: define dso_local x86_regcallcc double @__regcall3__bar([4 x <8 x double>] %a.coerce0, [4 x <16 x float>] %a.coerce1) #0
+// Lin: attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="512" "no-builtins" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+avx,+avx2,+avx512f,+avx512vl,+crc32,+cx8,+f16c,+fma,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" }
Index: clang/test/CodeGen/aarch64-neon-tbl.c
===================================================================
--- clang/test/CodeGen/aarch64-neon-tbl.c
+++ clang/test/CodeGen/aarch64-neon-tbl.c
@@ -42,7 +42,7 @@
   return vtbl2_s8(a, b);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl2_s8([2 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl2_s8([2 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #1 {
 // CHECK:   [[__P0_I:%.*]] = alloca %struct.int8x16x2_t, align 16
 // CHECK:   [[A:%.*]] = alloca %struct.int8x16x2_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.int8x16x2_t, %struct.int8x16x2_t* [[A]], i32 0, i32 0
@@ -89,7 +89,7 @@
   return vtbl3_s8(a, b);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl3_s8([3 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl3_s8([3 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #1 {
 // CHECK:   [[__P0_I:%.*]] = alloca %struct.int8x16x3_t, align 16
 // CHECK:   [[A:%.*]] = alloca %struct.int8x16x3_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.int8x16x3_t, %struct.int8x16x3_t* [[A]], i32 0, i32 0
@@ -142,7 +142,7 @@
   return vtbl4_s8(a, b);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl4_s8([4 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl4_s8([4 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #1 {
 // CHECK:   [[__P0_I:%.*]] = alloca %struct.int8x16x4_t, align 16
 // CHECK:   [[A:%.*]] = alloca %struct.int8x16x4_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.int8x16x4_t, %struct.int8x16x4_t* [[A]], i32 0, i32 0
@@ -352,7 +352,7 @@
   return vqtbx1_s8(a, b, c);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx2_s8(<8 x i8> noundef %a, [2 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx2_s8(<8 x i8> noundef %a, [2 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #1 {
 // CHECK:   [[__P1_I:%.*]] = alloca %struct.int8x16x2_t, align 16
 // CHECK:   [[B:%.*]] = alloca %struct.int8x16x2_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.int8x16x2_t, %struct.int8x16x2_t* [[B]], i32 0, i32 0
@@ -373,7 +373,7 @@
   return vqtbx2_s8(a, b, c);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx3_s8(<8 x i8> noundef %a, [3 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx3_s8(<8 x i8> noundef %a, [3 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #1 {
 // CHECK:   [[__P1_I:%.*]] = alloca %struct.int8x16x3_t, align 16
 // CHECK:   [[B:%.*]] = alloca %struct.int8x16x3_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.int8x16x3_t, %struct.int8x16x3_t* [[B]], i32 0, i32 0
@@ -397,7 +397,7 @@
   return vqtbx3_s8(a, b, c);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx4_s8(<8 x i8> noundef %a, [4 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx4_s8(<8 x i8> noundef %a, [4 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #1 {
 // CHECK:   [[__P1_I:%.*]] = alloca %struct.int8x16x4_t, align 16
 // CHECK:   [[B:%.*]] = alloca %struct.int8x16x4_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.int8x16x4_t, %struct.int8x16x4_t* [[B]], i32 0, i32 0
@@ -540,7 +540,7 @@
   return vtbl2_u8(a, b);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl2_u8([2 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl2_u8([2 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #1 {
 // CHECK:   [[__P0_I:%.*]] = alloca %struct.uint8x16x2_t, align 16
 // CHECK:   [[A:%.*]] = alloca %struct.uint8x16x2_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.uint8x16x2_t, %struct.uint8x16x2_t* [[A]], i32 0, i32 0
@@ -587,7 +587,7 @@
   return vtbl3_u8(a, b);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl3_u8([3 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl3_u8([3 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #1 {
 // CHECK:   [[__P0_I:%.*]] = alloca %struct.uint8x16x3_t, align 16
 // CHECK:   [[A:%.*]] = alloca %struct.uint8x16x3_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.uint8x16x3_t, %struct.uint8x16x3_t* [[A]], i32 0, i32 0
@@ -640,7 +640,7 @@
   return vtbl4_u8(a, b);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl4_u8([4 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl4_u8([4 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #1 {
 // CHECK:   [[__P0_I:%.*]] = alloca %struct.uint8x16x4_t, align 16
 // CHECK:   [[A:%.*]] = alloca %struct.uint8x16x4_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.uint8x16x4_t, %struct.uint8x16x4_t* [[A]], i32 0, i32 0
@@ -850,7 +850,7 @@
   return vqtbx1_u8(a, b, c);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx2_u8(<8 x i8> noundef %a, [2 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx2_u8(<8 x i8> noundef %a, [2 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #1 {
 // CHECK:   [[__P1_I:%.*]] = alloca %struct.uint8x16x2_t, align 16
 // CHECK:   [[B:%.*]] = alloca %struct.uint8x16x2_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.uint8x16x2_t, %struct.uint8x16x2_t* [[B]], i32 0, i32 0
@@ -871,7 +871,7 @@
   return vqtbx2_u8(a, b, c);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx3_u8(<8 x i8> noundef %a, [3 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx3_u8(<8 x i8> noundef %a, [3 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #1 {
 // CHECK:   [[__P1_I:%.*]] = alloca %struct.uint8x16x3_t, align 16
 // CHECK:   [[B:%.*]] = alloca %struct.uint8x16x3_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.uint8x16x3_t, %struct.uint8x16x3_t* [[B]], i32 0, i32 0
@@ -895,7 +895,7 @@
   return vqtbx3_u8(a, b, c);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx4_u8(<8 x i8> noundef %a, [4 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx4_u8(<8 x i8> noundef %a, [4 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #1 {
 // CHECK:   [[__P1_I:%.*]] = alloca %struct.uint8x16x4_t, align 16
 // CHECK:   [[B:%.*]] = alloca %struct.uint8x16x4_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.uint8x16x4_t, %struct.uint8x16x4_t* [[B]], i32 0, i32 0
@@ -1038,7 +1038,7 @@
   return vtbl2_p8(a, b);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl2_p8([2 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl2_p8([2 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #1 {
 // CHECK:   [[__P0_I:%.*]] = alloca %struct.poly8x16x2_t, align 16
 // CHECK:   [[A:%.*]] = alloca %struct.poly8x16x2_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.poly8x16x2_t, %struct.poly8x16x2_t* [[A]], i32 0, i32 0
@@ -1085,7 +1085,7 @@
   return vtbl3_p8(a, b);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl3_p8([3 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl3_p8([3 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #1 {
 // CHECK:   [[__P0_I:%.*]] = alloca %struct.poly8x16x3_t, align 16
 // CHECK:   [[A:%.*]] = alloca %struct.poly8x16x3_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.poly8x16x3_t, %struct.poly8x16x3_t* [[A]], i32 0, i32 0
@@ -1138,7 +1138,7 @@
   return vtbl4_p8(a, b);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl4_p8([4 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbl4_p8([4 x <16 x i8>] %a.coerce, <8 x i8> noundef %b) #1 {
 // CHECK:   [[__P0_I:%.*]] = alloca %struct.poly8x16x4_t, align 16
 // CHECK:   [[A:%.*]] = alloca %struct.poly8x16x4_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.poly8x16x4_t, %struct.poly8x16x4_t* [[A]], i32 0, i32 0
@@ -1348,7 +1348,7 @@
   return vqtbx1_p8(a, b, c);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx2_p8(<8 x i8> noundef %a, [2 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx2_p8(<8 x i8> noundef %a, [2 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #1 {
 // CHECK:   [[__P1_I:%.*]] = alloca %struct.poly8x16x2_t, align 16
 // CHECK:   [[B:%.*]] = alloca %struct.poly8x16x2_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.poly8x16x2_t, %struct.poly8x16x2_t* [[B]], i32 0, i32 0
@@ -1369,7 +1369,7 @@
   return vqtbx2_p8(a, b, c);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx3_p8(<8 x i8> noundef %a, [3 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx3_p8(<8 x i8> noundef %a, [3 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #1 {
 // CHECK:   [[__P1_I:%.*]] = alloca %struct.poly8x16x3_t, align 16
 // CHECK:   [[B:%.*]] = alloca %struct.poly8x16x3_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.poly8x16x3_t, %struct.poly8x16x3_t* [[B]], i32 0, i32 0
@@ -1393,7 +1393,7 @@
   return vqtbx3_p8(a, b, c);
 }
 
-// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx4_p8(<8 x i8> noundef %a, [4 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #0 {
+// CHECK-LABEL: define{{.*}} <8 x i8> @test_vqtbx4_p8(<8 x i8> noundef %a, [4 x <16 x i8>] %b.coerce, <8 x i8> noundef %c) #1 {
 // CHECK:   [[__P1_I:%.*]] = alloca %struct.poly8x16x4_t, align 16
 // CHECK:   [[B:%.*]] = alloca %struct.poly8x16x4_t, align 16
 // CHECK:   [[COERCE_DIVE:%.*]] = getelementptr inbounds %struct.poly8x16x4_t, %struct.poly8x16x4_t* [[B]], i32 0, i32 0
Index: clang/lib/CodeGen/TargetInfo.cpp
===================================================================
--- clang/lib/CodeGen/TargetInfo.cpp
+++ clang/lib/CodeGen/TargetInfo.cpp
@@ -2300,7 +2300,7 @@
   /// If the \arg Lo class is ComplexX87, then the \arg Hi class will
   /// also be ComplexX87.
   void classify(QualType T, uint64_t OffsetBase, Class &Lo, Class &Hi,
-                bool isNamedArg) const;
+                bool isNamedArg, bool IsRegCall = false) const;
 
   llvm::Type *GetByteVectorType(QualType Ty) const;
   llvm::Type *GetSSETypeAtOffset(llvm::Type *IRType,
@@ -2325,13 +2325,16 @@
 
   ABIArgInfo classifyArgumentType(QualType Ty, unsigned freeIntRegs,
                                   unsigned &neededInt, unsigned &neededSSE,
-                                  bool isNamedArg) const;
+                                  bool isNamedArg,
+                                  bool IsRegCall = false) const;
 
   ABIArgInfo classifyRegCallStructType(QualType Ty, unsigned &NeededInt,
-                                       unsigned &NeededSSE) const;
+                                       unsigned &NeededSSE,
+                                       unsigned &MaxSSEWidth) const;
 
   ABIArgInfo classifyRegCallStructTypeImpl(QualType Ty, unsigned &NeededInt,
-                                           unsigned &NeededSSE) const;
+                                           unsigned &NeededSSE,
+                                           unsigned &MaxSSEWidth) const;
 
   bool IsIllegalVectorType(QualType Ty) const;
 
@@ -2826,8 +2829,8 @@
   return SSE;
 }
 
-void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase,
-                             Class &Lo, Class &Hi, bool isNamedArg) const {
+void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase, Class &Lo,
+                             Class &Hi, bool isNamedArg, bool IsRegCall) const {
   // FIXME: This code can be simplified by introducing a simple value class for
   // Class pairs with appropriate constructor methods for the various
   // situations.
@@ -3025,7 +3028,7 @@
 
     // AMD64-ABI 3.2.3p2: Rule 1. If the size of an object is larger
     // than eight eightbytes, ..., it has class MEMORY.
-    if (Size > 512)
+    if (!IsRegCall && Size > 512)
       return;
 
     // AMD64-ABI 3.2.3p2: Rule 1. If ..., or it contains unaligned
@@ -3734,13 +3737,13 @@
 
 ABIArgInfo X86_64ABIInfo::classifyArgumentType(
   QualType Ty, unsigned freeIntRegs, unsigned &neededInt, unsigned &neededSSE,
-  bool isNamedArg)
+  bool isNamedArg, bool IsRegCall)
   const
 {
   Ty = useFirstFieldIfTransparentUnion(Ty);
 
   X86_64ABIInfo::Class Lo, Hi;
-  classify(Ty, 0, Lo, Hi, isNamedArg);
+  classify(Ty, 0, Lo, Hi, isNamedArg, IsRegCall);
 
   // Check some invariants.
   // FIXME: Enforce these by construction.
@@ -3863,7 +3866,8 @@
 
 ABIArgInfo
 X86_64ABIInfo::classifyRegCallStructTypeImpl(QualType Ty, unsigned &NeededInt,
-                                             unsigned &NeededSSE) const {
+                                             unsigned &NeededSSE,
+                                             unsigned &MaxSSEWidth) const {
   auto RT = Ty->getAs<RecordType>();
   assert(RT && "classifyRegCallStructType only valid with struct types");
 
@@ -3878,7 +3882,8 @@
     }
 
     for (const auto &I : CXXRD->bases())
-      if (classifyRegCallStructTypeImpl(I.getType(), NeededInt, NeededSSE)
+      if (classifyRegCallStructTypeImpl(I.getType(), NeededInt, NeededSSE,
+                                        MaxSSEWidth)
               .isIndirect()) {
         NeededInt = NeededSSE = 0;
         return getIndirectReturnResult(Ty);
@@ -3887,20 +3892,26 @@
 
   // Sum up members
   for (const auto *FD : RT->getDecl()->fields()) {
-    if (FD->getType()->isRecordType() && !FD->getType()->isUnionType()) {
-      if (classifyRegCallStructTypeImpl(FD->getType(), NeededInt, NeededSSE)
+    QualType MTy = FD->getType();
+    if (MTy->isRecordType() && !MTy->isUnionType()) {
+      if (classifyRegCallStructTypeImpl(MTy, NeededInt, NeededSSE, MaxSSEWidth)
               .isIndirect()) {
         NeededInt = NeededSSE = 0;
         return getIndirectReturnResult(Ty);
       }
     } else {
       unsigned LocalNeededInt, LocalNeededSSE;
-      if (classifyArgumentType(FD->getType(), UINT_MAX, LocalNeededInt,
-                               LocalNeededSSE, true)
+      if (classifyArgumentType(MTy, UINT_MAX, LocalNeededInt, LocalNeededSSE,
+                               true, true)
               .isIndirect()) {
         NeededInt = NeededSSE = 0;
         return getIndirectReturnResult(Ty);
       }
+      if (const auto *AT = getContext().getAsConstantArrayType(MTy))
+        MTy = AT->getElementType();
+      if (const auto *VT = MTy->getAs<VectorType>())
+        if (getContext().getTypeSize(VT) > MaxSSEWidth)
+          MaxSSEWidth = getContext().getTypeSize(VT);
       NeededInt += LocalNeededInt;
       NeededSSE += LocalNeededSSE;
     }
@@ -3909,14 +3920,16 @@
   return ABIArgInfo::getDirect();
 }
 
-ABIArgInfo X86_64ABIInfo::classifyRegCallStructType(QualType Ty,
-                                                    unsigned &NeededInt,
-                                                    unsigned &NeededSSE) const {
+ABIArgInfo
+X86_64ABIInfo::classifyRegCallStructType(QualType Ty, unsigned &NeededInt,
+                                         unsigned &NeededSSE,
+                                         unsigned &MaxSSEWidth) const {
 
   NeededInt = 0;
   NeededSSE = 0;
+  MaxSSEWidth = 0;
 
-  return classifyRegCallStructTypeImpl(Ty, NeededInt, NeededSSE);
+  return classifyRegCallStructTypeImpl(Ty, NeededInt, NeededSSE, MaxSSEWidth);
 }
 
 void X86_64ABIInfo::computeInfo(CGFunctionInfo &FI) const {
@@ -3936,13 +3949,13 @@
   // Keep track of the number of assigned registers.
   unsigned FreeIntRegs = IsRegCall ? 11 : 6;
   unsigned FreeSSERegs = IsRegCall ? 16 : 8;
-  unsigned NeededInt, NeededSSE;
+  unsigned NeededInt = 0, NeededSSE = 0, MaxSSEWidth = 0;
 
   if (!::classifyReturnType(getCXXABI(), FI, *this)) {
     if (IsRegCall && FI.getReturnType()->getTypePtr()->isRecordType() &&
         !FI.getReturnType()->getTypePtr()->isUnionType()) {
-      FI.getReturnInfo() =
-          classifyRegCallStructType(FI.getReturnType(), NeededInt, NeededSSE);
+      FI.getReturnInfo() = classifyRegCallStructType(
+          FI.getReturnType(), NeededInt, NeededSSE, MaxSSEWidth);
       if (FreeIntRegs >= NeededInt && FreeSSERegs >= NeededSSE) {
         FreeIntRegs -= NeededInt;
         FreeSSERegs -= NeededSSE;
@@ -3965,6 +3978,8 @@
   // integer register.
   if (FI.getReturnInfo().isIndirect())
     --FreeIntRegs;
+  else if (NeededSSE)
+    FI.setMaxVectorWidth(MaxSSEWidth);
 
   // The chain argument effectively gives us another free register.
   if (FI.isChainCall())
@@ -3979,7 +3994,8 @@
     bool IsNamedArg = ArgNo < NumRequiredArgs;
 
     if (IsRegCall && it->type->isStructureOrClassType())
-      it->info = classifyRegCallStructType(it->type, NeededInt, NeededSSE);
+      it->info = classifyRegCallStructType(it->type, NeededInt, NeededSSE,
+                                           MaxSSEWidth);
     else
       it->info = classifyArgumentType(it->type, FreeIntRegs, NeededInt,
                                       NeededSSE, IsNamedArg);
@@ -3991,6 +4007,8 @@
     if (FreeIntRegs >= NeededInt && FreeSSERegs >= NeededSSE) {
       FreeIntRegs -= NeededInt;
       FreeSSERegs -= NeededSSE;
+      if (MaxSSEWidth > FI.getMaxVectorWidth())
+        FI.setMaxVectorWidth(MaxSSEWidth);
     } else {
       it->info = getIndirectResult(it->type, FreeIntRegs);
     }
Index: clang/lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.cpp
+++ clang/lib/CodeGen/CodeGenFunction.cpp
@@ -485,6 +485,9 @@
         std::max((uint64_t)LargestVectorWidth,
                  VT->getPrimitiveSizeInBits().getKnownMinSize());
 
+  if (CurFnInfo->getMaxVectorWidth() > LargestVectorWidth)
+    LargestVectorWidth = CurFnInfo->getMaxVectorWidth();
+
   // Add the required-vector-width attribute. This contains the max width from:
   // 1. min-vector-width attribute used in the source program.
   // 2. Any builtins used that have a vector width specified.
Index: clang/lib/CodeGen/CGCall.cpp
===================================================================
--- clang/lib/CodeGen/CGCall.cpp
+++ clang/lib/CodeGen/CGCall.cpp
@@ -833,6 +833,7 @@
   FI->NumArgs = argTypes.size();
   FI->HasExtParameterInfos = !paramInfos.empty();
   FI->getArgsBuffer()[0].type = resultType;
+  FI->MaxVectorWidth = 0;
   for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
     FI->getArgsBuffer()[i + 1].type = argTypes[i];
   for (unsigned i = 0, e = paramInfos.size(); i != e; ++i)
@@ -4669,6 +4670,19 @@
 
 } // namespace
 
+static unsigned getMaxVectorWidth(const llvm::Type *Ty) {
+  if (auto *VT = dyn_cast<llvm::VectorType>(Ty)) // Vector: own width.
+    return VT->getPrimitiveSizeInBits().getKnownMinSize();
+  if (auto *AT = dyn_cast<llvm::ArrayType>(Ty)) // Array: element's width.
+    return getMaxVectorWidth(AT->getElementType());
+
+  unsigned MaxVectorWidth = 0; // Stays 0 when Ty is neither of the above.
+  if (auto *ST = dyn_cast<llvm::StructType>(Ty)) // Struct: widest member.
+    for (auto *I : ST->elements())
+      MaxVectorWidth = std::max(MaxVectorWidth, getMaxVectorWidth(I));
+  return MaxVectorWidth;
+}
+
 RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
                                  const CGCallee &Callee,
                                  ReturnValueSlot ReturnValue,
@@ -5220,12 +5234,9 @@
 #endif
 
   // Update the largest vector width if any arguments have vector types.
-  for (unsigned i = 0; i < IRCallArgs.size(); ++i) {
-    if (auto *VT = dyn_cast<llvm::VectorType>(IRCallArgs[i]->getType()))
-      LargestVectorWidth =
-          std::max((uint64_t)LargestVectorWidth,
-                   VT->getPrimitiveSizeInBits().getKnownMinSize());
-  }
+  for (unsigned i = 0; i < IRCallArgs.size(); ++i)
+    LargestVectorWidth = std::max(LargestVectorWidth,
+                                  getMaxVectorWidth(IRCallArgs[i]->getType()));
 
   // Compute the calling convention and attributes.
   unsigned CallingConv;
@@ -5347,10 +5358,8 @@
     CI->setName("call");
 
   // Update largest vector width from the return type.
-  if (auto *VT = dyn_cast<llvm::VectorType>(CI->getType()))
-    LargestVectorWidth =
-        std::max((uint64_t)LargestVectorWidth,
-                 VT->getPrimitiveSizeInBits().getKnownMinSize());
+  LargestVectorWidth =
+      std::max(LargestVectorWidth, getMaxVectorWidth(CI->getType()));
 
   // Insert instrumentation or attach profile metadata at indirect call sites.
   // For more details, see the comment before the definition of
Index: clang/include/clang/CodeGen/CGFunctionInfo.h
===================================================================
--- clang/include/clang/CodeGen/CGFunctionInfo.h
+++ clang/include/clang/CodeGen/CGFunctionInfo.h
@@ -586,6 +586,9 @@
   /// Whether this function has nocf_check attribute.
   unsigned NoCfCheck : 1;
 
+  /// Log2 of the maximum vector width plus one; zero means no width recorded.
+  unsigned MaxVectorWidth : 4;
+
   RequiredArgs Required;
 
   /// The struct representing all arguments passed in memory.  Only used when
@@ -731,6 +734,16 @@
     ArgStructAlign = Align.getQuantity();
   }
 
+  /// Return the recorded maximum vector width, or 0 if none was recorded.
+  unsigned getMaxVectorWidth() const {
+    return MaxVectorWidth ? 1U << (MaxVectorWidth - 1) : 0;
+  }
+
+  /// Record the maximum vector width; it is stored as Log2_32(Width) + 1.
+  void setMaxVectorWidth(unsigned Width) {
+    MaxVectorWidth = llvm::Log2_32(Width) + 1;
+  }
+
   void Profile(llvm::FoldingSetNodeID &ID) {
     ID.AddInteger(getASTCallingConvention());
     ID.AddBoolean(InstanceMethod);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to