https://gcc.gnu.org/g:ee743907d9bdaaa4e6717ab4c62e322d0987e801

commit r15-8816-gee743907d9bdaaa4e6717ab4c62e322d0987e801
Author: Arthur Cohen <arthur.co...@embecosm.com>
Date:   Tue Feb 4 16:12:25 2025 +0100

    gccrs: derive(Eq): Also derive StructuralEq
    
    gcc/rust/ChangeLog:
    
            * expand/rust-derive-eq.cc: Adapt functions to return two generated
            impls.
            * expand/rust-derive-eq.h: Likewise.
            * expand/rust-derive.cc (DeriveVisitor::derive): Likewise.

Diff:
---
 gcc/rust/expand/rust-derive-eq.cc | 68 +++++++++++++++++++++------------------
 gcc/rust/expand/rust-derive-eq.h  | 10 +++---
 gcc/rust/expand/rust-derive.cc    |  2 +-
 3 files changed, 43 insertions(+), 37 deletions(-)

diff --git a/gcc/rust/expand/rust-derive-eq.cc b/gcc/rust/expand/rust-derive-eq.cc
index a2a7a769065c..47a8350d2ffc 100644
--- a/gcc/rust/expand/rust-derive-eq.cc
+++ b/gcc/rust/expand/rust-derive-eq.cc
@@ -27,6 +27,16 @@
 namespace Rust {
 namespace AST {
 
+DeriveEq::DeriveEq (location_t loc) : DeriveVisitor (loc) {}
+
+std::vector<std::unique_ptr<AST::Item>>
+DeriveEq::go (Item &item)
+{
+  item.accept_vis (*this);
+  rust_assert (!expanded.empty ()); // filled by the visit_* overloads
+  return std::move (expanded);
+}
+
 std::unique_ptr<AssociatedItem>
 DeriveEq::assert_receiver_is_total_eq_fn (
   std::vector<std::unique_ptr<Type>> &&types)
@@ -98,33 +108,29 @@ DeriveEq::assert_type_is_eq (std::unique_ptr<Type> &&type)
   return builder.let (builder.wildcard (), std::move (full_path));
 }
 
-std::unique_ptr<Item>
-DeriveEq::eq_impl (
+std::vector<std::unique_ptr<Item>>
+DeriveEq::eq_impls (
   std::unique_ptr<AssociatedItem> &&fn, std::string name,
   const std::vector<std::unique_ptr<GenericParam>> &type_generics)
 {
   auto eq = builder.type_path ({"core", "cmp", "Eq"}, true);
+  auto steq = builder.type_path (LangItem::Kind::STRUCTURAL_TEQ);
 
   auto trait_items = vec (std::move (fn));
 
-  auto generics
+  auto eq_generics
     = setup_impl_generics (name, type_generics, builder.trait_bound (eq));
+  auto steq_generics = setup_impl_generics (name, type_generics);
 
-  return builder.trait_impl (eq, std::move (generics.self_type),
-                            std::move (trait_items),
-                            std::move (generics.impl));
-}
-
-DeriveEq::DeriveEq (location_t loc) : DeriveVisitor (loc), expanded (nullptr) {}
-
-std::unique_ptr<AST::Item>
-DeriveEq::go (Item &item)
-{
-  item.accept_vis (*this);
+  auto eq_impl = builder.trait_impl (eq, std::move (eq_generics.self_type),
+                                    std::move (trait_items),
+                                    std::move (eq_generics.impl));
+  auto steq_impl
+    = builder.trait_impl (steq, std::move (steq_generics.self_type),
+                         /* StructuralEq is a marker trait: no items */ {},
+                         std::move (steq_generics.impl));
 
-  rust_assert (expanded);
-
-  return std::move (expanded);
+  return vec (std::move (eq_impl), std::move (steq_impl));
 }
 
 void
@@ -135,9 +141,9 @@ DeriveEq::visit_tuple (TupleStruct &item)
   for (auto &field : item.get_fields ())
     types.emplace_back (field.get_field_type ().clone_type ());
 
-  expanded
-    = eq_impl (assert_receiver_is_total_eq_fn (std::move (types)),
-              item.get_identifier ().as_string (), item.get_generic_params ());
+  expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
+                      item.get_identifier ().as_string (),
+                      item.get_generic_params ());
 }
 
 void
