This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d0faffdc feat(python): fix python struct xlang ref track error (#3574)
6d0faffdc is described below

commit 6d0faffdcb6d3f82ace16ada1b0f9804087c5ab9
Author: Shawn Yang <[email protected]>
AuthorDate: Thu Apr 16 12:08:51 2026 +0800

    feat(python): fix python struct xlang ref track error (#3574)
    
    ## Why?
    
    
    
    ## What does this PR do?
    
    - fix Python struct xlang ref-tracking error
    - use `typing_extensions` for Python 3.8 compatibility
    
    ## Related issues
    
    Closes #3506
    
    ## AI Contribution Checklist
    
    
    
    - [ ] Substantial AI assistance was used in this PR: `yes` / `no`
    - [ ] If `yes`, I included a completed [AI Contribution Checklist](https://github.com/apache/fory/blob/main/AI_POLICY.md#9-contributor-checklist-for-ai-assisted-prs)
    in this PR description and the required `AI Usage Disclosure`.
    - [ ] If `yes`, my PR description includes the required `ai_review`
    summary and screenshot evidence of the final clean AI review results
    from both fresh reviewers on the current PR diff or current HEAD after
    the latest code changes.
    
    
    
    ## Does this PR introduce any user-facing change?
    
    
    
    - [ ] Does this PR introduce any public API change?
    - [ ] Does this PR introduce any binary protocol compatibility change?
    
    ## Benchmark
---
 .agents/languages/python.md              |  1 +
 python/pyfory/collection.pxi             | 10 +++---
 python/pyfory/collection.py              | 10 +++---
 python/pyfory/tests/test_ref_tracking.py | 55 +++++++++++++++++++++++++++++++
 python/pyfory/type_util.py               | 56 +++++++++++++++++++++++++-------
 python/pyfory/types.py                   | 13 ++++++--
 python/pyproject.toml                    |  1 +
 7 files changed, 124 insertions(+), 22 deletions(-)

diff --git a/.agents/languages/python.md b/.agents/languages/python.md
index 9548be2d7..9a470e474 100644
--- a/.agents/languages/python.md
+++ b/.agents/languages/python.md
@@ -10,6 +10,7 @@ Load this file when changing `python/`, Cython serialization, 
or Python xlang be
 - Use `ENABLE_FORY_CYTHON_SERIALIZATION=0` first when debugging protocol 
behavior.
 - Python mode is the pure-Python xlang implementation and is mainly for 
debugging and testing.
 - Cython mode is the default high-performance implementation.
+- Keep new Python test names compact and behavior-focused; avoid 
sentence-length names that restate setup details already obvious from the test 
body.
 - `ENABLE_FORY_DEBUG_OUTPUT=1` enables detailed struct serialization and 
deserialization logs.
 
 ## Key Paths
diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi
index 3b11a5342..18d7166ff 100644
--- a/python/pyfory/collection.pxi
+++ b/python/pyfory/collection.pxi
@@ -56,16 +56,18 @@ cdef class CollectionSerializer(Serializer):
     def __init__(self, type_resolver, type_, elem_serializer=None, 
elem_tracking_ref=None):
         super().__init__(type_resolver, type_)
         self.elem_serializer = elem_serializer
+        if elem_tracking_ref is not None:
+            self.elem_tracking_ref = <int8_t>(1 if elem_tracking_ref else 0)
+        else:
+            self.elem_tracking_ref = -1
         if elem_serializer is None:
             self.elem_type = None
             self.elem_type_info = self.type_resolver.get_type_info(None)
-            self.elem_tracking_ref = -1
         else:
             self.elem_type = elem_serializer.type_
             self.elem_type_info = 
self.type_resolver.get_type_info(self.elem_type)
-            self.elem_tracking_ref = <int8_t>elem_serializer.need_to_write_ref
-            if elem_tracking_ref is not None:
-                self.elem_tracking_ref = <int8_t>(1 if elem_tracking_ref else 
0)
+            if elem_tracking_ref is None:
+                self.elem_tracking_ref = 
<int8_t>elem_serializer.need_to_write_ref
 
     cdef inline TypeInfo write_header(self, WriteContext write_context, value, 
int8_t *collect_flag_ptr):
         cdef int8_t collect_flag = COLL_DEFAULT_FLAG
diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py
index 5fcc877a0..6010480f9 100644
--- a/python/pyfory/collection.py
+++ b/python/pyfory/collection.py
@@ -45,16 +45,18 @@ class CollectionSerializer(Serializer):
     def __init__(self, type_resolver, type_, elem_serializer=None, 
elem_tracking_ref=None):
         super().__init__(type_resolver, type_)
         self.elem_serializer = elem_serializer
+        if elem_tracking_ref is not None:
+            self.elem_tracking_ref = 1 if elem_tracking_ref else 0
+        else:
+            self.elem_tracking_ref = -1
         if elem_serializer is None:
             self.elem_type = None
             self.elem_type_info = self.type_resolver.get_type_info(None)
-            self.elem_tracking_ref = -1
         else:
             self.elem_type = elem_serializer.type_
             self.elem_type_info = 
self.type_resolver.get_type_info(self.elem_type)
-            self.elem_tracking_ref = int(elem_serializer.need_to_write_ref)
-            if elem_tracking_ref is not None:
-                self.elem_tracking_ref = 1 if elem_tracking_ref else 0
+            if elem_tracking_ref is None:
+                self.elem_tracking_ref = int(elem_serializer.need_to_write_ref)
 
     def write_header(self, write_context, value):
         collect_flag = COLL_DEFAULT_FLAG
diff --git a/python/pyfory/tests/test_ref_tracking.py 
b/python/pyfory/tests/test_ref_tracking.py
index e4641db95..97b80b350 100644
--- a/python/pyfory/tests/test_ref_tracking.py
+++ b/python/pyfory/tests/test_ref_tracking.py
@@ -16,13 +16,17 @@
 # under the License.
 
 from dataclasses import dataclass
+import typing
 from typing import Any, List
 
 import pytest
 
 import pyfory
+from pyfory import Ref
 from pyfory import _fory as fmod
 from pyfory.resolver import REF_FLAG, REF_VALUE_FLAG
+from pyfory.serializer import ListSerializer
+from pyfory.type_util import get_type_hints, unwrap_ref
 
 
 def _roundtrip(fory, value):
@@ -73,6 +77,16 @@ class Holder:
     values: List[pyfory.int64]
 
 
+@dataclass
+class CollectionRefOverrideItem:
+    value: pyfory.int32
+
+
+@dataclass
+class CollectionRefOverrideContainer:
+    items: List[Ref[CollectionRefOverrideItem, False]]
+
+
 class EvilIndex:
     def __init__(self):
         self.owner = None
@@ -190,6 +204,47 @@ def 
test_struct_field_ref_override_controls_alias_preservation(xlang):
     assert enabled.left is enabled.right
 
 
+def test_collection_ref_override_unsets_tracking_bit():
+    fory = pyfory.Fory(xlang=True, ref=True, compatible=True)
+    fory.register_type(CollectionRefOverrideItem, 
typename="example.CollectionRefOverrideItem")
+
+    serializer = ListSerializer(fory.type_resolver, list, 
elem_tracking_ref=False)
+    shared = CollectionRefOverrideItem(7)
+    buffer = pyfory.Buffer.allocate(64)
+    write_context = fory.write_context
+    write_context.prepare(buffer)
+
+    serializer.write(write_context, [shared, shared])
+
+    payload = buffer.to_bytes(0, buffer.get_writer_index())
+    reader = pyfory.Buffer(payload)
+    assert reader.read_var_uint32() == 2
+    assert (reader.read_int8() & 0b1) == 0
+
+
+def test_ref_annotation_preserves_override():
+    type_hints = get_type_hints(CollectionRefOverrideContainer)
+    item_type = typing.get_args(type_hints["items"])[0]
+    _, elem_ref_override = unwrap_ref(item_type)
+    assert elem_ref_override is False
+
+
+def test_collection_ref_override_disables_alias_preservation():
+    fory = pyfory.Fory(xlang=True, ref=True, compatible=True)
+    fory.register_type(CollectionRefOverrideItem, 
typename="example.CollectionRefOverrideItem")
+    fory.register_type(CollectionRefOverrideContainer, 
typename="example.CollectionRefOverrideContainer")
+
+    shared = CollectionRefOverrideItem(11)
+    restored = _roundtrip(
+        fory,
+        CollectionRefOverrideContainer(items=[shared, shared]),
+    )
+
+    assert restored.items[0] == shared
+    assert restored.items[1] == shared
+    assert restored.items[0] is not restored.items[1]
+
+
 def test_struct_self_cycle_and_nested_alias_python_mode():
     fory = pyfory.Fory(xlang=False, ref=True, strict=False)
     fory.register(RefNode)
diff --git a/python/pyfory/type_util.py b/python/pyfory/type_util.py
index 7b37e9b7e..60dc72007 100644
--- a/python/pyfory/type_util.py
+++ b/python/pyfory/type_util.py
@@ -25,18 +25,52 @@ from abc import ABC, abstractmethod
 
 from pyfory.types import RefMeta
 
+try:
+    from typing import Annotated as _Annotated
+except ImportError:
+    try:
+        from typing_extensions import Annotated as _Annotated
+    except ImportError:
+        _Annotated = None
+
+try:
+    from typing_extensions import get_type_hints as 
_typing_extensions_get_type_hints
+except ImportError:
+    _typing_extensions_get_type_hints = None
+
+try:
+    from typing_extensions import get_origin as _typing_extensions_get_origin
+    from typing_extensions import get_args as _typing_extensions_get_args
+except ImportError:
+    _typing_extensions_get_origin = None
+    _typing_extensions_get_args = None
+
+
+def _get_origin(type_):
+    if _typing_extensions_get_origin is not None:
+        return _typing_extensions_get_origin(type_)
+    return typing.get_origin(type_) if hasattr(typing, "get_origin") else 
getattr(type_, "__origin__", None)
+
+
+def _get_args(type_):
+    if _typing_extensions_get_args is not None:
+        return _typing_extensions_get_args(type_)
+    return typing.get_args(type_) if hasattr(typing, "get_args") else 
getattr(type_, "__args__", ())
+
 
 def get_type_hints(type_):
     try:
         return typing.get_type_hints(type_, include_extras=True)
     except TypeError:
+        if _typing_extensions_get_type_hints is not None:
+            return _typing_extensions_get_type_hints(type_, 
include_extras=True)
         return typing.get_type_hints(type_)
 
 
 def unwrap_ref(type_):
-    origin = typing.get_origin(type_) if hasattr(typing, "get_origin") else 
getattr(type_, "__origin__", None)
-    if origin is getattr(typing, "Annotated", None):
-        args = typing.get_args(type_) if hasattr(typing, "get_args") else 
getattr(type_, "__args__", ())
+    origin = _get_origin(type_)
+    if _Annotated is not None and origin is _Annotated:
+        args = _get_args(type_)
         if args:
             base = args[0]
             for meta in args[1:]:
@@ -44,7 +78,7 @@ def unwrap_ref(type_):
                     return base, meta.enable
             return base, None
     if origin is typing.Union:
-        args = typing.get_args(type_) if hasattr(typing, "get_args") else 
getattr(type_, "__args__", ())
+        args = _get_args(type_)
         new_args = list(args)
         ref_override = None
         for i, arg in enumerate(args):
@@ -206,9 +240,9 @@ def infer_field_types(type_, field_nullable=False):
 
 
 def is_optional_type(type_):
-    origin = typing.get_origin(type_) if hasattr(typing, "get_origin") else 
getattr(type_, "__origin__", None)
+    origin = _get_origin(type_)
     if origin is typing.Union:
-        args = typing.get_args(type_) if hasattr(typing, "get_args") else 
getattr(type_, "__args__", ())
+        args = _get_args(type_)
         return type(None) in args
     return False
 
@@ -216,7 +250,7 @@ def is_optional_type(type_):
 def unwrap_optional(type_, field_nullable=False):
     if not is_optional_type(type_):
         return type_, False or field_nullable
-    args = typing.get_args(type_) if hasattr(typing, "get_args") else 
getattr(type_, "__args__", ())
+    args = _get_args(type_)
     non_none_types = [arg for arg in args if arg is not type(None)]
     if len(non_none_types) == 1:
         return non_none_types[0], True
@@ -227,10 +261,10 @@ def get_homogeneous_tuple_elem_type(type_or_args):
     if isinstance(type_or_args, tuple):
         args = type_or_args
     else:
-        origin = typing.get_origin(type_or_args) if hasattr(typing, 
"get_origin") else getattr(type_or_args, "__origin__", None)
+        origin = _get_origin(type_or_args)
         if origin not in (tuple, typing.Tuple):
             return None
-        args = typing.get_args(type_or_args) if hasattr(typing, "get_args") 
else getattr(type_or_args, "__args__", ())
+        args = _get_args(type_or_args)
     if not args or args == ((),):
         return None
     if len(args) == 2 and args[1] is Ellipsis:
@@ -245,9 +279,9 @@ def infer_field(field_name, type_, visitor: TypeVisitor, 
types_path=None):
     types_path = list(types_path or [])
     type_, _ = unwrap_ref(type_)
     types_path.append(type_)
-    origin = typing.get_origin(type_) if hasattr(typing, "get_origin") else 
getattr(type_, "__origin__", type_)
+    origin = _get_origin(type_) or getattr(type_, "__origin__", type_)
     origin = origin or type_
-    args = typing.get_args(type_) if hasattr(typing, "get_args") else 
getattr(type_, "__args__", ())
+    args = _get_args(type_)
     if args:
         if origin is list or origin == typing.List:
             elem_type = args[0]
diff --git a/python/pyfory/types.py b/python/pyfory/types.py
index 7f8f871dd..b680d393c 100644
--- a/python/pyfory/types.py
+++ b/python/pyfory/types.py
@@ -20,6 +20,14 @@ import array
 import typing
 from typing import TypeVar
 
+try:
+    from typing import Annotated as _Annotated
+except ImportError:
+    try:
+        from typing_extensions import Annotated as _Annotated
+    except ImportError:
+        _Annotated = None
+
 try:
     import numpy as np
 
@@ -219,10 +227,9 @@ class Ref:
             enable = params[1]
         if not isinstance(enable, bool):
             raise TypeError("Ref enable must be a bool")
-        annotated = getattr(typing, "Annotated", None)
-        if annotated is None:
+        if _Annotated is None:
             return target
-        return annotated[target, RefMeta(enable)]
+        return _Annotated[target, RefMeta(enable)]
 
 
 _primitive_types = {
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 7d8ec8e6d..35779ce8f 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -50,6 +50,7 @@ classifiers = [
 ]
 keywords = ["fory", "serialization", "multi-language", "fast", "row-format", 
"jit", "codegen", "polymorphic", "zero-copy"]
 dependencies = [
+    "typing_extensions>=4.0; python_version < '3.9'",
 ]
 
 [project.optional-dependencies]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to