This is an automated email from the ASF dual-hosted git repository.
ebenizzy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/hamilton.git
The following commit(s) were added to refs/heads/main by this push:
new d601a92f Update extract fields (#1305)
d601a92f is described below
commit d601a92fa026b3aed76984d66c8a5c6eb95632f6
Author: Charles Swartz <[email protected]>
AuthorDate: Tue Jul 15 00:26:24 2025 -0400
Update extract fields (#1305)
---
docs/concepts/function-modifiers.rst | 70 +++++-
hamilton/function_modifiers/expanders.py | 206 +++++++++++------
tests/function_modifiers/test_expanders.py | 340 ++++++++++++++++++++++-------
3 files changed, 468 insertions(+), 148 deletions(-)
diff --git a/docs/concepts/function-modifiers.rst
b/docs/concepts/function-modifiers.rst
index 63b206eb..6bbed9eb 100644
--- a/docs/concepts/function-modifiers.rst
+++ b/docs/concepts/function-modifiers.rst
@@ -140,7 +140,7 @@ The ``@check_output`` function modifiers are applied on the
**node output / func
.. note::
- In the future, validatation capabailities may be added to ``@schema``. For
now, it's only added metadata.
+ In the future, validation capabilities may be added to ``@schema``. For
now, it's only added metadata.
@check_output*
~~~~~~~~~~~~~~
@@ -201,7 +201,7 @@ A good example is splitting a dataset into training,
validation, and test splits
from typing import Tuple
from hamilton.function_modifiers import unpack_fields
- @unpack_fields("X_train" "X_validation", "X_test")
+ @unpack_fields("X_train", "X_validation", "X_test")
def dataset_splits(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray,
np.ndarray]:
"""Randomly split data into train, validation, test"""
X_train, X_validation, X_test = random_split(X)
@@ -216,14 +216,14 @@ Now, ``X_train``, ``X_validation``, and ``X_test`` are
available to other nodes
@extract_fields
~~~~~~~~~~~~~~~
-Additionally, we can extract fields from an output dictionary using
``@extract_fields``. In this case, you must specify the dictionary keys and
their types. The function must return a dictionary that contains, at a minimum,
those keys specified in the decorator.
+Additionally, we can extract fields from an output dictionary using
``@extract_fields``. The function must return a dictionary that contains, at a
minimum, those keys specified in the decorator. In this case, you can specify a
dictionary of fields and their types:
.. code-block:: python
from typing import Dict
from hamilton.function_modifiers import extract_fields
- @extract_fields(dict( # don't forget the dictionary
+ @extract_fields(dict( # fields specified as a dictionary
X_train=np.ndarray,
X_validation=np.ndarray,
X_test=np.ndarray,
@@ -240,6 +240,68 @@ Additionally, we can extract fields from an output
dictionary using ``@extract_f
.. image:: ./_function-modifiers/extract_fields.png
:height: 250px
+Or if you are using a generic dictionary, you can specify solely the field
names.
+
+.. code-block:: python
+
+ from typing import Dict
+ from hamilton.function_modifiers import extract_fields
+
+ @extract_fields("X_train", "X_validation", "X_test") # field names only
+ def dataset_splits(X: np.ndarray) -> Dict[str, np.ndarray]: # generic dict
+ """Randomly split data into train, validation, test"""
+ X_train, X_validation, X_test = random_split(X)
+ return dict(
+ X_train=X_train,
+ X_validation=X_validation,
+ X_test=X_test,
+ )
+
+If you are using a `TypedDict`, you can specify just the field names.
+
+.. code-block:: python
+
+ from typing import TypedDict
+ from hamilton.function_modifiers import extract_fields
+
+ class DatasetSplits(TypedDict):
+ X_train: np.ndarray
+ X_validation: np.ndarray
+ X_test: np.ndarray
+
+ @extract_fields("X_train", "X_validation", "X_test")
+ def dataset_splits(X: np.ndarray) -> DatasetSplits:
+ """Randomly split data into train, validation, test"""
+ X_train, X_validation, X_test = random_split(X)
+ return dict(
+ X_train=X_train,
+ X_validation=X_validation,
+ X_test=X_test,
+ )
+
+
+Or you can leave the field names empty and extract all fields from the
`TypedDict`.
+
+.. code-block:: python
+
+ from typing import TypedDict
+ from hamilton.function_modifiers import extract_fields
+
+ class DatasetSplits(TypedDict):
+ X_train: np.ndarray
+ X_validation: np.ndarray
+ X_test: np.ndarray
+
+    @extract_fields(DatasetSplits)  # extract all fields declared by the TypedDict
+ def dataset_splits(X: np.ndarray) -> DatasetSplits:
+ """Randomly split data into train, validation, test"""
+ X_train, X_validation, X_test = random_split(X)
+ return dict(
+ X_train=X_train,
+ X_validation=X_validation,
+ X_test=X_test,
+ )
+
Again, ``X_train``, ``X_validation``, and ``X_test`` are now available to
other nodes, or you can query the ``dataset_splits`` node to retrieve all
splits in a dictionary.
diff --git a/hamilton/function_modifiers/expanders.py
b/hamilton/function_modifiers/expanders.py
index d04cf6ee..a74f2aa2 100644
--- a/hamilton/function_modifiers/expanders.py
+++ b/hamilton/function_modifiers/expanders.py
@@ -3,7 +3,7 @@ import dataclasses
import functools
import inspect
import typing
-from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union
+from typing import Any, Callable, Collection, Dict, List, Optional, Tuple,
Type, Union
import typing_extensions
import typing_inspect
@@ -699,6 +699,89 @@ class extract_columns(base.SingleNodeNodeTransformer):
return output_nodes
+def _determine_fields_to_extract(
+ fields: Optional[Union[Dict[str, Any], List[str]]], output_type: Any
+) -> Dict[str, Any]:
+    """Determines which fields to extract based on the user-requested fields
and the return type of the function.
+
+    :param fields: Fields to extract, as a dict of field names to types or a
list of field names.
+    :param output_type: The output type of the node function.
+    :return: Dict of field names to their types.
+ """
+
+ output_type_error = (
+ f"For extracting fields, the decorated function output type must be a
`dict` or a "
+ f"`typing.Dict` with or without type parameters (i.e. `dict[str, int]`
or "
+ f"`typing.Dict[str, int]`), not: {output_type}"
+ )
+
+ if output_type == dict or output_type == Dict:
+ # NOTE: typing_inspect.is_generic_type(typing.Dict) without type
parameters returns True,
+ # so we need to address the bare dictionaries first before
generics.
+ if fields is None or not isinstance(fields, dict):
+ raise base.InvalidDecoratorException(
+ "When extracting fields from a function that returns a bare
`dict` output without "
+ "type parameters, you must supply a `dict` mapping field names
to types."
+ )
+ elif typing_inspect.is_generic_type(output_type):
+ base_type = typing_inspect.get_origin(output_type)
+ if base_type != dict and base_type != Dict:
+ raise base.InvalidDecoratorException(output_type_error)
+ if fields is None:
+ raise base.InvalidDecoratorException(
+ "When extracting fields from a function that returns a generic
`dict`, you must "
+ "supply either a `dict` (`typing.Dict`) mapping field names to
types or "
+ "alternatively a `list` (`typing.List`) of field names."
+ )
+ output_args = typing_inspect.get_args(output_type)
+ if len(output_args) != 2:
+ raise base.InvalidDecoratorException(
+ f"When extracting fields from a function that returns a
generic `dict`, you "
+ f"must specify only two type parameters (key, value), not
{output_args}."
+ )
+ if isinstance(fields, list):
+ fields = {field: output_args[1] for field in fields} # Infer type
from annotation
+ elif typing_extensions.is_typeddict(output_type):
+ typed_dict_fields = typing.get_type_hints(output_type) # Dict of
field name -> type
+ errors = []
+ if fields is None:
+ fields = typed_dict_fields # Infer fields and types from
annotation
+ elif isinstance(fields, list):
+ reduced_fields = {}
+ for field in fields:
+ if field not in typed_dict_fields:
+ errors.append(f"{field} is not a field in the `TypedDict`
{output_type}.")
+ reduced_fields[field] = typed_dict_fields[field]
+ fields = reduced_fields
+ elif isinstance(fields, dict):
+ for field_name, field_type in fields.items():
+ expected_type = typed_dict_fields.get(field_name, None)
+ if expected_type is None:
+ errors.append(f"{field_name} is not a field in the
`TypedDict` {output_type}.")
+ continue
+ elif expected_type == field_type or
htypes.custom_subclass_check(
+ field_type, expected_type
+ ):
+ continue
+ errors.append(
+ f"Error {field_name} did not match the TypedDict
annotation's field "
+ f"{field_type}. Expected {expected_type}."
+ )
+ if errors:
+ raise base.InvalidDecoratorException(
+ f"Error {fields} did not match a subset of the TypedDict
annotation's fields "
+ f"{typed_dict_fields}. The following fields were not valid:
{errors}."
+ )
+ else:
+ raise base.InvalidDecoratorException(output_type_error)
+
+ assert isinstance(fields, dict), "Internal error: fields should be a dict
at this point."
+ _validate_extract_fields(fields)
+
+ return fields
+
+
def _validate_extract_fields(fields: dict):
"""Validates the fields dict for extract field.
Rules are:
@@ -739,61 +822,43 @@ def _validate_extract_fields(fields: dict):
class extract_fields(base.SingleNodeNodeTransformer):
"""Extracts fields from a dictionary of output."""
- def __init__(self, fields: dict = None, fill_with: Any = None):
+ output_type: Any
+ resolved_fields: Dict[str, Type]
+
+ def __init__(
+ self,
+ fields: Optional[Union[Dict[str, Any], List[str], Any]] = None,
+ *others,
+ fill_with: Any = None,
+ ):
"""Constructor for a modifier that expands a single function into the
following nodes:
- n functions, each of which take in the original dict and output a
specific field
- 1 function that outputs the original dict
- :param fields: Fields to extract. A dict of 'field_name' ->
'field_type'.
+ :param fields: Fields to extract. Can be a dict of field names to
types, a list of field names, or a single field name.
+        :param others: Additional field names to extract - argument
unpacking. Ignored if `fields` is a dict.
:param fill_with: If you want to extract a field that doesn't exist,
do you want to fill it with a default \
value? Or do you want to error out? Leave empty/None to error out, set
fill_value to dynamically create a \
field value.
"""
super(extract_fields, self).__init__()
+ if isinstance(fields, list):
+ fields = fields + list(others)
+ elif fields and not isinstance(fields, dict):
+ fields = [fields] + list(others)
self.fields = fields
self.fill_with = fill_with
def validate(self, fn: Callable):
- """A function is invalid if it is not annotated with a dict or
typing.Dict return type.
+ """A function is invalid if it is not annotated with a dict or
typing.Dict return type or if the
+ fields to extract are not valid.
:param fn: Function to validate.
:raises: InvalidDecoratorException If the function is not annotated
with a dict or typing.Dict type as output.
"""
- output_type = typing.get_type_hints(fn).get("return")
- if typing_inspect.is_generic_type(output_type):
- base_type = typing_inspect.get_origin(output_type)
- if base_type == dict or base_type == Dict:
- _validate_extract_fields(self.fields)
- else:
- raise base.InvalidDecoratorException(
- f"For extracting fields, output type must be a dict or
typing.Dict, not: {output_type}"
- )
- elif output_type == dict:
- _validate_extract_fields(self.fields)
- elif typing_extensions.is_typeddict(output_type):
- if self.fields is None:
- self.fields = typing.get_type_hints(output_type)
- else:
- # check that fields is a subset of TypedDict that is defined
- typed_dict_fields = typing.get_type_hints(output_type)
- for field_name, field_type in self.fields.items():
- expected_type = typed_dict_fields.get(field_name, None)
- if expected_type == field_type:
- pass # we're definitely good
- elif expected_type is not None and
htypes.custom_subclass_check(
- field_type, expected_type
- ):
- pass
- else:
- raise base.InvalidDecoratorException(
- f"Error {self.fields} did not match a subset of
the TypedDict annotation's fields {typed_dict_fields}."
- )
- _validate_extract_fields(self.fields)
- else:
- raise base.InvalidDecoratorException(
- f"For extracting fields, output type must be a dict or
typing.Dict, not: {output_type}"
- )
+ self.output_type = typing.get_type_hints(fn).get("return")
+ self.resolved_fields = _determine_fields_to_extract(self.fields,
self.output_type)
def transform_node(
self, node_: node.Node, config: Dict[str, Any], fn: Callable
@@ -813,53 +878,38 @@ class extract_fields(base.SingleNodeNodeTransformer):
# if fn is async
if inspect.iscoroutinefunction(fn):
- async def dict_generator(*args, **kwargs):
+ async def dict_generator(*args, **kwargs): # type: ignore
dict_generated = await fn(*args, **kwargs)
if self.fill_with is not None:
- for field in self.fields:
+ for field in self.resolved_fields:
if field not in dict_generated:
dict_generated[field] = self.fill_with
return dict_generated
else:
- def dict_generator(*args, **kwargs):
+ def dict_generator(*args, **kwargs): # type: ignore
dict_generated = fn(*args, **kwargs)
if self.fill_with is not None:
- for field in self.fields:
+ for field in self.resolved_fields:
if field not in dict_generated:
dict_generated[field] = self.fill_with
return dict_generated
output_nodes = [node_.copy_with(callabl=dict_generator)]
- for field, field_type in self.fields.items():
+ for field, field_type in self.resolved_fields.items():
doc_string = base_doc # default doc string of base function.
- # if fn is async
- if inspect.iscoroutinefunction(fn):
-
- async def extractor_fn(field_to_extract: str = field,
**kwargs) -> field_type:
- dt = kwargs[node_.name]
- if field_to_extract not in dt:
- raise base.InvalidDecoratorException(
- f"No such field: {field_to_extract} produced by
{node_.name}. "
- f"It only produced {list(dt.keys())}"
- )
- return kwargs[node_.name][field_to_extract]
-
- else:
-
- def extractor_fn(
- field_to_extract: str = field, **kwargs
- ) -> field_type: # avoiding problems with closures
- dt = kwargs[node_.name]
- if field_to_extract not in dt:
- raise base.InvalidDecoratorException(
- f"No such field: {field_to_extract} produced by
{node_.name}. "
- f"It only produced {list(dt.keys())}"
- )
- return kwargs[node_.name][field_to_extract]
+ # This extractor is constructed to avoid closure issues.
+ def extractor_fn(field_to_extract: str = field, **kwargs) ->
field_type: # type: ignore
+ dt = kwargs[node_.name]
+ if field_to_extract not in dt:
+ raise base.InvalidDecoratorException(
+ f"No such field: {field_to_extract} produced by
{node_.name}. "
+ f"It only produced {list(dt.keys())}"
+ )
+ return kwargs[node_.name][field_to_extract]
output_nodes.append(
node.Node(
@@ -867,15 +917,16 @@ class extract_fields(base.SingleNodeNodeTransformer):
field_type,
doc_string,
extractor_fn,
- input_types={node_.name: dict},
+ input_types={node_.name: self.output_type},
tags=node_.tags.copy(),
)
)
return output_nodes
-def _process_unpack_fields(fields: List[str], output_type: Any) -> List[Type]:
- """Processes the fields and base output type to extract a list of field
types.
+def _determine_fields_to_unpack(fields: List[str], output_type: Any) ->
List[Type]:
+    """Determines which fields to unpack based on the user-requested fields
and the return type of the function.
:param fields: List of fields to to unpack.
:param output_type: The output type of the node function.
@@ -957,8 +1008,13 @@ class unpack_fields(base.SingleNodeNodeTransformer):
@override
def validate(self, fn: Callable):
+        """Validates that the return type of the function is a tuple or
typing.Tuple containing the fields to unpack.
+
+ :param fn: Function to validate
+ :raises: InvalidDecoratorException If the function does not output a
tuple or typing.Tuple type.
+ """
output_type = typing.get_type_hints(fn).get("return")
- field_types = _process_unpack_fields(self.fields, output_type)
+ field_types = _determine_fields_to_unpack(self.fields, output_type)
self.field_types = field_types
self.output_type = output_type
@@ -966,6 +1022,14 @@ class unpack_fields(base.SingleNodeNodeTransformer):
def transform_node(
self, node_: node.Node, config: Dict[str, Any], fn: Callable
) -> Collection[node.Node]:
+        """Unpacks the specified fields from the tuple output into separate
nodes.
+
+ :param node_: Node to transform
+ :param config: Config to use
+ :param fn: Function to unpack fields from. Must output a tuple.
+ :return: A collection of nodes --
+ one for the original tuple generator, and another for each
field to unpack.
+ """
fn = node_.callable
base_doc = node_.documentation
base_tags = node_.tags.copy()
diff --git a/tests/function_modifiers/test_expanders.py
b/tests/function_modifiers/test_expanders.py
index f272cfd5..33292fe4 100644
--- a/tests/function_modifiers/test_expanders.py
+++ b/tests/function_modifiers/test_expanders.py
@@ -333,24 +333,6 @@ class MyDictBad(TypedDict):
test2: str
[email protected](
- "return_type",
- [
- dict,
- Dict,
- Dict[str, str],
- Dict[str, Any],
- MyDict,
- ],
-)
-def test_extract_fields_validate_happy(return_type):
- def return_dict() -> return_type:
- return {}
-
- annotation = function_modifiers.extract_fields({"test": int})
- annotation.validate(return_dict)
-
-
class SomeObject:
pass
@@ -369,95 +351,306 @@ class MyDictInheritanceBadCase(TypedDict):
test2: str
-def test_extract_fields_validate_happy_inheritance():
- def return_dict() -> MyDictInheritance:
- return {}
-
- annotation = function_modifiers.extract_fields({"test": InheritedObject})
- annotation.validate(return_dict)
-
-
-def test_extract_fields_validate_not_subclass():
- def return_dict() -> MyDictInheritanceBadCase:
- return {}
-
- annotation = function_modifiers.extract_fields({"test": SomeObject})
- with pytest.raises(base.InvalidDecoratorException):
- annotation.validate(return_dict)
-
-
@pytest.mark.parametrize(
- "return_type",
- [(int), (list), (np.ndarray), (pd.DataFrame), (MyDictBad)],
+ "return_type_str,fields",
+ [
+ ("Dict[str, int]", ("A", "B")),
+ ("Dict[str, int]", (["A", "B"])),
+ ("Dict", {"A": str, "B": int}),
+ ("MyDict", ()),
+ ("MyDict", {"test2": str}),
+ ("MyDictInheritance", {"test": InheritedObject}),
+ pytest.param("dict[str, int]", ("A", "B"),
marks=skipif(**prior_to_py39)),
+ pytest.param("dict[str, int]", (["A", "B"]),
marks=skipif(**prior_to_py39)),
+ pytest.param("dict", {"A": str, "B": int},
marks=skipif(**prior_to_py39)),
+ ],
)
-def test_extract_fields_validate_errors(return_type):
- def return_dict() -> return_type:
- return {}
-
- annotation = function_modifiers.extract_fields({"test": int})
- with
pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException):
- annotation.validate(return_dict)
+def test_extract_fields_valid_annotations_for_inferred_types(return_type_str,
fields):
+ return_type = eval(return_type_str)
+ def function() -> return_type: # type: ignore
+ return {} # Only testing validation, so return value doesn't matter
-def test_extract_fields_typeddict_empty_fields():
- def return_dict() -> MyDict:
- return {}
+ if isinstance(fields, tuple):
+ annotation = function_modifiers.extract_fields(*fields)
+ else:
+ annotation = function_modifiers.extract_fields(fields)
+ annotation.validate(function)
- # don't need fields for TypedDict
- annotation = function_modifiers.extract_fields()
- annotation.validate(return_dict)
[email protected](
+ "return_type_str,fields",
+ [
+ ("Dict", ("A", "B")),
+ ("Dict", (["A", "B"])),
+ ("Dict", (["A"])),
+ ("Dict", (["A", "B", "C"])),
+ ("int", {"A": int}),
+ ("list", {"A": int}),
+ ("np.ndarray", {"A": int}),
+ ("pd.DataFrame", {"A": int}),
+ ("MyDictBad", {"A": int}),
+ ("MyDictInheritanceBadCase", {"A": SomeObject}),
+ pytest.param("dict", ("A", "B"), marks=skipif(**prior_to_py39)),
+ pytest.param("dict", (["A", "B"]), marks=skipif(**prior_to_py39)),
+ pytest.param("dict", (["A"]), marks=skipif(**prior_to_py39)),
+ pytest.param("dict", (["A", "B", "C"]), marks=skipif(**prior_to_py39)),
+ ],
+)
+def
test_extract_fields_invalid_annotations_for_inferred_types(return_type_str,
fields):
+ return_type = eval(return_type_str)
-def test_extract_fields_typeddict_subset():
- def return_dict() -> MyDict:
- return {}
+ def function() -> return_type: # type: ignore
+ return {} # Only testing validation, so return value doesn't matter
- # test that a subset of fields is fine
- annotation = function_modifiers.extract_fields({"test2": str})
- annotation.validate(return_dict)
+ if isinstance(fields, tuple):
+ annotation = function_modifiers.extract_fields(*fields)
+ else:
+ annotation = function_modifiers.extract_fields(fields)
+ with
pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException):
+ annotation.validate(function)
-def test_valid_extract_fields():
- """Tests whole extract_fields decorator."""
+def test_extract_fields_transform_on_bare_dict_with_explicit_types():
+ """Tests whole extract_fields decorator using a bare, non-generic, dict
and explicit types."""
annotation = function_modifiers.extract_fields(
{"col_1": list, "col_2": int, "col_3": np.ndarray}
)
- def dummy_dict_generator() -> dict:
+ def dummy_dict() -> dict: # bare dict, not generic
"""dummy doc"""
return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2,
3, 4])}
- nodes = list(
- annotation.transform_node(node.Node.from_fn(dummy_dict_generator), {},
dummy_dict_generator)
- )
+ annotation.validate(dummy_dict)
+ nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {},
dummy_dict))
+
assert len(nodes) == 4
assert nodes[0] == node.Node(
- name=dummy_dict_generator.__name__,
+ name=dummy_dict.__name__,
typ=dict,
- doc_string=dummy_dict_generator.__doc__,
- callabl=dummy_dict_generator,
+ doc_string=getattr(dummy_dict, "__doc__", ""),
+ callabl=dummy_dict,
tags={"module": "tests.function_modifiers.test_expanders"},
)
assert nodes[1].name == "col_1"
assert nodes[1].type == list
assert nodes[1].documentation == "dummy doc" # we default to base
function doc.
- assert nodes[1].input_types == {dummy_dict_generator.__name__: (dict,
DependencyType.REQUIRED)}
+ assert nodes[1].input_types == {dummy_dict.__name__: (dict,
DependencyType.REQUIRED)}
assert nodes[2].name == "col_2"
assert nodes[2].type == int
assert nodes[2].documentation == "dummy doc"
- assert nodes[2].input_types == {dummy_dict_generator.__name__: (dict,
DependencyType.REQUIRED)}
+ assert nodes[2].input_types == {dummy_dict.__name__: (dict,
DependencyType.REQUIRED)}
assert nodes[3].name == "col_3"
assert nodes[3].type == np.ndarray
assert nodes[3].documentation == "dummy doc"
- assert nodes[3].input_types == {dummy_dict_generator.__name__: (dict,
DependencyType.REQUIRED)}
+ assert nodes[3].input_types == {dummy_dict.__name__: (dict,
DependencyType.REQUIRED)}
+
+
+def test_extract_fields_transform_on_generic_dict_with_explicit_types():
+ """Tests whole extract_fields decorator using a generic dict and explicit
types."""
+ annotation = function_modifiers.extract_fields({"col_1": int, "col_2":
int})
+
+ def dummy_dict() -> Dict[str, int]:
+ """dummy doc"""
+ return {"col_1": 1, "col_2": 2}
+
+ annotation.validate(dummy_dict)
+ nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {},
dummy_dict))
+
+ assert len(nodes) == 3
+ assert nodes[0] == node.Node(
+ name=dummy_dict.__name__,
+ typ=Dict[str, int],
+ doc_string=getattr(dummy_dict, "__doc__", ""),
+ callabl=dummy_dict,
+ tags={"module": "tests.function_modifiers.test_expanders"},
+ )
+
+ assert nodes[1].name == "col_1"
+ assert nodes[1].type == int
+ assert nodes[1].documentation == "dummy doc" # we default to base
function doc.
+ assert nodes[1].input_types == {dummy_dict.__name__: (Dict[str, int],
DependencyType.REQUIRED)}
+ assert nodes[2].name == "col_2"
+ assert nodes[2].type == int
+ assert nodes[2].documentation == "dummy doc"
+ assert nodes[2].input_types == {dummy_dict.__name__: (Dict[str, int],
DependencyType.REQUIRED)}
+
+
+def test_extract_fields_transform_on_generic_dict_with_field_list():
+ """Tests whole extract_fields decorator using a generic dict and a list of
field names."""
+ annotation = function_modifiers.extract_fields(["col_1", "col_2"])
+
+ def dummy_dict() -> Dict[str, int]:
+ """dummy doc"""
+ return {"col_1": 1, "col_2": 2}
+
+ annotation.validate(dummy_dict)
+ nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {},
dummy_dict))
+
+ assert len(nodes) == 3
+ assert nodes[0] == node.Node(
+ name=dummy_dict.__name__,
+ typ=Dict[str, int],
+ doc_string=getattr(dummy_dict, "__doc__", ""),
+ callabl=dummy_dict,
+ tags={"module": "tests.function_modifiers.test_expanders"},
+ )
+
+ assert nodes[1].name == "col_1"
+ assert nodes[1].type == int
+ assert nodes[1].documentation == "dummy doc" # we default to base
function doc.
+ assert nodes[1].input_types == {dummy_dict.__name__: (Dict[str, int],
DependencyType.REQUIRED)}
+ assert nodes[2].name == "col_2"
+ assert nodes[2].type == int
+ assert nodes[2].documentation == "dummy doc"
+ assert nodes[2].input_types == {dummy_dict.__name__: (Dict[str, int],
DependencyType.REQUIRED)}
+
+
+def test_extract_fields_transform_on_generic_dict_with_unpacked_fields():
+ """Tests whole extract_fields decorator using a generic dict and unpacked
field names."""
+ annotation = function_modifiers.extract_fields("col_1", "col_2")
+
+ def dummy_dict() -> Dict[str, int]:
+ """dummy doc"""
+ return {"col_1": 1, "col_2": 2}
+
+ annotation.validate(dummy_dict)
+ nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {},
dummy_dict))
+
+ assert len(nodes) == 3
+ assert nodes[0] == node.Node(
+ name=dummy_dict.__name__,
+ typ=Dict[str, int],
+ doc_string=getattr(dummy_dict, "__doc__", ""),
+ callabl=dummy_dict,
+ tags={"module": "tests.function_modifiers.test_expanders"},
+ )
+
+ assert nodes[1].name == "col_1"
+ assert nodes[1].type == int
+ assert nodes[1].documentation == "dummy doc" # we default to base
function doc.
+ assert nodes[1].input_types == {dummy_dict.__name__: (Dict[str, int],
DependencyType.REQUIRED)}
+ assert nodes[2].name == "col_2"
+ assert nodes[2].type == int
+ assert nodes[2].documentation == "dummy doc"
+ assert nodes[2].input_types == {dummy_dict.__name__: (Dict[str, int],
DependencyType.REQUIRED)}
+
+
+def test_extract_fields_transform_on_typed_dict_with_explicit_types():
+ """Tests whole extract_fields decorator using a TypedDict and explicit
types."""
+ annotation = function_modifiers.extract_fields({"test2": str})
+
+ def dummy_dict() -> MyDict:
+ """dummy doc"""
+ return {"test": 1, "test2": "2"}
+
+ annotation.validate(dummy_dict)
+ nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {},
dummy_dict))
+
+ assert len(nodes) == 2
+ assert nodes[0] == node.Node(
+ name=dummy_dict.__name__,
+ typ=MyDict,
+ doc_string=getattr(dummy_dict, "__doc__", ""),
+ callabl=dummy_dict,
+ tags={"module": "tests.function_modifiers.test_expanders"},
+ )
+
+ assert nodes[1].name == "test2"
+ assert nodes[1].type == str
+ assert nodes[1].documentation == "dummy doc"
+ assert nodes[1].input_types == {dummy_dict.__name__: (MyDict,
DependencyType.REQUIRED)}
+
+
+def test_extract_fields_transform_on_typed_dict_with_field_list():
+ """Tests whole extract_fields decorator using a TypedDict and a list of
field names."""
+ annotation = function_modifiers.extract_fields(["test2"])
+
+ def dummy_dict() -> MyDict:
+ """dummy doc"""
+ return {"test": 1, "test2": "2"}
+
+ annotation.validate(dummy_dict)
+ nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {},
dummy_dict))
+
+ assert len(nodes) == 2
+ assert nodes[0] == node.Node(
+ name=dummy_dict.__name__,
+ typ=MyDict,
+ doc_string=getattr(dummy_dict, "__doc__", ""),
+ callabl=dummy_dict,
+ tags={"module": "tests.function_modifiers.test_expanders"},
+ )
+
+ assert nodes[1].name == "test2"
+ assert nodes[1].type == str
+ assert nodes[1].documentation == "dummy doc"
+ assert nodes[1].input_types == {dummy_dict.__name__: (MyDict,
DependencyType.REQUIRED)}
+
+
+def test_extract_fields_transform_on_typed_dict_with_unpacked_fields():
+    """Tests whole extract_fields decorator using a TypedDict and unpacked
field names."""
+ annotation = function_modifiers.extract_fields("test2")
+
+ def dummy_dict() -> MyDict:
+ """dummy doc"""
+ return {"test": 1, "test2": "2"}
+
+ annotation.validate(dummy_dict)
+ nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {},
dummy_dict))
+
+ assert len(nodes) == 2
+ assert nodes[0] == node.Node(
+ name=dummy_dict.__name__,
+ typ=MyDict,
+ doc_string=getattr(dummy_dict, "__doc__", ""),
+ callabl=dummy_dict,
+ tags={"module": "tests.function_modifiers.test_expanders"},
+ )
+
+ assert nodes[1].name == "test2"
+ assert nodes[1].type == str
+ assert nodes[1].documentation == "dummy doc"
+ assert nodes[1].input_types == {dummy_dict.__name__: (MyDict,
DependencyType.REQUIRED)}
+
+
+def test_extract_fields_transform_on_typed_dict_with_inferred_types():
+ """Tests whole extract_fields decorator using a TypedDict and inferred
types."""
+ annotation = function_modifiers.extract_fields()
+
+ def dummy_dict() -> MyDict:
+ """dummy doc"""
+ return {"test": 1, "test2": "2"}
+
+ annotation.validate(dummy_dict)
+ nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {},
dummy_dict))
+
+ assert len(nodes) == 3
+ assert nodes[0] == node.Node(
+ name=dummy_dict.__name__,
+ typ=MyDict,
+ doc_string=getattr(dummy_dict, "__doc__", ""),
+ callabl=dummy_dict,
+ tags={"module": "tests.function_modifiers.test_expanders"},
+ )
+
+ assert nodes[1].name == "test"
+ assert nodes[1].type == int
+ assert nodes[1].documentation == "dummy doc" # we default to base
function doc.
+ assert nodes[1].input_types == {dummy_dict.__name__: (MyDict,
DependencyType.REQUIRED)}
+ assert nodes[2].name == "test2"
+ assert nodes[2].type == str
+ assert nodes[2].documentation == "dummy doc"
+ assert nodes[2].input_types == {dummy_dict.__name__: (MyDict,
DependencyType.REQUIRED)}
-def test_extract_fields_fill_with():
+def test_extract_fields_transform_using_fill_with():
def dummy_dict() -> dict:
"""dummy doc"""
return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2,
3, 4])}
annotation = function_modifiers.extract_fields({"col_2": int, "col_4":
float}, fill_with=1.0)
+ annotation.validate(dummy_dict)
original_node, extracted_field_node, missing_field_node =
annotation.transform_node(
node.Node.from_fn(dummy_dict), {}, dummy_dict
)
@@ -468,18 +661,19 @@ def test_extract_fields_fill_with():
assert missing_field == 1.0
-def test_extract_fields_no_fill_with():
+def test_extract_fields_transform_not_using_fill_with():
def dummy_dict() -> dict:
"""dummy doc"""
return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2,
3, 4])}
annotation = function_modifiers.extract_fields({"col_4": int})
+ annotation.validate(dummy_dict)
nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {},
dummy_dict))
with
pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException):
nodes[1].callable(dummy_dict=dummy_dict())
-def test_unpack_fields_valid_explicit_tuple():
+def test_unpack_fields_transform_on_explicit_tuple():
def dummy() -> Tuple[int, str, int]:
"""dummy doc"""
return 1, "2", 3
@@ -510,7 +704,7 @@ def test_unpack_fields_valid_explicit_tuple():
assert nodes[3].input_types == {dummy.__name__: (Tuple[int, str, int],
DependencyType.REQUIRED)}
-def test_unpack_fields_valid_explicit_tuple_subset():
+def test_unpack_fields_transform_on_explicit_tuple_subset():
def dummy() -> Tuple[int, str, int]:
"""dummy doc"""
return 1, "2", 3
@@ -533,7 +727,7 @@ def test_unpack_fields_valid_explicit_tuple_subset():
assert nodes[1].input_types == {dummy.__name__: (Tuple[int, str, int],
DependencyType.REQUIRED)}
-def test_unpack_fields_valid_indeterminate_tuple():
+def test_unpack_fields_transform_on_indeterminate_tuple():
def dummy() -> Tuple[int, ...]:
"""dummy doc"""
return 1, 2, 3