@@ -148,9 +154,9 @@ DeriveEq::visit_struct (StructStruct &item)
   for (auto &field : item.get_fields ())
     types.emplace_back (field.get_field_type ().clone_type ());
 
-  expanded
-    = eq_impl (assert_receiver_is_total_eq_fn (std::move (types)),
-              item.get_identifier ().as_string (), item.get_generic_params ());
+  expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
+                      item.get_identifier ().as_string (),
+                      item.get_generic_params ());
 }
 
 void
@@ -167,7 +173,7 @@ DeriveEq::visit_enum (Enum &item)
          // nothing to do as they contain no inner types
          continue;
          case EnumItem::Kind::Tuple: {
-           auto tuple = static_cast<EnumItemTuple &> (*variant);
+           auto &tuple = static_cast<EnumItemTuple &> (*variant);
 
            for (auto &field : tuple.get_tuple_fields ())
              types.emplace_back (field.get_field_type ().clone_type ());
@@ -175,7 +181,7 @@ DeriveEq::visit_enum (Enum &item)
            break;
          }
          case EnumItem::Kind::Struct: {
-           auto tuple = static_cast<EnumItemStruct &> (*variant);
+           auto &tuple = static_cast<EnumItemStruct &> (*variant);
 
            for (auto &field : tuple.get_struct_fields ())
              types.emplace_back (field.get_field_type ().clone_type ());
@@ -185,9 +191,9 @@ DeriveEq::visit_enum (Enum &item)
        }
     }
 
-  expanded
-    = eq_impl (assert_receiver_is_total_eq_fn (std::move (types)),
-              item.get_identifier ().as_string (), item.get_generic_params ());
+  expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
+                      item.get_identifier ().as_string (),
+                      item.get_generic_params ());
 }
 
 void
@@ -198,9 +204,9 @@ DeriveEq::visit_union (Union &item)
   for (auto &field : item.get_variants ())
     types.emplace_back (field.get_field_type ().clone_type ());
 
-  expanded
-    = eq_impl (assert_receiver_is_total_eq_fn (std::move (types)),
-              item.get_identifier ().as_string (), item.get_generic_params ());
+  expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
+                      item.get_identifier ().as_string (),
+                      item.get_generic_params ());
 }
 
 } // namespace AST
diff --git a/gcc/rust/expand/rust-derive-eq.h b/gcc/rust/expand/rust-derive-eq.h
index 655f1e82e02a..17af52653dea 100644
--- a/gcc/rust/expand/rust-derive-eq.h
+++ b/gcc/rust/expand/rust-derive-eq.h
@@ -31,10 +31,10 @@ class DeriveEq : DeriveVisitor
 public:
   DeriveEq (location_t loc);
 
-  std::unique_ptr<AST::Item> go (Item &item);
+  std::vector<std::unique_ptr<AST::Item>> go (Item &item);
 
 private:
-  std::unique_ptr<Item> expanded;
+  std::vector<std::unique_ptr<Item>> expanded;
 
   /**
    * Create the actual `assert_receiver_is_total_eq` function of the
@@ -52,9 +52,9 @@ private:
    * }
    *
    */
-  std::unique_ptr<Item>
-  eq_impl (std::unique_ptr<AssociatedItem> &&fn, std::string name,
-          const std::vector<std::unique_ptr<GenericParam>> &type_generics);
+  std::vector<std::unique_ptr<Item>>
+  eq_impls (std::unique_ptr<AssociatedItem> &&fn, std::string name,
+           const std::vector<std::unique_ptr<GenericParam>> &type_generics);
 
   /**
    * Generate the following structure definition
diff --git a/gcc/rust/expand/rust-derive.cc b/gcc/rust/expand/rust-derive.cc
index 250ef726f43d..39e03a67cd4a 100644
--- a/gcc/rust/expand/rust-derive.cc
+++ b/gcc/rust/expand/rust-derive.cc
@@ -50,7 +50,7 @@ DeriveVisitor::derive (Item &item, const Attribute &attr,
     case BuiltinMacro::Default:
       return vec (DeriveDefault (attr.get_locus ()).go (item));
     case BuiltinMacro::Eq:
-      return vec (DeriveEq (attr.get_locus ()).go (item));
+      return DeriveEq (attr.get_locus ()).go (item);
     case BuiltinMacro::PartialEq:
       return DerivePartialEq (attr.get_locus ()).go (item);
     case BuiltinMacro::Ord:

Reply via email to