This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 97f582f9e fix(python): enforce more checks in read (#3632)
97f582f9e is described below
commit 97f582f9e7ad89483b77de9f3f88b7d66a832727
Author: Shawn Yang <[email protected]>
AuthorDate: Wed Apr 29 01:23:11 2026 +0800
fix(python): enforce more checks in read (#3632)
## Why?
## What does this PR do?
## Related issues
## AI Contribution Checklist
- [ ] Substantial AI assistance was used in this PR: `yes` / `no`
- [ ] If `yes`, I included a completed [AI Contribution
Checklist](https://github.com/apache/fory/blob/main/AI_POLICY.md#9-contributor-checklist-for-ai-assisted-prs)
in this PR description and the required `AI Usage Disclosure`.
- [ ] If `yes`, my PR description includes the required `ai_review`
summary and screenshot evidence of the final clean AI review results
from both fresh reviewers on the current PR diff or current HEAD after
the latest code changes.
## Does this PR introduce any user-facing change?
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
---
python/pyfory/buffer.pxi | 23 +-
python/pyfory/context.pxi | 8 +-
python/pyfory/context.py | 6 +
python/pyfory/format/row.pxi | 4 +-
python/pyfory/format/tests/test_encoder.py | 12 +
python/pyfory/meta/typedef_decoder.py | 7 +
python/pyfory/registry.py | 19 +-
python/pyfory/serialization.pyx | 2 +-
python/pyfory/serializer.py | 262 ++++++++++----
python/pyfory/struct.pxi | 37 +-
python/pyfory/struct.py | 55 ++-
python/pyfory/tests/test_buffer.py | 14 +
python/pyfory/tests/test_meta_share.py | 10 +
python/pyfory/tests/test_policy.py | 534 ++++++++++++++++++++++++++++
python/pyfory/tests/test_size_guardrails.py | 67 ++++
python/pyfory/type_util.py | 29 +-
16 files changed, 985 insertions(+), 104 deletions(-)
diff --git a/python/pyfory/buffer.pxi b/python/pyfory/buffer.pxi
index 400520ee2..b9dd1cb7a 100644
--- a/python/pyfory/buffer.pxi
+++ b/python/pyfory/buffer.pxi
@@ -593,7 +593,7 @@ cdef class Buffer:
cdef uint32_t target_index = start_index
cdef uint8_t sep = 10 # '\n'
cdef int32_t buffer_size = self.c_buffer.size()
- while arr[target_index] != sep and target_index < buffer_size:
+ while target_index < <uint32_t>buffer_size and arr[target_index] != sep:
target_index += <int32_t>1
cdef bytes data = arr[start_index:target_index]
self.c_buffer.reader_index(target_index)
@@ -707,7 +707,12 @@ cdef class Buffer:
cpdef inline str read_string(self):
cdef uint64_t header = self.read_var_uint64()
- cdef uint32_t size = header >> 2
+ cdef uint64_t size64 = header >> 2
+ if size64 > <uint64_t>self.max_binary_size:
+ raise ValueError(f"String size {size64} exceeds the configured limit of {self.max_binary_size}")
+ if size64 > <uint64_t>2147483647:
+ raise ValueError(f"String size {size64} exceeds the maximum supported size")
+ cdef uint32_t size = <uint32_t>size64
cdef uint32_t encoding = header & <uint32_t>0b11
if size == 0:
return ""
@@ -744,11 +749,15 @@ cdef class Buffer:
return self.c_buffer.size()
def to_bytes(self, int32_t offset=0, int32_t length=0) -> bytes:
- if length != 0:
- assert 0 < length <= self.c_buffer.size(),\
- f"length {length} size {self.c_buffer.size()}"
- else:
- length = self.c_buffer.size()
+ cdef int32_t size_ = self.c_buffer.size()
+ if offset < 0 or offset > size_:
+ raise ValueError(f"offset {offset} out of bound {0, size_}")
+ if length < 0:
+ raise ValueError(f"length {length} must be non-negative")
+ if length == 0:
+ length = size_ - offset
+ elif length > size_ - offset:
+ raise ValueError(f"Address range {(offset, offset + length)} out of bound {(0, size_)}")
cdef:
uint8_t* data = self.c_buffer.data() + offset
return data[:length]
diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi
index 75dc39a1b..7aec0d11b 100644
--- a/python/pyfory/context.pxi
+++ b/python/pyfory/context.pxi
@@ -329,6 +329,8 @@ cdef class MetaStringReader:
if header & 0b1:
if length <= 0:
raise ValueError("Invalid dynamic metastring id 0")
+ if length > <int32_t>
self._c_dynamic_id_to_encoded_meta_string_vec.size():
+ raise ValueError(f"Invalid dynamic metastring id {length}")
return <object>
self._c_dynamic_id_to_encoded_meta_string_vec[length - 1]
if length <= SMALL_STRING_THRESHOLD:
if length == 0:
@@ -868,11 +870,13 @@ cdef class ReadContext:
return obj
cpdef read_buffer_object(self):
- cdef int32_t size
+ cdef uint32_t size
cdef int32_t reader_index
cdef Buffer buf
if not self.peer_out_of_band_enabled:
size = self.read_var_uint32()
+ if size > <uint32_t>self.max_binary_size:
+ raise ValueError(f"Binary size {size} exceeds the configured limit of {self.max_binary_size}")
if self.buffer.has_input_stream():
return self.buffer.read_bytes(size)
reader_index = self.buffer.get_reader_index()
@@ -883,6 +887,8 @@ cdef class ReadContext:
assert self.buffers is not None
return next(self.buffers)
size = self.read_var_uint32()
+ if size > <uint32_t>self.max_binary_size:
+ raise ValueError(f"Binary size {size} exceeds the configured limit of {self.max_binary_size}")
if self.buffer.has_input_stream():
return self.buffer.read_bytes(size)
reader_index = self.buffer.get_reader_index()
diff --git a/python/pyfory/context.py b/python/pyfory/context.py
index 332378b64..9f5becfe6 100644
--- a/python/pyfory/context.py
+++ b/python/pyfory/context.py
@@ -141,6 +141,8 @@ class MetaStringReader:
if (header & 0b1) != 0:
if length <= 0:
raise ValueError("Invalid dynamic metastring id 0")
+ if length > len(self._dynamic_id_to_encoded_meta_strings):
+ raise ValueError(f"Invalid dynamic metastring id {length}")
return self._dynamic_id_to_encoded_meta_strings[length - 1]
if length <= SMALL_STRING_THRESHOLD:
encoded_meta_string = self._read_small_meta_string(buffer, length)
@@ -622,6 +624,8 @@ class ReadContext:
def read_buffer_object(self):
if not self.peer_out_of_band_enabled:
size = self.buffer.read_var_uint32()
+ if size > self.max_binary_size:
+ raise ValueError(f"Binary size {size} exceeds the configured limit of {self.max_binary_size}")
if self.buffer.has_input_stream():
return self.buffer.read_bytes(size)
reader_index = self.buffer.get_reader_index()
@@ -633,6 +637,8 @@ class ReadContext:
assert self.buffers is not None
return next(self.buffers)
size = self.buffer.read_var_uint32()
+ if size > self.max_binary_size:
+ raise ValueError(f"Binary size {size} exceeds the configured limit of {self.max_binary_size}")
if self.buffer.has_input_stream():
return self.buffer.read_bytes(size)
reader_index = self.buffer.get_reader_index()
diff --git a/python/pyfory/format/row.pxi b/python/pyfory/format/row.pxi
index a1d82cff5..f4409ab96 100644
--- a/python/pyfory/format/row.pxi
+++ b/python/pyfory/format/row.pxi
@@ -169,7 +169,7 @@ cdef class ArrayData(Getter):
return MapData.wrap(v, map_type)
def __getitem__(self, i):
- if i > self.num_elements or i < 0:
+ if i >= self.num_elements or i < 0:
raise IndexError("length is {}, but index is {}"
.format(self.num_elements, i))
return self.get(i)
@@ -350,7 +350,7 @@ cdef class RowData(Getter):
if not isinstance(i, int):
assert type(i) is str
i = self.schema_.names.index(i)
- if i > self.num_fields or i < 0:
+ if i >= self.num_fields or i < 0:
raise IndexError("num_fields is {}, but index is {}"
.format(self.num_fields, i))
return self.get(i)
diff --git a/python/pyfory/format/tests/test_encoder.py b/python/pyfory/format/tests/test_encoder.py
index dfbf22e54..b94fc5f9d 100644
--- a/python/pyfory/format/tests/test_encoder.py
+++ b/python/pyfory/format/tests/test_encoder.py
@@ -18,6 +18,8 @@
import timeit
import pickle
+import pytest
+
import pyfory
from pyfory.format import (
schema,
@@ -102,6 +104,16 @@ def test_encode():
assert foo.f6 == new_foo.f6
+def test_row_and_array_reject_one_past_end_index():
+ encoder = pyfory.create_row_encoder(foo_schema())
+ row = encoder.to_row(create_foo())
+ with pytest.raises(IndexError):
+ row[row.num_fields]
+ array_data = row.get_array_data(2)
+ with pytest.raises(IndexError):
+ array_data[array_data.num_elements]
+
+
def test_encoder():
foo = create_foo()
encoder = pyfory.encoder(Foo)
diff --git a/python/pyfory/meta/typedef_decoder.py b/python/pyfory/meta/typedef_decoder.py
index 59e850b5f..b763519ad 100644
--- a/python/pyfory/meta/typedef_decoder.py
+++ b/python/pyfory/meta/typedef_decoder.py
@@ -152,6 +152,8 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef:
if has_fields_meta:
field_infos = read_fields_info(meta_buffer, resolver, name, num_fields)
if type_cls is None:
+ if getattr(resolver, "strict", False) and not getattr(resolver, "_allow_unregistered_typedef", False):
+ raise ValueError(f"TypeDef {name} is not registered in strict mode")
# Check generated class count limit
if _generated_class_count >= MAX_GENERATED_CLASSES:
raise ValueError(
@@ -164,6 +166,11 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef:
# Use a valid Python identifier for class name
class_name = typename.replace(".", "_").replace("$", "_")
type_cls = make_dataclass(class_name, field_definitions)
+ policy = getattr(resolver, "policy", None)
+ if policy is not None:
+ result = policy.validate_class(type_cls, is_local=False)
+ if result is not None:
+ type_cls = result
# Create TypeDef object
type_def = TypeDef(
diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py
index 6b3bee2cc..2711b748d 100644
--- a/python/pyfory/registry.py
+++ b/python/pyfory/registry.py
@@ -75,9 +75,11 @@ from pyfory.serializer import (
EnumSerializer,
SliceSerializer,
StatefulSerializer,
+ _DefaultPolicyStatefulSerializer,
ReduceSerializer,
FunctionSerializer,
ObjectSerializer,
+ _DefaultPolicyObjectSerializer,
TypeSerializer,
ModuleSerializer,
MappingProxySerializer,
@@ -87,6 +89,7 @@ from pyfory.serializer import (
PickleBufferSerializer,
UnionSerializer,
)
+from pyfory.policy import DEFAULT_POLICY
from pyfory.serialization import (
Serializer as CythonSerializer,
bfloat16,
@@ -331,6 +334,7 @@ class TypeResolver:
"meta_share",
"_internal_py_serializer_map",
"_actual_type_resolver",
+ "_allow_unregistered_typedef",
)
def __init__(self, config, *, shared_registry):
@@ -365,6 +369,7 @@ class TypeResolver:
self.meta_share = config.meta_share
self._internal_py_serializer_map = {}
self._actual_type_resolver = self
+ self._allow_unregistered_typedef = False
def _set_actual_resolver(self, type_resolver):
# Cython mode injects the compiled companion before initialize() so all
@@ -840,6 +845,7 @@ class TypeResolver:
def _create_serializer(self, cls):
serializer_type_resolver = self._actual_type_resolver
+ use_default_policy = serializer_type_resolver.policy is DEFAULT_POLICY
# Check if it's a Union type first
origin = typing.get_origin(cls) if hasattr(typing, "get_origin") else
getattr(cls, "__origin__", None)
if origin is typing.Union:
@@ -894,9 +900,11 @@ class TypeResolver:
serializer = ReduceSerializer(serializer_type_resolver, cls)
elif hasattr(cls, "__getstate__") and hasattr(cls, "__setstate__"):
# Use StatefulSerializer for objects that support __getstate__
and __setstate__
- serializer = StatefulSerializer(serializer_type_resolver, cls)
+ serializer_cls = _DefaultPolicyStatefulSerializer if
use_default_policy else StatefulSerializer
+ serializer = serializer_cls(serializer_type_resolver, cls)
elif hasattr(cls, "__dict__") or hasattr(cls, "__slots__"):
- serializer = ObjectSerializer(serializer_type_resolver, cls)
+ serializer_cls = _DefaultPolicyObjectSerializer if
use_default_policy else ObjectSerializer
+ serializer = serializer_cls(serializer_type_resolver, cls)
else:
# c-extension types will go to here
serializer = UnsupportedSerializer(serializer_type_resolver,
cls)
@@ -957,7 +965,7 @@ class TypeResolver:
if typeinfo is not None:
self._ns_type_to_type_info[(ns_metabytes, type_metabytes)] =
typeinfo
return typeinfo
- cls = load_class(ns + "#" + typename)
+ cls = load_class(ns + "#" + typename, policy=self.policy)
typeinfo = self.get_type_info(cls)
self._ns_type_to_type_info[(ns_metabytes, type_metabytes)] = typeinfo
return typeinfo
@@ -1012,7 +1020,7 @@ class TypeResolver:
return typeinfo
typename = split_typename
ns = split_ns
- if typename:
+ if typename and not self.strict:
matches = [info for (reg_ns, reg_typename), info in
self._named_type_to_type_info.items() if reg_typename == typename]
if len(matches) == 1:
typeinfo = matches[0]
@@ -1127,12 +1135,9 @@ class TypeResolver:
# Check if we already have this TypeDef cached
type_info = self._meta_shared_type_info.get(header)
if type_info is not None:
- # Skip the rest of the TypeDef binary for faster performance
skip_typedef(buffer, header)
else:
- # Read the TypeDef and create TypeInfo
type_def = decode_typedef(buffer, self, header=header)
type_info = self._build_type_info_from_typedef(type_def)
- # Cache the tuple for future use
self._meta_shared_type_info[header] = type_info
return type_info
diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx
index 717104b36..c9e85de3c 100644
--- a/python/pyfory/serialization.pyx
+++ b/python/pyfory/serialization.pyx
@@ -40,7 +40,7 @@ from pyfory._fory import (
NO_USER_TYPE_ID,
NOT_NULL_INT64_FLAG,
)
-from pyfory.meta.typedef_decoder import decode_typedef, skip_typedef
+from pyfory.meta.typedef_decoder import decode_typedef
from pyfory.meta.metastring import MetaStringDecoder
from pyfory.policy import DEFAULT_POLICY
from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index a50fa9a38..f086a5c08 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -29,6 +29,7 @@ from typing import Tuple
from pyfory.serialization import Buffer
from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG
+from pyfory.policy import DEFAULT_POLICY
try:
import numpy as np
@@ -45,6 +46,87 @@ _WINDOWS = os.name == "nt"
from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION
+
+def _import_validated_module(policy, module_name):
+ result = policy.validate_module(module_name)
+ if result is not None:
+ if isinstance(result, types.ModuleType):
+ return result
+ assert isinstance(result, str), f"validate_module must return module, str, or None, got {type(result)}"
+ module_name = result
+ return importlib.import_module(module_name)
+
+
+def _resolve_validated_module_attr(policy, module_name, attr_name):
+ module = _import_validated_module(policy, module_name)
+ return getattr(module, attr_name)
+
+
+def _resolve_validated_module_qualname(policy, module_name, qualname):
+ obj = _import_validated_module(policy, module_name)
+ for name in qualname.split("."):
+ obj = getattr(obj, name)
+ return obj
+
+
+def _check_collection_size(read_context, size, kind):
+ if size < 0:
+ raise ValueError(f"{kind} size {size} must be non-negative")
+ if size > read_context.max_collection_size:
+ raise ValueError(f"{kind} size {size} exceeds the configured limit of {read_context.max_collection_size}")
+
+
+def _validate_function_value(policy, func, is_local):
+ if isinstance(func, type):
+ result = policy.validate_class(func, is_local=is_local)
+ if result is not None:
+ func = result
+ if isinstance(func, type):
+ raise TypeError(f"Function serializer resolved class {func.__module__}.{func.__qualname__}")
+ if not callable(func):
+ raise TypeError(f"Function serializer resolved non-callable object {func!r}")
+ result = policy.validate_function(func, is_local=is_local)
+ if result is not None:
+ func = result
+ return func
+
+
+def _bind_static_method(obj, method_name):
+ cls = obj if isinstance(obj, type) else obj.__class__
+ try:
+ attr = inspect.getattr_static(obj, method_name)
+ except AttributeError as exc:
+ raise ValueError(f"Cannot resolve method {method_name!r} safely") from exc
+
+ if isinstance(attr, staticmethod):
+ method = attr.__func__
+ elif isinstance(attr, classmethod):
+ method = types.MethodType(attr.__func__, cls)
+ elif isinstance(attr, types.FunctionType):
+ method = types.MethodType(attr, obj)
+ elif isinstance(attr, (types.MethodType, types.BuiltinMethodType)):
+ method = attr
+ elif isinstance(attr, (types.MethodDescriptorType,
types.WrapperDescriptorType)):
+ method = attr.__get__(obj, cls)
+ elif callable(attr) and not hasattr(type(attr), "__get__"):
+ method = attr
+ else:
+ raise ValueError(f"Cannot resolve method {method_name!r} safely")
+ if not callable(method):
+ raise ValueError(f"Resolved method {method_name!r} is not callable")
+ return method
+
+
+def _resolve_validated_bound_method(policy, obj, method_name, is_local):
+ if policy is DEFAULT_POLICY:
+ return getattr(obj, method_name)
+ method = _bind_static_method(obj, method_name)
+ result = policy.validate_method(method, is_local=is_local)
+ if result is not None:
+ method = result
+ return method
+
+
# Keep the mode switch here instead of inside `_serializer.py`.
# In Cython mode the active hot-path serializer classes, including primitive,
# enum/slice, and collection serializers, must come from
`pyfory.serialization`.
@@ -646,11 +728,15 @@ class PythonNDArraySerializer(NDArraySerializer):
read_context.set_reader_index(reader_index)
dtype = np.dtype(read_context.read_string())
ndim = read_context.read_var_uint32()
+ _check_collection_size(read_context, ndim, "ndarray dimension")
shape = tuple(read_context.read_var_uint32() for _ in range(ndim))
if dtype.kind == "O":
length = read_context.read_varint32()
+ _check_collection_size(read_context, length, "ndarray object")
items = [read_context.read_ref() for _ in range(length)]
return np.array(items, dtype=object)
+ for dim in shape:
+ _check_collection_size(read_context, dim, "ndarray dimension")
fory_buf = read_context.read_buffer_object()
if isinstance(fory_buf, memoryview):
return np.frombuffer(fory_buf, dtype=dtype).reshape(shape)
@@ -788,6 +874,7 @@ class StatefulSerializer(Serializer):
kwargs = read_context.read_ref()
state = read_context.read_ref()
+ read_context.policy.authorize_instantiation(self.cls)
if args or kwargs:
# Case 1: __getnewargs__ was used. Re-create by calling __init__.
obj = self.cls(*args, **kwargs)
@@ -795,12 +882,30 @@ class StatefulSerializer(Serializer):
# Case 2: Only __getstate__ was used. Create without calling
__init__.
obj = self.cls.__new__(self.cls)
- if state:
+ if state is not None:
read_context.policy.intercept_setstate(obj, state)
obj.__setstate__(state)
return obj
+class _DefaultPolicyStatefulSerializer(StatefulSerializer):
+ def read(self, read_context):
+ args = read_context.read_ref()
+ kwargs = read_context.read_ref()
+ state = read_context.read_ref()
+
+ if args or kwargs:
+ # Case 1: __getnewargs__ was used. Re-create by calling __init__.
+ obj = self.cls(*args, **kwargs)
+ else:
+ # Case 2: Only __getstate__ was used. Create without calling
__init__.
+ obj = self.cls.__new__(self.cls)
+
+ if state is not None:
+ obj.__setstate__(state)
+ return obj
+
+
class ReduceSerializer(Serializer):
"""
Serializer for objects that support __reduce__ or __reduce_ex__.
@@ -817,6 +922,36 @@ class ReduceSerializer(Serializer):
self._getnewargs_ex = getattr(cls, "__getnewargs_ex__", None)
self._getnewargs = getattr(cls, "__getnewargs__", None)
+ def _validate_global_object(self, policy, obj):
+ result = None
+ if isinstance(obj, type):
+ result = policy.validate_class(obj, is_local=False)
+ elif isinstance(
+ obj,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ types.MethodType,
+ types.BuiltinMethodType,
+ ),
+ ):
+ result = policy.validate_function(obj, is_local=False)
+ if result is not None:
+ obj = result
+ return obj
+
+ def _resolve_global_name(self, read_context, global_name):
+ policy = read_context.policy
+ if "." in global_name:
+ module_name, obj_name = global_name.rsplit(".", 1)
+ else:
+ module_name, obj_name = "builtins", global_name
+ try:
+ obj = _resolve_validated_module_attr(policy, module_name, obj_name)
+ except AttributeError:
+ raise ValueError(f"Cannot resolve global name: {global_name}")
+ return self._validate_global_object(policy, obj)
+
def write(self, write_context, value):
# Try __reduce_ex__ first (with protocol 5 for pickle5 out-of-band
buffer support), then __reduce__
# Check if the object has a custom __reduce_ex__ method (not just the
default from object)
@@ -869,28 +1004,15 @@ class ReduceSerializer(Serializer):
def read(self, read_context):
reduce_data_num_items = read_context.read_var_uint32()
- assert reduce_data_num_items <= 6, read_context
+ if reduce_data_num_items > 6:
+ raise ValueError(f"Invalid reduce data length: {reduce_data_num_items}")
reduce_data = [None] * 6
for i in range(reduce_data_num_items):
reduce_data[i] = read_context.read_ref()
if reduce_data[0] == 0:
# Case 1: Global name
- global_name = reduce_data[1]
- # Import and return the global object
- if "." in global_name:
- module_name, obj_name = global_name.rsplit(".", 1)
- module = __import__(module_name, fromlist=[obj_name])
- return getattr(module, obj_name)
- else:
- # Handle case where global_name doesn't contain a dot
- # This might be a built-in type or a simple name
- try:
- import builtins
-
- return getattr(builtins, global_name)
- except AttributeError:
- raise ValueError(f"Cannot resolve global name: {global_name}")
+ return self._resolve_global_name(read_context, reduce_data[1])
elif reduce_data[0] == 1:
# Case 2-5: Callable with args and optional state/items
callable_obj = reduce_data[1]
@@ -902,10 +1024,13 @@ class ReduceSerializer(Serializer):
obj = read_context.policy.intercept_reduce_call(callable_obj, args)
if obj is None:
# Create the object using the callable and args
+ if isinstance(callable_obj, type):
+ read_context.policy.authorize_instantiation(callable_obj)
obj = callable_obj(*args)
# Restore state if present
if state is not None:
+ read_context.policy.intercept_setstate(obj, state)
if hasattr(obj, "__setstate__"):
obj.__setstate__(state)
else:
@@ -959,9 +1084,7 @@ class TypeSerializer(Serializer):
return self._deserialize_local_class(read_context)
module_name = read_context.read_string()
qualname = read_context.read_string()
- cls = importlib.import_module(module_name)
- for name in qualname.split("."):
- cls = getattr(cls, name)
+ cls = _resolve_validated_module_qualname(read_context.policy,
module_name, qualname)
result = read_context.policy.validate_class(cls, is_local=False)
if result is not None:
cls = result
@@ -1009,11 +1132,18 @@ class TypeSerializer(Serializer):
ref_id = read_context.last_preserved_ref_id()
num_bases = read_context.read_var_uint32()
+ _check_collection_size(read_context, num_bases, "local class base")
bases = tuple(read_context.read_ref() for _ in range(num_bases))
+ read_context.policy.authorize_instantiation(type, module=module,
qualname=qualname, bases=bases)
cls = type(name, bases, {})
read_context.set_read_ref(ref_id, cls)
+ result = read_context.policy.validate_class(cls, is_local=True)
+ if result is not None:
+ cls = result
- for _ in range(read_context.read_var_uint32()):
+ num_class_methods = read_context.read_var_uint32()
+ _check_collection_size(read_context, num_class_methods, "local class
method")
+ for _ in range(num_class_methods):
attr_name = read_context.read_string()
func = read_context.read_ref()
method = types.MethodType(func, cls)
@@ -1042,13 +1172,7 @@ class ModuleSerializer(Serializer):
def read(self, read_context):
mod_name = read_context.read_string()
- result = read_context.policy.validate_module(mod_name)
- if result is not None:
- if isinstance(result, types.ModuleType):
- return result
- assert isinstance(result, str), f"validate_module must return module, str, or None, got {type(result)}"
- mod_name = result
- return importlib.import_module(mod_name)
+ return _import_validated_module(read_context.policy, mod_name)
class MappingProxySerializer(Serializer):
@@ -1195,25 +1319,20 @@ class FunctionSerializer(Serializer):
if func_type_id == 0:
self_obj = read_context.read_ref()
method_name = read_context.read_string()
- func = getattr(self_obj, method_name)
- result = read_context.policy.validate_function(func,
is_local=False)
- if result is not None:
- func = result
- return func
+ policy = read_context.policy
+ if policy is DEFAULT_POLICY:
+ return getattr(self_obj, method_name)
+ return _resolve_validated_bound_method(policy, self_obj,
method_name, is_local=False)
if func_type_id == 1:
module = read_context.read_string()
qualname = read_context.read_string()
- mod = importlib.import_module(module)
- for name in qualname.split("."):
- mod = getattr(mod, name)
- result = read_context.policy.validate_function(mod, is_local=False)
- if result is not None:
- mod = result
- return mod
+ mod = _resolve_validated_module_qualname(read_context.policy,
module, qualname)
+ return _validate_function_value(read_context.policy, mod,
is_local=False)
module = read_context.read_string()
qualname = read_context.read_string()
+ mod = _import_validated_module(read_context.policy, module)
name = qualname.rsplit(".")[-1]
marshalled_code = read_context.read_bytes_and_size()
@@ -1223,6 +1342,7 @@ class FunctionSerializer(Serializer):
defaults = None
if has_defaults:
num_defaults = read_context.read_var_uint32()
+ _check_collection_size(read_context, num_defaults, "function
default")
default_values = []
for _ in range(num_defaults):
default_values.append(read_context.read_ref())
@@ -1230,6 +1350,7 @@ class FunctionSerializer(Serializer):
has_closure = read_context.read_bool()
num_freevars = read_context.read_var_uint32()
+ _check_collection_size(read_context, num_freevars, "function closure")
closure = None
closure_values = []
@@ -1240,6 +1361,7 @@ class FunctionSerializer(Serializer):
closure = tuple(types.CellType(value) for value in closure_values)
num_freevars = read_context.read_var_uint32()
+ _check_collection_size(read_context, num_freevars, "function free
variable")
freevars = []
for _ in range(num_freevars):
freevars.append(read_context.read_string())
@@ -1248,12 +1370,8 @@ class FunctionSerializer(Serializer):
# Create a globals dictionary with module's globals as the base
func_globals = {}
- try:
- mod = importlib.import_module(module)
- if mod:
- func_globals.update(mod.__dict__)
- except (KeyError, AttributeError):
- pass
+ if mod:
+ func_globals.update(mod.__dict__)
func_globals.update(globals_dict)
@@ -1270,10 +1388,7 @@ class FunctionSerializer(Serializer):
for attr_name, attr_value in attrs.items():
setattr(func, attr_name, attr_value)
- result = read_context.policy.validate_function(func, is_local=True)
- if result is not None:
- func = result
- return func
+ return _validate_function_value(read_context.policy, func,
is_local=True)
def write(self, write_context, value):
self._serialize_function(write_context, value)
@@ -1299,14 +1414,15 @@ class NativeFuncMethodSerializer(Serializer):
name = read_context.read_string()
if read_context.read_bool():
module = read_context.read_string()
- mod = importlib.import_module(module)
- func = getattr(mod, name)
+ func = _resolve_validated_module_attr(read_context.policy, module,
name)
+ func = _validate_function_value(read_context.policy, func,
is_local=False)
else:
obj = read_context.read_ref()
- func = getattr(obj, name)
- result = read_context.policy.validate_function(func, is_local=False)
- if result is not None:
- func = result
+ policy = read_context.policy
+ if policy is DEFAULT_POLICY:
+ func = getattr(obj, name)
+ else:
+ func = _resolve_validated_bound_method(policy, obj, name,
is_local=False)
return func
@@ -1316,6 +1432,7 @@ class MethodSerializer(Serializer):
def __init__(self, type_resolver, cls):
super().__init__(type_resolver, cls)
self.cls = cls
+ self._use_default_policy = type_resolver.policy is DEFAULT_POLICY
def write(self, write_context, value):
instance = value.__self__
@@ -1328,13 +1445,11 @@ class MethodSerializer(Serializer):
instance = read_context.read_ref()
method_name = read_context.read_string()
- method = getattr(instance, method_name)
- cls = method.__self__.__class__
+ if self._use_default_policy:
+ return getattr(instance, method_name)
+ cls = instance if isinstance(instance, type) else instance.__class__
is_local = cls.__module__ == "__main__" or "<locals>" in
cls.__qualname__
- result = read_context.policy.validate_method(method, is_local=is_local)
- if result is not None:
- method = result
- return method
+ return _resolve_validated_bound_method(read_context.policy, instance,
method_name, is_local=is_local)
class ObjectSerializer(Serializer):
@@ -1367,10 +1482,31 @@ class ObjectSerializer(Serializer):
write_context.write_ref(field_value)
def read(self, read_context):
- read_context.policy.authorize_instantiation(self.type_)
+ policy = read_context.policy
+ policy.authorize_instantiation(self.type_)
+ obj = self.type_.__new__(self.type_)
+ read_context.reference(obj)
+ num_fields = read_context.read_var_uint32()
+ if num_fields > read_context.max_collection_size:
+ raise ValueError(f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}")
+ state = {}
+ for _ in range(num_fields):
+ field_name = read_context.read_string()
+ field_value = read_context.read_ref()
+ state[field_name] = field_value
+ policy.intercept_setstate(obj, state)
+ for field_name, field_value in state.items():
+ setattr(obj, field_name, field_value)
+ return obj
+
+
+class _DefaultPolicyObjectSerializer(ObjectSerializer):
+ def read(self, read_context):
obj = self.type_.__new__(self.type_)
read_context.reference(obj)
num_fields = read_context.read_var_uint32()
+ if num_fields > read_context.max_collection_size:
+ raise ValueError(f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}")
for _ in range(num_fields):
field_name = read_context.read_string()
field_value = read_context.read_ref()
diff --git a/python/pyfory/struct.pxi b/python/pyfory/struct.pxi
index 371ac15bb..7b4ee4259 100644
--- a/python/pyfory/struct.pxi
+++ b/python/pyfory/struct.pxi
@@ -41,6 +41,7 @@ cdef class DataClassSerializer(Serializer):
cdef public dict _type_hints
cdef public bint _has_slots
cdef public bint _fields_from_typedef
+ cdef public bint _has_missing_fields
cdef public list _field_names
cdef public list _serializers
cdef public dict _nullable_fields
@@ -235,6 +236,7 @@ cdef class DataClassSerializer(Serializer):
cdef FieldRuntimeInfo runtime_info
self._field_runtime_infos.clear()
+ self._has_missing_fields = False
current_fields = set(self._get_field_names(self.type_))
self._field_runtime_infos.reserve(len(self._field_names))
@@ -249,6 +251,8 @@ cdef class DataClassSerializer(Serializer):
runtime_info.track_ref = 1 if is_tracking_ref else 0
runtime_info.is_dynamic = 1 if is_dynamic else 0
runtime_info.field_exists = 1 if field_name in current_fields else 0
+ if runtime_info.field_exists == 0:
+ self._has_missing_fields = True
runtime_info.field_name = <PyObject *>self._field_name_interned[i]
runtime_info.serializer = <PyObject *>serializer
self._field_runtime_infos.push_back(runtime_info)
@@ -349,7 +353,7 @@ cdef class DataClassSerializer(Serializer):
cdef object obj
cdef int32_t read_hash
- if not read_context.strict:
+ if read_context.policy is not DEFAULT_POLICY:
read_context.policy.authorize_instantiation(self.type_)
if not read_context.compatible:
@@ -384,11 +388,20 @@ cdef class DataClassSerializer(Serializer):
cdef object field_name
cdef FieldRuntimeInfo *field_info
+ if not self._has_missing_fields:
+ for i in range(field_count):
+ field_info = &self._field_runtime_infos[i]
+ field_value = self._read_field_value(read_context, field_info)
+ field_name = <object>field_info.field_name
+ obj_dict[field_name] = field_value
+ return
+
for i in range(field_count):
field_info = &self._field_runtime_infos[i]
- field_value = self._read_field_value(read_context, field_info)
if field_info.field_exists == 0:
+ self._read_missing_field_value(read_context, field_info)
continue
+ field_value = self._read_field_value(read_context, field_info)
field_name = <object>field_info.field_name
obj_dict[field_name] = field_value
@@ -399,14 +412,32 @@ cdef class DataClassSerializer(Serializer):
cdef object field_name
cdef FieldRuntimeInfo *field_info
+ if not self._has_missing_fields:
+ for i in range(field_count):
+ field_info = &self._field_runtime_infos[i]
+ field_value = self._read_field_value(read_context, field_info)
+ field_name = <object>field_info.field_name
+ PyObject_SetAttr(obj, field_name, field_value)
+ return
+
for i in range(field_count):
field_info = &self._field_runtime_infos[i]
- field_value = self._read_field_value(read_context, field_info)
if field_info.field_exists == 0:
+ self._read_missing_field_value(read_context, field_info)
continue
+ field_value = self._read_field_value(read_context, field_info)
field_name = <object>field_info.field_name
PyObject_SetAttr(obj, field_name, field_value)
+ cdef inline object _read_missing_field_value(self, ReadContext read_context, FieldRuntimeInfo *field_info):
+ cdef object resolver = self.type_resolver.resolver
+ cdef object previous = resolver._allow_unregistered_typedef
+ resolver._allow_unregistered_typedef = True
+ try:
+ return self._read_field_value(read_context, field_info)
+ finally:
+ resolver._allow_unregistered_typedef = previous
+
cdef inline object _read_field_value(self, ReadContext read_context, FieldRuntimeInfo *field_info):
cdef uint8_t type_id = field_info.basic_type_id
cdef bint is_nullable = field_info.is_nullable != 0
diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py
index 1658c28fc..ad4492e82 100644
--- a/python/pyfory/struct.py
+++ b/python/pyfory/struct.py
@@ -28,6 +28,7 @@ import typing
from typing import List, Dict
from pyfory.lib.mmh3 import hash_buffer
+from pyfory.policy import DEFAULT_POLICY
from pyfory.types import (
TypeId,
int8,
@@ -422,6 +423,7 @@ class DataClassSerializer(Serializer):
self._field_name_interned = {name: sys.intern(name) for name in self._field_names}
self._current_class_field_names = set(self._get_field_names(self.type_))
+ self._has_missing_fields = any(field_name not in self._current_class_field_names for field_name in self._field_names)
self._default_values_factory = (
build_default_values_factory(self.type_resolver, self._type_hints, dataclasses.fields(self.type_))
if dataclasses.is_dataclass(self.type_)
@@ -544,7 +546,7 @@ class DataClassSerializer(Serializer):
write_context.try_flush()
def read(self, read_context):
- if not self.type_resolver.strict:
+ if read_context.policy is not DEFAULT_POLICY:
read_context.policy.authorize_instantiation(self.type_)
if not self.type_resolver.compatible:
hash_ = read_context.read_int32()
@@ -555,20 +557,35 @@ class DataClassSerializer(Serializer):
obj = self.type_.__new__(self.type_)
read_context.reference(obj)
obj_dict = obj.__dict__ if not self._has_slots else None
- for index, field_name in enumerate(self._field_names):
- serializer = self._serializers[index]
- is_nullable = self._nullable_fields.get(field_name, False)
- is_dynamic = self._dynamic_fields.get(field_name, False)
- is_tracking_ref = self._ref_fields.get(field_name, False)
- is_basic = self._basic_field_flags[index]
- field_value = self._read_field_value(read_context, serializer, is_nullable, is_dynamic, is_basic, is_tracking_ref)
- if field_name not in self._current_class_field_names:
- continue
- interned_name = self._field_name_interned[field_name]
- if obj_dict is not None:
- obj_dict[interned_name] = field_value
- else:
- setattr(obj, interned_name, field_value)
+ if self._has_missing_fields:
+ for index, field_name in enumerate(self._field_names):
+ serializer = self._serializers[index]
+ is_nullable = self._nullable_fields.get(field_name, False)
+ is_dynamic = self._dynamic_fields.get(field_name, False)
+ is_tracking_ref = self._ref_fields.get(field_name, False)
+ is_basic = self._basic_field_flags[index]
+ if field_name not in self._current_class_field_names:
+ self._read_missing_field_value(read_context, serializer, is_nullable, is_dynamic, is_basic, is_tracking_ref)
+ continue
+ field_value = self._read_field_value(read_context, serializer, is_nullable, is_dynamic, is_basic, is_tracking_ref)
+ interned_name = self._field_name_interned[field_name]
+ if obj_dict is not None:
+ obj_dict[interned_name] = field_value
+ else:
+ setattr(obj, interned_name, field_value)
+ else:
+ for index, field_name in enumerate(self._field_names):
+ serializer = self._serializers[index]
+ is_nullable = self._nullable_fields.get(field_name, False)
+ is_dynamic = self._dynamic_fields.get(field_name, False)
+ is_tracking_ref = self._ref_fields.get(field_name, False)
+ is_basic = self._basic_field_flags[index]
+ field_value = self._read_field_value(read_context, serializer, is_nullable, is_dynamic, is_basic, is_tracking_ref)
+ interned_name = self._field_name_interned[field_name]
+ if obj_dict is not None:
+ obj_dict[interned_name] = field_value
+ else:
+ setattr(obj, interned_name, field_value)
if self._missing_field_defaults:
for field_name, default_factory in self._missing_field_defaults:
@@ -580,6 +597,14 @@ class DataClassSerializer(Serializer):
read_context.shrink_input_buffer()
return obj
+ def _read_missing_field_value(self, read_context, serializer, is_nullable, is_dynamic, is_basic, is_tracking_ref):
+ previous = self.type_resolver._allow_unregistered_typedef
+ self.type_resolver._allow_unregistered_typedef = True
+ try:
+ return self._read_field_value(read_context, serializer, is_nullable, is_dynamic, is_basic, is_tracking_ref)
+ finally:
+ self.type_resolver._allow_unregistered_typedef = previous
+
class DataClassStubSerializer(DataClassSerializer):
def __init__(self, type_resolver, clz: type):
diff --git a/python/pyfory/tests/test_buffer.py b/python/pyfory/tests/test_buffer.py
index 021412e57..e9a569b69 100644
--- a/python/pyfory/tests/test_buffer.py
+++ b/python/pyfory/tests/test_buffer.py
@@ -149,6 +149,20 @@ def test_empty_buffer():
assert writable_buffer.get_int32(0) == 10
+def test_to_bytes_rejects_out_of_bounds_range():
+ buffer = Buffer(b"abc")
+ assert buffer.to_bytes(1) == b"bc"
+ assert buffer.to_bytes(1, 2) == b"bc"
+ with pytest.raises(ValueError, match="offset 99 out of bound"):
+ buffer.to_bytes(99, 1)
+ with pytest.raises(ValueError, match="out of bound"):
+ buffer.to_bytes(2, 2)
+
+
+def test_readline_without_newline_does_not_read_out_of_bounds():
+ assert Buffer(b"abc").readline() == b"abc"
+
+
def test_write_varint32():
buf = Buffer.allocate(32)
for i in range(1):
diff --git a/python/pyfory/tests/test_meta_share.py b/python/pyfory/tests/test_meta_share.py
index 59ba7508b..764a67915 100644
--- a/python/pyfory/tests/test_meta_share.py
+++ b/python/pyfory/tests/test_meta_share.py
@@ -18,6 +18,8 @@
import dataclasses
from typing import Dict, List
+import pytest
+
import pyfory
from pyfory import Fory
@@ -109,6 +111,14 @@ class TestMetaShareMode:
deserialized = fory.deserialize(fory.serialize(obj))
assert deserialized == obj
+ def test_strict_reader_rejects_unknown_typedef(self):
+ writer = Fory(xlang=True, compatible=True, strict=False)
+ writer.register_type(SimpleDataClass)
+ reader = Fory(xlang=True, compatible=True, strict=True)
+
+ with pytest.raises(ValueError, match="not registered in strict mode"):
+ reader.deserialize(writer.serialize(SimpleDataClass(name="test", age=25, active=True)))
+
def test_multiple_objects_same_type(self):
fory = Fory(xlang=True, compatible=True)
fory.register_type(SimpleDataClass)
diff --git a/python/pyfory/tests/test_policy.py b/python/pyfory/tests/test_policy.py
index ef790ef74..7da8dd099 100644
--- a/python/pyfory/tests/test_policy.py
+++ b/python/pyfory/tests/test_policy.py
@@ -15,8 +15,54 @@
# specific language governing permissions and limitations
# under the License.
+import types
+
import pytest
from pyfory import Fory, DeserializationPolicy
+from pyfory.serializer import FunctionSerializer, NativeFuncMethodSerializer
+
+
+def policy_global_function():
+ return "safe"
+
+
+class FakeReadContext:
+ def __init__(self, policy, values):
+ self.policy = policy
+ self._values = iter(values)
+
+ def read_int8(self):
+ return next(self._values)
+
+ def read_bool(self):
+ return next(self._values)
+
+ def read_string(self):
+ return next(self._values)
+
+
+class FalseyState:
+ bool_called = False
+
+ def __bool__(self):
+ type(self).bool_called = True
+ return False
+
+
+class FalseyStatePayload:
+ def __getstate__(self):
+ return FalseyState()
+
+ def __setstate__(self, state):
+ self.state = state
+
+
+class ObjectSetAttrPayload:
+ setattr_called = False
+
+ def __setattr__(self, name, value):
+ type(self).setattr_called = True
+ super().__setattr__(name, value)
class BlockClassPolicy(DeserializationPolicy):
@@ -146,6 +192,73 @@ def test_sanitize_state():
assert result.password == "***REDACTED***"
+def test_reduce_state_sanitizes_state():
+ """Test sanitizing object state restored from __reduce__."""
+
+ class CountingSanitizePolicy(DeserializationPolicy):
+ def __init__(self):
+ self.intercept_setstate_calls = 0
+
+ def intercept_setstate(self, obj, state, **kwargs):
+ self.intercept_setstate_calls += 1
+ if isinstance(state, dict) and "password" in state:
+ state["password"] = "***REDACTED***"
+ return None
+
+ class SecretReduceHolder:
+ def __reduce__(self):
+ return (SecretReduceHolder, (), {"username": "admin", "password": "secret123"})
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+
+ policy = CountingSanitizePolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ data = fory.serialize(SecretReduceHolder())
+
+ result = fory.deserialize(data)
+ assert policy.intercept_setstate_calls == 1
+ assert result.username == "admin"
+ assert result.password == "***REDACTED***"
+
+
+def test_stateful_intercepts_falsey_state_before_bool():
+ """Test stateful path calls intercept_setstate without evaluating state truthiness."""
+
+ class BlockSetStatePolicy(DeserializationPolicy):
+ def intercept_setstate(self, obj, state, **kwargs):
+ raise ValueError("state blocked")
+
+ FalseyState.bool_called = False
+ fory = Fory(ref=True, strict=False, policy=BlockSetStatePolicy())
+ data = fory.serialize(FalseyStatePayload())
+
+ with pytest.raises(ValueError, match="state blocked"):
+ fory.deserialize(data)
+ assert not FalseyState.bool_called
+
+
+def test_object_serializer_intercepts_state_before_setattr():
+ """Test object serializer state hook runs before applying attacker-controlled fields."""
+
+ class BlockSetStatePolicy(DeserializationPolicy):
+ def intercept_setstate(self, obj, state, **kwargs):
+ raise ValueError("object state blocked")
+
+ obj = ObjectSetAttrPayload()
+ obj.value = 1
+ ObjectSetAttrPayload.setattr_called = False
+
+ writer = Fory(ref=True, strict=False)
+ reader = Fory(ref=True, strict=False, policy=BlockSetStatePolicy())
+ writer.register(ObjectSetAttrPayload)
+ reader.register(ObjectSetAttrPayload)
+
+ with pytest.raises(ValueError, match="object state blocked"):
+ reader.deserialize(writer.serialize(obj))
+ assert not ObjectSetAttrPayload.setattr_called
+
+
def test_policy_with_local_class():
"""Test policy intercepts local class deserialization."""
@@ -261,6 +374,94 @@ def test_policy_with_nested_reduce():
fory.deserialize(data)
+def test_stateful_authorizes_instantiation():
+ """Test authorize_instantiation policy hook for stateful deserialization."""
+
+ class StatefulPayload:
+ def __init__(self):
+ self.value = 1
+
+ def __getstate__(self):
+ return {"value": self.value}
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+
+ class BlockInstantiationPolicy(DeserializationPolicy):
+ def __init__(self):
+ self.authorize_instantiation_calls = 0
+
+ def authorize_instantiation(self, cls, **kwargs):
+ self.authorize_instantiation_calls += 1
+ if cls is StatefulPayload:
+ raise ValueError("StatefulPayload blocked")
+ return None
+
+ policy = BlockInstantiationPolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ with pytest.raises(ValueError, match="StatefulPayload blocked"):
+ fory.deserialize(fory.serialize(StatefulPayload()))
+ assert policy.authorize_instantiation_calls == 1
+
+
+def test_reduce_class_callable_authorizes_instantiation():
+ """Test authorize_instantiation policy hook for reduce class callables."""
+
+ class ReduceTarget:
+ pass
+
+ class ReducePayload:
+ def __reduce__(self):
+ return (ReduceTarget, ())
+
+ class BlockInstantiationPolicy(DeserializationPolicy):
+ def __init__(self):
+ self.authorize_instantiation_calls = 0
+ self.reduce_target_calls = 0
+
+ def authorize_instantiation(self, cls, **kwargs):
+ self.authorize_instantiation_calls += 1
+ if cls.__name__ == "ReduceTarget":
+ self.reduce_target_calls += 1
+ raise ValueError("ReduceTarget blocked")
+ return None
+
+ policy = BlockInstantiationPolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ with pytest.raises(ValueError, match="ReduceTarget blocked"):
+ fory.deserialize(fory.serialize(ReducePayload()))
+ assert policy.reduce_target_calls == 1
+
+
+def test_registered_dataclass_authorizes_instantiation_in_strict_mode():
+ """Test registered dataclass reads still honor authorize_instantiation."""
+ from dataclasses import dataclass
+
+ @dataclass
+ class StrictDataClass:
+ value: int
+
+ class BlockInstantiationPolicy(DeserializationPolicy):
+ def __init__(self):
+ self.authorize_instantiation_calls = 0
+
+ def authorize_instantiation(self, cls, **kwargs):
+ self.authorize_instantiation_calls += 1
+ if cls is StrictDataClass:
+ raise ValueError("StrictDataClass blocked")
+ return None
+
+ policy = BlockInstantiationPolicy()
+ writer = Fory(ref=True, strict=True)
+ reader = Fory(ref=True, strict=True, policy=policy)
+ writer.register(StrictDataClass)
+ reader.register(StrictDataClass)
+
+ with pytest.raises(ValueError, match="StrictDataClass blocked"):
+ reader.deserialize(writer.serialize(StrictDataClass(1)))
+ assert policy.authorize_instantiation_calls == 1
+
+
def test_validate_module():
"""Test validate_module policy hook for module deserialization."""
import json
@@ -291,3 +492,336 @@ def test_validate_module():
fory3 = Fory(ref=True, strict=False, policy=BlockPolicy())
with pytest.raises(ValueError, match="blocked"):
fory3.deserialize(fory3.serialize(json))
+
+
+def test_type_deserialization_validates_module():
+ """Test validate_module policy hook for global class deserialization."""
+ import subprocess
+
+ class BlockModulePolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_module_calls = 0
+
+ def validate_module(self, module_name, **kwargs):
+ self.validate_module_calls += 1
+ if module_name == "subprocess":
+ raise ValueError("subprocess blocked")
+ return None
+
+ policy = BlockModulePolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ with pytest.raises(ValueError, match="subprocess blocked"):
+ fory.deserialize(fory.serialize(subprocess.Popen))
+ assert policy.validate_module_calls == 1
+
+
+def test_native_bound_method_uses_validate_method():
+ """Test bound native methods are checked by method policy, not function policy."""
+
+ class BlockMethodPolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_method_calls = 0
+ self.validate_function_calls = 0
+
+ def validate_method(self, method, is_local, **kwargs):
+ self.validate_method_calls += 1
+ raise ValueError("method blocked")
+
+ def validate_function(self, func, is_local, **kwargs):
+ self.validate_function_calls += 1
+ return None
+
+ policy = BlockMethodPolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+
+ with pytest.raises(ValueError, match="method blocked"):
+ fory.deserialize(fory.serialize([].append))
+ assert policy.validate_method_calls == 1
+ assert policy.validate_function_calls == 0
+
+
+def test_bound_method_policy_runs_before_getattribute_side_effect():
+ """Test bound method deserialization validates before dynamic attribute lookup."""
+
+ class GuardedMethod:
+ getattribute_called = False
+
+ def __getattribute__(self, name):
+ if name == "run":
+ type(self).getattribute_called = True
+ return super().__getattribute__(name)
+
+ def run(self):
+ return "unsafe"
+
+ class BlockMethodPolicy(DeserializationPolicy):
+ def validate_method(self, method, is_local, **kwargs):
+ raise ValueError("method blocked")
+
+ obj = GuardedMethod()
+ method = types.MethodType(GuardedMethod.run, obj)
+ fory = Fory(ref=True, strict=False, policy=BlockMethodPolicy())
+ data = fory.serialize(method)
+
+ GuardedMethod.getattribute_called = False
+ with pytest.raises(ValueError, match="method blocked"):
+ fory.deserialize(data)
+ assert not GuardedMethod.getattribute_called
+
+
+def test_function_serializer_rejects_class_resolution():
+ """Test function deserialization cannot resolve classes through the function policy."""
+
+ class BlockClassPolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_class_calls = 0
+ self.validate_function_calls = 0
+
+ def validate_class(self, cls, is_local, **kwargs):
+ self.validate_class_calls += 1
+ raise ValueError("class blocked")
+
+ def validate_function(self, func, is_local, **kwargs):
+ self.validate_function_calls += 1
+ return None
+
+ policy = BlockClassPolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ serializer = FunctionSerializer(fory.type_resolver, type(policy_global_function))
+ read_context = FakeReadContext(policy, [1, "subprocess", "Popen"])
+
+ with pytest.raises(ValueError, match="class blocked"):
+ serializer._deserialize_function(read_context)
+ assert policy.validate_class_calls == 1
+ assert policy.validate_function_calls == 0
+
+
+def test_native_function_serializer_rejects_class_resolution():
+ """Test native function deserialization cannot resolve classes through the function policy."""
+
+ class BlockClassPolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_class_calls = 0
+ self.validate_function_calls = 0
+
+ def validate_class(self, cls, is_local, **kwargs):
+ self.validate_class_calls += 1
+ raise ValueError("class blocked")
+
+ def validate_function(self, func, is_local, **kwargs):
+ self.validate_function_calls += 1
+ return None
+
+ policy = BlockClassPolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ serializer = NativeFuncMethodSerializer(fory.type_resolver, type(policy_global_function))
+ read_context = FakeReadContext(policy, ["Popen", True, "subprocess"])
+
+ with pytest.raises(ValueError, match="class blocked"):
+ serializer.read(read_context)
+ assert policy.validate_class_calls == 1
+ assert policy.validate_function_calls == 0
+
+
+def test_global_function_deserialization_validates_module():
+ """Test validate_module policy hook for global function deserialization."""
+
+ class BlockModulePolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_module_calls = 0
+
+ def validate_module(self, module_name, **kwargs):
+ self.validate_module_calls += 1
+ if module_name == policy_global_function.__module__:
+ raise ValueError("function module blocked")
+ return None
+
+ policy = BlockModulePolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ with pytest.raises(ValueError, match="function module blocked"):
+ fory.deserialize(fory.serialize(policy_global_function))
+ assert policy.validate_module_calls == 1
+
+
+def test_local_function_deserialization_validates_module():
+ """Test validate_module policy hook for local function deserialization."""
+
+ def local_function():
+ return "safe"
+
+ class BlockModulePolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_module_calls = 0
+
+ def validate_module(self, module_name, **kwargs):
+ self.validate_module_calls += 1
+ if module_name == local_function.__module__:
+ raise ValueError("local function module blocked")
+ return None
+
+ policy = BlockModulePolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ with pytest.raises(ValueError, match="local function module blocked"):
+ fory.deserialize(fory.serialize(local_function))
+ assert policy.validate_module_calls == 1
+
+
+def test_native_function_deserialization_validates_module():
+ """Test validate_module policy hook for native function deserialization."""
+ import time
+
+ class BlockModulePolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_module_calls = 0
+
+ def validate_module(self, module_name, **kwargs):
+ self.validate_module_calls += 1
+ if module_name == "time":
+ raise ValueError("time blocked")
+ return None
+
+ policy = BlockModulePolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ with pytest.raises(ValueError, match="time blocked"):
+ fory.deserialize(fory.serialize(time.time))
+ assert policy.validate_module_calls == 1
+
+
+def test_type_metadata_load_validates_module():
+ """Test validate_module policy hook for by-name type metadata loading."""
+
+ class BlockModulePolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_module_calls = 0
+
+ def validate_module(self, module_name, **kwargs):
+ self.validate_module_calls += 1
+ if module_name == "subprocess":
+ raise ValueError("subprocess blocked")
+ return None
+
+ policy = BlockModulePolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ from pyfory.registry import SharedRegistry, TypeResolver
+
+ resolver = TypeResolver(fory.config, shared_registry=SharedRegistry())
+ namespace = resolver.namespace_encoder.encode("subprocess")
+ ns_metabytes = resolver.shared_registry.get_encoded_meta_string(namespace)
+ typename = resolver.typename_encoder.encode("Popen")
+ type_metabytes = resolver.shared_registry.get_encoded_meta_string(typename)
+
+ with pytest.raises(ValueError, match="subprocess blocked"):
+ resolver._load_metabytes_to_type_info(ns_metabytes, type_metabytes)
+ assert policy.validate_module_calls == 1
+
+
+def test_type_metadata_load_validates_class():
+ """Test validate_class policy hook for by-name type metadata loading."""
+
+ class BlockClassPolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_class_calls = 0
+
+ def validate_class(self, cls, is_local, **kwargs):
+ self.validate_class_calls += 1
+ if cls.__module__ == "subprocess" and cls.__name__ == "Popen":
+ raise ValueError("Popen blocked")
+ return None
+
+ policy = BlockClassPolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ from pyfory.registry import SharedRegistry, TypeResolver
+
+ resolver = TypeResolver(fory.config, shared_registry=SharedRegistry())
+ namespace = resolver.namespace_encoder.encode("subprocess")
+ ns_metabytes = resolver.shared_registry.get_encoded_meta_string(namespace)
+ typename = resolver.typename_encoder.encode("Popen")
+ type_metabytes = resolver.shared_registry.get_encoded_meta_string(typename)
+
+ with pytest.raises(ValueError, match="Popen blocked"):
+ resolver._load_metabytes_to_type_info(ns_metabytes, type_metabytes)
+ assert policy.validate_class_calls == 1
+
+
+def test_reduce_global_name_validates_module():
+ """Test validate_module policy hook for reduce global-name deserialization."""
+
+ class GlobalNamePayload:
+ def __reduce__(self):
+ return "subprocess.Popen"
+
+ class BlockModulePolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_module_calls = 0
+
+ def validate_module(self, module_name, **kwargs):
+ self.validate_module_calls += 1
+ if module_name == "subprocess":
+ raise ValueError(f"Module {module_name} blocked")
+ return None
+
+ policy = BlockModulePolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ with pytest.raises(ValueError, match="subprocess blocked"):
+ fory.deserialize(fory.serialize(GlobalNamePayload()))
+ assert policy.validate_module_calls == 1
+
+
+def test_reduce_global_name_validates_class():
+ """Test validate_class policy hook for reduce global-name deserialization."""
+
+ class GlobalNamePayload:
+ def __reduce__(self):
+ return "subprocess.Popen"
+
+ class BlockClassPolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_module_calls = 0
+ self.validate_class_calls = 0
+
+ def validate_module(self, module_name, **kwargs):
+ self.validate_module_calls += 1
+ return None
+
+ def validate_class(self, cls, is_local, **kwargs):
+ self.validate_class_calls += 1
+ if cls.__module__ == "subprocess" and cls.__name__ == "Popen":
+ raise ValueError("subprocess.Popen blocked")
+ return None
+
+ policy = BlockClassPolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ with pytest.raises(ValueError, match="subprocess.Popen blocked"):
+ fory.deserialize(fory.serialize(GlobalNamePayload()))
+ assert policy.validate_module_calls == 1
+ assert policy.validate_class_calls == 1
+
+
+def test_reduce_global_name_validates_function():
+ """Test validate_function policy hook for reduce builtins-name deserialization."""
+
+ class GlobalNamePayload:
+ def __reduce__(self):
+ return "eval"
+
+ class BlockFunctionPolicy(DeserializationPolicy):
+ def __init__(self):
+ self.validate_module_calls = 0
+ self.validate_function_calls = 0
+
+ def validate_module(self, module_name, **kwargs):
+ self.validate_module_calls += 1
+ return None
+
+ def validate_function(self, func, is_local, **kwargs):
+ self.validate_function_calls += 1
+ if func.__name__ == "eval":
+ raise ValueError("eval blocked")
+ return None
+
+ policy = BlockFunctionPolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+ with pytest.raises(ValueError, match="eval blocked"):
+ fory.deserialize(fory.serialize(GlobalNamePayload()))
+ assert policy.validate_module_calls == 1
+ assert policy.validate_function_calls == 1
diff --git a/python/pyfory/tests/test_size_guardrails.py b/python/pyfory/tests/test_size_guardrails.py
index 700b2baa2..b199c4413 100644
--- a/python/pyfory/tests/test_size_guardrails.py
+++ b/python/pyfory/tests/test_size_guardrails.py
@@ -31,6 +31,11 @@ import pytest
import pyfory
from pyfory import Fory
from pyfory.serialization import Buffer
+from pyfory.types import TypeId
+
+
+class ObjectPayload:
+ pass
def roundtrip(data, limit, xlang=False, ref=False):
@@ -111,6 +116,49 @@ class TestCollectionSizeLimit:
with pytest.raises(ValueError, match="exceeds the configured limit"):
reader.deserialize(writer.serialize(Container(items=list(range(10)))))
+ def test_object_field_count_exceeds_limit(self):
+ obj = ObjectPayload()
+ obj.value = 1
+ writer = Fory(ref=True, strict=False)
+ reader = Fory(ref=True, strict=False, max_collection_size=0)
+ writer.register(ObjectPayload)
+ reader.register(ObjectPayload)
+
+ with pytest.raises(ValueError, match="object field size 1 exceeds"):
+ reader.deserialize(writer.serialize(obj))
+
+ def test_local_class_base_count_exceeds_limit(self):
+ def make_local_class():
+ class LocalPayload:
+ pass
+
+ return LocalPayload
+
+ writer = Fory(ref=True, strict=False)
+ reader = Fory(ref=True, strict=False, max_collection_size=0)
+
+ with pytest.raises(ValueError, match="local class base size 1 exceeds"):
+ reader.deserialize(writer.serialize(make_local_class()))
+
+ def test_local_function_defaults_exceed_limit(self):
+ def local_function(value=1):
+ return value
+
+ writer = Fory(ref=True, strict=False)
+ reader = Fory(ref=True, strict=False, max_collection_size=0)
+
+ with pytest.raises(ValueError, match="function default size 1 exceeds"):
+ reader.deserialize(writer.serialize(local_function))
+
+ def test_object_ndarray_length_exceeds_limit(self):
+ np = pytest.importorskip("numpy")
+ arr = np.array([object(), object()], dtype=object)
+ writer = Fory(ref=True, strict=False)
+ reader = Fory(ref=True, strict=False, max_collection_size=1)
+
+ with pytest.raises(ValueError, match="ndarray object size 2 exceeds"):
+ reader.deserialize(writer.serialize(arr))
+
class TestBinarySizeLimit:
"""Binary reads are guarded by max_binary_size on the Buffer."""
@@ -127,6 +175,13 @@ class TestBinarySizeLimit:
with pytest.raises(ValueError, match="exceeds the configured limit"):
roundtrip_binary(b"x" * 200, max_binary_size=100, xlang=xlang)
+ @pytest.mark.parametrize("xlang", [False, True])
+ def test_string_exceeds_limit_fails(self, xlang):
+ writer = Fory(xlang=xlang)
+ reader = Fory(xlang=xlang, max_binary_size=1)
+ with pytest.raises(ValueError, match="String size 2 exceeds"):
+ reader.deserialize(writer.serialize("xx"))
+
def test_from_stream_respects_limit(self):
import io
@@ -134,3 +189,15 @@ class TestBinarySizeLimit:
buf = Buffer.from_stream(io.BytesIO(payload), max_binary_size=100)
with pytest.raises(ValueError, match="exceeds the configured limit"):
Fory(max_binary_size=100).deserialize(buf)
+
+ def test_in_band_buffer_object_respects_limit(self):
+ payload = b"x" * 200
+ data = Fory(ref=True).serialize(payload, buffer_callback=lambda _buffer: True)
+
+ with pytest.raises(ValueError, match="exceeds the configured limit"):
+ Fory(ref=True, max_binary_size=100).deserialize(data, buffers=[])
+
+ def test_malformed_metastring_ref_raises_value_error(self):
+ payload = bytes([2, 255, TypeId.NAMED_STRUCT, 3])
+ with pytest.raises(ValueError, match="Invalid dynamic metastring id"):
+ Fory(xlang=True, strict=False).deserialize(payload)
diff --git a/python/pyfory/type_util.py b/python/pyfory/type_util.py
index 60dc72007..d57b95b27 100644
--- a/python/pyfory/type_util.py
+++ b/python/pyfory/type_util.py
@@ -18,6 +18,7 @@
import dataclasses
import importlib
import inspect
+import types
import typing
from typing import TypeVar
@@ -337,17 +338,35 @@ def qualified_class_name(cls):
return cls.__module__ + "#" + cls.__qualname__
-def load_class(classname: str):
+def load_class(classname: str, policy=None):
mod_name, cls_name = classname.rsplit("#", 1)
- try:
- mod = importlib.import_module(mod_name)
- except ImportError as ex:
- raise Exception(f"Can't import module {mod_name}") from ex
+ if policy is not None:
+ result = policy.validate_module(mod_name)
+ if result is not None:
+ if isinstance(result, str):
+ mod_name = result
+ mod = None
+ else:
+ assert isinstance(result, types.ModuleType), f"validate_module must return module, str, or None, got {type(result)}"
+ mod = result
+ else:
+ mod = None
+ else:
+ mod = None
+ if mod is None:
+ try:
+ mod = importlib.import_module(mod_name)
+ except ImportError as ex:
+ raise Exception(f"Can't import module {mod_name}") from ex
try:
classes = cls_name.split(".")
cls = getattr(mod, classes.pop(0))
while classes:
cls = getattr(cls, classes.pop(0))
+ if policy is not None:
+ result = policy.validate_class(cls, is_local=False)
+ if result is not None:
+ cls = result
return cls
except AttributeError as ex:
raise Exception(f"Can't import class {cls_name} from module {mod_name}") from ex
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]