danix800 updated this revision to Diff 549059.
danix800 added a comment.

`CXXRecordDecl::friend_iterator` is actually a reverse iterator, and
deduplicating friends in different iteration directions produces different
results. ASTImporter iterates friends forward, so structural equivalence
checking should be consistent with that.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D157114/new/

https://reviews.llvm.org/D157114

Files:
  clang/lib/AST/ASTImporter.cpp
  clang/lib/AST/ASTStructuralEquivalence.cpp
  clang/unittests/AST/ASTImporterTest.cpp
  clang/unittests/AST/StructuralEquivalenceTest.cpp

Index: clang/unittests/AST/StructuralEquivalenceTest.cpp
===================================================================
--- clang/unittests/AST/StructuralEquivalenceTest.cpp
+++ clang/unittests/AST/StructuralEquivalenceTest.cpp
@@ -833,7 +833,18 @@
   auto t = makeNamedDecls("struct foo { friend class X; };",
                           "struct foo { friend class X; friend class X; };",
                           Lang_CXX11);
-  EXPECT_FALSE(testStructuralMatch(t));
+  EXPECT_TRUE(testStructuralMatch(t));
+}
+
+TEST_F(StructuralEquivalenceRecordTest,
+       SameFriendMultipleTimesForwardIteratorDirection) {
+  // Deduplicating in declaration order reduces both records to [X, Y];
+  // friend_iterator (reverse) order would give [Y, X] vs. [X, Y].
+  auto t = makeNamedDecls(
+      "struct foo { friend class X; friend class Y; };",
+      "struct foo { friend class X; friend class Y; friend class X; };",
+      Lang_CXX11);
+  EXPECT_TRUE(testStructuralMatch(t));
 }
 
 TEST_F(StructuralEquivalenceRecordTest, SameFriendsDifferentOrder) {
Index: clang/unittests/AST/ASTImporterTest.cpp
===================================================================
--- clang/unittests/AST/ASTImporterTest.cpp
+++ clang/unittests/AST/ASTImporterTest.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "clang/AST/ASTStructuralEquivalence.h"
 #include "clang/AST/RecordLayout.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
 #include "llvm/ADT/StringMap.h"
@@ -4385,6 +4386,47 @@
   EXPECT_EQ(ToFriend2, ToImportedFriend2);
 }
 
+TEST_P(ASTImporterOptionSpecificTestBase, ImportRepeatedFriendDeclIntoEmptyDC) {
+  Decl *From, *To;
+  std::tie(From, To) = getImportedDecl(R"(
+      template <class T>
+      class A {
+      public:
+        template <class U> friend A<U> &f();
+        template <class U> friend A<U> &f();
+      };
+  )",
+                                       Lang_CXX17, "", Lang_CXX17, "A");
+
+  auto *FromFriend1 = FirstDeclMatcher<FriendDecl>().match(From, friendDecl());
+  auto *FromFriend2 = LastDeclMatcher<FriendDecl>().match(From, friendDecl());
+  auto *ToFriend1 = FirstDeclMatcher<FriendDecl>().match(To, friendDecl());
+  auto *ToFriend2 = LastDeclMatcher<FriendDecl>().match(To, friendDecl());
+
+  // Two distinct FriendDecls in the From context.
+  EXPECT_NE(FromFriend1, FromFriend2);
+  // Only one of them is imported into the empty DC.
+  EXPECT_EQ(ToFriend1, ToFriend2);
+
+  // 'A' imported into an empty DC must stay structurally equivalent to the
+  // original, and the check must give the same answer in both directions.
+  llvm::DenseSet<std::pair<Decl *, Decl *>> NonEquivalentDecls01;
+  llvm::DenseSet<std::pair<Decl *, Decl *>> NonEquivalentDecls10;
+  StructuralEquivalenceContext Ctx01(
+      From->getASTContext(), To->getASTContext(), NonEquivalentDecls01,
+      StructuralEquivalenceKind::Default, /*StrictTypeSpelling=*/false,
+      /*Complain=*/false);
+  StructuralEquivalenceContext Ctx10(
+      To->getASTContext(), From->getASTContext(), NonEquivalentDecls10,
+      StructuralEquivalenceKind::Default, /*StrictTypeSpelling=*/false,
+      /*Complain=*/false);
+
+  bool Eq01 = Ctx01.IsEquivalent(From, To);
+  bool Eq10 = Ctx10.IsEquivalent(To, From);
+  EXPECT_EQ(Eq01, Eq10);
+  EXPECT_TRUE(Eq01);
+}
+
 TEST_P(ASTImporterOptionSpecificTestBase, FriendFunInClassTemplate) {
   auto *Code = R"(
   template <class T>
Index: clang/lib/AST/ASTStructuralEquivalence.cpp
===================================================================
--- clang/lib/AST/ASTStructuralEquivalence.cpp
+++ clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -1464,6 +1464,179 @@
   return IsStructurallyEquivalent(GetName(D1), GetName(D2));
 }
 
