This is an automated email from the ASF dual-hosted git repository. skrawcz pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/hamilton.git
commit 5d57508c96fcfe21ec087376ede8f26b4360987e Author: Dev-iL <[email protected]> AuthorDate: Sat Feb 14 22:01:37 2026 +0200 Remove outdated numpy<2 version pins. Spark supports numpy 2 as of version 4.0.0. --- .github/workflows/hamilton-main.yml | 12 ++++++------ hamilton/plugins/numpy_extensions.py | 16 ++++++++-------- pyproject.toml | 6 +++--- tests/test_node.py | 12 ++++++++++-- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/.github/workflows/hamilton-main.yml b/.github/workflows/hamilton-main.yml index 49cad6c9..f13bb657 100644 --- a/.github/workflows/hamilton-main.yml +++ b/.github/workflows/hamilton-main.yml @@ -122,7 +122,7 @@ jobs: run: | sudo apt-get install --no-install-recommends --yes default-jre uv sync --group test --extra pyspark - uv pip install 'numpy<2' 'pyspark[connect]' 'grpcio' + uv pip install 'pyspark[connect]' 'grpcio' uv pip install --no-cache --reinstall --strict 'grpcio-status >= 1.48.1' uv run pytest plugin_tests/h_spark @@ -132,23 +132,23 @@ jobs: PYSPARK_SUBMIT_ARGS: "--conf spark.sql.ansi.enabled=false pyspark-shell" run: | uv sync --group test --extra pyspark - uv pip install 'numpy<2' 'pyspark[connect]' 'grpcio' + uv pip install 'pyspark[connect]' 'grpcio' uv pip install --no-cache --reinstall --strict 'grpcio-status >= 1.48.1' uv run pytest plugin_tests/h_spark + # Vaex 4.19 supports py<=3.12 and numpy>2 (https://github.com/vaexio/vaex/pull/2449) but limited by dask<2024.9 + # For now the test matrix is py3.10 and numpy<2 - name: Test vaex - # Vaex supports <= py3.10 and numpy<2 if: ${{ runner.os == 'Linux' && matrix.python-version == '3.10' }} run: | sudo apt-get install --no-install-recommends --yes libpcre3-dev cargo uv sync --group test --extra vaex uv pip install "numpy<2" - uv run pytest plugin_tests/h_vaex + uv run --no-sync pytest plugin_tests/h_vaex - name: Test vaex - # Vaex supports <= py3.10 and numpy<2 if: ${{ runner.os != 'Linux' && matrix.python-version == '3.10' }} run: | uv sync --group test 
--extra vaex uv pip install "numpy<2" - uv run pytest plugin_tests/h_vaex + uv run --no-sync pytest plugin_tests/h_vaex diff --git a/hamilton/plugins/numpy_extensions.py b/hamilton/plugins/numpy_extensions.py index 18b058c6..1527ebb5 100644 --- a/hamilton/plugins/numpy_extensions.py +++ b/hamilton/plugins/numpy_extensions.py @@ -43,12 +43,10 @@ class NumpyNpyWriter(DataSaver): fix_imports: bool | None = None def save_data(self, data: np.ndarray) -> dict[str, Any]: - np.save( - file=self.path, - arr=data, - allow_pickle=self.allow_pickle, - fix_imports=self.fix_imports, - ) + kwargs = dict(file=self.path, arr=data, allow_pickle=self.allow_pickle) + if np.__version__ < "2.4" and self.fix_imports is not None: + kwargs["fix_imports"] = self.fix_imports + np.save(**kwargs) return utils.get_file_metadata(self.path) @classmethod @@ -77,13 +75,15 @@ class NumpyNpyReader(DataLoader): return [np.ndarray] def load_data(self, type_: type) -> tuple[np.ndarray, dict[str, Any]]: - array = np.load( + kwargs = dict( file=self.path, mmap_mode=self.mmap_mode, allow_pickle=self.allow_pickle, - fix_imports=self.fix_imports, encoding=self.encoding, ) + if np.__version__ < "2.4" and self.fix_imports is not None: + kwargs["fix_imports"] = self.fix_imports + array = np.load(**kwargs) metadata = utils.get_file_metadata(self.path) return array, metadata diff --git a/pyproject.toml b/pyproject.toml index 39e757bd..fe00b6b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ pandera = ["pandera"] pydantic = ["pydantic>=2.0"] pyspark = [ # we have to run these dependencies because Spark does not check to ensure the right target was called - "pyspark[pandas_on_spark,sql]", + "pyspark[pandas_on_spark,sql] >= 4.0.0", ] ray = ["ray>=2.0.0; python_version < '3.14'", "pyarrow"] rich = ["rich"] @@ -149,13 +149,13 @@ docs = [ "mock==1.0.1", # read the docs pins "myst-nb", "narwhals", - "numpy < 2.0.0", + "numpy", "packaging", "pandera", "pillow", "polars", "pyarrow >= 1.0.0", - "pydantic 
>=2.0", + "pydantic >= 2.0", "pyspark", "openlineage-python", "PyYAML", diff --git a/tests/test_node.py b/tests/test_node.py index 45bd62c1..a0879572 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -16,7 +16,6 @@ # under the License. import inspect -import sys from typing import Any, Literal, TypeVar import numpy as np @@ -76,7 +75,7 @@ def test_node_handles_annotated(): node = Node.from_fn(annotated_func) assert node.name == "annotated_func" - if major == 2 and minor > 1 and sys.version_info > (3, 9): # greater that 2.1 + if major == 2 and 2 <= minor <= 3: # numpy 2.2-2.3 expected = { "first": ( Annotated[np.ndarray[tuple[int, ...], np.dtype[np.float64]], Literal["N"]], @@ -85,6 +84,15 @@ def test_node_handles_annotated(): "other": (float, DependencyType.OPTIONAL), } expected_type = Annotated[np.ndarray[tuple[int, ...], np.dtype[np.float64]], Literal["N"]] + elif (major, minor) >= (2, 4): # numpy 2.4+ + expected = { + "first": ( + Annotated[np.ndarray[tuple[Any, ...], np.dtype[np.float64]], Literal["N"]], + DependencyType.REQUIRED, + ), + "other": (float, DependencyType.OPTIONAL), + } + expected_type = Annotated[np.ndarray[tuple[Any, ...], np.dtype[np.float64]], Literal["N"]] else: expected = { "first": (
