This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git


The following commit(s) were added to refs/heads/main by this push:
     new 23aa3f98e fix(compiler): never generate C++ equality methods for message and union containing any (#3810)
23aa3f98e is described below

commit 23aa3f98ee56687c555c681a488a69d8cfbb5832
Author: Peiyang He <[email protected]>
AuthorDate: Fri Jul 3 13:30:12 2026 +0800

    fix(compiler): never generate C++ equality methods for message and union containing any (#3810)
    
    ## Why?
    
    Fory IDL `any` maps to `std::any` in generated C++ code.
    `std::any` does not define `operator==`, so generated equality code MUST
    NOT ask the C++ standard library to compare two `std::any` directly.
    
    However, the C++ generator still emits such comparisons in some generated
    `operator==` implementations, in particular when `any` appears inside
    standard-comparable containers or union alternatives.
    
    For example, generated code such as `values_ == other.values_` for
    `std::vector<std::any>`, `by_name_ == other.by_name_` for
    `std::unordered_map<std::string, std::any>`, or `value_ == other.value_`
    for `std::variant<std::any, ...>` causes the standard container or variant
    equality operator to instantiate `std::any == std::any`.
    
    This causes compilation errors like:
    
    ```text
    error: no match for 'operator==' (operand types are 'const std::any' and 'const std::any')
    ```
    
    The Rust code generator avoids the equivalent derived-equality problem by
    not deriving comparison traits for generated types that contain `any`:
    
    https://github.com/apache/fory/blob/9b2bcec9618702d6aa5ba4b166f86619fce51baf/compiler/fory_compiler/generators/rust.py#L1149-L1158
    
    ## What does this PR do?
    
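    - Follow the Rust generator's approach, i.e. do not generate C++ equality
    methods for messages and unions that contain `any`, either directly or
    transitively through lists, arrays, map values, nested messages, or union
    alternatives.
    - Remove the previous special case that compared direct `any` fields only
    by emptiness and held type (`has_value()` / `type()`), not by value.
    - Add a generator test case, and compare `any` fields individually in the
    C++ IDL integration test instead of using `operator==` on `AnyHolder`.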
    
    ## Related issues
    
    N/A.
    
    ## AI Contribution Checklist
    
    - [X] Substantial AI assistance was used in this PR: `no`
    
    ## Does this PR introduce any user-facing change?
    
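    Yes. Generated C++ messages and unions that contain `any` (directly or
    transitively) no longer define `operator==`; code that compared such
    values must now compare their fields explicitly.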
    
    ## Benchmark
    
    N/A.
---
 compiler/fory_compiler/generators/cpp.py           | 148 +++++++++++++++++----
 .../fory_compiler/tests/test_generated_code.py     |  88 ++++++++++++
 integration_tests/idl_tests/cpp/main.cc            |  39 +++++-
 3 files changed, 246 insertions(+), 29 deletions(-)

diff --git a/compiler/fory_compiler/generators/cpp.py 
b/compiler/fory_compiler/generators/cpp.py
index 91e6b49fd..f7302dd15 100644
--- a/compiler/fory_compiler/generators/cpp.py
+++ b/compiler/fory_compiler/generators/cpp.py
@@ -691,13 +691,6 @@ class CppGenerator(BaseGenerator):
     ) -> str:
         member_name = self.get_field_member_name(field)
         other_member = f"other.{member_name}"
-        if isinstance(field.field_type, PrimitiveType) and (
-            field.field_type.kind == PrimitiveKind.ANY
-        ):
-            return (
-                f"((!{member_name}.has_value() && !{other_member}.has_value()) 
|| "
-                f"({member_name}.type() == {other_member}.type()))"
-            )
         if self.is_message_type(
             field.field_type, parent_stack
         ) and self.get_field_weak_ref(field):
@@ -716,6 +709,104 @@ class CppGenerator(BaseGenerator):
             )
         return f"{member_name} == {other_member}"
 
+    def message_has_any(
+        self,
+        message: Message,
+        parent_stack: Optional[List[Message]] = None,
+        visiting: Optional[Set[Tuple[str, int]]] = None,
+    ) -> bool:
+        if visiting is None:
+            visiting = set()
+        key = ("message", id(message))
+        if key in visiting:
+            return False
+        visiting.add(key)
+        try:
+            lineage = (parent_stack or []) + [message]
+            return any(
+                self.field_type_has_any(field.field_type, lineage, visiting)
+                for field in message.fields
+            )
+        finally:
+            visiting.remove(key)
+
+    def union_has_any(
+        self,
+        union: Union,
+        parent_stack: Optional[List[Message]] = None,
+        visiting: Optional[Set[Tuple[str, int]]] = None,
+    ) -> bool:
+        if visiting is None:
+            visiting = set()
+        key = ("union", id(union))
+        if key in visiting:
+            return False
+        visiting.add(key)
+        try:
+            return any(
+                self.field_type_has_any(field.field_type, parent_stack, 
visiting)
+                for field in union.fields
+            )
+        finally:
+            visiting.remove(key)
+
+    def field_type_has_any(
+        self,
+        field_type: FieldType,
+        parent_stack: Optional[List[Message]] = None,
+        visiting: Optional[Set[Tuple[str, int]]] = None,
+    ) -> bool:
+        """Return True when a field type or its children contain `any`."""
+        if isinstance(field_type, PrimitiveType):
+            return field_type.kind == PrimitiveKind.ANY
+        if isinstance(field_type, ListType):
+            return self.field_type_has_any(
+                field_type.element_type, parent_stack, visiting
+            )
+        if isinstance(field_type, ArrayType):
+            return self.field_type_has_any(
+                field_type.element_type, parent_stack, visiting
+            )
+        if isinstance(field_type, MapType):
+            # `any` is not allowed as map key (rejected first by the 
validator),
+            # so we only check map value here.
+            return self.field_type_has_any(
+                field_type.value_type, parent_stack, visiting
+            )
+        if isinstance(field_type, NamedType):
+            named_type = self.resolve_named_type(field_type.name, parent_stack)
+            if isinstance(named_type, Message):
+                return self.message_has_any(
+                    named_type, self._parent_stack_for_type(named_type), 
visiting
+                )
+            if isinstance(named_type, Union):
+                return self.union_has_any(
+                    named_type, self._parent_stack_for_type(named_type), 
visiting
+                )
+        return False
+
+    def _parent_stack_for_type(self, type_def: object) -> List[Message]:
+        def visit(message: Message, parents: List[Message]) -> 
Optional[List[Message]]:
+            if message is type_def:
+                return parents
+            for nested_union in message.nested_unions:
+                if nested_union is type_def:
+                    return parents + [message]
+            for nested_enum in message.nested_enums:
+                if nested_enum is type_def:
+                    return parents + [message]
+            for nested_message in message.nested_messages:
+                found = visit(nested_message, parents + [message])
+                if found is not None:
+                    return found
+            return None
+
+        for top in self.schema.messages:
+            found = visit(top, [])
+            if found is not None:
+                return found
+        return []
+
     def is_numeric_field(self, field: Field) -> bool:
         if not isinstance(field.field_type, PrimitiveType):
             return False
@@ -914,19 +1005,23 @@ class CppGenerator(BaseGenerator):
                     lines.append("")
             lines.append("")
 
-        lines.append(
-            f"{body_indent}bool operator==(const {class_name}& other) const {{"
-        )
-        if message.fields:
-            conditions = [
-                self.get_field_eq_expression(field, lineage) for field in 
message.fields
-            ]
-            lines.append(f"{body_indent}  return {' && '.join(conditions)};")
-        else:
-            lines.append(f"{body_indent}  return true;")
-        lines.append(f"{body_indent}}}")
+        # We don't generate equality method for message containing `any`
+        # since C++ doesn't support std::any == std::any.
+        if not self.message_has_any(message, parent_stack):
+            lines.append(
+                f"{body_indent}bool operator==(const {class_name}& other) 
const {{"
+            )
+            if message.fields:
+                conditions = [
+                    self.get_field_eq_expression(field, lineage)
+                    for field in message.fields
+                ]
+                lines.append(f"{body_indent}  return {' && 
'.join(conditions)};")
+            else:
+                lines.append(f"{body_indent}  return true;")
+            lines.append(f"{body_indent}}}")
+            lines.append("")
 
-        lines.append("")
         lines.extend(self.generate_bytes_methods(class_name, body_indent))
 
         struct_type_name = self.get_qualified_type_name(message.name, 
parent_stack)
@@ -1069,12 +1164,15 @@ class CppGenerator(BaseGenerator):
         )
         lines.append(f"{body_indent}  }}")
         lines.append("")
-        lines.append(
-            f"{body_indent}  bool operator==(const {class_name}& other) const 
{{"
-        )
-        lines.append(f"{body_indent}    return value_ == other.value_;")
-        lines.append(f"{body_indent}  }}")
-        lines.append("")
+        # We don't generate equality method for union containing `any`
+        # since C++ doesn't support std::any == std::any.
+        if not self.union_has_any(union, parent_stack):
+            lines.append(
+                f"{body_indent}  bool operator==(const {class_name}& other) 
const {{"
+            )
+            lines.append(f"{body_indent}    return value_ == other.value_;")
+            lines.append(f"{body_indent}  }}")
+            lines.append("")
 
         lines.extend(self.generate_bytes_methods(class_name, f"{body_indent}  
"))
 
diff --git a/compiler/fory_compiler/tests/test_generated_code.py 
b/compiler/fory_compiler/tests/test_generated_code.py
index d5f7b3960..16e8fca31 100644
--- a/compiler/fory_compiler/tests/test_generated_code.py
+++ b/compiler/fory_compiler/tests/test_generated_code.py
@@ -1126,6 +1126,94 @@ def 
test_cpp_generator_supports_decimal_fields_and_unions():
     assert "(amount, fory::serialization::Decimal, fory::F(1))" in cpp_output
 
 
+def test_cpp_omits_equality_for_any_types():
+    schema = parse_fdl(
+        dedent(
+            """
+            package gen;
+
+            message Inner {
+                any value = 1;
+            }
+
+            union AnyChoice {
+                Inner inner = 1;
+                string name = 2;
+            }
+
+            message DirectAny {
+                any value = 1;
+            }
+
+            message AnyList {
+                list<any> values = 1;
+            }
+
+            message AnyMap {
+                map<string, any> values = 1;
+            }
+
+            union DirectChoice {
+                any payload = 1;
+                list<any> values = 2;
+                string name = 3;
+            }
+
+            message DirectOwner {
+                Inner inner = 1;
+            }
+
+            message ListOwner {
+                list<Inner> values = 1;
+            }
+
+            message MapOwner {
+                map<string, Inner> values = 1;
+            }
+
+            message UnionOwner {
+                AnyChoice choice = 1;
+            }
+
+            message DeclaresNestedOnly {
+                message Nested {
+                    any value = 1;
+                }
+
+                string name = 1;
+            }
+
+            message Plain {
+                string name = 1;
+                list<int32> values = 2;
+                map<string, int32> counts = 3;
+            }
+
+            union PlainChoice {
+                string name = 1;
+                int32 code = 2;
+            }
+            """
+        )
+    )
+
+    cpp_output = render_files(generate_files(schema, CppGenerator))
+    assert "bool operator==(const Inner& other) const" not in cpp_output
+    assert "bool operator==(const AnyChoice& other) const" not in cpp_output
+    assert "bool operator==(const DirectAny& other) const" not in cpp_output
+    assert "bool operator==(const AnyList& other) const" not in cpp_output
+    assert "bool operator==(const AnyMap& other) const" not in cpp_output
+    assert "bool operator==(const DirectChoice& other) const" not in cpp_output
+    assert "bool operator==(const DirectOwner& other) const" not in cpp_output
+    assert "bool operator==(const ListOwner& other) const" not in cpp_output
+    assert "bool operator==(const MapOwner& other) const" not in cpp_output
+    assert "bool operator==(const UnionOwner& other) const" not in cpp_output
+    assert "bool operator==(const Nested& other) const" not in cpp_output
+    assert "bool operator==(const DeclaresNestedOnly& other) const" in 
cpp_output
+    assert "bool operator==(const Plain& other) const" in cpp_output
+    assert "bool operator==(const PlainChoice& other) const" in cpp_output
+
+
 def test_cpp_nested_container_ref_uses_correct_pointer_type():
     schema = parse_fdl(
         dedent(
diff --git a/integration_tests/idl_tests/cpp/main.cc 
b/integration_tests/idl_tests/cpp/main.cc
index cb41a987a..8d9ec6baf 100644
--- a/integration_tests/idl_tests/cpp/main.cc
+++ b/integration_tests/idl_tests/cpp/main.cc
@@ -1076,6 +1076,23 @@ fory::Result<void, fory::Error> RunEvolvingRoundTrip() {
 
 using StringMap = std::unordered_map<std::string, std::string>;
 
+template <typename T>
+fory::Result<void, fory::Error>
+ValidateAnyField(const std::any &actual_any, const std::any &expected_any,
+                 const std::string &field_name) {
+  const auto *actual = std::any_cast<T>(&actual_any);
+  const auto *expected = std::any_cast<T>(&expected_any);
+  if (actual == nullptr || expected == nullptr) {
+    return fory::Unexpected(
+        fory::Error::invalid("any holder " + field_name + " type mismatch"));
+  }
+  if (!(*actual == *expected)) {
+    return fory::Unexpected(
+        fory::Error::invalid("any holder " + field_name + " value mismatch"));
+  }
+  return fory::Result<void, fory::Error>();
+}
+
 fory::Result<void, fory::Error> RunRoundTrip(bool compatible) {
   auto fory = fory::serialization::Fory::builder()
                   .xlang(true)
@@ -1479,10 +1496,24 @@ fory::Result<void, fory::Error> RunRoundTrip(bool 
compatible) {
   FORY_TRY(any_roundtrip, fory.deserialize<any_example::AnyHolder>(
                               any_bytes.data(), any_bytes.size()));
 
-  if (!(any_roundtrip == any_holder)) {
-    return fory::Unexpected(
-        fory::Error::invalid("any holder roundtrip mismatch"));
-  }
+  FORY_RETURN_IF_ERROR(ValidateAnyField<bool>(
+      any_roundtrip.bool_value(), any_holder.bool_value(), "bool_value"));
+  FORY_RETURN_IF_ERROR(ValidateAnyField<std::string>(
+      any_roundtrip.string_value(), any_holder.string_value(), 
"string_value"));
+  FORY_RETURN_IF_ERROR(ValidateAnyField<fory::serialization::Date>(
+      any_roundtrip.date_value(), any_holder.date_value(), "date_value"));
+  FORY_RETURN_IF_ERROR(ValidateAnyField<fory::serialization::Timestamp>(
+      any_roundtrip.timestamp_value(), any_holder.timestamp_value(),
+      "timestamp_value"));
+  FORY_RETURN_IF_ERROR(ValidateAnyField<any_example::AnyInner>(
+      any_roundtrip.message_value(), any_holder.message_value(),
+      "message_value"));
+  FORY_RETURN_IF_ERROR(ValidateAnyField<any_example::AnyUnion>(
+      any_roundtrip.union_value(), any_holder.union_value(), "union_value"));
+  FORY_RETURN_IF_ERROR(ValidateAnyField<std::vector<std::string>>(
+      any_roundtrip.list_value(), any_holder.list_value(), "list_value"));
+  FORY_RETURN_IF_ERROR(ValidateAnyField<StringMap>(
+      any_roundtrip.map_value(), any_holder.map_value(), "map_value"));
 
   example_peer::ExampleMessage example_message = BuildExampleMessage();
   FORY_TRY(example_bytes, fory.serialize(example_message));


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to