teemperor updated this revision to Diff 196955.
teemperor retitled this revision from "[ASTImporter] Add an ImportInternal 
method to allow customizing Import behavior." to "[ASTImporter] Add an 
ImportImpl method to allow customizing Import behavior.".
teemperor edited the summary of this revision.
teemperor added a comment.

- Added `RegisterImportedDecl` call as suggested in D59537 
<https://reviews.llvm.org/D59537>
- Added an assert that checks that we register any decl that an overridden 
ImportImpl would create.
- Fixed the tests so that they no longer trigger that new assert.


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D59485/new/

https://reviews.llvm.org/D59485

Files:
  clang/include/clang/AST/ASTImporter.h
  clang/lib/AST/ASTImporter.cpp
  clang/unittests/AST/ASTImporterTest.cpp

Index: clang/unittests/AST/ASTImporterTest.cpp
===================================================================
--- clang/unittests/AST/ASTImporterTest.cpp
+++ clang/unittests/AST/ASTImporterTest.cpp
@@ -313,6 +313,14 @@
   const char *const InputFileName = "input.cc";
   const char *const OutputFileName = "output.cc";
 
+public:
+  /// Allocates an ASTImporter (or one of its subclasses).
+  typedef std::function<ASTImporter *(ASTContext &, FileManager &, ASTContext &,
+                                      FileManager &, bool,
+                                      ASTImporterLookupTable *)>
+      ImporterConstructor;
+
+private:
   // Buffer for the To context, must live in the test scope.
   std::string ToCode;
 
@@ -325,22 +333,33 @@
     std::unique_ptr<ASTUnit> Unit;
     TranslationUnitDecl *TUDecl = nullptr;
     std::unique_ptr<ASTImporter> Importer;
-    TU(StringRef Code, StringRef FileName, ArgVector Args)
+    // The lambda that constructs the ASTImporter we use in this test.
+    ImporterConstructor Creator;
+    TU(StringRef Code, StringRef FileName, ArgVector Args,
+       ImporterConstructor C = ImporterConstructor())
         : Code(Code), FileName(FileName),
           Unit(tooling::buildASTFromCodeWithArgs(this->Code, Args,
                                                  this->FileName)),