+/// Determine whether the base classes of two C++ records are structurally
+/// equivalent: same number of bases, pairwise equivalent base types and
+/// matching virtual/non-virtual inheritance. Emits ODR diagnostics when
+/// Context.Complain is set.
+static bool
+IsCXXRecordBaseStructurallyEquivalent(StructuralEquivalenceContext &Context,
+                                      RecordDecl *D1, RecordDecl *D2) {
+  auto *D1CXX = cast<CXXRecordDecl>(D1);
+  auto *D2CXX = cast<CXXRecordDecl>(D2);
+
+  if (D1CXX->getNumBases() != D2CXX->getNumBases()) {
+    if (Context.Complain) {
+      Context.Diag2(D2->getLocation(), Context.getApplicableDiagnostic(
+                                           diag::err_odr_tag_type_inconsistent))
+          << Context.ToCtx.getTypeDeclType(D2);
+      Context.Diag2(D2->getLocation(), diag::note_odr_number_of_bases)
+          << D2CXX->getNumBases();
+      Context.Diag1(D1->getLocation(), diag::note_odr_number_of_bases)
+          << D1CXX->getNumBases();
+    }
+    return false;
+  }
+
+  for (CXXRecordDecl::base_class_iterator Base1 = D1CXX->bases_begin(),
+                                          BaseEnd1 = D1CXX->bases_end(),
+                                          Base2 = D2CXX->bases_begin();
+       Base1 != BaseEnd1; ++Base1, ++Base2) {
+    if (!IsStructurallyEquivalent(Context, Base1->getType(),
+                                  Base2->getType())) {
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2);
+        Context.Diag2(Base2->getBeginLoc(), diag::note_odr_base)
+            << Base2->getType() << Base2->getSourceRange();
+        Context.Diag1(Base1->getBeginLoc(), diag::note_odr_base)
+            << Base1->getType() << Base1->getSourceRange();
+      }
+      return false;
+    }
+
+    // Check virtual vs. non-virtual inheritance mismatch.
+    if (Base1->isVirtual() != Base2->isVirtual()) {
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2);
+        // Both notes describe the derivation kind (a bool), not a base type,
+        // so both must use note_odr_virtual_base.
+        Context.Diag2(Base2->getBeginLoc(), diag::note_odr_virtual_base)
+            << Base2->isVirtual() << Base2->getSourceRange();
+        Context.Diag1(Base1->getBeginLoc(), diag::note_odr_virtual_base)
+            << Base1->isVirtual() << Base1->getSourceRange();
+      }
+      return false;
+    }
+  }
+
+  return true;
+}
+
+using NonEquivalentDeclSet = llvm::DenseSet<std::pair<Decl *, Decl *>>;
+
+/// Determine whether two friend declarations befriend structurally
+/// equivalent entities. Only used to deduplicate the friends of a single
+/// record, so no diagnostics are emitted.
+static bool IsEquivalentFriend(FriendDecl *F1, FriendDecl *F2,
+                               NonEquivalentDeclSet &NonEquivalentDecls) {
+  StructuralEquivalenceContext Ctx(
+      F1->getASTContext(), F2->getASTContext(), NonEquivalentDecls,
+      StructuralEquivalenceKind::Minimal, /*StrictTypeSpelling=*/false,
+      /*Complain=*/false);
+  if (F1->getFriendDecl() && F2->getFriendDecl())
+    return Ctx.IsEquivalent(F1->getFriendDecl(), F2->getFriendDecl());
+  if (F1->getFriendType() && F2->getFriendType())
+    return Ctx.IsEquivalent(F1->getFriendType()->getType(),
+                            F2->getFriendType()->getType());
+
+  // A befriended declaration never matches a befriended type.
+  return false;
+}
+
+/// Return true if \p F is equivalent to any of \p Friends.
+static bool
+IsEquivalentToAnyExistingFriends(FriendDecl *F, ArrayRef<FriendDecl *> Friends,
+                                 NonEquivalentDeclSet &NonEquivalentDecls) {
+  for (FriendDecl *Other : Friends)
+    if (IsEquivalentFriend(F, Other, NonEquivalentDecls))
+      return true;
+
+  return false;
+}
+
+/// Collect the friends of \p RD in declaration order, dropping every friend
+/// that is equivalent to an earlier one.
+///
+/// CXXRecordDecl::friend_iterator visits friends in reverse declaration
+/// order, while ASTImporter imports them in declaration order. The direction
+/// matters: deduplicating in reverse would turn {X, Y, X} into [X, Y] but
+/// {X, Y} into [Y, X].
+static SmallVector<FriendDecl *, 2> getDeduplicatedFriends(CXXRecordDecl *RD) {
+  SmallVector<FriendDecl *, 2> Friends;
+  for (auto *Friend : RD->friends())
+    Friends.push_back(Friend);
+
+  NonEquivalentDeclSet NonEquivalentDecls;
+  SmallVector<FriendDecl *, 2> UniqueFriends;
+  // Walk the collected friends backwards to restore declaration order.
+  for (auto It = Friends.rbegin(), End = Friends.rend(); It != End; ++It)
+    if (!IsEquivalentToAnyExistingFriends(*It, UniqueFriends,
+                                          NonEquivalentDecls))
+      UniqueFriends.push_back(*It);
+
+  return UniqueFriends;
+}
+
+/// Determine whether the friends of two C++ records are structurally
+/// equivalent. Each record's friends are deduplicated with
+/// getDeduplicatedFriends and then compared pairwise in declaration order.
+/// Emits ODR diagnostics when Context.Complain is set.
+static bool
+IsFriendInCXXRecordStructurallyEquivalent(StructuralEquivalenceContext &Context,
+                                          RecordDecl *D1, RecordDecl *D2) {
+  auto *D1CXX = cast<CXXRecordDecl>(D1);
+  auto *D2CXX = cast<CXXRecordDecl>(D2);
+
+  SmallVector<FriendDecl *, 2> Friends1 = getDeduplicatedFriends(D1CXX);
+  SmallVector<FriendDecl *, 2> Friends2 = getDeduplicatedFriends(D2CXX);
+
+  auto Friend2 = Friends2.begin(), Friend2End = Friends2.end();
+  for (auto Friend1 = Friends1.begin(), Friend1End = Friends1.end();
+       Friend1 != Friend1End; ++Friend1, ++Friend2) {
+    if (Friend2 == Friend2End) {
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2CXX);
+        Context.Diag1((*Friend1)->getFriendLoc(), diag::note_odr_friend);
+        Context.Diag2(D2->getLocation(), diag::note_odr_missing_friend);
+      }
+      return false;
+    }
+
+    if (!IsStructurallyEquivalent(Context, *Friend1, *Friend2)) {
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2CXX);
+        Context.Diag1((*Friend1)->getFriendLoc(), diag::note_odr_friend);
+        Context.Diag2((*Friend2)->getFriendLoc(), diag::note_odr_friend);
+      }
+      return false;
+    }
+  }
+
+  if (Friend2 != Friend2End) {
+    if (Context.Complain) {
+      Context.Diag2(D2->getLocation(), Context.getApplicableDiagnostic(
+                                           diag::err_odr_tag_type_inconsistent))
+          << Context.ToCtx.getTypeDeclType(D2);
+      Context.Diag2((*Friend2)->getFriendLoc(), diag::note_odr_friend);
+      Context.Diag1(D1->getLocation(), diag::note_odr_missing_friend);
+    }
+    return false;
+  }
+
+  return true;
+}
+
 /// Determine structural equivalence of two records.
 static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
                                      RecordDecl *D1, RecordDecl *D2) {
@@ -1562,98 +1721,11 @@
           return false;
       }
 
