This is an automated email from the ASF dual-hosted git repository.

skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/hamilton.git

commit 092a1d4e612ec8e7608a9be8fd49771fdafb025e
Author: Charles Swartz <[email protected]>
AuthorDate: Sun Apr 6 08:31:58 2025 -0400

    Initial tests for `unpack_fields`
---
 tests/function_modifiers/test_expanders.py | 134 ++++++++++++++++++++++++++++-
 1 file changed, 133 insertions(+), 1 deletion(-)

diff --git a/tests/function_modifiers/test_expanders.py b/tests/function_modifiers/test_expanders.py
index 68741b8e..9baa96fd 100644
--- a/tests/function_modifiers/test_expanders.py
+++ b/tests/function_modifiers/test_expanders.py
@@ -1,5 +1,5 @@
 import sys
-from typing import Any, Dict, List, Optional, Type, TypedDict
+from typing import Any, Dict, List, Optional, Tuple, Type, TypedDict
 
 import numpy as np
 import pandas as pd
@@ -19,6 +19,9 @@ from hamilton.function_modifiers.dependencies import (
 from hamilton.htypes import Collect, Parallelizable
 from hamilton.node import DependencyType
 
+# Skip marker for builtin-generic cases (PEP 585, e.g. tuple[int, int]); always defined so it resolves.
+skip_if_before_39 = pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires Python 3.9+")
+
 
 def test_parametrized_invalid_params():
     annotation = function_modifiers.parameterize_values(
@@ -468,6 +471,135 @@ def test_extract_fields_no_fill_with():
         nodes[1].callable(dummy_dict=dummy_dict())
 
 
+def test_unpack_fields_valid_explicit_tuple():
+    def dummy() -> Tuple[int, str, int]:
+        """dummy doc"""
+        return 1, "2", 3
+
+    annotation = function_modifiers.unpack_fields("A", "B", "C")
+    annotation.validate(dummy)
+    nodes = list(annotation.transform_node(node.Node.from_fn(dummy), {}, dummy))
+    assert len(nodes) == 4
+
+    assert nodes[0] == node.Node(
+        name=dummy.__name__,
+        typ=Tuple[int, str, int],
+        doc_string=getattr(dummy, "__doc__", ""),
+        callabl=dummy,
+        tags={"module": "tests.function_modifiers.test_expanders"},
+    )
+    assert nodes[1].name == "A"
+    assert nodes[1].type == int
+    assert nodes[1].documentation == "dummy doc"
+    assert nodes[1].input_types == {dummy.__name__: (Tuple[int, str, int], DependencyType.REQUIRED)}
+    assert nodes[2].name == "B"
+    assert nodes[2].type == str
+    assert nodes[2].documentation == "dummy doc"
+    assert nodes[2].input_types == {dummy.__name__: (Tuple[int, str, int], DependencyType.REQUIRED)}
+    assert nodes[3].name == "C"
+    assert nodes[3].type == int
+    assert nodes[3].documentation == "dummy doc"
+    assert nodes[3].input_types == {dummy.__name__: (Tuple[int, str, int], DependencyType.REQUIRED)}
+
+
+def test_unpack_fields_valid_explicit_tuple_subset():
+    def dummy() -> Tuple[int, str, int]:
+        """dummy doc"""
+        return 1, "2", 3
+
+    annotation = function_modifiers.unpack_fields("A")
+    annotation.validate(dummy)
+    nodes = list(annotation.transform_node(node.Node.from_fn(dummy), {}, dummy))
+    assert len(nodes) == 2
+
+    assert nodes[0] == node.Node(
+        name=dummy.__name__,
+        typ=Tuple[int, str, int],
+        doc_string=getattr(dummy, "__doc__", ""),
+        callabl=dummy,
+        tags={"module": "tests.function_modifiers.test_expanders"},
+    )
+    assert nodes[1].name == "A"
+    assert nodes[1].type == int
+    assert nodes[1].documentation == "dummy doc"
+    assert nodes[1].input_types == {dummy.__name__: (Tuple[int, str, int], DependencyType.REQUIRED)}
+
+
+def test_unpack_fields_valid_indeterminate_tuple():
+    def dummy() -> Tuple[int, ...]:
+        """dummy doc"""
+        return 1, 2, 3
+
+    annotation = function_modifiers.unpack_fields("A", "B", "C")
+    annotation.validate(dummy)
+    nodes = list(annotation.transform_node(node.Node.from_fn(dummy), {}, dummy))
+    assert len(nodes) == 4
+
+    assert nodes[0] == node.Node(
+        name=dummy.__name__,
+        typ=Tuple[int, ...],
+        doc_string=getattr(dummy, "__doc__", ""),
+        callabl=dummy,
+        tags={"module": "tests.function_modifiers.test_expanders"},
+    )
+    assert nodes[1].name == "A"
+    assert nodes[1].type == int
+    assert nodes[1].documentation == "dummy doc"
+    assert nodes[1].input_types == {dummy.__name__: (Tuple[int, ...], DependencyType.REQUIRED)}
+    assert nodes[2].name == "B"
+    assert nodes[2].type == int
+    assert nodes[2].documentation == "dummy doc"
+    assert nodes[2].input_types == {dummy.__name__: (Tuple[int, ...], DependencyType.REQUIRED)}
+    assert nodes[3].name == "C"
+    assert nodes[3].type == int
+    assert nodes[3].documentation == "dummy doc"
+    assert nodes[3].input_types == {dummy.__name__: (Tuple[int, ...], DependencyType.REQUIRED)}
+
+
+@pytest.mark.parametrize(
+    "return_type,fields",
+    [
+        (Tuple[int, int], ("A", "B")),
+        (Tuple[int, int, str], ("A", "B", "C")),
+        (Tuple[int, ...], ("A", "B", "C")),
+        pytest.param(tuple[int, int], ("A", "B"), marks=skip_if_before_39),
+        pytest.param(tuple[int, int, str], ("A", "B", "C"), marks=skip_if_before_39),
+        pytest.param(tuple[int, ...], ("A", "B", "C"), marks=skip_if_before_39),
+    ],
+)
+def test_unpack_fields_valid_type_annotations(return_type, fields):
+    def function() -> return_type:
+        return 1, 2, "3"  # Only testing validation, so return value doesn't matter
+
+    annotation = function_modifiers.unpack_fields(*fields)
+    annotation.validate(function)
+
+
+@pytest.mark.parametrize(
+    "return_type,fields",
+    [
+        (int, ("A",)),
+        (list, ("A",)),
+        (dict, ("A",)),
+        (Tuple, ("A",)),
+        (Tuple[...], ("A", "B")),  # NOTE(review): typing raises TypeError here before 3.11
+        (Tuple[int, int, ...], ("A", "B")),  # NOTE(review): typing raises TypeError here before 3.11
+        (Tuple[int, int], ("A", "B", "C", "D")),
+        pytest.param(tuple, ("A",), marks=skip_if_before_39),
+        pytest.param(tuple[...], ("A", "B"), marks=skip_if_before_39),
+        pytest.param(tuple[int, int, ...], ("A", "B"), marks=skip_if_before_39),
+        pytest.param(tuple[int, int], ("A", "B", "C", "D"), marks=skip_if_before_39),
+    ],
+)
+def test_unpack_fields_invalid_type_annotations(return_type, fields):
+    def function() -> return_type:
+        return 1, 2, 3  # Only testing validation, so return value doesn't matter
+
+    annotation = function_modifiers.unpack_fields(*fields)
+    with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException):
+        annotation.validate(function)
+
+
 def concat(upstream_parameter: str, literal_parameter: str) -> Any:
     """Concatenates {upstream_parameter} with literal_parameter"""
     return f"{upstream_parameter}{literal_parameter}"

Reply via email to