tbaeder created this revision.
tbaeder added reviewers: aaron.ballman, erichkeane, tahonermann, shafik.
Herald added a project: All.
tbaeder requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

  BaseToDerived casts can be implemented similarly to DerivedToBase casts:
  walk the class hierarchy, sum the base offsets, and subtract the sum from
  the current base offset of the pointer.

As a side effect, this also changes `DerivedToBase` casts to emit only
//one// opcode instead of one per base cast.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D149133

Files:
  clang/lib/AST/Interp/ByteCodeExprGen.cpp
  clang/lib/AST/Interp/ByteCodeExprGen.h
  clang/lib/AST/Interp/Interp.cpp
  clang/lib/AST/Interp/Interp.h
  clang/lib/AST/Interp/Opcodes.td
  clang/lib/AST/Interp/Pointer.h
  clang/test/AST/Interp/records.cpp

Index: clang/test/AST/Interp/records.cpp
===================================================================
--- clang/test/AST/Interp/records.cpp
+++ clang/test/AST/Interp/records.cpp
@@ -586,6 +586,58 @@
   static_assert(test() == 1);
 }
 
+namespace BaseToDerived { // Base-to-derived pointer casts.
+namespace A { // Downcasting a pointer one past the end of a base is rejected.
+  struct A {};
+  struct B : A { int n; };
+  struct C : B {};
+  C c = {};
+  constexpr C *pb = (C*)((A*)&c + 1); // expected-error {{must be initialized by a constant expression}} \
+                                      // expected-note {{cannot access derived class of pointer past the end of object}} \
+                                      // ref-error {{must be initialized by a constant expression}} \
+                                      // ref-note {{cannot access derived class of pointer past the end of object}}
+}
+namespace B { // Downcast via a base that is not the first base (B : Z, A).
+  struct A {};
+  struct Z {};
+  struct B : Z, A {
+    int n;
+   constexpr B() : n(10) {}
+  };
+  struct C : B {
+   constexpr C() : B() {}
+  };
+
+  constexpr C c = {};
+  constexpr const A *pa = &c;
+  constexpr const C *cp = (C*)pa;
+  constexpr const B *cb = (B*)cp;
+
+  static_assert(cb->n == 10);
+  static_assert(cp->n == 10);
+}
+
+namespace C { // Downcast across several inheritance levels at once.
+  struct Base { int *a; };
+  struct Base2 : Base { int f[12]; };
+
+  struct Middle1 { int b[3]; };
+  struct Middle2 : Base2 { char c; };
+  struct Middle3 : Middle2 { char g[3]; };
+  struct Middle4 { int f[3]; };
+  struct Middle5 : Middle4, Middle3 { char g2[3]; };
+
+  struct NotQuiteDerived : Middle1, Middle5 { bool d; };
+  struct Derived : NotQuiteDerived { int e; };
+
+  constexpr NotQuiteDerived NQD1 = {};
+
+  constexpr Middle5 *M4 = (Middle5*)((Base2*)&NQD1);
+  static_assert(M4->a == nullptr);
+  static_assert(M4->g2[0] == 0);
+}
+}
+
 
 namespace VirtualDtors {
   class A {
Index: clang/lib/AST/Interp/Pointer.h
===================================================================
--- clang/lib/AST/Interp/Pointer.h
+++ clang/lib/AST/Interp/Pointer.h
@@ -100,6 +100,14 @@
     return Pointer(Pointee, Field, Field);
   }
 
+  /// Returns a pointer whose Base and Offset are both set to this
+  /// pointer's Offset minus \p Off (used for base-to-derived casts).
+  Pointer atFieldSub(unsigned Off) const {
+    assert(Offset >= Off); // Debug-only check; O would wrap otherwise.
+    unsigned O = Offset - Off;
+    return Pointer(Pointee, O, O);
+  }
+
   /// Restricts the scope of an array element pointer.
   Pointer narrow() const {
     // Null pointers cannot be narrowed.
Index: clang/lib/AST/Interp/Opcodes.td
===================================================================
--- clang/lib/AST/Interp/Opcodes.td
+++ clang/lib/AST/Interp/Opcodes.td
@@ -291,6 +291,12 @@
   let Args = [ArgUint32];
 }
 
+// [Pointer] -> [Pointer]
+def GetPtrDerivedPop : Opcode {
+  // Offset of the base subobject within the derived class.
+  let Args = [ArgUint32];
+}
+
 // [Pointer] -> [Pointer]
 def GetPtrVirtBase : Opcode {
   // RecordDecl of base class.
Index: clang/lib/AST/Interp/Interp.h
===================================================================
--- clang/lib/AST/Interp/Interp.h
+++ clang/lib/AST/Interp/Interp.h
@@ -67,8 +67,9 @@
 bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
                 CheckSubobjectKind CSK);
 
-/// Checks if accessing a base of the given pointer is valid.
-bool CheckBase(InterpState &S, CodePtr OpPC, const Pointer &Ptr);
+/// Checks if accessing a base or derived record of the given pointer is valid.
+bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
+                      CheckSubobjectKind CSK);
 
 /// Checks if a pointer points to const storage.
 bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr);
@@ -1078,11 +1079,28 @@
   return true;
 }
 