-      if (D1CXX->getNumBases() != D2CXX->getNumBases()) {
-        if (Context.Complain) {
-          Context.Diag2(D2->getLocation(),
-                        Context.getApplicableDiagnostic(
-                            diag::err_odr_tag_type_inconsistent))
-              << Context.ToCtx.getTypeDeclType(D2);
-          Context.Diag2(D2->getLocation(), diag::note_odr_number_of_bases)
-              << D2CXX->getNumBases();
-          Context.Diag1(D1->getLocation(), diag::note_odr_number_of_bases)
-              << D1CXX->getNumBases();
-        }
+      if (!IsCXXRecordBaseStructurallyEquivalent(Context, D1, D2))
         return false;
-      }
-
-      // Check the base classes.
-      for (CXXRecordDecl::base_class_iterator Base1 = D1CXX->bases_begin(),
-                                              BaseEnd1 = D1CXX->bases_end(),
-                                              Base2 = D2CXX->bases_begin();
-           Base1 != BaseEnd1; ++Base1, ++Base2) {
-        if (!IsStructurallyEquivalent(Context, Base1->getType(),
-                                      Base2->getType())) {
-          if (Context.Complain) {
-            Context.Diag2(D2->getLocation(),
-                          Context.getApplicableDiagnostic(
-                              diag::err_odr_tag_type_inconsistent))
-                << Context.ToCtx.getTypeDeclType(D2);
-            Context.Diag2(Base2->getBeginLoc(), diag::note_odr_base)
-                << Base2->getType() << Base2->getSourceRange();
-            Context.Diag1(Base1->getBeginLoc(), diag::note_odr_base)
-                << Base1->getType() << Base1->getSourceRange();
-          }
-          return false;
-        }
-
-        // Check virtual vs. non-virtual inheritance mismatch.
-        if (Base1->isVirtual() != Base2->isVirtual()) {
-          if (Context.Complain) {
-            Context.Diag2(D2->getLocation(),
-                          Context.getApplicableDiagnostic(
-                              diag::err_odr_tag_type_inconsistent))
-                << Context.ToCtx.getTypeDeclType(D2);
-            Context.Diag2(Base2->getBeginLoc(), diag::note_odr_virtual_base)
-                << Base2->isVirtual() << Base2->getSourceRange();
-            Context.Diag1(Base1->getBeginLoc(), diag::note_odr_base)
-                << Base1->isVirtual() << Base1->getSourceRange();
-          }
-          return false;
-        }
-      }
 
