This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 6d0faffdc feat(python): fix python struct xlang ref track error (#3574)
6d0faffdc is described below
commit 6d0faffdcb6d3f82ace16ada1b0f9804087c5ab9
Author: Shawn Yang <[email protected]>
AuthorDate: Thu Apr 16 12:08:51 2026 +0800
feat(python): fix python struct xlang ref track error (#3574)
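## Why?
Collection element ref-tracking overrides such as `List[Ref[T, False]]` were dropped when the collection serializer had no element serializer, and on interpreters without `typing.Annotated` (Python 3.8) `Ref[...]` annotations silently lost their ref-tracking metadata.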
## What does this PR do?
- fix Python struct xlang ref-tracking error
- use `typing_extensions` for Python 3.8
## Related issues
Closes #3506
## AI Contribution Checklist
- [ ] Substantial AI assistance was used in this PR: `yes` / `no`
- [ ] If `yes`, I included a completed [AI Contribution
Checklist](https://github.com/apache/fory/blob/main/AI_POLICY.md#9-contributor-checklist-for-ai-assisted-prs)
in this PR description and the required `AI Usage Disclosure`.
- [ ] If `yes`, my PR description includes the required `ai_review`
summary and screenshot evidence of the final clean AI review results
from both fresh reviewers on the current PR diff or current HEAD after
the latest code changes.
## Does this PR introduce any user-facing change?
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
---
.agents/languages/python.md | 1 +
python/pyfory/collection.pxi | 10 +++---
python/pyfory/collection.py | 10 +++---
python/pyfory/tests/test_ref_tracking.py | 55 +++++++++++++++++++++++++++++++
python/pyfory/type_util.py | 56 +++++++++++++++++++++++++-------
python/pyfory/types.py | 13 ++++++--
python/pyproject.toml | 1 +
7 files changed, 124 insertions(+), 22 deletions(-)
diff --git a/.agents/languages/python.md b/.agents/languages/python.md
index 9548be2d7..9a470e474 100644
--- a/.agents/languages/python.md
+++ b/.agents/languages/python.md
@@ -10,6 +10,7 @@ Load this file when changing `python/`, Cython serialization,
or Python xlang be
- Use `ENABLE_FORY_CYTHON_SERIALIZATION=0` first when debugging protocol
behavior.
- Python mode is the pure-Python xlang implementation and is mainly for
debugging and testing.
- Cython mode is the default high-performance implementation.
+- Keep new Python test names compact and behavior-focused; avoid
sentence-length names that restate setup details already obvious from the test
body.
- `ENABLE_FORY_DEBUG_OUTPUT=1` enables detailed struct serialization and
deserialization logs.
## Key Paths
diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi
index 3b11a5342..18d7166ff 100644
--- a/python/pyfory/collection.pxi
+++ b/python/pyfory/collection.pxi
@@ -56,16 +56,18 @@ cdef class CollectionSerializer(Serializer):
def __init__(self, type_resolver, type_, elem_serializer=None,
elem_tracking_ref=None):
super().__init__(type_resolver, type_)
self.elem_serializer = elem_serializer
+ if elem_tracking_ref is not None:
+ self.elem_tracking_ref = <int8_t>(1 if elem_tracking_ref else 0)
+ else:
+ self.elem_tracking_ref = -1
if elem_serializer is None:
self.elem_type = None
self.elem_type_info = self.type_resolver.get_type_info(None)
- self.elem_tracking_ref = -1
else:
self.elem_type = elem_serializer.type_
self.elem_type_info =
self.type_resolver.get_type_info(self.elem_type)
- self.elem_tracking_ref = <int8_t>elem_serializer.need_to_write_ref
- if elem_tracking_ref is not None:
- self.elem_tracking_ref = <int8_t>(1 if elem_tracking_ref else
0)
+ if elem_tracking_ref is None:
+ self.elem_tracking_ref =
<int8_t>elem_serializer.need_to_write_ref
cdef inline TypeInfo write_header(self, WriteContext write_context, value,
int8_t *collect_flag_ptr):
cdef int8_t collect_flag = COLL_DEFAULT_FLAG
diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py
index 5fcc877a0..6010480f9 100644
--- a/python/pyfory/collection.py
+++ b/python/pyfory/collection.py
@@ -45,16 +45,18 @@ class CollectionSerializer(Serializer):
def __init__(self, type_resolver, type_, elem_serializer=None,
elem_tracking_ref=None):
super().__init__(type_resolver, type_)
self.elem_serializer = elem_serializer
+ if elem_tracking_ref is not None:
+ self.elem_tracking_ref = 1 if elem_tracking_ref else 0
+ else:
+ self.elem_tracking_ref = -1
if elem_serializer is None:
self.elem_type = None
self.elem_type_info = self.type_resolver.get_type_info(None)
- self.elem_tracking_ref = -1
else:
self.elem_type = elem_serializer.type_
self.elem_type_info =
self.type_resolver.get_type_info(self.elem_type)
- self.elem_tracking_ref = int(elem_serializer.need_to_write_ref)
- if elem_tracking_ref is not None:
- self.elem_tracking_ref = 1 if elem_tracking_ref else 0
+ if elem_tracking_ref is None:
+ self.elem_tracking_ref = int(elem_serializer.need_to_write_ref)
def write_header(self, write_context, value):
collect_flag = COLL_DEFAULT_FLAG
diff --git a/python/pyfory/tests/test_ref_tracking.py
b/python/pyfory/tests/test_ref_tracking.py
index e4641db95..97b80b350 100644
--- a/python/pyfory/tests/test_ref_tracking.py
+++ b/python/pyfory/tests/test_ref_tracking.py
@@ -16,13 +16,17 @@
# under the License.
from dataclasses import dataclass
+import typing
from typing import Any, List
import pytest
import pyfory
+from pyfory import Ref
from pyfory import _fory as fmod
from pyfory.resolver import REF_FLAG, REF_VALUE_FLAG
+from pyfory.serializer import ListSerializer
+from pyfory.type_util import get_type_hints, unwrap_ref
def _roundtrip(fory, value):
@@ -73,6 +77,16 @@ class Holder:
values: List[pyfory.int64]
+@dataclass
+class CollectionRefOverrideItem:
+ value: pyfory.int32
+
+
+@dataclass
+class CollectionRefOverrideContainer:
+ items: List[Ref[CollectionRefOverrideItem, False]]
+
+
class EvilIndex:
def __init__(self):
self.owner = None
@@ -190,6 +204,47 @@ def
test_struct_field_ref_override_controls_alias_preservation(xlang):
assert enabled.left is enabled.right
+def test_collection_ref_override_unsets_tracking_bit():
+ fory = pyfory.Fory(xlang=True, ref=True, compatible=True)
+ fory.register_type(CollectionRefOverrideItem,
typename="example.CollectionRefOverrideItem")
+
+ serializer = ListSerializer(fory.type_resolver, list,
elem_tracking_ref=False)
+ shared = CollectionRefOverrideItem(7)
+ buffer = pyfory.Buffer.allocate(64)
+ write_context = fory.write_context
+ write_context.prepare(buffer)
+
+ serializer.write(write_context, [shared, shared])
+
+ payload = buffer.to_bytes(0, buffer.get_writer_index())
+ reader = pyfory.Buffer(payload)
+ assert reader.read_var_uint32() == 2
+ assert (reader.read_int8() & 0b1) == 0
+
+
+def test_ref_annotation_preserves_override():
+ type_hints = get_type_hints(CollectionRefOverrideContainer)
+ item_type = typing.get_args(type_hints["items"])[0]
+ _, elem_ref_override = unwrap_ref(item_type)
+ assert elem_ref_override is False
+
+
+def test_collection_ref_override_disables_alias_preservation():
+ fory = pyfory.Fory(xlang=True, ref=True, compatible=True)
+ fory.register_type(CollectionRefOverrideItem,
typename="example.CollectionRefOverrideItem")
+ fory.register_type(CollectionRefOverrideContainer,
typename="example.CollectionRefOverrideContainer")
+
+ shared = CollectionRefOverrideItem(11)
+ restored = _roundtrip(
+ fory,
+ CollectionRefOverrideContainer(items=[shared, shared]),
+ )
+
+ assert restored.items[0] == shared
+ assert restored.items[1] == shared
+ assert restored.items[0] is not restored.items[1]
+
+
def test_struct_self_cycle_and_nested_alias_python_mode():
fory = pyfory.Fory(xlang=False, ref=True, strict=False)
fory.register(RefNode)
diff --git a/python/pyfory/type_util.py b/python/pyfory/type_util.py
index 7b37e9b7e..60dc72007 100644
--- a/python/pyfory/type_util.py
+++ b/python/pyfory/type_util.py
@@ -25,18 +25,52 @@ from abc import ABC, abstractmethod
from pyfory.types import RefMeta
+try:
+ from typing import Annotated as _Annotated
+except ImportError:
+ try:
+ from typing_extensions import Annotated as _Annotated
+ except ImportError:
+ _Annotated = None
+
+try:
+ from typing_extensions import get_type_hints as
_typing_extensions_get_type_hints
+except ImportError:
+ _typing_extensions_get_type_hints = None
+
+try:
+ from typing_extensions import get_origin as _typing_extensions_get_origin
+ from typing_extensions import get_args as _typing_extensions_get_args
+except ImportError:
+ _typing_extensions_get_origin = None
+ _typing_extensions_get_args = None
+
+
+def _get_origin(type_):
+ if _typing_extensions_get_origin is not None:
+ return _typing_extensions_get_origin(type_)
+ return typing.get_origin(type_) if hasattr(typing, "get_origin") else
getattr(type_, "__origin__", None)
+
+
+def _get_args(type_):
+ if _typing_extensions_get_args is not None:
+ return _typing_extensions_get_args(type_)
+ return typing.get_args(type_) if hasattr(typing, "get_args") else
getattr(type_, "__args__", ())
+
def get_type_hints(type_):
try:
return typing.get_type_hints(type_, include_extras=True)
except TypeError:
+ if _typing_extensions_get_type_hints is not None:
+ return _typing_extensions_get_type_hints(type_,
include_extras=True)
return typing.get_type_hints(type_)
def unwrap_ref(type_):
- origin = typing.get_origin(type_) if hasattr(typing, "get_origin") else
getattr(type_, "__origin__", None)
- if origin is getattr(typing, "Annotated", None):
- args = typing.get_args(type_) if hasattr(typing, "get_args") else
getattr(type_, "__args__", ())
+ origin = _get_origin(type_)
+ if _Annotated is not None and origin is _Annotated:
+ args = _get_args(type_)
if args:
base = args[0]
for meta in args[1:]:
@@ -44,7 +78,7 @@ def unwrap_ref(type_):
return base, meta.enable
return base, None
if origin is typing.Union:
- args = typing.get_args(type_) if hasattr(typing, "get_args") else
getattr(type_, "__args__", ())
+ args = _get_args(type_)
new_args = list(args)
ref_override = None
for i, arg in enumerate(args):
@@ -206,9 +240,9 @@ def infer_field_types(type_, field_nullable=False):
def is_optional_type(type_):
- origin = typing.get_origin(type_) if hasattr(typing, "get_origin") else
getattr(type_, "__origin__", None)
+ origin = _get_origin(type_)
if origin is typing.Union:
- args = typing.get_args(type_) if hasattr(typing, "get_args") else
getattr(type_, "__args__", ())
+ args = _get_args(type_)
return type(None) in args
return False
@@ -216,7 +250,7 @@ def is_optional_type(type_):
def unwrap_optional(type_, field_nullable=False):
if not is_optional_type(type_):
return type_, False or field_nullable
- args = typing.get_args(type_) if hasattr(typing, "get_args") else
getattr(type_, "__args__", ())
+ args = _get_args(type_)
non_none_types = [arg for arg in args if arg is not type(None)]
if len(non_none_types) == 1:
return non_none_types[0], True
@@ -227,10 +261,10 @@ def get_homogeneous_tuple_elem_type(type_or_args):
if isinstance(type_or_args, tuple):
args = type_or_args
else:
- origin = typing.get_origin(type_or_args) if hasattr(typing,
"get_origin") else getattr(type_or_args, "__origin__", None)
+ origin = _get_origin(type_or_args)
if origin not in (tuple, typing.Tuple):
return None
- args = typing.get_args(type_or_args) if hasattr(typing, "get_args")
else getattr(type_or_args, "__args__", ())
+ args = _get_args(type_or_args)
if not args or args == ((),):
return None
if len(args) == 2 and args[1] is Ellipsis:
@@ -245,9 +279,9 @@ def infer_field(field_name, type_, visitor: TypeVisitor,
types_path=None):
types_path = list(types_path or [])
type_, _ = unwrap_ref(type_)
types_path.append(type_)
- origin = typing.get_origin(type_) if hasattr(typing, "get_origin") else
getattr(type_, "__origin__", type_)
+ origin = _get_origin(type_) or getattr(type_, "__origin__", type_)
origin = origin or type_
- args = typing.get_args(type_) if hasattr(typing, "get_args") else
getattr(type_, "__args__", ())
+ args = _get_args(type_)
if args:
if origin is list or origin == typing.List:
elem_type = args[0]
diff --git a/python/pyfory/types.py b/python/pyfory/types.py
index 7f8f871dd..b680d393c 100644
--- a/python/pyfory/types.py
+++ b/python/pyfory/types.py
@@ -20,6 +20,14 @@ import array
import typing
from typing import TypeVar
+try:
+ from typing import Annotated as _Annotated
+except ImportError:
+ try:
+ from typing_extensions import Annotated as _Annotated
+ except ImportError:
+ _Annotated = None
+
try:
import numpy as np
@@ -219,10 +227,9 @@ class Ref:
enable = params[1]
if not isinstance(enable, bool):
raise TypeError("Ref enable must be a bool")
- annotated = getattr(typing, "Annotated", None)
- if annotated is None:
+ if _Annotated is None:
return target
- return annotated[target, RefMeta(enable)]
+ return _Annotated[target, RefMeta(enable)]
_primitive_types = {
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 7d8ec8e6d..35779ce8f 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -50,6 +50,7 @@ classifiers = [
]
keywords = ["fory", "serialization", "multi-language", "fast", "row-format",
"jit", "codegen", "polymorphic", "zero-copy"]
dependencies = [
+ "typing_extensions>=4.0; python_version < '3.9'",
]
[project.optional-dependencies]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]