llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-flang-runtime Author: Slava Zakharin (vzakhari) <details> <summary>Changes</summary> This is a simplified implementation of std::reference_wrapper that can be used in the offload builds for the device code. The methods are properly marked with RT_API_ATTRS so that the device compilation succeeds. --- Full diff: https://github.com/llvm/llvm-project/pull/85178.diff 2 Files Affected: - (added) flang/include/flang/Common/reference-wrapper.h (+114) - (modified) flang/runtime/io-stmt.h (+34-25) ``````````diff diff --git a/flang/include/flang/Common/reference-wrapper.h b/flang/include/flang/Common/reference-wrapper.h new file mode 100644 index 00000000000000..66f924662d9612 --- /dev/null +++ b/flang/include/flang/Common/reference-wrapper.h @@ -0,0 +1,114 @@ +//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// clang-format off +// +// Implementation of std::reference_wrapper borrowed from libcu++ +// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1 +// with modifications. +// +// The original source code is distributed under the Apache License v2.0 +// with LLVM Exceptions. +// +// TODO: using libcu++ is the best option for CUDA, but there are a couple +// of issues: +// * The include paths need to be set up such that all STD header files +// are taken from libcu++. +// * the cuda:: namespace needs to be forced for all std:: references.
+// +// clang-format on + +#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H +#define FORTRAN_COMMON_REFERENCE_WRAPPER_H + +#include "flang/Runtime/api-attrs.h" +#include <functional> +#include <type_traits> + +#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \ + (defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__) +#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1 +#endif + +namespace Fortran::common { + +template <class _Tp> +using __remove_cvref_t = std::remove_cv_t<std::remove_reference_t<_Tp>>; +template <class _Tp, class _Up> +struct __is_same_uncvref + : std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {}; + +#if STD_REFERENCE_WRAPPER_UNSUPPORTED +template <class _Tp> class reference_wrapper { +public: + // types + typedef _Tp type; + +private: + type *__f_; + + static RT_API_ATTRS void __fun(_Tp &); + static void __fun(_Tp &&) = delete; + +public: + template <class _Up, + class = + std::enable_if_t<!__is_same_uncvref<_Up, reference_wrapper>::value, + decltype(__fun(std::declval<_Up>()))>> + constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) { + type &__f = static_cast<_Up &&>(__u); + __f_ = std::addressof(__f); + } + + // access + constexpr RT_API_ATTRS operator type &() const { return *__f_; } + constexpr RT_API_ATTRS type &get() const { return *__f_; } + + // invoke + template <class... 
_ArgTypes> + constexpr RT_API_ATTRS typename std::invoke_result_t<type &, _ArgTypes...> + operator()(_ArgTypes &&...__args) const { + return std::invoke(get(), std::forward<_ArgTypes>(__args)...); + } +}; + +template <class _Tp> reference_wrapper(_Tp &) -> reference_wrapper<_Tp>; + +template <class _Tp> +inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) { + return reference_wrapper<_Tp>(__t); +} + +template <class _Tp> +inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref( + reference_wrapper<_Tp> __t) { + return __t; +} + +template <class _Tp> +inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref( + const _Tp &__t) { + return reference_wrapper<const _Tp>(__t); +} + +template <class _Tp> +inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref( + reference_wrapper<_Tp> __t) { + return __t; +} + +template <class _Tp> void ref(const _Tp &&) = delete; +template <class _Tp> void cref(const _Tp &&) = delete; +#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED +using std::cref; +using std::ref; +using std::reference_wrapper; +#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED + +} // namespace Fortran::common + +#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h index 0477c32b3b53ad..e00d54980aae59 100644 --- a/flang/runtime/io-stmt.h +++ b/flang/runtime/io-stmt.h @@ -17,6 +17,7 @@ #include "internal-unit.h" #include "io-error.h" #include "flang/Common/optional.h" +#include "flang/Common/reference-wrapper.h" #include "flang/Common/visit.h" #include "flang/Runtime/descriptor.h" #include "flang/Runtime/io-api.h" @@ -210,39 +211,47 @@ class IoStatementState { } private: - std::variant<std::reference_wrapper<OpenStatementState>, - std::reference_wrapper<CloseStatementState>, - std::reference_wrapper<NoopStatementState>, - std::reference_wrapper< + std::variant<Fortran::common::reference_wrapper<OpenStatementState>, + Fortran::common::reference_wrapper<CloseStatementState>, + 
Fortran::common::reference_wrapper<NoopStatementState>, + Fortran::common::reference_wrapper< InternalFormattedIoStatementState<Direction::Output>>, - std::reference_wrapper< + Fortran::common::reference_wrapper< InternalFormattedIoStatementState<Direction::Input>>, - std::reference_wrapper<InternalListIoStatementState<Direction::Output>>, - std::reference_wrapper<InternalListIoStatementState<Direction::Input>>, - std::reference_wrapper< + Fortran::common::reference_wrapper< + InternalListIoStatementState<Direction::Output>>, + Fortran::common::reference_wrapper< + InternalListIoStatementState<Direction::Input>>, + Fortran::common::reference_wrapper< ExternalFormattedIoStatementState<Direction::Output>>, - std::reference_wrapper< + Fortran::common::reference_wrapper< ExternalFormattedIoStatementState<Direction::Input>>, - std::reference_wrapper<ExternalListIoStatementState<Direction::Output>>, - std::reference_wrapper<ExternalListIoStatementState<Direction::Input>>, - std::reference_wrapper< + Fortran::common::reference_wrapper< + ExternalListIoStatementState<Direction::Output>>, + Fortran::common::reference_wrapper< + ExternalListIoStatementState<Direction::Input>>, + Fortran::common::reference_wrapper< ExternalUnformattedIoStatementState<Direction::Output>>, - std::reference_wrapper< + Fortran::common::reference_wrapper< ExternalUnformattedIoStatementState<Direction::Input>>, - std::reference_wrapper<ChildFormattedIoStatementState<Direction::Output>>, - std::reference_wrapper<ChildFormattedIoStatementState<Direction::Input>>, - std::reference_wrapper<ChildListIoStatementState<Direction::Output>>, - std::reference_wrapper<ChildListIoStatementState<Direction::Input>>, - std::reference_wrapper< + Fortran::common::reference_wrapper< + ChildFormattedIoStatementState<Direction::Output>>, + Fortran::common::reference_wrapper< + ChildFormattedIoStatementState<Direction::Input>>, + Fortran::common::reference_wrapper< + ChildListIoStatementState<Direction::Output>>, + 
Fortran::common::reference_wrapper< + ChildListIoStatementState<Direction::Input>>, + Fortran::common::reference_wrapper< ChildUnformattedIoStatementState<Direction::Output>>, - std::reference_wrapper< + Fortran::common::reference_wrapper< ChildUnformattedIoStatementState<Direction::Input>>, - std::reference_wrapper<InquireUnitState>, - std::reference_wrapper<InquireNoUnitState>, - std::reference_wrapper<InquireUnconnectedFileState>, - std::reference_wrapper<InquireIOLengthState>, - std::reference_wrapper<ExternalMiscIoStatementState>, - std::reference_wrapper<ErroneousIoStatementState>> + Fortran::common::reference_wrapper<InquireUnitState>, + Fortran::common::reference_wrapper<InquireNoUnitState>, + Fortran::common::reference_wrapper<InquireUnconnectedFileState>, + Fortran::common::reference_wrapper<InquireIOLengthState>, + Fortran::common::reference_wrapper<ExternalMiscIoStatementState>, + Fortran::common::reference_wrapper<ErroneousIoStatementState>> u_; }; `````````` </details> https://github.com/llvm/llvm-project/pull/85178 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits