This is an automated email from the ASF dual-hosted git repository.

skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/hamilton.git

commit f6dc944ddcd8d35a98ad7e35ba9f122bf5b0c831
Author: Charles Swartz <[email protected]>
AuthorDate: Sat Apr 5 14:38:35 2025 -0400

    Initial implementation of `unpack_fields`
---
 hamilton/function_modifiers/__init__.py  |   1 +
 hamilton/function_modifiers/expanders.py | 157 ++++++++++++++++++++++++++++++-
 2 files changed, 157 insertions(+), 1 deletion(-)

diff --git a/hamilton/function_modifiers/__init__.py 
b/hamilton/function_modifiers/__init__.py
index 3113b13f..1e27fc90 100644
--- a/hamilton/function_modifiers/__init__.py
+++ b/hamilton/function_modifiers/__init__.py
@@ -59,6 +59,7 @@ parametrized_input = expanders.parametrized_input
 # Extract decorators
 extract_columns = expanders.extract_columns
 extract_fields = expanders.extract_fields
+unpack_fields = expanders.unpack_fields
 
 # does decorator
 does = macros.does
diff --git a/hamilton/function_modifiers/expanders.py 
b/hamilton/function_modifiers/expanders.py
index 62610f89..a5d13d23 100644
--- a/hamilton/function_modifiers/expanders.py
+++ b/hamilton/function_modifiers/expanders.py
@@ -3,7 +3,7 @@ import dataclasses
 import functools
 import inspect
 import typing
-from typing import Any, Callable, Collection, Dict, Tuple, Union
+from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union
 
 import typing_extensions
 import typing_inspect
@@ -18,6 +18,11 @@ from hamilton.function_modifiers.dependencies import (
     value,
 )
 
+try:
+    from typing import override
+except ImportError:
+    override = lambda x: x  # noqa E731
+
 """Decorators that enables DRY code by expanding one node into many"""
 
 
@@ -869,6 +874,156 @@ class extract_fields(base.SingleNodeNodeTransformer):
         return output_nodes
 
 
