This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new a268134ce feat(python): make scalar wire markers typing-friendly (#3756)
a268134ce is described below
commit a268134ced3d7a689e5d5c1cc5024cb3f9c6d458
Author: Shawn Yang <[email protected]>
AuthorDate: Sat Jun 13 00:47:14 2026 +0800
feat(python): make scalar wire markers typing-friendly (#3756)
## Why?
## What does this PR do?
## Related issues
Closes #3725
## AI Contribution Checklist
- [ ] Substantial AI assistance was used in this PR: `yes` / `no`
- [ ] If `yes`, I included a completed [AI Contribution
Checklist](https://github.com/apache/fory/blob/main/AI_POLICY.md#9-contributor-checklist-for-ai-assisted-prs)
in this PR description and the required `AI Usage Disclosure`.
- [ ] If `yes`, my PR description includes the required `ai_review`
summary and screenshot evidence of the final clean AI review results
from both fresh reviewers on the current PR diff or current HEAD after
the latest code changes.
## Does this PR introduce any user-facing change?
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
---
docs/guide/python/schema-metadata.md | 4 +
docs/guide/python/xlang-serialization.md | 4 +
python/pyfory/_fory.py | 15 ++-
python/pyfory/_serializer.py | 6 +-
python/pyfory/annotation.py | 84 +++++----------
python/pyfory/format/infer.py | 25 ++---
python/pyfory/format/schema.pxi | 35 +------
python/pyfory/format/schema.py | 70 +++++++------
python/pyfory/format/tests/test_infer.py | 25 +++++
python/pyfory/primitive.pxi | 2 +-
python/pyfory/registry.py | 25 +++--
python/pyfory/serialization.pyx | 39 ++-----
python/pyfory/struct.py | 139 +++++++++++--------------
python/pyfory/tests/test_struct.py | 96 ++++++++++++++++-
python/pyfory/type_id.py | 173 +++++++++++++++++++++++++++++++
python/pyfory/type_util.py | 73 ++++++++++---
python/pyfory/types.py | 158 +---------------------------
17 files changed, 536 insertions(+), 437 deletions(-)
diff --git a/docs/guide/python/schema-metadata.md b/docs/guide/python/schema-metadata.md
index 50e4456a6..d956cf341 100644
--- a/docs/guide/python/schema-metadata.md
+++ b/docs/guide/python/schema-metadata.md
@@ -231,6 +231,10 @@ class Container:
Fory provides type annotations to control integer encoding:
+Use these markers directly in Python type annotations. Field values remain
+ordinary Python `int` or `float` values, and Fory serializes them with the
+requested xlang numeric width and encoding.
+
### Signed Integers
```python
diff --git a/docs/guide/python/xlang-serialization.md b/docs/guide/python/xlang-serialization.md
index e9da01734..34cec4048 100644
--- a/docs/guide/python/xlang-serialization.md
+++ b/docs/guide/python/xlang-serialization.md
@@ -97,6 +97,10 @@ let person: Person = fory.deserialize(&binary_data)?;
Use pyfory type annotations for explicit xlang type mapping:
+Use these markers directly in Python type annotations. Field values remain
+ordinary Python `int` or `float` values, and Fory serializes them with the
+requested xlang numeric width and encoding.
+
```python
from dataclasses import dataclass
from typing import Dict, List
diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py
index 477b3bfb7..9414c3c7e 100644
--- a/python/pyfory/_fory.py
+++ b/python/pyfory/_fory.py
@@ -17,7 +17,7 @@
import os
from abc import ABC, abstractmethod
-from typing import Iterable, Optional, TypeVar, Union
+from typing import Iterable, Optional, Union
_ENABLE_TYPE_REGISTRATION_FORCIBLY =
os.getenv("ENABLE_TYPE_REGISTRATION_FORCIBLY", "0") in {
"1",
@@ -217,7 +217,7 @@ class Fory:
def register(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
@@ -262,10 +262,9 @@ class Fory:
serializer=serializer,
)
- # `Union[type, TypeVar]` is not supported in py3.6
def register_type(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
@@ -311,7 +310,7 @@ class Fory:
def register_union(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
@@ -664,7 +663,7 @@ class ThreadSafeFory:
def register(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
@@ -674,7 +673,7 @@ class ThreadSafeFory:
def register_type(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
@@ -684,7 +683,7 @@ class ThreadSafeFory:
def register_union(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py
index ec268190a..0e74a536e 100644
--- a/python/pyfory/_serializer.py
+++ b/python/pyfory/_serializer.py
@@ -44,6 +44,7 @@ from pyfory.serialization import (
_float16_to_bits,
)
from pyfory.types import is_primitive_type
+from pyfory.type_util import normalize_fory_type
from pyfory.utils import is_little_endian
try:
@@ -64,9 +65,10 @@ class Serializer(ABC):
__slots__ = "type_resolver", "type_", "need_to_write_ref"
- def __init__(self, type_resolver, type_: type):
+ def __init__(self, type_resolver, type_):
self.type_resolver = type_resolver
- self.type_: type = type_
+ type_ = normalize_fory_type(type_)
+ self.type_ = type_
self.need_to_write_ref = type_resolver.track_ref and not
is_primitive_type(type_)
def write(self, write_context, value):
diff --git a/python/pyfory/annotation.py b/python/pyfory/annotation.py
index 1bea95bdb..0bb34eed8 100644
--- a/python/pyfory/annotation.py
+++ b/python/pyfory/annotation.py
@@ -17,7 +17,13 @@
import array
import typing
-from typing import TypeVar
+
+from pyfory.type_id import TypeId
+
+try:
+ from typing import Annotated
+except ImportError:
+ from typing_extensions import Annotated
if typing.TYPE_CHECKING:
from pyfory.serialization import (
@@ -36,34 +42,25 @@ if typing.TYPE_CHECKING:
UInt8Array,
)
-try:
- from typing import Annotated as _Annotated
-except ImportError:
- try:
- from typing_extensions import Annotated as _Annotated
- except ImportError:
- _Annotated = None
-
-
Bool = bool
-Int8 = TypeVar("Int8", bound=int)
-UInt8 = TypeVar("UInt8", bound=int)
-Int16 = TypeVar("Int16", bound=int)
-UInt16 = TypeVar("UInt16", bound=int)
-Int32 = TypeVar("Int32", bound=int)
-UInt32 = TypeVar("UInt32", bound=int)
-FixedInt32 = TypeVar("FixedInt32", bound=int)
-FixedUInt32 = TypeVar("FixedUInt32", bound=int)
-Int64 = TypeVar("Int64", bound=int)
-UInt64 = TypeVar("UInt64", bound=int)
-FixedInt64 = TypeVar("FixedInt64", bound=int)
-TaggedInt64 = TypeVar("TaggedInt64", bound=int)
-FixedUInt64 = TypeVar("FixedUInt64", bound=int)
-TaggedUInt64 = TypeVar("TaggedUInt64", bound=int)
-Float16 = TypeVar("Float16", bound=float)
-BFloat16 = TypeVar("BFloat16", bound=float)
-Float32 = TypeVar("Float32", bound=float)
-Float64 = TypeVar("Float64", bound=float)
+Int8 = Annotated[int, TypeId.INT8]
+UInt8 = Annotated[int, TypeId.UINT8]
+Int16 = Annotated[int, TypeId.INT16]
+UInt16 = Annotated[int, TypeId.UINT16]
+Int32 = Annotated[int, TypeId.VARINT32]
+UInt32 = Annotated[int, TypeId.VAR_UINT32]
+FixedInt32 = Annotated[int, TypeId.INT32]
+FixedUInt32 = Annotated[int, TypeId.UINT32]
+Int64 = Annotated[int, TypeId.VARINT64]
+UInt64 = Annotated[int, TypeId.VAR_UINT64]
+FixedInt64 = Annotated[int, TypeId.INT64]
+TaggedInt64 = Annotated[int, TypeId.TAGGED_INT64]
+FixedUInt64 = Annotated[int, TypeId.UINT64]
+TaggedUInt64 = Annotated[int, TypeId.TAGGED_UINT64]
+Float16 = Annotated[float, TypeId.FLOAT16]
+BFloat16 = Annotated[float, TypeId.BFLOAT16]
+Float32 = Annotated[float, TypeId.FLOAT32]
+Float64 = Annotated[float, TypeId.FLOAT64]
_ARRAY_EXPORTS = {
"BoolArray",
@@ -111,9 +108,7 @@ class Ref:
enable = params[1]
if not isinstance(enable, bool):
raise TypeError("Ref enable must be a bool")
- if _Annotated is None:
- return target
- return _Annotated[target, RefMeta(enable)]
+ return Annotated[target, RefMeta(enable)]
class ArrayMeta:
@@ -133,29 +128,6 @@ class ArrayMeta:
return f"ArrayMeta(element_type={self.element_type!r},
carrier={self.carrier!r})"
-class _ArrayTypeHint:
- __slots__ = ("__origin__", "__args__", "__fory_array_meta__")
-
- def __init__(self, origin, element_type, carrier: str):
- self.__origin__ = origin
- self.__args__ = (element_type,)
- self.__fory_array_meta__ = ArrayMeta(element_type, carrier)
-
- def __repr__(self):
- return f"{self.__origin__.__name__}[{self.__args__[0]!r}]"
-
- def __eq__(self, other):
- return (
- type(other) is _ArrayTypeHint
- and self.__origin__ is other.__origin__
- and self.__args__ == other.__args__
- and self.__fory_array_meta__ == other.__fory_array_meta__
- )
-
- def __hash__(self):
- return hash((self.__origin__, self.__args__, self.__fory_array_meta__))
-
-
class _ArrayHint:
_carrier = "array"
@@ -168,9 +140,7 @@ class _ArrayHint:
if len(element_type) != 1:
raise TypeError(f"{cls.__name__} expects exactly one element
type")
element_type = element_type[0]
- if _Annotated is None:
- return _ArrayTypeHint(cls, element_type, cls._carrier)
- return _Annotated[cls._base_type(element_type), ArrayMeta(element_type, cls._carrier)]
+ return Annotated[cls._base_type(element_type), ArrayMeta(element_type, cls._carrier)]
class Array(_ArrayHint):
diff --git a/python/pyfory/format/infer.py b/python/pyfory/format/infer.py
index 3865773dc..045150b4f 100644
--- a/python/pyfory/format/infer.py
+++ b/python/pyfory/format/infer.py
@@ -108,24 +108,17 @@ _supported_types_mapping = {
datetime.datetime: timestamp,
}
-# Add pyfory type annotations support
-from pyfory.annotation import (
- Int8 as int8_type,
- Int16 as int16_type,
- Int32 as int32_type,
- Int64 as int64_type,
- Float32 as float32_type,
- Float64 as float64_type,
-)
-
_supported_types_mapping.update(
{
- int8_type: int8,
- int16_type: int16,
- int32_type: int32,
- int64_type: int64,
- float32_type: float32,
- float64_type: float64,
+ TypeId.INT8: int8,
+ TypeId.INT16: int16,
+ TypeId.INT32: int32,
+ TypeId.VARINT32: int32,
+ TypeId.INT64: int64,
+ TypeId.VARINT64: int64,
+ TypeId.TAGGED_INT64: int64,
+ TypeId.FLOAT32: float32,
+ TypeId.FLOAT64: float64,
}
)
diff --git a/python/pyfory/format/schema.pxi b/python/pyfory/format/schema.pxi
index 84c859cf1..7f9b8527f 100644
--- a/python/pyfory/format/schema.pxi
+++ b/python/pyfory/format/schema.pxi
@@ -29,6 +29,7 @@ from libcpp cimport bool as c_bool
from libc.stdint cimport int64_t
from cython.operator cimport dereference as deref
+from pyfory.type_id import TypeId
from pyfory.includes.libformat cimport (
CDataType, CDataTypePtr, CField, CFieldPtr, CSchema, CSchemaPtr,
CListType, CListTypePtr, CMapType, CMapTypePtr, CStructType,
CStructTypePtr,
@@ -42,40 +43,6 @@ from pyfory.includes.libformat cimport (
)
-# Create Python-accessible TypeId enum
-# The CTypeId enum from libformat.pxd is only accessible from Cython
-class TypeId:
- """Type identifiers for Fory data types."""
- BOOL = 1
- INT8 = 2
- INT16 = 3
- INT32 = 4
- INT64 = 6
- UINT8 = 9
- UINT16 = 10
- UINT32 = 11
- UINT64 = 13
- FLOAT8 = 16
- FLOAT16 = 17
- BFLOAT16 = 18
- FLOAT32 = 19
- FLOAT64 = 20
- STRING = 21
- LIST = 22
- SET = 23
- MAP = 24
- STRUCT = 27
- UNION = 33
- TYPED_UNION = 34
- NAMED_UNION = 35
- NONE = 36
- DURATION = 37
- TIMESTAMP = 38
- DATE = 39
- DECIMAL = 40
- BINARY = 41
-
-
cdef class DataType:
"""Base class for all Fory data types."""
cdef shared_ptr[CDataType] c_type
diff --git a/python/pyfory/format/schema.py b/python/pyfory/format/schema.py
index 18baa2037..e9fffdf37 100644
--- a/python/pyfory/format/schema.py
+++ b/python/pyfory/format/schema.py
@@ -25,6 +25,8 @@ and Fory's internal schema representation for the row format.
import pyarrow as pa
from pyarrow import types as pa_types
+from pyfory.type_id import TypeId
+
def arrow_type_to_fory_type_id(arrow_type):
"""
@@ -41,51 +43,51 @@ def arrow_type_to_fory_type_id(arrow_type):
"""
# Boolean
if pa_types.is_boolean(arrow_type):
- return 1 # BOOL
+ return TypeId.BOOL
# Integer types
if pa_types.is_int8(arrow_type):
- return 2 # INT8
+ return TypeId.INT8
if pa_types.is_int16(arrow_type):
- return 3 # INT16
+ return TypeId.INT16
if pa_types.is_int32(arrow_type):
- return 4 # INT32
+ return TypeId.INT32
if pa_types.is_int64(arrow_type):
- return 6 # INT64
+ return TypeId.INT64
# Floating point types
if pa_types.is_float16(arrow_type):
- return 17 # FLOAT16
+ return TypeId.FLOAT16
if pa_types.is_float32(arrow_type):
- return 19 # FLOAT32
+ return TypeId.FLOAT32
if pa_types.is_float64(arrow_type):
- return 20 # FLOAT64
+ return TypeId.FLOAT64
# String and binary
if pa_types.is_string(arrow_type) or pa_types.is_large_string(arrow_type):
- return 21 # STRING
+ return TypeId.STRING
if pa_types.is_binary(arrow_type) or pa_types.is_large_binary(arrow_type):
- return 41 # BINARY
+ return TypeId.BINARY
# Date/time types
if pa_types.is_date32(arrow_type):
- return 39 # DATE
+ return TypeId.DATE
if pa_types.is_timestamp(arrow_type):
- return 38 # TIMESTAMP
+ return TypeId.TIMESTAMP
if pa_types.is_duration(arrow_type):
- return 37 # DURATION
+ return TypeId.DURATION
# Decimal
if pa_types.is_decimal(arrow_type):
- return 40 # DECIMAL
+ return TypeId.DECIMAL
# Complex types
if pa_types.is_list(arrow_type) or pa_types.is_large_list(arrow_type):
- return 22 # LIST
+ return TypeId.LIST
if pa_types.is_map(arrow_type):
- return 24 # MAP
+ return TypeId.MAP
if pa_types.is_struct(arrow_type):
- return 27 # STRUCT
+ return TypeId.STRUCT
raise NotImplementedError(f"Unsupported Arrow type: {arrow_type}")
@@ -110,42 +112,42 @@ def fory_type_id_to_arrow_type(type_id, precision=None, scale=None, list_type=No
NotImplementedError: If the Fory type is not supported.
"""
type_map = {
- 1: pa.bool_(), # BOOL
- 2: pa.int8(), # INT8
- 3: pa.int16(), # INT16
- 4: pa.int32(), # INT32
- 6: pa.int64(), # INT64
- 17: pa.float16(), # FLOAT16
- 19: pa.float32(), # FLOAT32
- 20: pa.float64(), # FLOAT64
- 21: pa.utf8(), # STRING
- 37: pa.duration("ns"), # DURATION
- 38: pa.timestamp("us"), # TIMESTAMP
- 39: pa.date32(), # DATE
- 41: pa.binary(), # BINARY
+ TypeId.BOOL: pa.bool_(),
+ TypeId.INT8: pa.int8(),
+ TypeId.INT16: pa.int16(),
+ TypeId.INT32: pa.int32(),
+ TypeId.INT64: pa.int64(),
+ TypeId.FLOAT16: pa.float16(),
+ TypeId.FLOAT32: pa.float32(),
+ TypeId.FLOAT64: pa.float64(),
+ TypeId.STRING: pa.utf8(),
+ TypeId.DURATION: pa.duration("ns"),
+ TypeId.TIMESTAMP: pa.timestamp("us"),
+ TypeId.DATE: pa.date32(),
+ TypeId.BINARY: pa.binary(),
}
if type_id in type_map:
return type_map[type_id]
# Decimal
- if type_id == 40: # DECIMAL
+ if type_id == TypeId.DECIMAL:
return pa.decimal128(precision or 38, scale or 18)
# List
- if type_id == 22: # LIST
+ if type_id == TypeId.LIST:
if list_type is None:
raise ValueError("list_type must be provided for LIST type")
return pa.list_(list_type)
# Map
- if type_id == 24: # MAP
+ if type_id == TypeId.MAP:
if map_key_type is None or map_value_type is None:
raise ValueError("map_key_type and map_value_type must be provided
for MAP type")
return pa.map_(map_key_type, map_value_type)
# Struct
- if type_id == 27: # STRUCT
+ if type_id == TypeId.STRUCT:
if struct_fields is None:
raise ValueError("struct_fields must be provided for STRUCT type")
return pa.struct(struct_fields)
diff --git a/python/pyfory/format/tests/test_infer.py b/python/pyfory/format/tests/test_infer.py
index 810e43614..131e405f0 100644
--- a/python/pyfory/format/tests/test_infer.py
+++ b/python/pyfory/format/tests/test_infer.py
@@ -16,6 +16,7 @@
# under the License.
import datetime
+import importlib
import pyfory
import pytest
@@ -34,6 +35,7 @@ from pyfory.format import (
TypeId,
)
from pyfory.tests.core import require_pyarrow
+from pyfory.types import TypeId as SerializationTypeId
from typing import Dict, List, Tuple
@@ -97,6 +99,16 @@ def test_infer_field():
assert result.type.id == TypeId.STRUCT
+def test_pyfory_scalar_markers_infer_row_value_types():
+ assert pyfory.format.TypeId is SerializationTypeId
+ assert _infer_field("", pyfory.Int8).type.id == TypeId.INT8
+ assert _infer_field("", pyfory.Int16).type.id == TypeId.INT16
+ assert _infer_field("", pyfory.Int32).type.id == TypeId.INT32
+ assert _infer_field("", pyfory.Int64).type.id == TypeId.INT64
+ assert _infer_field("", pyfory.Float32).type.id == TypeId.FLOAT32
+ assert _infer_field("", pyfory.Float64).type.id == TypeId.FLOAT64
+
+
def test_infer_class_schema():
schema = infer_schema(Foo)
assert schema.num_fields == 7
@@ -156,6 +168,19 @@ def test_pyarrow_schema_fields_roundtrip_through_row_format_schema():
assert roundtrip_arrow_schema == arrow_schema
+@require_pyarrow
+def test_pyarrow_type_id_helpers_use_serialization_type_id():
+ import pyarrow as pa
+
+ schema_util = importlib.import_module("pyfory.format.schema")
+
+ assert schema_util.TypeId is SerializationTypeId
+ assert schema_util.arrow_type_to_fory_type_id(pa.int32()) == SerializationTypeId.INT32
+ assert schema_util.arrow_type_to_fory_type_id(pa.float64()) == SerializationTypeId.FLOAT64
+ assert schema_util.fory_type_id_to_arrow_type(SerializationTypeId.INT16) == pa.int16()
+ assert schema_util.fory_type_id_to_arrow_type(SerializationTypeId.STRING) == pa.utf8()
+
+
def test_row_format_rejects_xlang_array_carrier_annotations():
with pytest.raises(TypeError, match="Row format does not support
pyfory.array array annotations"):
_infer_field("values", pyfory.Array[pyfory.Int32])
diff --git a/python/pyfory/primitive.pxi b/python/pyfory/primitive.pxi
index a49515e09..03de47ff9 100644
--- a/python/pyfory/primitive.pxi
+++ b/python/pyfory/primitive.pxi
@@ -254,7 +254,7 @@ cdef class DateSerializer(Serializer):
cdef class TimestampSerializer(Serializer):
cdef bint win_platform
- def __init__(self, type_resolver, type_: Union[type, TypeVar]):
+ def __init__(self, type_resolver, type_):
super().__init__(type_resolver, type_)
self.win_platform = platform.system() == "Windows"
diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py
index 15231b9d9..a46700241 100644
--- a/python/pyfory/registry.py
+++ b/python/pyfory/registry.py
@@ -26,7 +26,6 @@ import logging
import pickle
import types
import typing
-from typing import TypeVar, Union
from enum import Enum
from pyfory import ENABLE_FORY_CYTHON_SERIALIZATION
@@ -136,6 +135,7 @@ from pyfory.types import (
)
from pyfory.type_util import (
load_class,
+ normalize_fory_type,
record_class_factory,
)
from pyfory._fory import (
@@ -547,7 +547,7 @@ class TypeResolver:
def register_type(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
@@ -564,7 +564,7 @@ class TypeResolver:
def register_union(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
@@ -601,7 +601,7 @@ class TypeResolver:
def _register_type(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
user_type_id: int = NO_USER_TYPE_ID,
@@ -612,6 +612,7 @@ class TypeResolver:
):
"""Register type with given type id or typename. If typename is not
None, it will be used for
cross-language serialization."""
+ cls = normalize_fory_type(cls)
if internal:
if type_id is not None and type_id >= 0 and type_id > 0xFF:
raise ValueError(f"Internal type id overflow: {type_id}")
@@ -652,7 +653,7 @@ class TypeResolver:
def _register_xtype(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
user_type_id: int = NO_USER_TYPE_ID,
@@ -710,7 +711,7 @@ class TypeResolver:
def __register_type(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
user_type_id: int = NO_USER_TYPE_ID,
@@ -780,8 +781,9 @@ class TypeResolver:
type_id = self._type_id_counter = self._type_id_counter + 1
return type_id
- def register_serializer(self, cls: Union[type, TypeVar], serializer):
- assert isinstance(cls, (type, TypeVar)), cls
+ def register_serializer(self, cls, serializer):
+ cls = normalize_fory_type(cls)
+ assert isinstance(cls, type) or type(cls) is int, cls
if cls not in self._types_info:
raise TypeUnregisteredError(f"{cls} not registered")
typeinfo = self._types_info[cls]
@@ -811,6 +813,7 @@ class TypeResolver:
return self.get_type_info(cls).serializer
def get_type_info(self, cls, create=True):
+ cls = normalize_fory_type(cls)
if cls is tuple and self.xlang:
return self.get_type_info(list, create=create)
type_info = self._types_info.get(cls)
@@ -950,6 +953,7 @@ class TypeResolver:
return serializer
def is_registered_by_name(self, cls):
+ cls = normalize_fory_type(cls)
typeinfo = self._types_info.get(cls)
if typeinfo is None:
return False
@@ -957,6 +961,7 @@ class TypeResolver:
def is_registered_by_id(self, cls=None, type_id=None,
user_type_id=NO_USER_TYPE_ID):
if cls is not None:
+ cls = normalize_fory_type(cls)
typeinfo = self._types_info.get(cls)
if typeinfo is None:
return False
@@ -971,21 +976,25 @@ class TypeResolver:
return type_id in self._type_id_to_type_info
def get_registered_name(self, cls):
+ cls = normalize_fory_type(cls)
typeinfo = self._types_info.get(cls)
assert typeinfo is not None, f"{cls} not registered"
return typeinfo.decode_namespace(), typeinfo.decode_typename()
def get_registered_id(self, cls):
+ cls = normalize_fory_type(cls)
typeinfo = self._types_info.get(cls)
assert typeinfo is not None, f"{cls} not registered"
return typeinfo.type_id
def get_registered_user_type_id(self, cls):
+ cls = normalize_fory_type(cls)
typeinfo = self._types_info.get(cls)
assert typeinfo is not None, f"{cls} not registered"
return typeinfo.user_type_id
def get_registered_type_ids(self, cls):
+ cls = normalize_fory_type(cls)
typeinfo = self._types_info.get(cls)
assert typeinfo is not None, f"{cls} not registered"
return typeinfo.type_id, typeinfo.user_type_id
diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx
index 0270ef829..715218a30 100644
--- a/python/pyfory/serialization.pyx
+++ b/python/pyfory/serialization.pyx
@@ -19,7 +19,6 @@ import datetime
import os
import platform
import time
-from typing import TypeVar, Union
import cython
from libc.stdint cimport int32_t, int64_t, uint8_t, uint64_t
@@ -46,6 +45,7 @@ from pyfory.meta.typedef_decoder import decode_typedef
from pyfory.meta.metastring import MetaStringDecoder
from pyfory.policy import DEFAULT_POLICY
from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG
+from pyfory.type_util import normalize_fory_type
from pyfory.includes.libserialization cimport (
TypeId,
TypeRegistrationKind,
@@ -83,35 +83,13 @@ cdef int32_t NOT_NULL_STRING_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (<int32_t>Typ
cdef int32_t NOT_NULL_FLOAT64_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) |
(<int32_t>TypeId.FLOAT64 << 8)
cdef int32_t MAX_CACHED_TYPE_DEFS = 8192
-_PRIMITIVE_TYPEVAR_NAMES = frozenset(
- {
- "Int8",
- "UInt8",
- "Int16",
- "UInt16",
- "Int32",
- "UInt32",
- "FixedInt32",
- "FixedUInt32",
- "Int64",
- "UInt64",
- "FixedInt64",
- "TaggedInt64",
- "FixedUInt64",
- "TaggedUInt64",
- "Float16",
- "BFloat16",
- "Float32",
- "Float64",
- }
-)
_PRIMITIVE_TYPE_IDS = frozenset(range(1, 21)) - {16}
def _is_primitive_type(type_):
if type(type_) is int:
return type_ in _PRIMITIVE_TYPE_IDS
- return type_ in (bool, int, float) or getattr(type_, "__name__", None) in _PRIMITIVE_TYPEVAR_NAMES
+ return type_ in (bool, int, float)
@cython.final
@@ -273,7 +251,7 @@ cdef class TypeResolver:
def register_type(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
@@ -290,7 +268,7 @@ cdef class TypeResolver:
def register_union(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
@@ -609,7 +587,7 @@ cdef class Serializer:
cdef readonly object type_
cdef public bint need_to_write_ref
- def __init__(self, TypeResolver type_resolver, type_: Union[type, TypeVar]):
+ def __init__(self, TypeResolver type_resolver, type_):
"""
Initialize a serializer for one declared Python type.
@@ -618,6 +596,7 @@ cdef class Serializer:
type_: Declared Python type handled by this serializer.
"""
self.type_resolver = type_resolver
+ type_ = normalize_fory_type(type_)
self.type_ = type_
self.need_to_write_ref = self.type_resolver.track_ref and not
_is_primitive_type(type_)
@@ -751,7 +730,7 @@ cdef class TypeInfo:
def __init__(
self,
- cls: Union[type, TypeVar] = None,
+ cls=None,
type_id: int = 0,
user_type_id: int = 0xFFFFFFFF,
serializer=None,
@@ -891,7 +870,7 @@ cdef class Fory:
def register_type(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
@@ -906,7 +885,7 @@ cdef class Fory:
def register_union(
self,
- cls: Union[type, TypeVar],
+ cls,
*,
type_id: int = None,
name: str = None,
diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py
index 43984d140..a8daf4668 100644
--- a/python/pyfory/struct.py
+++ b/python/pyfory/struct.py
@@ -30,24 +30,6 @@ from typing import List, Dict
from pyfory.annotation import (
ArrayMeta,
- BFloat16,
- Float16,
- Float32,
- Float64,
- FixedInt32,
- FixedInt64,
- FixedUInt32,
- FixedUInt64,
- Int8,
- Int16,
- Int32,
- Int64,
- TaggedInt64,
- TaggedUInt64,
- UInt8,
- UInt16,
- UInt32,
- UInt64,
)
from pyfory.lib.mmh3 import hash_buffer
from pyfory.policy import DEFAULT_POLICY
@@ -68,6 +50,8 @@ from pyfory.type_util import (
get_homogeneous_tuple_elem_type,
is_subclass,
get_type_hints,
+ normalize_fory_type,
+ scalar_type_id,
unwrap_array,
unwrap_optional,
unwrap_ref,
@@ -98,26 +82,28 @@ logger = logging.getLogger(__name__)
_MISSING_DEFAULT_INT_TYPES = {
int,
- Int8,
- Int16,
- Int32,
- FixedInt32,
- Int64,
- FixedInt64,
- TaggedInt64,
- UInt8,
- UInt16,
- UInt32,
- FixedUInt32,
- UInt64,
- FixedUInt64,
- TaggedUInt64,
+ TypeId.INT8,
+ TypeId.INT16,
+ TypeId.VARINT32,
+ TypeId.INT32,
+ TypeId.VARINT64,
+ TypeId.INT64,
+ TypeId.TAGGED_INT64,
+ TypeId.UINT8,
+ TypeId.UINT16,
+ TypeId.VAR_UINT32,
+ TypeId.UINT32,
+ TypeId.VAR_UINT64,
+ TypeId.UINT64,
+ TypeId.TAGGED_UINT64,
}
_MISSING_DEFAULT_FLOAT_TYPES = {
float,
- Float32,
- Float64,
+ TypeId.FLOAT16,
+ TypeId.BFLOAT16,
+ TypeId.FLOAT32,
+ TypeId.FLOAT64,
}
@@ -353,6 +339,7 @@ def resolve_missing_field_default(
return dc_field.default_factory
if not effective_nullable:
+ unwrapped_type = normalize_fory_type(unwrapped_type)
origin = typing.get_origin(unwrapped_type) if hasattr(typing,
"get_origin") else getattr(unwrapped_type, "__origin__", None)
origin = origin or unwrapped_type
if is_subclass(unwrapped_type, enum.Enum):
@@ -808,28 +795,6 @@ class DataClassStubSerializer(DataClassSerializer):
basic_types = {
bool,
- # Signed integers
- Int8,
- Int16,
- Int32,
- FixedInt32,
- Int64,
- FixedInt64,
- TaggedInt64,
- # Unsigned integers
- UInt8,
- UInt16,
- UInt32,
- FixedUInt32,
- UInt64,
- FixedUInt64,
- TaggedUInt64,
- # Floats
- Float16,
- BFloat16,
- Float32,
- Float64,
- # Python native types
int,
float,
str,
@@ -839,46 +804,65 @@ basic_types = {
datetime.time,
datetime.timedelta,
decimal.Decimal,
+ TypeId.INT8,
+ TypeId.INT16,
+ TypeId.VARINT32,
+ TypeId.INT32,
+ TypeId.VARINT64,
+ TypeId.INT64,
+ TypeId.TAGGED_INT64,
+ TypeId.UINT8,
+ TypeId.UINT16,
+ TypeId.VAR_UINT32,
+ TypeId.UINT32,
+ TypeId.VAR_UINT64,
+ TypeId.UINT64,
+ TypeId.TAGGED_UINT64,
+ TypeId.FLOAT16,
+ TypeId.BFLOAT16,
+ TypeId.FLOAT32,
+ TypeId.FLOAT64,
}
_ARRAY_ELEMENT_TYPE_IDS = {
bool: TypeId.BOOL_ARRAY,
- Int8: TypeId.INT8_ARRAY,
- Int16: TypeId.INT16_ARRAY,
- Int32: TypeId.INT32_ARRAY,
- Int64: TypeId.INT64_ARRAY,
- UInt8: TypeId.UINT8_ARRAY,
- UInt16: TypeId.UINT16_ARRAY,
- UInt32: TypeId.UINT32_ARRAY,
- UInt64: TypeId.UINT64_ARRAY,
- Float16: TypeId.FLOAT16_ARRAY,
- BFloat16: TypeId.BFLOAT16_ARRAY,
- Float32: TypeId.FLOAT32_ARRAY,
- Float64: TypeId.FLOAT64_ARRAY,
+ TypeId.INT8: TypeId.INT8_ARRAY,
+ TypeId.INT16: TypeId.INT16_ARRAY,
+ TypeId.VARINT32: TypeId.INT32_ARRAY,
+ TypeId.VARINT64: TypeId.INT64_ARRAY,
+ TypeId.UINT8: TypeId.UINT8_ARRAY,
+ TypeId.UINT16: TypeId.UINT16_ARRAY,
+ TypeId.VAR_UINT32: TypeId.UINT32_ARRAY,
+ TypeId.VAR_UINT64: TypeId.UINT64_ARRAY,
+ TypeId.FLOAT16: TypeId.FLOAT16_ARRAY,
+ TypeId.BFLOAT16: TypeId.BFLOAT16_ARRAY,
+ TypeId.FLOAT32: TypeId.FLOAT32_ARRAY,
+ TypeId.FLOAT64: TypeId.FLOAT64_ARRAY,
}
_ARRAY_INVALID_SCALAR_MODIFIERS = {
- FixedInt32,
- FixedInt64,
- FixedUInt32,
- FixedUInt64,
- TaggedInt64,
- TaggedUInt64,
+ TypeId.INT32,
+ TypeId.INT64,
+ TypeId.UINT32,
+ TypeId.UINT64,
+ TypeId.TAGGED_INT64,
+ TypeId.TAGGED_UINT64,
}
def _array_type_id(elem_type, carrier):
elem_type, ref_override = unwrap_ref(elem_type)
elem_type, elem_nullable = unwrap_optional(elem_type)
+ elem_type = normalize_fory_type(elem_type)
if elem_nullable:
raise TypeError("array<T> does not allow optional elements")
if ref_override is not None:
raise TypeError("array<T> does not allow ref-tracked elements")
if elem_type in _ARRAY_INVALID_SCALAR_MODIFIERS:
raise TypeError(f"array<T> does not allow scalar encoding modifier
{elem_type}")
- if carrier == "ndarray" and elem_type is BFloat16:
+ if carrier == "ndarray" and elem_type == TypeId.BFLOAT16:
raise TypeError("pyfory.NDArray does not support BFloat16 arrays")
- if carrier == "pyarray" and elem_type in (bool, Float16, BFloat16):
+ if carrier == "pyarray" and elem_type in (bool, TypeId.FLOAT16, TypeId.BFLOAT16):
raise TypeError("pyfory.PyArray supports Python array.array numeric
typecodes only")
type_id = _ARRAY_ELEMENT_TYPE_IDS.get(elem_type)
if type_id is None:
@@ -1175,6 +1159,9 @@ def _normalize_schema_fingerprint_type_id(type_id):
def _leaf_schema_fingerprint_type_id(type_resolver, type_hint):
+ scalar_id = scalar_type_id(type_hint)
+ if scalar_id is not None:
+ return scalar_id
if type_hint is typing.Any or type_hint is object:
return TypeId.UNKNOWN
if is_primitive_array_type(type_hint):
diff --git a/python/pyfory/tests/test_struct.py b/python/pyfory/tests/test_struct.py
index 55d2b2c80..b42af4f49 100644
--- a/python/pyfory/tests/test_struct.py
+++ b/python/pyfory/tests/test_struct.py
@@ -26,11 +26,18 @@ from typing import Dict, Any, List, Set, Optional, Tuple
import pytest
import typing
+try:
+ from typing_extensions import get_args
+except ImportError:
+ from typing import get_args
+
import pyfory
from pyfory import Fory
from pyfory.error import ForyInvalidDataError, TypeNotCompatibleError,
TypeUnregisteredError
from pyfory.resolver import NOT_NULL_VALUE_FLAG, REF_VALUE_FLAG
-from pyfory.struct import DataClassSerializer, build_default_values_factory
+from pyfory.serializer import FixedInt32Serializer
+from pyfory.struct import DataClassSerializer, build_default_values_factory, compute_struct_fingerprint
+from pyfory.type_util import get_type_hints
from pyfory.types import TypeId
@@ -74,6 +81,93 @@ class ComplexObject:
f10: Optional[Dict[pyfory.Int32, pyfory.Float64]] = None
+@dataclass
+class TypingFriendlyScalarObject:
+ byte_value: pyfory.Int8 = 0
+ int_value: pyfory.Int32 = 0
+ fixed_int_value: pyfory.FixedInt32 = 0
+ float_value: pyfory.Float32 = 0.0
+ double_value: pyfory.Float64 = 0.0
+ values: List[pyfory.Int16] = dataclasses.field(default_factory=list)
+ dense_values: pyfory.Array[pyfory.Int32] = dataclasses.field(default_factory=pyfory.Int32Array)
+
+
+def _plain_numeric_hint_base(type_hint):
+ args = get_args(type_hint)
+ if args and args[0] in (int, float):
+ return args[0]
+ return type_hint
+
+
+def test_scalar_markers_are_typing_friendly_aliases():
+ plain_hints = typing.get_type_hints(TypingFriendlyScalarObject)
+ assert _plain_numeric_hint_base(plain_hints["byte_value"]) is int
+ assert _plain_numeric_hint_base(plain_hints["int_value"]) is int
+ assert _plain_numeric_hint_base(plain_hints["float_value"]) is float
+ assert _plain_numeric_hint_base(plain_hints["double_value"]) is float
+
+ fory_hints = get_type_hints(TypingFriendlyScalarObject)
+ assert get_args(fory_hints["byte_value"]) == (int, TypeId.INT8)
+ assert get_args(fory_hints["int_value"]) == (int, TypeId.VARINT32)
+ assert get_args(fory_hints["fixed_int_value"]) == (int, TypeId.INT32)
+ assert get_args(fory_hints["float_value"]) == (float, TypeId.FLOAT32)
+ assert get_args(fory_hints["double_value"]) == (float, TypeId.FLOAT64)
+
+
+def test_scalar_marker_typeids_drive_struct_fields():
+ fory = Fory(xlang=True, compatible=True, ref=False)
+ fory.register_type(TypingFriendlyScalarObject, type_id=702)
+ serializer = fory.type_resolver.get_serializer(TypingFriendlyScalarObject)
+ field_infos = {field_info.name: field_info for field_info in serializer._field_infos}
+
+ assert field_infos["byte_value"].field_type.type_id == TypeId.INT8
+ assert field_infos["int_value"].field_type.type_id == TypeId.VARINT32
+ assert field_infos["fixed_int_value"].field_type.type_id == TypeId.INT32
+ assert field_infos["float_value"].field_type.type_id == TypeId.FLOAT32
+ assert field_infos["double_value"].field_type.type_id == TypeId.FLOAT64
+ assert field_infos["values"].field_type.element_type.type_id == TypeId.INT16
+ assert field_infos["dense_values"].field_type.type_id == TypeId.INT32_ARRAY
+ fingerprint = compute_struct_fingerprint(
+ fory.type_resolver,
+ serializer._field_names,
+ serializer._serializers,
+ serializer._nullable_fields,
+ serializer._field_infos,
+ )
+ assert f"byte_value,{TypeId.INT8},0,0;" in fingerprint
+ assert f"fixed_int_value,{TypeId.INT32},0,0;" in fingerprint
+ assert f"float_value,{TypeId.FLOAT32},0,0;" in fingerprint
+ assert f"values,{TypeId.LIST},0,0[{TypeId.INT16},0,0];" in fingerprint
+
+ value = TypingFriendlyScalarObject(
+ byte_value=127,
+ int_value=2**31 - 1,
+ fixed_int_value=-(2**31),
+ float_value=1.5,
+ double_value=2.5,
+ values=[1, 2, 3],
+ dense_values=pyfory.Int32Array([4, 5, 6]),
+ )
+ assert ser_de(fory, value) == value
+
+
+def test_scalar_marker_resolves_registered_serializer():
+ fory = Fory(xlang=True, compatible=False, ref=True)
+ cases = [
+ (pyfory.Int8, TypeId.INT8, pyfory.ByteSerializer),
+ (pyfory.Int32, TypeId.VARINT32, pyfory.Int32Serializer),
+ (pyfory.FixedInt32, TypeId.INT32, FixedInt32Serializer),
+ (pyfory.Float32, TypeId.FLOAT32, pyfory.Float32Serializer),
+ (pyfory.Float64, TypeId.FLOAT64, pyfory.Float64Serializer),
+ ]
+ for marker, type_id, serializer_type in cases:
+ typeinfo = fory.type_resolver.get_type_info(marker)
+ assert typeinfo.cls == type_id
+ assert typeinfo.type_id == type_id
+ assert type(typeinfo.serializer) is serializer_type
+ assert typeinfo.serializer.need_to_write_ref is False
+
+
def test_struct():
fory = Fory(xlang=True, compatible=False, ref=True)
fory.register_type(SimpleObject, name="SimpleObject")
diff --git a/python/pyfory/type_id.py b/python/pyfory/type_id.py
new file mode 100644
index 000000000..02d172dae
--- /dev/null
+++ b/python/pyfory/type_id.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+class TypeId:
+ """
+ Fory type for cross-language serialization.
+ See `org.apache.fory.types.Type`
+ """
+
+ # Unknown/polymorphic type marker.
+ UNKNOWN = 0
+ # a boolean value (true or false).
+ BOOL = 1
+ # a 8-bit signed integer.
+ INT8 = 2
+ # a 16-bit signed integer.
+ INT16 = 3
+ # a 32-bit signed integer.
+ INT32 = 4
+ # a 32-bit signed integer using variable-length encoding.
+ VARINT32 = 5
+ # a 64-bit signed integer.
+ INT64 = 6
+ # a 64-bit signed integer using variable-length encoding.
+ VARINT64 = 7
+ # a 64-bit signed integer using tagged encoding.
+ TAGGED_INT64 = 8
+ # an 8-bit unsigned integer.
+ UINT8 = 9
+ # a 16-bit unsigned integer.
+ UINT16 = 10
+ # a 32-bit unsigned integer.
+ UINT32 = 11
+ # a 32-bit unsigned integer using variable-length encoding.
+ VAR_UINT32 = 12
+ # a 64-bit unsigned integer.
+ UINT64 = 13
+ # a 64-bit unsigned integer using variable-length encoding.
+ VAR_UINT64 = 14
+ # a 64-bit unsigned integer using tagged encoding.
+ TAGGED_UINT64 = 15
+ # an 8-bit floating point number.
+ FLOAT8 = 16
+ # a 16-bit floating point number.
+ FLOAT16 = 17
+ # a 16-bit brain floating point number.
+ BFLOAT16 = 18
+ # a 32-bit floating point number.
+ FLOAT32 = 19
+ # a 64-bit floating point number including NaN and Infinity.
+ FLOAT64 = 20
+ # a text string encoded using Latin1/UTF16/UTF-8 encoding.
+ STRING = 21
+ # a sequence of objects.
+ LIST = 22
+ # an unordered set of unique elements.
+ SET = 23
+ # a map of key-value pairs. Mutable types such as `list/map/set/array/tensor/arrow` are not allowed as key of map.
+ MAP = 24
+ # a data type consisting of a set of named values. Rust enum with non-predefined field values are not supported as
+ # an enum.
+ ENUM = 25
+ # an enum whose value will be serialized as the registered name.
+ NAMED_ENUM = 26
+ # a morphic(final) type serialized by Fory Struct serializer. i.e., it doesn't have subclasses. Suppose we're
+ # deserializing `List[SomeClass]`, we can save dynamic serializer dispatch since `SomeClass` is morphic(final).
+ STRUCT = 27
+ # a morphic(final) type serialized by Fory compatible Struct serializer.
+ COMPATIBLE_STRUCT = 28
+ # a `struct` whose type mapping will be encoded as a name.
+ NAMED_STRUCT = 29
+ # a `compatible_struct` whose type mapping will be encoded as a name.
+ NAMED_COMPATIBLE_STRUCT = 30
+ # a type which will be serialized by a customized serializer.
+ EXT = 31
+ # an `ext` type whose type mapping will be encoded as a name.
+ NAMED_EXT = 32
+ # a union value whose schema identity is not embedded.
+ UNION = 33
+ # a union value with embedded numeric union type ID.
+ TYPED_UNION = 34
+ # a union value with embedded union type name/TypeDef.
+ NAMED_UNION = 35
+ # represents an empty/unit value with no data (e.g., for empty union alternatives).
+ NONE = 36
+ # an absolute length of time, independent of any calendar/timezone, as a count of nanoseconds.
+ DURATION = 37
+ # a point in time, independent of any calendar/timezone, as a count of nanoseconds. The count is relative
+ # to an epoch at UTC midnight on January 1, 1970.
+ TIMESTAMP = 38
+ # a naive date without timezone. The count is days relative to an epoch at UTC midnight on Jan 1, 1970.
+ DATE = 39
+ # exact decimal value represented as an integer value in two's complement.
+ DECIMAL = 40
+ # a variable-length array of bytes.
+ BINARY = 41
+ # generic dense array descriptor, reserved for future shaped-array metadata.
+ ARRAY = 42
+ # one dimensional bool array.
+ BOOL_ARRAY = 43
+ # one dimensional Int8 array.
+ INT8_ARRAY = 44
+ # one dimensional Int16 array.
+ INT16_ARRAY = 45
+ # one dimensional Int32 array.
+ INT32_ARRAY = 46
+ # one dimensional Int64 array.
+ INT64_ARRAY = 47
+ # one dimensional UInt8 array.
+ UINT8_ARRAY = 48
+ # one dimensional UInt16 array.
+ UINT16_ARRAY = 49
+ # one dimensional UInt32 array.
+ UINT32_ARRAY = 50
+ # one dimensional UInt64 array.
+ UINT64_ARRAY = 51
+ # one dimensional float8 array.
+ FLOAT8_ARRAY = 52
+ # one dimensional Float16 array.
+ FLOAT16_ARRAY = 53
+ # one dimensional BFloat16 array.
+ BFLOAT16_ARRAY = 54
+ # one dimensional Float32 array.
+ FLOAT32_ARRAY = 55
+ # one dimensional Float64 array.
+ FLOAT64_ARRAY = 56
+
+ # Bound value for range checks (types with id >= BOUND are not internal types).
+ BOUND = 64
+
+ @staticmethod
+ def is_namespaced_type(type_id: int) -> bool:
+ return type_id in __NAMESPACED_TYPES__
+
+ @staticmethod
+ def is_type_share_meta(type_id: int) -> bool:
+ return type_id in __TYPE_SHARE_META__
+
+
+__NAMESPACED_TYPES__ = {
+ TypeId.NAMED_EXT,
+ TypeId.NAMED_ENUM,
+ TypeId.NAMED_STRUCT,
+ TypeId.NAMED_COMPATIBLE_STRUCT,
+ TypeId.NAMED_UNION,
+}
+
+__TYPE_SHARE_META__ = {
+ TypeId.NAMED_ENUM,
+ TypeId.NAMED_STRUCT,
+ TypeId.NAMED_EXT,
+ TypeId.COMPATIBLE_STRUCT,
+ TypeId.NAMED_COMPATIBLE_STRUCT,
+ TypeId.NAMED_UNION,
+}
+
+
+__all__ = ["TypeId"]
diff --git a/python/pyfory/type_util.py b/python/pyfory/type_util.py
index ca1a750dc..0e52681ee 100644
--- a/python/pyfory/type_util.py
+++ b/python/pyfory/type_util.py
@@ -20,18 +20,15 @@ import importlib
import inspect
import typing
-from typing import TypeVar
from abc import ABC, abstractmethod
from pyfory.annotation import ArrayMeta, RefMeta
+from pyfory.type_id import TypeId
try:
- from typing import Annotated as _Annotated
+ from typing import Annotated
except ImportError:
- try:
- from typing_extensions import Annotated as _Annotated
- except ImportError:
- _Annotated = None
+ from typing_extensions import Annotated
try:
from typing_extensions import get_type_hints as _typing_extensions_get_type_hints
@@ -75,7 +72,7 @@ def get_type_hints(type_):
def unwrap_ref(type_):
origin = _get_origin(type_)
- if _Annotated is not None and origin is _Annotated:
+ if origin is Annotated:
args = _get_args(type_)
if args:
base = args[0]
@@ -83,7 +80,7 @@ def unwrap_ref(type_):
for meta in args[1:]:
if isinstance(meta, RefMeta):
if other_metadata:
- return _Annotated[(base, *other_metadata)], meta.enable
+ return Annotated[(base, *other_metadata)], meta.enable
return base, meta.enable
other_metadata.append(meta)
return type_, None
@@ -103,7 +100,7 @@ def unwrap_ref(type_):
def unwrap_array(type_):
origin = _get_origin(type_)
- if _Annotated is not None and origin is _Annotated:
+ if origin is Annotated:
args = _get_args(type_)
for meta in args[1:]:
if isinstance(meta, ArrayMeta):
@@ -111,6 +108,56 @@ def unwrap_array(type_):
return getattr(type_, "__fory_array_meta__", None)
+_INT_SCALAR_TYPE_IDS = frozenset(
+ {
+ TypeId.INT8,
+ TypeId.INT16,
+ TypeId.INT32,
+ TypeId.VARINT32,
+ TypeId.INT64,
+ TypeId.VARINT64,
+ TypeId.TAGGED_INT64,
+ TypeId.UINT8,
+ TypeId.UINT16,
+ TypeId.UINT32,
+ TypeId.VAR_UINT32,
+ TypeId.UINT64,
+ TypeId.VAR_UINT64,
+ TypeId.TAGGED_UINT64,
+ }
+)
+_FLOAT_SCALAR_TYPE_IDS = frozenset({TypeId.FLOAT16, TypeId.BFLOAT16, TypeId.FLOAT32, TypeId.FLOAT64})
+_SCALAR_TYPE_IDS = _INT_SCALAR_TYPE_IDS | _FLOAT_SCALAR_TYPE_IDS
+
+
+def scalar_type_id(type_):
+ if type(type_) is int and type_ in _SCALAR_TYPE_IDS:
+ return type_
+ origin = _get_origin(type_)
+ if origin is not Annotated:
+ return None
+ args = _get_args(type_)
+ if not args:
+ return None
+ base = args[0]
+ for meta in args[1:]:
+ if type(meta) is not int or meta not in _SCALAR_TYPE_IDS:
+ continue
+ if base is int and meta in _INT_SCALAR_TYPE_IDS:
+ return meta
+ if base is float and meta in _FLOAT_SCALAR_TYPE_IDS:
+ return meta
+ raise TypeError(f"Fory scalar TypeId {meta} does not match Annotated base type {base}")
+ return None
+
+
+def normalize_fory_type(type_):
+ type_id = scalar_type_id(type_)
+ if type_id is None:
+ return type_
+ return type_id
+
+
# modified from `fluent python`
def record_class_factory(cls_name, field_names, *, publish=True):
"""
@@ -306,6 +353,9 @@ def infer_field(field_name, type_, visitor: TypeVisitor, types_path=None):
array_meta = unwrap_array(type_)
if array_meta is not None:
return visitor.visit_array(field_name, array_meta.element_type, array_meta.carrier, types_path=types_path)
+ normalized_type = normalize_fory_type(type_)
+ if normalized_type is not type_:
+ return visitor.visit_other(field_name, normalized_type, types_path=types_path)
origin = _get_origin(type_) or getattr(type_, "__origin__", type_)
origin = origin or type_
args = _get_args(type_)
@@ -358,10 +408,7 @@ def compute_string_hash(string):
def qualified_class_name(cls):
- if isinstance(cls, TypeVar):
- return cls.__module__ + "#" + cls.__name__
- else:
- return cls.__module__ + "#" + cls.__qualname__
+ return cls.__module__ + "#" + cls.__qualname__
def load_class(classname: str, policy=None):
diff --git a/python/pyfory/types.py b/python/pyfory/types.py
index 47e8822b5..74979720a 100644
--- a/python/pyfory/types.py
+++ b/python/pyfory/types.py
@@ -17,6 +17,7 @@
import typing
+from pyfory.type_id import TypeId
from pyfory.annotation import (
BFloat16 as _BFloat16,
Float16 as _Float16,
@@ -43,160 +44,6 @@ from pyfory.annotation import (
__all__ = ["TypeId"]
-class TypeId:
- """
- Fory type for cross-language serialization.
- See `org.apache.fory.types.Type`
- """
-
- # Unknown/polymorphic type marker.
- UNKNOWN = 0
- # a boolean value (true or false).
- BOOL = 1
- # a 8-bit signed integer.
- INT8 = 2
- # a 16-bit signed integer.
- INT16 = 3
- # a 32-bit signed integer.
- INT32 = 4
- # a 32-bit signed integer using variable-length encoding.
- VARINT32 = 5
- # a 64-bit signed integer.
- INT64 = 6
- # a 64-bit signed integer using variable-length encoding.
- VARINT64 = 7
- # a 64-bit signed integer using tagged encoding.
- TAGGED_INT64 = 8
- # an 8-bit unsigned integer.
- UINT8 = 9
- # a 16-bit unsigned integer.
- UINT16 = 10
- # a 32-bit unsigned integer.
- UINT32 = 11
- # a 32-bit unsigned integer using variable-length encoding.
- VAR_UINT32 = 12
- # a 64-bit unsigned integer.
- UINT64 = 13
- # a 64-bit unsigned integer using variable-length encoding.
- VAR_UINT64 = 14
- # a 64-bit unsigned integer using tagged encoding.
- TAGGED_UINT64 = 15
- # an 8-bit floating point number.
- FLOAT8 = 16
- # a 16-bit floating point number.
- FLOAT16 = 17
- # a 16-bit brain floating point number.
- BFLOAT16 = 18
- # a 32-bit floating point number.
- FLOAT32 = 19
- # a 64-bit floating point number including NaN and Infinity.
- FLOAT64 = 20
- # a text string encoded using Latin1/UTF16/UTF-8 encoding.
- STRING = 21
- # a sequence of objects.
- LIST = 22
- # an unordered set of unique elements.
- SET = 23
- # a map of key-value pairs. Mutable types such as `list/map/set/array/tensor/arrow` are not allowed as key of map.
- MAP = 24
- # a data type consisting of a set of named values. Rust enum with non-predefined field values are not supported as
- # an enum.
- ENUM = 25
- # an enum whose value will be serialized as the registered name.
- NAMED_ENUM = 26
- # a morphic(final) type serialized by Fory Struct serializer. i.e., it doesn't have subclasses. Suppose we're
- # deserializing `List[SomeClass]`, we can save dynamic serializer dispatch since `SomeClass` is morphic(final).
- STRUCT = 27
- # a morphic(final) type serialized by Fory compatible Struct serializer.
- COMPATIBLE_STRUCT = 28
- # a `struct` whose type mapping will be encoded as a name.
- NAMED_STRUCT = 29
- # a `compatible_struct` whose type mapping will be encoded as a name.
- NAMED_COMPATIBLE_STRUCT = 30
- # a type which will be serialized by a customized serializer.
- EXT = 31
- # an `ext` type whose type mapping will be encoded as a name.
- NAMED_EXT = 32
- # a union value whose schema identity is not embedded.
- UNION = 33
- # a union value with embedded numeric union type ID.
- TYPED_UNION = 34
- # a union value with embedded union type name/TypeDef.
- NAMED_UNION = 35
- # represents an empty/unit value with no data (e.g., for empty union alternatives).
- NONE = 36
- # an absolute length of time, independent of any calendar/timezone, as a count of nanoseconds.
- DURATION = 37
- # a point in time, independent of any calendar/timezone, as a count of nanoseconds. The count is relative
- # to an epoch at UTC midnight on January 1, 1970.
- TIMESTAMP = 38
- # a naive date without timezone. The count is days relative to an epoch at UTC midnight on Jan 1, 1970.
- DATE = 39
- # exact decimal value represented as an integer value in two's complement.
- DECIMAL = 40
- # a variable-length array of bytes.
- BINARY = 41
- # generic dense array descriptor, reserved for future shaped-array metadata.
- ARRAY = 42
- # one dimensional bool array.
- BOOL_ARRAY = 43
- # one dimensional Int8 array.
- INT8_ARRAY = 44
- # one dimensional Int16 array.
- INT16_ARRAY = 45
- # one dimensional Int32 array.
- INT32_ARRAY = 46
- # one dimensional Int64 array.
- INT64_ARRAY = 47
- # one dimensional UInt8 array.
- UINT8_ARRAY = 48
- # one dimensional UInt16 array.
- UINT16_ARRAY = 49
- # one dimensional UInt32 array.
- UINT32_ARRAY = 50
- # one dimensional UInt64 array.
- UINT64_ARRAY = 51
- # one dimensional float8 array.
- FLOAT8_ARRAY = 52
- # one dimensional Float16 array.
- FLOAT16_ARRAY = 53
- # one dimensional BFloat16 array.
- BFLOAT16_ARRAY = 54
- # one dimensional Float32 array.
- FLOAT32_ARRAY = 55
- # one dimensional Float64 array.
- FLOAT64_ARRAY = 56
-
- # Bound value for range checks (types with id >= BOUND are not internal types).
- BOUND = 64
-
- @staticmethod
- def is_namespaced_type(type_id: int) -> bool:
- return type_id in __NAMESPACED_TYPES__
-
- @staticmethod
- def is_type_share_meta(type_id: int) -> bool:
- return type_id in __TYPE_SHARE_META__
-
-
-__NAMESPACED_TYPES__ = {
- TypeId.NAMED_EXT,
- TypeId.NAMED_ENUM,
- TypeId.NAMED_STRUCT,
- TypeId.NAMED_COMPATIBLE_STRUCT,
- TypeId.NAMED_UNION,
-}
-
-__TYPE_SHARE_META__ = {
- TypeId.NAMED_ENUM,
- TypeId.NAMED_STRUCT,
- TypeId.NAMED_EXT,
- TypeId.COMPATIBLE_STRUCT,
- TypeId.NAMED_COMPATIBLE_STRUCT,
- TypeId.NAMED_UNION,
-}
-
-
_primitive_types = {
int,
float,
@@ -252,9 +99,6 @@ _primitive_types_ids = {
}
-# `Union[type, TypeVar]` is not supported in py3.6, so skip adding type hints for `type_` # noqa: E501
-# See more at https://github.com/python/typing/issues/492 and
-# https://stackoverflow.com/questions/69427175/how-to-pass-forwardref-as-args-to-typevar-in-python-3-6 # noqa: E501
def is_primitive_type(type_) -> bool:
if type(type_) is int:
return type_ in _primitive_types_ids
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]