-          TUDecl(Unit->getASTContext().getTranslationUnitDecl()) {
+          TUDecl(Unit->getASTContext().getTranslationUnitDecl()), Creator(C) {
       Unit->enableSourceFileDiagnostics();
+
+      // If the test doesn't need a specific ASTImporter, we just create a
+      // normal ASTImporter with it.
+      if (!Creator)
+        Creator = [](ASTContext &ToContext, FileManager &ToFileManager,
+                     ASTContext &FromContext, FileManager &FromFileManager,
+                     bool MinimalImport, ASTImporterLookupTable *LookupTable) {
+          return new ASTImporter(ToContext, ToFileManager, FromContext,
+                                 FromFileManager, MinimalImport, LookupTable);
+        };
     }
 
     void lazyInitImporter(ASTImporterLookupTable &LookupTable, ASTUnit *ToAST) {
       assert(ToAST);
-      if (!Importer) {
-        Importer.reset(
-            new ASTImporter(ToAST->getASTContext(), ToAST->getFileManager(),
-                            Unit->getASTContext(), Unit->getFileManager(),
-                            false, &LookupTable));
-      }
+      if (!Importer)
+        Importer.reset(Creator(ToAST->getASTContext(), ToAST->getFileManager(),
+                               Unit->getASTContext(), Unit->getFileManager(),
+                               false, &LookupTable));
       assert(&ToAST->getASTContext() == &Importer->getToContext());
       createVirtualFileIfNeeded(ToAST, FileName, Code);
     }
@@ -420,11 +439,12 @@
   // Must not be called more than once within the same test.
   std::tuple<Decl *, Decl *>
   getImportedDecl(StringRef FromSrcCode, Language FromLang, StringRef ToSrcCode,
-                  Language ToLang, StringRef Identifier = DeclToImportID) {
+                  Language ToLang, StringRef Identifier = DeclToImportID,
+                  ImporterConstructor Creator = ImporterConstructor()) {
     ArgVector FromArgs = getArgVectorForLanguage(FromLang),
               ToArgs = getArgVectorForLanguage(ToLang);
 
-    FromTUs.emplace_back(FromSrcCode, InputFileName, FromArgs);
+    FromTUs.emplace_back(FromSrcCode, InputFileName, FromArgs, Creator);
     TU &FromTU = FromTUs.back();
 
     assert(!ToAST);
@@ -562,6 +582,75 @@
   EXPECT_THAT(RedeclsD1, ::testing::ContainerEq(RedeclsD2));
 }
 
+struct CustomImporter : ASTImporterOptionSpecificTestBase {};
+
+namespace {
+struct RedirectingImporter : public ASTImporter {
+  using ASTImporter::ASTImporter;
+  // ImporterConstructor that constructs this class.
+  static ASTImporterOptionSpecificTestBase::ImporterConstructor Constructor;
+
+protected:
+  llvm::Expected<Decl *> ImportImpl(Decl *FromD) override {
+    auto *ND = dyn_cast<NamedDecl>(FromD);
+    if (!ND || ND->getName() != "shouldNotBeImported")
+      return ASTImporter::ImportImpl(FromD);
+    for (Decl *D : getToContext().getTranslationUnitDecl()->decls()) {
+      if (auto *ND = dyn_cast<NamedDecl>(D))
+        if (ND->getName() == "realDecl") {
+          RegisterImportedDecl(FromD, ND);
+          return ND;
+        }
+    }
+    return ASTImporter::ImportImpl(FromD);
+  }
+};
+
+ASTImporterOptionSpecificTestBase::ImporterConstructor
+    RedirectingImporter::Constructor =
+        [](ASTContext &ToContext, FileManager &ToFileManager,
+           ASTContext &FromContext, FileManager &FromFileManager,
+           bool MinimalImport, ASTImporterLookupTable *LookupTable) {
+          return new RedirectingImporter(ToContext, ToFileManager, FromContext,
+                                         FromFileManager, MinimalImport,
+                                         LookupTable);
+        };
+} // namespace
+
+// Test that an ASTImporter subclass can intercept an import call.
+TEST_P(CustomImporter, InterceptImport) {
+  Decl *From, *To;
+  std::tie(From, To) = getImportedDecl(
+      "class shouldNotBeImported {};", Lang_CXX, "class realDecl {};", Lang_CXX,
+      "shouldNotBeImported", RedirectingImporter::Constructor);
+  auto *Imported = cast<CXXRecordDecl>(To);
+  EXPECT_EQ(Imported->getQualifiedNameAsString(), "realDecl");
+
+  // Make sure our importer prevented the importing of the decl.
+  auto *ToTU = Imported->getTranslationUnitDecl();
+  auto Pattern = functionDecl(hasName("shouldNotBeImported"));
+  unsigned count =
+      DeclCounterWithPredicate<CXXRecordDecl>().match(ToTU, Pattern);
+  EXPECT_EQ(0U, count);
+}
+
+// Test that when we indirectly import a declaration the custom ASTImporter
+// is still intercepting the import.
+TEST_P(CustomImporter, InterceptIndirectImport) {
+  Decl *From, *To;
+  std::tie(From, To) = getImportedDecl("class shouldNotBeImported {};"
+                                       "class F { shouldNotBeImported f; };",
+                                       Lang_CXX, "class realDecl {};", Lang_CXX,
+                                       "F", RedirectingImporter::Constructor);
+
+  // Make sure our ASTImporter prevented the importing of the decl.
+  auto *ToTU = To->getTranslationUnitDecl();
+  auto Pattern = functionDecl(hasName("shouldNotBeImported"));
+  unsigned count =
+      DeclCounterWithPredicate<CXXRecordDecl>().match(ToTU, Pattern);
+  EXPECT_EQ(0U, count);
+}
+
 TEST_P(ImportExpr, ImportStringLiteral) {
   MatchVerifier<Decl> Verifier;
   testImport(
@@ -5549,6 +5638,9 @@
 INSTANTIATE_TEST_CASE_P(ParameterizedTests, ASTImporterOptionSpecificTestBase,
                         DefaultTestValuesForRunOptions, );
 
+INSTANTIATE_TEST_CASE_P(ParameterizedTests, CustomImporter,
+                        DefaultTestValuesForRunOptions, );
+
 INSTANTIATE_TEST_CASE_P(ParameterizedTests, ImportFunctions,
                         DefaultTestValuesForRunOptions, );
 
Index: clang/lib/AST/ASTImporter.cpp
===================================================================
--- clang/lib/AST/ASTImporter.cpp
+++ clang/lib/AST/ASTImporter.cpp
@@ -255,8 +255,7 @@
         return true; // Already imported.
       ToD = CreateFun(std::forward<Args>(args)...);
       // Keep track of imported Decls.
-      Importer.MapImported(FromD, ToD);
-      Importer.AddToLookupTable(ToD);
+      Importer.RegisterImportedDecl(FromD, ToD);
       InitializeImportedDecl(FromD, ToD);
       return false; // A new Decl is created.
     }
@@ -7656,6 +7655,17 @@
       LookupTable->add(ToND);
 }
 
+Expected<Decl *> ASTImporter::ImportImpl(Decl *FromD) {
+  // Import the decl using ASTNodeImporter.
+  ASTNodeImporter Importer(*this);
+  return Importer.Visit(FromD);
+}
+
+void ASTImporter::RegisterImportedDecl(Decl *FromD, Decl *ToD) {
+  MapImported(FromD, ToD);
+  AddToLookupTable(ToD);
+}
+
 Expected<QualType> ASTImporter::Import_New(QualType FromT) {
   if (FromT.isNull())
     return QualType{};
@@ -7749,7 +7759,6 @@
   if (!FromD)
     return nullptr;
 
-  ASTNodeImporter Importer(*this);
 
   // Check whether we've already imported this declaration.
   Decl *ToD = GetAlreadyImportedOrNull(FromD);
@@ -7760,7 +7769,7 @@
   }
 
   // Import the declaration.
-  ExpectedDecl ToDOrErr = Importer.Visit(FromD);
+  ExpectedDecl ToDOrErr = ImportImpl(FromD);
   if (!ToDOrErr)
     return ToDOrErr;
   ToD = *ToDOrErr;
@@ -7771,6 +7780,9 @@
     return nullptr;
   }
 