-      // Check the friends for consistency.
-      CXXRecordDecl::friend_iterator Friend2 = D2CXX->friend_begin(),
-                                     Friend2End = D2CXX->friend_end();
-      for (CXXRecordDecl::friend_iterator Friend1 = D1CXX->friend_begin(),
-                                          Friend1End = D1CXX->friend_end();
-           Friend1 != Friend1End; ++Friend1, ++Friend2) {
-        if (Friend2 == Friend2End) {
-          if (Context.Complain) {
-            Context.Diag2(D2->getLocation(),
-                          Context.getApplicableDiagnostic(
-                              diag::err_odr_tag_type_inconsistent))
-                << Context.ToCtx.getTypeDeclType(D2CXX);
-            Context.Diag1((*Friend1)->getFriendLoc(), diag::note_odr_friend);
-            Context.Diag2(D2->getLocation(), diag::note_odr_missing_friend);
-          }
-          return false;
-        }
-
-        if (!IsStructurallyEquivalent(Context, *Friend1, *Friend2)) {
-          if (Context.Complain) {
-            Context.Diag2(D2->getLocation(),
-                          Context.getApplicableDiagnostic(
-                              diag::err_odr_tag_type_inconsistent))
-                << Context.ToCtx.getTypeDeclType(D2CXX);
-            Context.Diag1((*Friend1)->getFriendLoc(), diag::note_odr_friend);
-            Context.Diag2((*Friend2)->getFriendLoc(), diag::note_odr_friend);
-          }
-          return false;
-        }
-      }
-
-      if (Friend2 != Friend2End) {
-        if (Context.Complain) {
-          Context.Diag2(D2->getLocation(),
-                        Context.getApplicableDiagnostic(
-                            diag::err_odr_tag_type_inconsistent))
-              << Context.ToCtx.getTypeDeclType(D2);
-          Context.Diag2((*Friend2)->getFriendLoc(), diag::note_odr_friend);
-          Context.Diag1(D1->getLocation(), diag::note_odr_missing_friend);
-        }
+      if (!IsFriendInCXXRecordStructurallyEquivalent(Context, D1, D2))
         return false;
-      }
     } else if (D1CXX->getNumBases() > 0) {
       if (Context.Complain) {
         Context.Diag2(D2->getLocation(),
@@ -2327,8 +2399,8 @@
     Decl *D1 = P.first;
     Decl *D2 = P.second;
 
-    bool Equivalent =
-        CheckCommonEquivalence(D1, D2) && CheckKindSpecificEquivalence(D1, D2);
+    bool Equivalent = (D1 == D2) || (CheckCommonEquivalence(D1, D2) &&
+                                     CheckKindSpecificEquivalence(D1, D2));
 
     if (!Equivalent) {
       // Note that these two declarations are not equivalent (and we already
Index: clang/lib/AST/ASTImporter.cpp
===================================================================
--- clang/lib/AST/ASTImporter.cpp
+++ clang/lib/AST/ASTImporter.cpp
@@ -6447,7 +6447,8 @@
 
   ToFunc->setAccess(D->getAccess());
   ToFunc->setLexicalDeclContext(LexicalDC);
-  LexicalDC->addDeclInternal(ToFunc);
+  if (D->getFriendObjectKind() == Decl::FOK_None)
+    LexicalDC->addDeclInternal(ToFunc);
 
   ASTImporterLookupTable *LT = Importer.SharedState->getLookupTable();
   if (LT && !OldParamDC.empty()) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to