+/// Pops a pointer to a base-class subobject and pushes that pointer adjusted
+/// by atFieldSub(\p Off); emitted for base-to-derived casts. Fails if the
+/// popped pointer is null or one past the end.
+inline bool GetPtrDerivedPop(InterpState &S, CodePtr OpPC, uint32_t Off) {
+  const Pointer &Ptr = S.Stk.pop<Pointer>();
+
+  if (!CheckNull(S, OpPC, Ptr, CSK_Derived))
+    return false;
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Derived))
+    return false;
+
+  // NOTE(review): nothing here verifies that Ptr really is a base subobject
+  // of the cast's target type; atFieldSub() only asserts Offset >= Off.
+  S.Stk.push<Pointer>(Ptr.atFieldSub(Off));
+  return true;
+}
+
 inline bool GetPtrBase(InterpState &S, CodePtr OpPC, uint32_t Off) {
   const Pointer &Ptr = S.Stk.peek<Pointer>();
   if (!CheckNull(S, OpPC, Ptr, CSK_Base))
     return false;
-  if (!CheckBase(S, OpPC, Ptr))
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
     return false;
   S.Stk.push<Pointer>(Ptr.atField(Off));
   return true;
@@ -1092,7 +1110,7 @@
   const Pointer &Ptr = S.Stk.pop<Pointer>();
   if (!CheckNull(S, OpPC, Ptr, CSK_Base))
     return false;
-  if (!CheckBase(S, OpPC, Ptr))
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
     return false;
   S.Stk.push<Pointer>(Ptr.atField(Off));
   return true;
Index: clang/lib/AST/Interp/Interp.cpp
===================================================================
--- clang/lib/AST/Interp/Interp.cpp
+++ clang/lib/AST/Interp/Interp.cpp
@@ -211,12 +211,13 @@
   return false;
 }
 
-bool CheckBase(InterpState &S, CodePtr OpPC, const Pointer &Ptr) {
+bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
+                      CheckSubobjectKind CSK) { // CSK is reported in the note.
   if (!Ptr.isOnePastEnd())
     return true;
 
   const SourceInfo &Loc = S.Current->getSource(OpPC);
-  S.FFDiag(Loc, diag::note_constexpr_past_end_subobject) << CSK_Base;
+  S.FFDiag(Loc, diag::note_constexpr_past_end_subobject) << CSK;
   return false;
 }
 
Index: clang/lib/AST/Interp/ByteCodeExprGen.h
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.h
+++ clang/lib/AST/Interp/ByteCodeExprGen.h
@@ -269,8 +269,10 @@
   }
 
   bool emitRecordDestruction(const Descriptor *Desc);
-  bool emitDerivedToBaseCasts(const RecordType *DerivedType,
-                              const RecordType *BaseType, const Expr *E);
+  /// Returns the summed offsets of the base subobjects on the path from
+  /// \p DerivedType up to \p BaseType. Note: the base type comes first.
+  unsigned collectBaseOffset(const RecordType *BaseType,
+                             const RecordType *DerivedType);
   bool emitBuiltinBitCast(const CastExpr *E);
 
 protected:
Index: clang/lib/AST/Interp/ByteCodeExprGen.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -138,8 +138,20 @@
     if (!this->visit(SubExpr))
       return false;
 
-    return this->emitDerivedToBaseCasts(getRecordTy(SubExpr->getType()),
-                                        getRecordTy(CE->getType()), CE);
+    unsigned BaseOffset = collectBaseOffset(getRecordTy(CE->getType()),
+                                            getRecordTy(SubExpr->getType()));
+
+    return this->emitGetPtrBasePop(BaseOffset, CE);
+  }
+
+  case CK_BaseToDerived: {
+    if (!this->visit(SubExpr))
+      return false;
+
+    unsigned BaseOffset = collectBaseOffset(getRecordTy(SubExpr->getType()),
+                                            getRecordTy(CE->getType()));
+
+    return this->emitGetPtrDerivedPop(BaseOffset, CE);
   }
 
   case CK_FloatingCast: {
@@ -2124,13 +2136,18 @@
 }
 
+/// Returns the sum of the offsets of the base subobjects on the path from
+/// \p DerivedType up to \p BaseType, which is expected to be one of its
+/// (direct or indirect) bases.
 template <class Emitter>
-bool ByteCodeExprGen<Emitter>::emitDerivedToBaseCasts(
-    const RecordType *DerivedType, const RecordType *BaseType, const Expr *E) {
-  // Pointer of derived type is already on the stack.
+unsigned
+ByteCodeExprGen<Emitter>::collectBaseOffset(const RecordType *BaseType,
+                                            const RecordType *DerivedType) {
   const auto *FinalDecl = cast<CXXRecordDecl>(BaseType->getDecl());
   const RecordDecl *CurDecl = DerivedType->getDecl();
   const Record *CurRecord = getRecord(CurDecl);
   assert(CurDecl && FinalDecl);
+
+  unsigned OffsetSum = 0;
   for (;;) {
     assert(CurRecord->getNumBases() > 0);
     // One level up
@@ -2138,21 +2155,19 @@
       const auto *BaseDecl = cast<CXXRecordDecl>(B.Decl);
 
       if (BaseDecl == FinalDecl || BaseDecl->isDerivedFrom(FinalDecl)) {
-        // This decl will lead us to the final decl, so emit a base cast.
-        if (!this->emitGetPtrBasePop(B.Offset, E))
-          return false;
-
+        // This base leads to FinalDecl, so accumulate its offset.
+        OffsetSum += B.Offset;
         CurRecord = B.R;
         CurDecl = BaseDecl;
         break;
       }
     }
     if (CurDecl == FinalDecl)
-      return true;
+      break;
   }
 
-  llvm_unreachable("Couldn't find the base class?");
-  return false;
+  assert(OffsetSum > 0);
+  return OffsetSum;
 }
 
 /// When calling this, we have a pointer of the local-to-destroy
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to