+  // Make sure that ImportImpl registered the imported decl.
+  assert(ImportedDecls.count(FromD) != 0 && "Missing call to MapImported?");
+
   // Once the decl is connected to the existing declarations, i.e. when the
   // redecl chain is properly set then we populate the lookup again.
   // This way the primary context will be able to find all decls.
Index: clang/include/clang/AST/ASTImporter.h
===================================================================
--- clang/include/clang/AST/ASTImporter.h
+++ clang/include/clang/AST/ASTImporter.h
@@ -142,6 +142,12 @@
 
     void AddToLookupTable(Decl *ToD);
 
+  protected:
+    /// Can be overwritten by subclasses to implement their own import logic.
+    /// The overwritten method should call this method if it didn't import the
+    /// decl on its own.
+    virtual Expected<Decl *> ImportImpl(Decl *From);
+
   public:
 
     /// \param ToContext The context we'll be importing into.
@@ -427,6 +433,8 @@
     /// \c To declaration mappings as they are imported.
     virtual void Imported(Decl *From, Decl *To) {}
 
+    void RegisterImportedDecl(Decl *FromD, Decl *ToD);
+
     /// Store and assign the imported declaration to its counterpart.
     Decl *MapImported(Decl *From, Decl *To);
 
_______________________________________________
lldb-commits mailing list
lldb-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits

Reply via email to