https://github.com/tbaederr created https://github.com/llvm/llvm-project/pull/102723

None

>From 5d00cf049b64bf31a6039c1c4761dca64483a4f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timm=20B=C3=A4der?= <tbae...@redhat.com>
Date: Fri, 9 Aug 2024 15:30:30 +0200
Subject: [PATCH] [clang][Interp] Start implementing unions and changing the
 active member

---
 clang/lib/AST/Interp/Compiler.cpp   | 16 ++++--
 clang/lib/AST/Interp/Descriptor.cpp | 49 +++++++++-------
 clang/lib/AST/Interp/Descriptor.h   |  8 ++-
 clang/lib/AST/Interp/Disasm.cpp     |  3 +
 clang/lib/AST/Interp/Interp.cpp     | 21 ++++++-
 clang/lib/AST/Interp/Interp.h       |  8 ++-
 clang/lib/AST/Interp/InterpBlock.h  |  5 +-
 clang/lib/AST/Interp/Pointer.cpp    | 27 ++++++++-
 clang/lib/AST/Interp/Pointer.h      |  6 ++
 clang/test/AST/Interp/unions.cpp    | 88 +++++++++++++++++++++++++++++
 10 files changed, 197 insertions(+), 34 deletions(-)

diff --git a/clang/lib/AST/Interp/Compiler.cpp b/clang/lib/AST/Interp/Compiler.cpp
index 11fe2acf2d7b95..0d72e33c1c7d25 100644
--- a/clang/lib/AST/Interp/Compiler.cpp
+++ b/clang/lib/AST/Interp/Compiler.cpp
@@ -4739,7 +4739,8 @@ bool Compiler<Emitter>::visitFunc(const FunctionDecl *F) {
   // Classify the return type.
   ReturnType = this->classify(F->getReturnType());
 
-  auto emitFieldInitializer = [&](const Record::Field *F, unsigned FieldOffset,
+  auto emitFieldInitializer = [&](const Record *R, const Record::Field *F,
+                                  unsigned FieldOffset,
                                   const Expr *InitExpr) -> bool {
     // We don't know what to do with these, so just return false.
     if (InitExpr->getType().isNull())
@@ -4751,6 +4752,8 @@ bool Compiler<Emitter>::visitFunc(const FunctionDecl *F) {
 
       if (F->isBitField())
         return this->emitInitThisBitField(*T, F, FieldOffset, InitExpr);
+      if (R->isUnion())
+        return this->emitInitThisFieldActive(*T, FieldOffset, InitExpr);
       return this->emitInitThisField(*T, FieldOffset, InitExpr);
     }
     // Non-primitive case. Get a pointer to the field-to-initialize
@@ -4762,7 +4765,7 @@ bool Compiler<Emitter>::visitFunc(const FunctionDecl *F) {
     if (!this->visitInitializer(InitExpr))
       return false;
 
-    return this->emitPopPtr(InitExpr);
+    return this->emitFinishInitPop(InitExpr);
   };
 
   // Emit custom code if this is a lambda static invoker.
@@ -4786,7 +4789,7 @@ bool Compiler<Emitter>::visitFunc(const FunctionDecl *F) {
       if (const FieldDecl *Member = Init->getMember()) {
         const Record::Field *F = R->getField(Member);
 
-        if (!emitFieldInitializer(F, F->Offset, InitExpr))
+        if (!emitFieldInitializer(R, F, F->Offset, InitExpr))
           return false;
       } else if (const Type *Base = Init->getBaseClass()) {
         const auto *BaseDecl = Base->getAsCXXRecordDecl();
@@ -4814,11 +4817,11 @@ bool Compiler<Emitter>::visitFunc(const FunctionDecl *F) {
         assert(IFD->getChainingSize() >= 2);
 
         unsigned NestedFieldOffset = 0;
+        const Record *FieldRecord = nullptr;
         const Record::Field *NestedField = nullptr;
         for (const NamedDecl *ND : IFD->chain()) {
           const auto *FD = cast<FieldDecl>(ND);
-          const Record *FieldRecord =
-              this->P.getOrCreateRecord(FD->getParent());
+          FieldRecord = this->P.getOrCreateRecord(FD->getParent());
           assert(FieldRecord);
 
           NestedField = FieldRecord->getField(FD);
@@ -4828,7 +4831,8 @@ bool Compiler<Emitter>::visitFunc(const FunctionDecl *F) {
         }
         assert(NestedField);
 
-        if (!emitFieldInitializer(NestedField, NestedFieldOffset, InitExpr))
+        if (!emitFieldInitializer(FieldRecord, NestedField, NestedFieldOffset,
+                                  InitExpr))
           return false;
       } else {
         assert(Init->isDelegatingInitializer());
diff --git a/clang/lib/AST/Interp/Descriptor.cpp b/clang/lib/AST/Interp/Descriptor.cpp
index 23dd08ca486275..634413b2f2f52a 100644
--- a/clang/lib/AST/Interp/Descriptor.cpp
+++ b/clang/lib/AST/Interp/Descriptor.cpp
@@ -20,7 +20,7 @@ using namespace clang;
 using namespace clang::interp;
 
 template <typename T>
-static void ctorTy(Block *, std::byte *Ptr, bool, bool, bool,
+static void ctorTy(Block *, std::byte *Ptr, bool, bool, bool, bool,
                    const Descriptor *) {
   new (Ptr) T();
 }
@@ -40,7 +40,7 @@ static void moveTy(Block *, const std::byte *Src, std::byte *Dst,
 }
 
 template <typename T>
-static void ctorArrayTy(Block *, std::byte *Ptr, bool, bool, bool,
+static void ctorArrayTy(Block *, std::byte *Ptr, bool, bool, bool, bool,
                         const Descriptor *D) {
   new (Ptr) InitMapPtr(std::nullopt);
 
@@ -83,7 +83,8 @@ static void moveArrayTy(Block *, const std::byte *Src, std::byte *Dst,
 }
 
 static void ctorArrayDesc(Block *B, std::byte *Ptr, bool IsConst,
-                          bool IsMutable, bool IsActive, const Descriptor *D) {
+                          bool IsMutable, bool IsActive, bool InUnion,
+                          const Descriptor *D) {
   const unsigned NumElems = D->getNumElems();
   const unsigned ElemSize =
       D->ElemDesc->getAllocSize() + sizeof(InlineDescriptor);
@@ -102,9 +103,11 @@ static void ctorArrayDesc(Block *B, std::byte *Ptr, bool IsConst,
     Desc->IsActive = IsActive;
     Desc->IsConst = IsConst || D->IsConst;
     Desc->IsFieldMutable = IsMutable || D->IsMutable;
+    Desc->InUnion = InUnion;
+
     if (auto Fn = D->ElemDesc->CtorFn)
       Fn(B, ElemLoc, Desc->IsConst, Desc->IsFieldMutable, IsActive,
-         D->ElemDesc);
+         Desc->InUnion || SD->isUnion(), D->ElemDesc);
   }
 }
 
@@ -146,25 +149,26 @@ static void moveArrayDesc(Block *B, const std::byte *Src, std::byte *Dst,
 }
 
 static void initField(Block *B, std::byte *Ptr, bool IsConst, bool IsMutable,
-                      bool IsActive, bool IsUnion, const Descriptor *D,
-                      unsigned FieldOffset) {
+                      bool IsActive, bool IsUnionField, bool InUnion,
+                      const Descriptor *D, unsigned FieldOffset) {
   auto *Desc = reinterpret_cast<InlineDescriptor *>(Ptr + FieldOffset) - 1;
   Desc->Offset = FieldOffset;
   Desc->Desc = D;
   Desc->IsInitialized = D->IsArray;
   Desc->IsBase = false;
-  Desc->IsActive = IsActive && !IsUnion;
+  Desc->IsActive = IsActive && !IsUnionField;
+  Desc->InUnion = InUnion;
   Desc->IsConst = IsConst || D->IsConst;
   Desc->IsFieldMutable = IsMutable || D->IsMutable;
 
   if (auto Fn = D->CtorFn)
     Fn(B, Ptr + FieldOffset, Desc->IsConst, Desc->IsFieldMutable,
-       Desc->IsActive, D);
+       Desc->IsActive, InUnion || D->isUnion(), D);
 }
 
 static void initBase(Block *B, std::byte *Ptr, bool IsConst, bool IsMutable,
-                     bool IsActive, const Descriptor *D, unsigned FieldOffset,
-                     bool IsVirtualBase) {
+                     bool IsActive, bool InUnion, const Descriptor *D,
+                     unsigned FieldOffset, bool IsVirtualBase) {
   assert(D);
   assert(D->ElemRecord);
 
@@ -180,21 +184,26 @@ static void initBase(Block *B, std::byte *Ptr, bool IsConst, bool IsMutable,
   Desc->IsFieldMutable = IsMutable || D->IsMutable;
 
   for (const auto &V : D->ElemRecord->bases())
-    initBase(B, Ptr + FieldOffset, IsConst, IsMutable, IsActive, V.Desc,
-             V.Offset, false);
+    initBase(B, Ptr + FieldOffset, IsConst, IsMutable, IsActive, InUnion,
+             V.Desc, V.Offset, false);
   for (const auto &F : D->ElemRecord->fields())
-    initField(B, Ptr + FieldOffset, IsConst, IsMutable, IsActive, IsUnion,
-              F.Desc, F.Offset);
+    initField(B, Ptr + FieldOffset, IsConst, IsMutable, IsActive, IsUnion,
+              InUnion, F.Desc, F.Offset);
 }
 
 static void ctorRecord(Block *B, std::byte *Ptr, bool IsConst, bool IsMutable,
-                       bool IsActive, const Descriptor *D) {
+                       bool IsActive, bool InUnion, const Descriptor *D) {
   for (const auto &V : D->ElemRecord->bases())
-    initBase(B, Ptr, IsConst, IsMutable, IsActive, V.Desc, V.Offset, false);
-  for (const auto &F : D->ElemRecord->fields())
-    initField(B, Ptr, IsConst, IsMutable, IsActive, D->ElemRecord->isUnion(), F.Desc, F.Offset);
+    initBase(B, Ptr, IsConst, IsMutable, IsActive, InUnion, V.Desc, V.Offset,
+             false);
+  for (const auto &F : D->ElemRecord->fields()) {
+    bool IsUnionField = D->isUnion();
+    initField(B, Ptr, IsConst, IsMutable, IsActive, IsUnionField,
+              InUnion || IsUnionField, F.Desc, F.Offset);
+  }
   for (const auto &V : D->ElemRecord->virtual_bases())
-    initBase(B, Ptr, IsConst, IsMutable, IsActive, V.Desc, V.Offset, true);
+    initBase(B, Ptr, IsConst, IsMutable, IsActive, InUnion, V.Desc, V.Offset,
+             true);
 }
 
 static void destroyField(Block *B, std::byte *Ptr, const Descriptor *D,
@@ -403,6 +412,8 @@ SourceLocation Descriptor::getLocation() const {
   llvm_unreachable("Invalid descriptor type");
 }
 
+bool Descriptor::isUnion() const { return isRecord() && ElemRecord->isUnion(); }
+
 InitMap::InitMap(unsigned N)
     : UninitFields(N), Data(std::make_unique<T[]>(numFields(N))) {
   std::fill_n(data(), numFields(N), 0);
diff --git a/clang/lib/AST/Interp/Descriptor.h b/clang/lib/AST/Interp/Descriptor.h
index 0cc5d77c407e34..6f1adeb898c430 100644
--- a/clang/lib/AST/Interp/Descriptor.h
+++ b/clang/lib/AST/Interp/Descriptor.h
@@ -32,7 +32,7 @@ using InitMapPtr = std::optional<std::pair<bool, std::shared_ptr<InitMap>>>;
 /// inline descriptors of all fields and array elements. It also initializes
 /// all the fields which contain non-trivial types.
 using BlockCtorFn = void (*)(Block *Storage, std::byte *FieldPtr, bool IsConst,
-                             bool IsMutable, bool IsActive,
+                             bool IsMutable, bool IsActive, bool InUnion,
                              const Descriptor *FieldDesc);
 
 /// Invoked when a block is destroyed. Invokes the destructors of all
@@ -83,11 +83,15 @@ struct InlineDescriptor {
   /// Flag indicating if the field is an embedded base class.
   LLVM_PREFERRED_TYPE(bool)
   unsigned IsBase : 1;
+  /// Flag indicating if the field is a virtual base class.
   LLVM_PREFERRED_TYPE(bool)
   unsigned IsVirtualBase : 1;
   /// Flag indicating if the field is the active member of a union.
   LLVM_PREFERRED_TYPE(bool)
   unsigned IsActive : 1;
+  /// Flag indicating if this field is in a union (even if nested).
+  LLVM_PREFERRED_TYPE(bool)
+  unsigned InUnion : 1;
   /// Flag indicating if the field is mutable (if in a record).
   LLVM_PREFERRED_TYPE(bool)
   unsigned IsFieldMutable : 1;
@@ -250,6 +254,8 @@ struct Descriptor final {
   bool isArray() const { return IsArray; }
   /// Checks if the descriptor is of a record.
   bool isRecord() const { return !IsArray && ElemRecord; }
+  /// Checks if the descriptor is of a union.
+  bool isUnion() const;
   /// Checks if this is a dummy descriptor.
   bool isDummy() const { return IsDummy; }
 
diff --git a/clang/lib/AST/Interp/Disasm.cpp b/clang/lib/AST/Interp/Disasm.cpp
index 5e3a5b9515b528..e1051e5c2bbf62 100644
--- a/clang/lib/AST/Interp/Disasm.cpp
+++ b/clang/lib/AST/Interp/Disasm.cpp
@@ -226,6 +226,8 @@ LLVM_DUMP_METHOD void Descriptor::dump(llvm::raw_ostream &OS) const {
     OS << " primitive-array";
   else if (isCompositeArray())
     OS << " composite-array";
+  else if (isUnion())
+    OS << " union";
   else if (isRecord())
     OS << " record";
   else if (isPrimitive())
@@ -250,6 +252,7 @@ LLVM_DUMP_METHOD void InlineDescriptor::dump(llvm::raw_ostream &OS) const {
   OS << "IsInitialized: " << IsInitialized << "\n";
   OS << "IsBase: " << IsBase << "\n";
   OS << "IsActive: " << IsActive << "\n";
+  OS << "InUnion: " << InUnion << "\n";
   OS << "IsFieldMutable: " << IsFieldMutable << "\n";
   OS << "Desc: ";
   if (Desc)
diff --git a/clang/lib/AST/Interp/Interp.cpp b/clang/lib/AST/Interp/Interp.cpp
index 85cb8ff2db974e..588680adb4616f 100644
--- a/clang/lib/AST/Interp/Interp.cpp
+++ b/clang/lib/AST/Interp/Interp.cpp
@@ -127,16 +127,31 @@ static bool CheckActive(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
 
   // Get the inactive field descriptor.
   const FieldDecl *InactiveField = Ptr.getField();
+  assert(InactiveField);
 
-  // Walk up the pointer chain to find the union which is not active.
+  // Walk up the pointer chain to find the closest union.
   Pointer U = Ptr.getBase();
-  while (!U.isActive()) {
+  while (!U.getFieldDesc()->isUnion())
     U = U.getBase();
-  }
 
   // Find the active field of the union.
   const Record *R = U.getRecord();
   assert(R && R->isUnion() && "Not a union");
+
+  // Consider:
+  // union U {
+  //   struct {
+  //     int x;
+  //     int y;
+  //   } a;
+  // };
+  //
+  // When activating x, we will also activate a. If we now try to read
+  // from y, we will get to CheckActive, because y is not active. In that
+  // case we return here and let later code handle this.
+  if (!llvm::is_contained(R->getDecl()->fields(), InactiveField))
+    return true;
+
   const FieldDecl *ActiveField = nullptr;
   for (unsigned I = 0, N = R->getNumFields(); I < N; ++I) {
     const Pointer &Field = U.atField(R->getField(I)->Offset);
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index 832fc028ad6696..196fc15a77519b 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -1719,8 +1719,10 @@ bool Store(InterpState &S, CodePtr OpPC) {
   const Pointer &Ptr = S.Stk.peek<Pointer>();
   if (!CheckStore(S, OpPC, Ptr))
     return false;
-  if (Ptr.canBeInitialized())
+  if (Ptr.canBeInitialized()) {
     Ptr.initialize();
+    Ptr.activate();
+  }
   Ptr.deref<T>() = Value;
   return true;
 }
@@ -1731,8 +1733,10 @@ bool StorePop(InterpState &S, CodePtr OpPC) {
   const Pointer &Ptr = S.Stk.pop<Pointer>();
   if (!CheckStore(S, OpPC, Ptr))
     return false;
-  if (Ptr.canBeInitialized())
+  if (Ptr.canBeInitialized()) {
     Ptr.initialize();
+    Ptr.activate();
+  }
   Ptr.deref<T>() = Value;
   return true;
 }
diff --git a/clang/lib/AST/Interp/InterpBlock.h b/clang/lib/AST/Interp/InterpBlock.h
index 3760ded7b13fef..20cf8cb9d3810e 100644
--- a/clang/lib/AST/Interp/InterpBlock.h
+++ b/clang/lib/AST/Interp/InterpBlock.h
@@ -110,9 +110,10 @@ class Block final {
   void invokeCtor() {
     assert(!IsInitialized);
     std::memset(rawData(), 0, Desc->getAllocSize());
-    if (Desc->CtorFn)
+    if (Desc->CtorFn) {
       Desc->CtorFn(this, data(), Desc->IsConst, Desc->IsMutable,
-                   /*isActive=*/true, Desc);
+                   /*isActive=*/true, /*InUnion=*/false, Desc);
+    }
     IsInitialized = true;
   }
 
diff --git a/clang/lib/AST/Interp/Pointer.cpp b/clang/lib/AST/Interp/Pointer.cpp
index ba9683a059e18b..f1f7a27c1400dd 100644
--- a/clang/lib/AST/Interp/Pointer.cpp
+++ b/clang/lib/AST/Interp/Pointer.cpp
@@ -388,12 +388,37 @@ void Pointer::initialize() const {
 void Pointer::activate() const {
   // Field has its bit in an inline descriptor.
   assert(PointeeStorage.BS.Base != 0 &&
-         "Only composite fields can be initialised");
+         "Only composite fields can be activated");
 
   if (isRoot() && PointeeStorage.BS.Base == sizeof(GlobalInlineDescriptor))
     return;
+  if (!getInlineDesc()->InUnion)
+    return;
 
   getInlineDesc()->IsActive = true;
+
+  // Find the closest enclosing union and deactivate all of its other members.
+  Pointer UnionPtr = getBase();
+  while (!UnionPtr.getFieldDesc()->isUnion())
+    UnionPtr = UnionPtr.getBase();
+
+  const Record *UnionRecord = UnionPtr.getRecord();
+  for (const Record::Field &F : UnionRecord->fields()) {
+    Pointer FieldPtr = UnionPtr.atField(F.Offset);
+    // Leave the member being activated alone; deactivate its siblings.
+    if (FieldPtr == *this)
+      continue;
+    FieldPtr.getInlineDesc()->IsActive = false;
+    // FIXME: Recurse.
+  }
+
+  Pointer B = getBase();
+  while (!B.getFieldDesc()->isUnion()) {
+    // FIXME: Need to de-activate other fields of parent records.
+    B.getInlineDesc()->IsActive = true;
+    assert(B.isActive());
+    B = B.getBase();
+  }
 }
 
 void Pointer::deactivate() const {
diff --git a/clang/lib/AST/Interp/Pointer.h b/clang/lib/AST/Interp/Pointer.h
index b7b4f82f16f66b..07ff8025ba9541 100644
--- a/clang/lib/AST/Interp/Pointer.h
+++ b/clang/lib/AST/Interp/Pointer.h
@@ -400,6 +400,12 @@ class Pointer {
       return getFieldDesc()->IsArray;
     return false;
   }
+  /// Checks if the pointee is (possibly transitively) inside a union.
+  bool inUnion() const {
+    if (isBlockPointer())
+      return getInlineDesc()->InUnion;
+    return false;
+  }
   /// Checks if the structure is a primitive array.
   bool inPrimitiveArray() const {
     if (isBlockPointer())
diff --git a/clang/test/AST/Interp/unions.cpp b/clang/test/AST/Interp/unions.cpp
index 293a1981a52f07..d615b3584b30b7 100644
--- a/clang/test/AST/Interp/unions.cpp
+++ b/clang/test/AST/Interp/unions.cpp
@@ -1,5 +1,7 @@
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -verify=expected,both %s
+// RUN: %clang_cc1 -std=c++20 -fexperimental-new-constant-interpreter -verify=expected,both %s
 // RUN: %clang_cc1 -verify=ref,both %s
+// RUN: %clang_cc1 -std=c++20 -verify=ref,both %s
 
 union U {
   int a;
@@ -65,3 +67,89 @@ namespace ZeroInit {
   constexpr Z z{};
   static_assert(z.f == 0.0, "");
 }
+
+namespace DefaultInit {
+  union U1 {
+    constexpr U1() {}
+    int a, b = 42;
+  };
+
+  constexpr U1 u1; /// OK.
+
+  constexpr int foo() {
+    U1 u;
+    return u.a; // both-note {{read of member 'a' of union with active member 'b'}}
+  }
+  static_assert(foo() == 42); // both-error {{not an integral constant expression}} \
+                              // both-note {{in call to}}
+}
+
+#if __cplusplus >= 202002L
+namespace SimpleActivate {
+  constexpr int foo() { // ref-error {{never produces a constant expression}}
+    union {
+      int a;
+      int b;
+    } Z;
+
+    Z.a = 10;
+    Z.b = 20;
+    return Z.a; // both-note {{read of member 'a' of union with active member 'b'}} \
+                // ref-note {{read of member 'a' of union with active member 'b'}}
+  }
+  static_assert(foo() == 20); // both-error {{not an integral constant expression}} \
+                              // both-note {{in call to}}
+
+  constexpr int foo2() {
+    union {
+      int a;
+      int b;
+    } Z;
+
+    Z.a = 10;
+    Z.b = 20;
+    return Z.b;
+  }
+  static_assert(foo2() == 20);
+
+
+  constexpr int foo3() {
+    union {
+      struct {
+        float x,y;
+      } a;
+      int b;
+    } Z;
+
+    Z.a.y = 10;
+
+    return Z.a.x; // both-note {{read of uninitialized object}}
+  }
+  static_assert(foo3() == 10); // both-error {{not an integral constant expression}} \
+                               // both-note {{in call to}}
+
+  constexpr int foo4() {
+    union {
+      struct {
+        float x,y;
+      } a;
+      int b;
+    } Z;
+
+    Z.a.x = 100;
+    Z.a.y = 10;
+
+    return Z.a.x;
+  }
+  static_assert(foo4() == 100);
+}
+
+namespace IndirectFieldDecl {
+  struct C {
+    union { int a, b = 2, c; };
+    union { int d, e = 5, f; };
+    constexpr C() : a(1) {}
+  };
+  static_assert(C().a == 1, "");
+}
+#endif

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to