llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Timm Baeder (tbaederr)

<details>
<summary>Changes</summary>

All members of a union share the same address, so pointers into a union cannot be compared using `Pointer::getByteOffset()` alone.

---
Full diff: https://github.com/llvm/llvm-project/pull/133852.diff


4 Files Affected:

- (modified) clang/lib/AST/ByteCode/Interp.h (+10-1) 
- (modified) clang/lib/AST/ByteCode/Pointer.cpp (+31) 
- (modified) clang/lib/AST/ByteCode/Pointer.h (+3-1) 
- (modified) clang/test/AST/ByteCode/unions.cpp (+38) 


``````````diff
diff --git a/clang/lib/AST/ByteCode/Interp.h b/clang/lib/AST/ByteCode/Interp.h
index 938077a9f10ae..6fe1d4b1f95ae 100644
--- a/clang/lib/AST/ByteCode/Interp.h
+++ b/clang/lib/AST/ByteCode/Interp.h
@@ -1070,9 +1070,18 @@ inline bool CmpHelperEQ<Pointer>(InterpState &S, CodePtr OpPC, CompareFn Fn) {
   }
 
   if (Pointer::hasSameBase(LHS, RHS)) {
+    if (LHS.inUnion() && RHS.inUnion()) {
+      // If the pointers point into a union, things are a little more
+      // complicated since the offset we save in interp::Pointer can't be used
+      // to compare the pointers directly.
+      size_t A = LHS.computeOffsetForComparison();
+      size_t B = RHS.computeOffsetForComparison();
+      S.Stk.push<BoolT>(BoolT::from(Fn(Compare(A, B))));
+      return true;
+    }
+
     unsigned VL = LHS.getByteOffset();
     unsigned VR = RHS.getByteOffset();
-
     // In our Pointer class, a pointer to an array and a pointer to the first
     // element in the same array are NOT equal. They have the same Base value,
     // but a different Offset. This is a pretty rare case, so we fix this here
diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp
index 79b47c26992ae..7ccefd6aeb9e1 100644
--- a/clang/lib/AST/ByteCode/Pointer.cpp
+++ b/clang/lib/AST/ByteCode/Pointer.cpp
@@ -339,6 +339,37 @@ void Pointer::print(llvm::raw_ostream &OS) const {
   }
 }
 
+/// Compute an integer that can be used to compare this pointer to
+/// another one.
+size_t Pointer::computeOffsetForComparison() const {
+  if (!isBlockPointer())
+    return Offset;
+
+  size_t Result = 0;
+  Pointer P = *this;
+  while (!P.isRoot()) {
+    if (P.isArrayRoot()) {
+      P = P.getBase();
+      continue;
+    }
+    if (P.isArrayElement()) {
+      Result += (P.getIndex() * P.elemSize());
+      P = P.getArray();
+    }
+
+    if (const Record *R = P.getBase().getRecord(); R && R->isUnion()) {
+      // Direct child of a union - all have offset 0.
+      P = P.getBase();
+      continue;
+    }
+
+    Result += P.getInlineDesc()->Offset;
+    P = P.getBase();
+  }
+
+  return Result;
+}
+
 std::string Pointer::toDiagnosticString(const ASTContext &Ctx) const {
   if (isZero())
     return "nullptr";
diff --git a/clang/lib/AST/ByteCode/Pointer.h b/clang/lib/AST/ByteCode/Pointer.h
index fd33ee9955f55..988237d39fff4 100644
--- a/clang/lib/AST/ByteCode/Pointer.h
+++ b/clang/lib/AST/ByteCode/Pointer.h
@@ -417,7 +417,7 @@ class Pointer {
     return false;
   }
   bool inUnion() const {
-    if (isBlockPointer())
+    if (isBlockPointer() && asBlockPointer().Base >= sizeof(InlineDescriptor))
       return getInlineDesc()->InUnion;
     return false;
   };
@@ -727,6 +727,8 @@ class Pointer {
   /// Prints the pointer.
   void print(llvm::raw_ostream &OS) const;
 
+  size_t computeOffsetForComparison() const;
+
 private:
   friend class Block;
   friend class DeadBlock;
diff --git a/clang/test/AST/ByteCode/unions.cpp b/clang/test/AST/ByteCode/unions.cpp
index 66b8389606b85..3911a2b2f7dde 100644
--- a/clang/test/AST/ByteCode/unions.cpp
+++ b/clang/test/AST/ByteCode/unions.cpp
@@ -600,3 +600,41 @@ namespace MoveOrAssignOp {
   static_assert(foo());
 }
 #endif
+
+namespace AddressComparison {
+  union {
+    int a;
+    int c;
+  } U;
+  static_assert(__builtin_addressof(U.a) == (void*)__builtin_addressof(U.c));
+  static_assert(&U.a == &U.c);
+
+
+  struct {
+    union {
+      struct {
+        int a;
+        int b;
+      } a;
+      struct {
+        int b;
+        int a;
+      }b;
+    } u;
+    int b;
+  } S;
+
+  static_assert(&S.u.a.a == &S.u.b.b);
+  static_assert(&S.u.a.b != &S.u.b.b);
+  static_assert(&S.u.a.b == &S.u.b.b); // both-error {{failed}}
+
+
+  union {
+    int a[2];
+    int b[2];
+  } U2;
+
+  static_assert(&U2.a[0] == &U2.b[0]);
+  static_assert(&U2.a[0] != &U2.b[1]);
+  static_assert(&U2.a[0] == &U2.b[1]); // both-error {{failed}}
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/133852
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to