Author: Chuanqi Xu Date: 2025-02-28T16:03:50+08:00 New Revision: 15c49b9db3f60bdbd320271d5e97f118c00b95dd
URL: https://github.com/llvm/llvm-project/commit/15c49b9db3f60bdbd320271d5e97f118c00b95dd DIFF: https://github.com/llvm/llvm-project/commit/15c49b9db3f60bdbd320271d5e97f118c00b95dd.diff LOG: [Coroutines] [CodeGen] Don't change AST in CodeGen/Coroutines The root source of other odd bugs. We performed a hack in CodeGen/Coroutines. But we didn't recognize that the CodeGen is a consumer of the AST. The CodeGen shouldn't change the AST in any way. It'll break the assumption about the ASTConsumer in Clang's framework, which may break any other clang-based tools that depend on multiple consumers to work together. The fix here is simple. But I am not super happy about the test. It is too specific and verbose. We can remove this if we can get the signature of the AST in ASTContext. Added: clang/unittests/Frontend/NoAlterCodeGenActionTest.cpp Modified: clang/lib/CodeGen/CGCoroutine.cpp clang/unittests/Frontend/CMakeLists.txt Removed: ################################################################################ diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp index 9abf2e8c9190d..058ec01f8ce0e 100644 --- a/clang/lib/CodeGen/CGCoroutine.cpp +++ b/clang/lib/CodeGen/CGCoroutine.cpp @@ -942,9 +942,16 @@ void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) { if (Stmt *Ret = S.getReturnStmt()) { // Since we already emitted the return value above, so we shouldn't // emit it again here. - if (GroManager.DirectEmit) + Expr *PreviousRetValue = nullptr; + if (GroManager.DirectEmit) { + PreviousRetValue = cast<ReturnStmt>(Ret)->getRetValue(); cast<ReturnStmt>(Ret)->setRetValue(nullptr); + } EmitStmt(Ret); + // Set the return value back. The code generator, as the AST **Consumer**, + // shouldn't change the AST. + if (PreviousRetValue) + cast<ReturnStmt>(Ret)->setRetValue(PreviousRetValue); } // LLVM require the frontend to mark the coroutine. 
diff --git a/clang/unittests/Frontend/CMakeLists.txt b/clang/unittests/Frontend/CMakeLists.txt index 0f05813338f2a..3c94846243870 100644 --- a/clang/unittests/Frontend/CMakeLists.txt +++ b/clang/unittests/Frontend/CMakeLists.txt @@ -10,6 +10,7 @@ add_clang_unittest(FrontendTests FixedPointString.cpp FrontendActionTest.cpp CodeGenActionTest.cpp + NoAlterCodeGenActionTest.cpp ParsedSourceLocationTest.cpp PCHPreambleTest.cpp ReparseWorkingDirTest.cpp @@ -27,4 +28,5 @@ clang_target_link_libraries(FrontendTests clangCodeGen clangFrontendTool clangSerialization + clangTooling ) diff --git a/clang/unittests/Frontend/NoAlterCodeGenActionTest.cpp b/clang/unittests/Frontend/NoAlterCodeGenActionTest.cpp new file mode 100644 index 0000000000000..e7a3bf5a7f87a --- /dev/null +++ b/clang/unittests/Frontend/NoAlterCodeGenActionTest.cpp @@ -0,0 +1,198 @@ +//===- unittests/Frontend/NoAlterCodeGenActionTest.cpp --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Unit tests for CodeGenAction may not alter the AST. 
+// +//===----------------------------------------------------------------------===// + +#include "clang/AST/ASTConsumer.h" +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/Basic/LangStandard.h" +#include "clang/CodeGen/BackendUtil.h" +#include "clang/CodeGen/CodeGenAction.h" +#include "clang/Frontend/CompilerInstance.h" +#include "clang/Frontend/MultiplexConsumer.h" +#include "clang/Lex/PreprocessorOptions.h" +#include "clang/Tooling/Tooling.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/VirtualFileSystem.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace clang; +using namespace clang::frontend; +using namespace clang::tooling; + +namespace { + +class ASTChecker : public RecursiveASTVisitor<ASTChecker> { +public: + ASTContext &Ctx; + ASTChecker(ASTContext &Ctx) : Ctx(Ctx) {} + bool VisitReturnStmt(ReturnStmt *RS) { + EXPECT_TRUE(RS->getRetValue()); + return true; + } + + bool VisitCoroutineBodyStmt(CoroutineBodyStmt *CS) { + return VisitReturnStmt(cast<ReturnStmt>(CS->getReturnStmt())); + } +}; + +class ASTCheckerConsumer : public ASTConsumer { +public: + void HandleTranslationUnit(ASTContext &Ctx) override { + ASTChecker Checker(Ctx); + Checker.TraverseAST(Ctx); + } +}; + +class TestCodeGenAction : public EmitLLVMAction { +public: + using Base = EmitLLVMAction; + TestCodeGenAction(llvm::LLVMContext *_VMContext = nullptr) + : EmitLLVMAction(_VMContext) {} + + std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI, + StringRef InFile) override { + std::vector<std::unique_ptr<ASTConsumer>> Consumers; + Consumers.push_back(std::make_unique<ASTCheckerConsumer>()); + Consumers.push_back(Base::CreateASTConsumer(CI, InFile)); + return std::make_unique<MultiplexConsumer>(std::move(Consumers)); + } +}; + +const char *test_contents = R"cpp( + +namespace std { + +template <typename R, typename...> struct coroutine_traits { + using promise_type = typename R::promise_type; +}; + +template <typename Promise = 
void> struct coroutine_handle; + +template <> struct coroutine_handle<void> { + static coroutine_handle from_address(void *addr) noexcept; + void operator()() { resume(); } + void *address() const noexcept; + void resume() const { __builtin_coro_resume(ptr); } + void destroy() const { __builtin_coro_destroy(ptr); } + bool done() const; + coroutine_handle &operator=(decltype(nullptr)); + coroutine_handle(decltype(nullptr)) : ptr(nullptr) {} + coroutine_handle() : ptr(nullptr) {} +// void reset() { ptr = nullptr; } // add to P0057? + explicit operator bool() const; + +protected: + void *ptr; +}; + +template <typename Promise> struct coroutine_handle : coroutine_handle<> { + using coroutine_handle<>::operator=; + + static coroutine_handle from_address(void *addr) noexcept; + + Promise &promise() const; + static coroutine_handle from_promise(Promise &promise); +}; + +template <typename _PromiseT> +bool operator==(coroutine_handle<_PromiseT> const &_Left, + coroutine_handle<_PromiseT> const &_Right) noexcept { + return _Left.address() == _Right.address(); +} + +template <typename _PromiseT> +bool operator!=(coroutine_handle<_PromiseT> const &_Left, + coroutine_handle<_PromiseT> const &_Right) noexcept { + return !(_Left == _Right); +} + +struct noop_coroutine_promise {}; + +template <> +struct coroutine_handle<noop_coroutine_promise> { + operator coroutine_handle<>() const noexcept; + + constexpr explicit operator bool() const noexcept { return true; } + constexpr bool done() const noexcept { return false; } + + constexpr void operator()() const noexcept {} + constexpr void resume() const noexcept {} + constexpr void destroy() const noexcept {} + + noop_coroutine_promise &promise() const noexcept { + return *static_cast<noop_coroutine_promise *>( + __builtin_coro_promise(this->__handle_, alignof(noop_coroutine_promise), false)); + } + + constexpr void *address() const noexcept { return __handle_; } + +private: + friend coroutine_handle<noop_coroutine_promise> 
noop_coroutine() noexcept; + + coroutine_handle() noexcept { + this->__handle_ = __builtin_coro_noop(); + } + + void *__handle_ = nullptr; +}; + +using noop_coroutine_handle = coroutine_handle<noop_coroutine_promise>; + +inline noop_coroutine_handle noop_coroutine() noexcept { return noop_coroutine_handle(); } + +struct suspend_always { + bool await_ready() noexcept { return false; } + void await_suspend(coroutine_handle<>) noexcept {} + void await_resume() noexcept {} +}; +struct suspend_never { + bool await_ready() noexcept { return true; } + void await_suspend(coroutine_handle<>) noexcept {} + void await_resume() noexcept {} +}; + +} // namespace std + +using namespace std; + +class invoker { +public: + class invoker_promise { + public: + invoker get_return_object() { return invoker{}; } + auto initial_suspend() { return suspend_always{}; } + auto final_suspend() noexcept { return suspend_always{}; } + void return_void() {} + void unhandled_exception() {} + }; + using promise_type = invoker_promise; + invoker() {} + invoker(const invoker &) = delete; + invoker &operator=(const invoker &) = delete; + invoker(invoker &&) = delete; + invoker &operator=(invoker &&) = delete; +}; + +invoker g() { + co_return; +} + +)cpp"; + +TEST(CodeGenTest, TestNonAlterTest) { + EXPECT_TRUE(runToolOnCodeWithArgs(std::make_unique<TestCodeGenAction>(), + test_contents, + { + "-std=c++20", + })); +} +} // namespace _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits