This is an automated email from the ASF dual-hosted git repository. skrawcz pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/hamilton.git
commit f6dc944ddcd8d35a98ad7e35ba9f122bf5b0c831 Author: Charles Swartz <[email protected]> AuthorDate: Sat Apr 5 14:38:35 2025 -0400 Initial implementation of `unpack_fields` --- hamilton/function_modifiers/__init__.py | 1 + hamilton/function_modifiers/expanders.py | 157 ++++++++++++++++++++++++++++++- 2 files changed, 157 insertions(+), 1 deletion(-) diff --git a/hamilton/function_modifiers/__init__.py b/hamilton/function_modifiers/__init__.py index 3113b13f..1e27fc90 100644 --- a/hamilton/function_modifiers/__init__.py +++ b/hamilton/function_modifiers/__init__.py @@ -59,6 +59,7 @@ parametrized_input = expanders.parametrized_input # Extract decorators extract_columns = expanders.extract_columns extract_fields = expanders.extract_fields +unpack_fields = expanders.unpack_fields # does decorator does = macros.does diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 62610f89..a5d13d23 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -3,7 +3,7 @@ import dataclasses import functools import inspect import typing -from typing import Any, Callable, Collection, Dict, Tuple, Union +from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union import typing_extensions import typing_inspect @@ -18,6 +18,11 @@ from hamilton.function_modifiers.dependencies import ( value, ) +try: + from typing import override +except ImportError: + override = lambda x: x # noqa E731 + """Decorators that enables DRY code by expanding one node into many""" @@ -869,6 +874,156 @@ class extract_fields(base.SingleNodeNodeTransformer): return output_nodes +def _process_unpack_fields(fields: List[str], output_type: Any) -> List[Type] | None: + """Processes the fields and base output type args to extract a list of field types. + + :param fields: Names of the fields to unpack from the tuple. + :param output_type: Output type annotation whose tuple arguments supply the field types. + :return: List of types, or None if the output type is not a tuple. 
+ """ + + base_type = typing_inspect.get_origin(output_type) + if base_type != tuple and base_type != Tuple: + return + + output_args = typing_inspect.get_args(output_type) + num_ellipsis = output_args.count(Ellipsis) + if num_ellipsis > 1: + raise base.InvalidDecoratorException( + f"Invalid tuple: Found more than one ellipsis ('...'): {output_type}" + ) + elif num_ellipsis == 1: + if len(output_args) != 2: + raise base.InvalidDecoratorException( + f"Invalid tuple: Ellipsis ('...') must be second element: {output_type}" + ) + # Valid Indeterminate length tuple, e.g. `tuple[int, ...]`, `typing.Tuple[int, ...]` + output_args = tuple(output_args[0] for _ in range(len(fields))) + + if len(output_args) < len(fields): + raise base.InvalidDecoratorException( + f"Number of unpacked fields ({len(fields)}) is greater than the number of fields in " + f"the output type ({len(output_args)}): {output_type}" + ) + + errors = [] + field_types = [] + for idx, arg in enumerate(output_args): + # Determine if the type is a valid type. Note that for Python <3.11, `Any` is not a type + if not ( + isinstance(arg, type) + or arg is Any + or typing_inspect.is_generic_type(arg) + or typing_inspect.is_union_type(arg) + ): + field_name = fields[idx] + errors.append(f"Field {field_name} (index {idx}) does not declare a valid type: {arg}") + field_types.append(arg) + + if errors: + raise base.InvalidDecoratorException(f"Found errors in the output type: {errors}") + + return field_types + + +class unpack_fields(base.SingleNodeNodeTransformer): + """Extracts fields from a tuple output. 
+ + Expands a single function into the following nodes: + + - 1 function that outputs the original tuple + - n functions, each of which takes in the original tuple and outputs a specific field + + The decorated function must have a return type of either `tuple` (python 3.9+) or + `typing.Tuple`, and must specify either: + - An explicit length tuple (e.g. `tuple[int, str]`, `typing.Tuple[int, str]`) + - An indeterminate length tuple (e.g. `tuple[int, ...]`, `typing.Tuple[int, ...]`) + + :param fields: Fields to extract from the return value of the decorated function. + """ + + output_type: Any + field_types: List[Type] + + def __init__(self, *fields: str): + super().__init__() + self.fields = list(fields) + + @override + def validate(self, fn: Callable): + output_type = typing.get_type_hints(fn).get("return") + field_types = _process_unpack_fields(self.fields, output_type) + if field_types: + self.field_types = field_types + self.output_type = output_type + else: + message = ( + f"For unpacking fields, the decorated function output type must be either an " + f"explicit length tuple (e.g.`tuple[int, str]`, `typing.Tuple[int, str]`) or an " + f"indeterminate length tuple (e.g. 
`tuple[int, ...]`, `typing.Tuple[int, ...]`), " + f"not: {output_type}" + ) + raise base.InvalidDecoratorException(message) + + @override + def transform_node( + self, node_: node.Node, config: Dict[str, Any], fn: Callable + ) -> Collection[node.Node]: + fn = node_.callable + base_doc = node_.documentation + base_tags = node_.tags.copy() + + if inspect.iscoroutinefunction(fn): + + async def tuple_generator(*args, **kwargs): # type: ignore + tuple_generated = await fn(*args, **kwargs) + return tuple_generated + + else: + + def tuple_generator(*args, **kwargs): + tuple_generated = fn(*args, **kwargs) + return tuple_generated + + output_nodes = [node_.copy_with(callabl=tuple_generator)] + + for idx, (field_name, field_type) in enumerate(zip(self.fields, self.field_types)): + # NOTE: The extractors, as defined below, are constructed to avoid closure issues. + if inspect.iscoroutinefunction(fn): + + async def extractor(field_index: int = idx, **kwargs) -> field_type: # type: ignore + dt = kwargs[node_.name] + if field_index < 0 or field_index >= len(dt): + raise base.InvalidDecoratorException( + f"Out of bounds field: {node_.name} contains only {len(dt)} fields, " + f"index requested: {field_index}. " + ) + return kwargs[node_.name][field_index] + + else: + + def extractor(field_index: int = idx, **kwargs) -> field_type: # type: ignore + dt = kwargs[node_.name] + if field_index < 0 or field_index >= len(dt): + raise base.InvalidDecoratorException( + f"Out of bounds field: {field_index} produced by {node_.name}. " + f"It only produced {list(dt)} fields." + ) + return kwargs[node_.name][field_index] + + output_nodes.append( + node.Node( + field_name, + field_type, + base_doc, + extractor, + input_types={node_.name: self.output_type}, + tags=base_tags, + ) + ) + return output_nodes + + @dataclasses.dataclass class ParameterizedExtract: """Dataclass to hold inputs for @parameterize and @parameterize_extract_columns.
