From: Arthur Cohen <arthur.co...@embecosm.com>

Enums are now ordered by comparing their discriminant values.  The fields
of tuple and struct variants are not taken into account yet.

gcc/rust/ChangeLog:

	* expand/rust-derive-ord.cc (DeriveOrd::cmp_call): New function.
	(DeriveOrd::recursive_match): Use it.
	(DeriveOrd::visit_enum): Implement using it.
	* expand/rust-derive-ord.h (DeriveOrd::not_equal): Prefix with '#'.
	(DeriveOrd::self_discr, DeriveOrd::other_discr): New constants.
	(DeriveOrd::match_enum_tuple, DeriveOrd::match_enum_struct)
	(DeriveOrd::cmp_call): Declare.
---
 gcc/rust/expand/rust-derive-ord.cc | 68 +++++++++++++++++++++++++-----
 gcc/rust/expand/rust-derive-ord.h  | 16 ++++++-
 2 files changed, 73 insertions(+), 11 deletions(-)

diff --git a/gcc/rust/expand/rust-derive-ord.cc b/gcc/rust/expand/rust-derive-ord.cc
index e39c6b44ca4..1f39c94d87b 100644
--- a/gcc/rust/expand/rust-derive-ord.cc
+++ b/gcc/rust/expand/rust-derive-ord.cc
@@ -39,6 +39,17 @@ DeriveOrd::go (Item &item)
   return std::move (expanded);
 }
 
+std::unique_ptr<Expr>
+DeriveOrd::cmp_call (std::unique_ptr<Expr> &&self_expr,
+		     std::unique_ptr<Expr> &&other_expr)
+{
+  auto cmp_fn_path = builder.path_in_expression (
+    {"core", "cmp", trait (ordering), fn (ordering)}, true);
+
+  return builder.call (ptrify (cmp_fn_path),
+		       vec (std::move (self_expr), std::move (other_expr)));
+}
+
 std::unique_ptr<Item>
 DeriveOrd::cmp_impl (
   std::unique_ptr<BlockExpr> &&fn_block, Identifier type_name,
@@ -132,18 +143,14 @@ DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
     {
       auto &member = *it;
 
-      auto cmp_fn_path = builder.path_in_expression (
-	{"core", "cmp", trait (ordering), fn (ordering)}, true);
-
-      auto cmp_call = builder.call (ptrify (cmp_fn_path),
-				    vec (std::move (member.self_expr),
-					 std::move (member.other_expr)));
+      auto call = cmp_call (std::move (member.self_expr),
+			    std::move (member.other_expr));
 
       // For the last member (so the first iterator), we just create a call
       // expression
       if (it == members.rbegin ())
 	{
-	  final_expr = std::move (cmp_call);
+	  final_expr = std::move (call);
 	  continue;
 	}
 
@@ -157,8 +164,7 @@ DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
 	builder.match_case (std::move (match_arms.second),
 			    builder.identifier (DeriveOrd::not_equal))};
 
-      final_expr
-	= builder.match (std::move (cmp_call), std::move (match_cases));
+      final_expr = builder.match (std::move (call), std::move (match_cases));
     }
 
   return final_expr;
@@ -227,7 +233,49 @@ DeriveOrd::visit_tuple (TupleStruct &item)
 // }
 void
 DeriveOrd::visit_enum (Enum &item)
-{}
+{
+  auto cases = std::vector<MatchCase> ();
+  auto type_name = item.get_identifier ().as_string ();
+
+  auto let_sd = builder.discriminant_value (DeriveOrd::self_discr, "self");
+  auto other_sd = builder.discriminant_value (DeriveOrd::other_discr, "other");
+
+  auto discr_cmp = cmp_call (builder.identifier (DeriveOrd::self_discr),
+			     builder.identifier (DeriveOrd::other_discr));
+
+  for (auto &variant : item.get_variants ())
+    {
+      auto variant_path
+	= builder.variant_path (type_name,
+				variant->get_identifier ().as_string ());
+
+      switch (variant->get_enum_item_kind ())
+	{
+	case EnumItem::Kind::Tuple:
+	case EnumItem::Kind::Struct:
+	case EnumItem::Kind::Identifier:
+	case EnumItem::Kind::Discriminant:
+	  // FIXME: Only discriminants are compared for now: this is correct for
+	  // fieldless variants, but Tuple and Struct fields are not compared yet
+	  break;
+	}
+    }
+
+  // Add the last case which compares the discriminant values in case `self` and
+  // `other` are actually different variants of the enum
+  cases.emplace_back (
+    builder.match_case (builder.wildcard (), std::move (discr_cmp)));
+
+  auto match
+    = builder.match (builder.tuple (vec (builder.identifier ("self"),
+					 builder.identifier ("other"))),
+		     std::move (cases));
+
+  expanded
+    = cmp_impl (builder.block (vec (std::move (let_sd), std::move (other_sd)),
+			       std::move (match)),
+		type_name, item.get_generic_params ());
+}
 
 void
 DeriveOrd::visit_union (Union &item)
diff --git a/gcc/rust/expand/rust-derive-ord.h b/gcc/rust/expand/rust-derive-ord.h
index 047ebfb0c01..a360dd26d97 100644
--- a/gcc/rust/expand/rust-derive-ord.h
+++ b/gcc/rust/expand/rust-derive-ord.h
@@ -69,7 +69,9 @@ private:
   Ordering ordering;
 
   /* Identifier patterns for the non-equal match arms */
-  constexpr static const char *not_equal = "non_eq";
+  constexpr static const char *not_equal = "#non_eq";
+  constexpr static const char *self_discr = "#self_discr";
+  constexpr static const char *other_discr = "#other_discr";
 
   /**
    * Create the recursive matching structure used when implementing the
@@ -89,6 +91,18 @@ private:
    */
   std::pair<MatchArm, MatchArm> make_cmp_arms ();
 
+  MatchCase match_enum_tuple (PathInExpression variant_path,
+			      const EnumItemTuple &variant);
+  MatchCase match_enum_struct (PathInExpression variant_path,
+			       const EnumItemStruct &variant);
+
+  /**
+   * Generate a call to the proper trait function, based on the ordering, in
+   * order to compare two given expressions
+   */
+  std::unique_ptr<Expr> cmp_call (std::unique_ptr<Expr> &&self_expr,
+				  std::unique_ptr<Expr> &&other_expr);
+
   std::unique_ptr<Item>
   cmp_impl (std::unique_ptr<BlockExpr> &&fn_block, Identifier type_name,
 	    const std::vector<std::unique_ptr<GenericParam>> &type_generics);
-- 
2.49.0