From: Arthur Cohen <arthur.co...@embecosm.com>

gcc/rust/ChangeLog:

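        * expand/rust-derive-ord.cc (DeriveOrd::cmp_call): New function.
        (DeriveOrd::recursive_match): Use it.
        (DeriveOrd::visit_enum): Implement by comparing discriminant values.
        * expand/rust-derive-ord.h (DeriveOrd::not_equal): Prefix with '#'.
        (DeriveOrd::self_discr, DeriveOrd::other_discr): New constants.
        (DeriveOrd::match_enum_tuple, DeriveOrd::match_enum_struct)
        (DeriveOrd::cmp_call): Declare.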
---
 gcc/rust/expand/rust-derive-ord.cc | 68 +++++++++++++++++++++++++-----
 gcc/rust/expand/rust-derive-ord.h  | 16 ++++++-
 2 files changed, 73 insertions(+), 11 deletions(-)

diff --git a/gcc/rust/expand/rust-derive-ord.cc b/gcc/rust/expand/rust-derive-ord.cc
index e39c6b44ca4..1f39c94d87b 100644
--- a/gcc/rust/expand/rust-derive-ord.cc
+++ b/gcc/rust/expand/rust-derive-ord.cc
@@ -39,6 +39,17 @@ DeriveOrd::go (Item &item)
   return std::move (expanded);
 }
 
+std::unique_ptr<Expr>
+DeriveOrd::cmp_call (std::unique_ptr<Expr> &&self_expr,
+                    std::unique_ptr<Expr> &&other_expr)
+{
+  /* Path to `core::cmp::<trait>::<fn>`, both selected by `ordering`.  */
+  auto fn_path = builder.path_in_expression (
+    {"core", "cmp", trait (ordering), fn (ordering)}, true);
+  auto args = vec (std::move (self_expr), std::move (other_expr));
+  return builder.call (ptrify (fn_path), std::move (args));
+}
+
 std::unique_ptr<Item>
 DeriveOrd::cmp_impl (
   std::unique_ptr<BlockExpr> &&fn_block, Identifier type_name,
@@ -132,18 +143,14 @@ DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
     {
       auto &member = *it;
 
-      auto cmp_fn_path = builder.path_in_expression (
-       {"core", "cmp", trait (ordering), fn (ordering)}, true);
-
-      auto cmp_call = builder.call (ptrify (cmp_fn_path),
-                                   vec (std::move (member.self_expr),
-                                        std::move (member.other_expr)));
+      auto call = cmp_call (std::move (member.self_expr),
+                           std::move (member.other_expr));
 
       // For the last member (so the first iterator), we just create a call
       // expression
       if (it == members.rbegin ())
        {
-         final_expr = std::move (cmp_call);
+         final_expr = std::move (call);
          continue;
        }
 
@@ -157,8 +164,7 @@ DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
           builder.match_case (std::move (match_arms.second),
                               builder.identifier (DeriveOrd::not_equal))};
 
-      final_expr
-       = builder.match (std::move (cmp_call), std::move (match_cases));
+      final_expr = builder.match (std::move (call), std::move (match_cases));
     }
 
   return final_expr;
@@ -227,7 +233,49 @@ DeriveOrd::visit_tuple (TupleStruct &item)
 // }
 void
 DeriveOrd::visit_enum (Enum &item)
-{}
+{
+  auto cases = std::vector<MatchCase> ();
+  auto type_name = item.get_identifier ().as_string ();
+
+  auto self_sd = builder.discriminant_value (DeriveOrd::self_discr, "self");
+  auto other_sd = builder.discriminant_value (DeriveOrd::other_discr, "other");
+
+  auto discr_cmp = cmp_call (builder.identifier (DeriveOrd::self_discr),
+                            builder.identifier (DeriveOrd::other_discr));
+
+  for (auto &variant : item.get_variants ())
+    {
+      auto variant_path
+       = builder.variant_path (type_name,
+                               variant->get_identifier ().as_string ());
+
+      switch (variant->get_enum_item_kind ())
+       {
+       case EnumItem::Kind::Tuple:
+       case EnumItem::Kind::Struct:
+       case EnumItem::Kind::Identifier:
+       case EnumItem::Kind::Discriminant:
+         // FIXME: Two values of the same Tuple or Struct variant must also
+         // have their fields compared; only discriminants are compared so far
+         break;
+       }
+    }
+
+  // Add the last case which compares the discriminant values in case `self`
+  // and `other` are actually different variants of the enum
+  cases.emplace_back (
+    builder.match_case (builder.wildcard (), std::move (discr_cmp)));
+
+  auto match
+    = builder.match (builder.tuple (vec (builder.identifier ("self"),
+                                        builder.identifier ("other"))),
+                    std::move (cases));
+
+  expanded
+    = cmp_impl (builder.block (vec (std::move (self_sd), std::move (other_sd)),
+                              std::move (match)),
+               type_name, item.get_generic_params ());
+}
 
 void
 DeriveOrd::visit_union (Union &item)
diff --git a/gcc/rust/expand/rust-derive-ord.h b/gcc/rust/expand/rust-derive-ord.h
index 047ebfb0c01..a360dd26d97 100644
--- a/gcc/rust/expand/rust-derive-ord.h
+++ b/gcc/rust/expand/rust-derive-ord.h
@@ -69,7 +69,9 @@ private:
   Ordering ordering;
 
   /* Identifier patterns for the non-equal match arms */
-  constexpr static const char *not_equal = "non_eq";
+  constexpr static const char *not_equal = "#non_eq";
+  constexpr static const char *self_discr = "#self_discr";
+  constexpr static const char *other_discr = "#other_discr";
 
   /**
    * Create the recursive matching structure used when implementing the
@@ -89,6 +91,18 @@ private:
    */
   std::pair<MatchArm, MatchArm> make_cmp_arms ();
 
+  /* Not defined yet; presumably each builds the match case comparing the
+     fields of `self` and `other` when both are the given variant.  */
+  MatchCase match_enum_tuple (PathInExpression variant_path,
+                             const EnumItemTuple &variant);
+  MatchCase match_enum_struct (PathInExpression variant_path,
+                              const EnumItemStruct &variant);
+
+  /* Build a call to `core::cmp::<trait>::<fn>`, both chosen according to
+     `ordering`, passing SELF_EXPR and OTHER_EXPR as the arguments.  */
+  std::unique_ptr<Expr> cmp_call (std::unique_ptr<Expr> &&self_expr,
+                                 std::unique_ptr<Expr> &&other_expr);
+
   std::unique_ptr<Item>
   cmp_impl (std::unique_ptr<BlockExpr> &&fn_block, Identifier type_name,
            const std::vector<std::unique_ptr<GenericParam>> &type_generics);
-- 
2.49.0

Reply via email to