+def _process_unpack_fields(fields: List[str], output_type: Any) -> List[Type] 
| None:
+    """Processes the fields and output type to extract a list of field 
types.
+
+    :param fields: Names of the fields to unpack from the tuple.
+    :param output_type: Return type annotation of the decorated function.
+    :return: List of field types, or None if the type is not a tuple.
+    """
+
+    base_type = typing_inspect.get_origin(output_type)
+    if base_type != tuple and base_type != Tuple:
+        return
+
+    output_args = typing_inspect.get_args(output_type)
+    num_ellipsis = output_args.count(Ellipsis)
+    if num_ellipsis > 1:
+        raise base.InvalidDecoratorException(
+            f"Invalid tuple: Found more than one ellipsis ('...'): 
{output_type}"
+        )
+    elif num_ellipsis == 1:
+        if len(output_args) != 2:
+            raise base.InvalidDecoratorException(
+                f"Invalid tuple: Ellipsis ('...') must be second element: 
{output_type}"
+            )
+        # Valid indeterminate length tuple, e.g. `tuple[int, ...]`, 
`typing.Tuple[int, ...]`
+        output_args = tuple(output_args[0] for _ in range(len(fields)))
+
+    if len(output_args) < len(fields):
+        raise base.InvalidDecoratorException(
+            f"Number of unpacked fields ({len(fields)}) is greater than the 
number of fields in "
+            f"the output type ({len(output_args)}): {output_type}"
+        )
+
+    errors = []
+    field_types = []
+    for idx, arg in enumerate(output_args):
+        # Determine if the type is a valid type. Note that for Python <3.11, 
`Any` is not a type
+        if not (
+            isinstance(arg, type)
+            or arg is Any
+            or typing_inspect.is_generic_type(arg)
+            or typing_inspect.is_union_type(arg)
+        ):
+            field_name = fields[idx]
+            errors.append(f"Field {field_name} (index {idx}) does not declare 
a valid type: {arg}")
+        field_types.append(arg)
+
+    if errors:
+        raise base.InvalidDecoratorException(f"Found errors in the output 
type: {errors}")
+
+    return field_types
+
+
+class unpack_fields(base.SingleNodeNodeTransformer):
+    """Extracts fields from a tuple output.
+
+    Expands a single function into the following nodes:
+
+    - 1 function that outputs the original tuple
+    - n functions, each of which take in the original tuple and output a 
specific field
+
+    The decorated function must have a return type of either `tuple` (Python 
3.9+) or
+    `typing.Tuple`, and must specify either:
+    - An explicit length tuple (e.g. `tuple[int, str]`, `typing.Tuple[int, 
str]`)
+    - An indeterminate length tuple (e.g. `tuple[int, ...]`, 
`typing.Tuple[int, ...]`)
+
+    :param fields: Fields to extract from the return value of the decorated 
function.
+    """
+
+    output_type: Any
+    field_types: List[Type]
+
+    def __init__(self, *fields: str):
+        super().__init__()
+        self.fields = list(fields)
+
+    @override
+    def validate(self, fn: Callable):
+        output_type = typing.get_type_hints(fn).get("return")
+        field_types = _process_unpack_fields(self.fields, output_type)
+        if field_types:
+            self.field_types = field_types
+            self.output_type = output_type
+        else:
+            message = (
+                f"For unpacking fields, the decorated function output type 
must be either an "
+                f"explicit length tuple (e.g.`tuple[int, str]`, 
`typing.Tuple[int, str]`) or an "
+                f"indeterminate length tuple (e.g. `tuple[int, ...]`, 
`typing.Tuple[int, ...]`), "
+                f"not: {output_type}"
+            )
+            raise base.InvalidDecoratorException(message)
+
+    @override
+    def transform_node(
+        self, node_: node.Node, config: Dict[str, Any], fn: Callable
+    ) -> Collection[node.Node]:
+        fn = node_.callable
+        base_doc = node_.documentation
+        base_tags = node_.tags.copy()
+
+        if inspect.iscoroutinefunction(fn):
+
+            async def tuple_generator(*args, **kwargs):  # type: ignore
+                tuple_generated = await fn(*args, **kwargs)
+                return tuple_generated
+
+        else:
+
+            def tuple_generator(*args, **kwargs):
+                tuple_generated = fn(*args, **kwargs)
+                return tuple_generated
+
+        output_nodes = [node_.copy_with(callabl=tuple_generator)]
+
+        for idx, (field_name, field_type) in enumerate(zip(self.fields, 
self.field_types)):
+            # NOTE: The extractors, as defined below, are constructed to avoid 
closure issues.
+            if inspect.iscoroutinefunction(fn):
+
+                async def extractor(field_index: int = idx, **kwargs) -> 
field_type:  # type: ignore
+                    dt = kwargs[node_.name]
+                    if field_index < 0 or field_index >= len(dt):
+                        raise base.InvalidDecoratorException(
+                            f"Out of bounds field: {node_.name} contains only 
{len(dt)} fields, "
+                            f"index requested: {field_index}. "
+                        )
+                    return kwargs[node_.name][field_index]
+
+            else:
+
+                def extractor(field_index: int = idx, **kwargs) -> field_type: 
 # type: ignore
+                    dt = kwargs[node_.name]
+                    if field_index < 0 or field_index >= len(dt):
+                        raise base.InvalidDecoratorException(
+                            f"Out of bounds field: {field_index} produced by 
{node_.name}. "
+                            f"It only produced {list(dt)} fields."
+                        )
+                    return kwargs[node_.name][field_index]
+
+            output_nodes.append(
+                node.Node(
+                    field_name,
+                    field_type,
+                    base_doc,
+                    extractor,
+                    input_types={node_.name: self.output_type},
+                    tags=base_tags,
+                )
+            )
+        return output_nodes
+
+
 @dataclasses.dataclass
 class ParameterizedExtract:
     """Dataclass to hold inputs for @parameterize and 
@parameterize_extract_columns.

Reply via email to