This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch SEDONA-725-add-flink-register-functions
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 0c4f10cd69895bd9b54d6b7e69dbda1aca992dbc
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Mar 23 19:06:02 2025 +0100

    SEDONA-725 Add pyflink to Sedona.
---
 .github/workflows/pyflink.yml                 | 53 +++++++++++++++++++++++++++
 python/sedona/flink/__init__.py               | 12 ++++++
 python/sedona/flink/context.py                | 23 ++++++++++++
 python/tests/flink/conftest.py                | 32 ++++++++++++++++
 python/tests/flink/test_flink_registration.py | 45 +++++++++++++++++++++++
 5 files changed, 165 insertions(+)

diff --git a/.github/workflows/pyflink.yml b/.github/workflows/pyflink.yml
new file mode 100644
index 0000000000..8b5a69b833
--- /dev/null
+++ b/.github/workflows/pyflink.yml
@@ -0,0 +1,53 @@
+name: Sedona Pyflink Test
+
+on:
+  push:
+    branches:
+      - master
+    paths:
+      - 'common/**'
+      - 'flink/**'
+      - 'flink-shaded/**'
+      - 'pom.xml'
+      - 'python/**'
+      - '.github/workflows/pyflink.yml'
+  pull_request:
+    branches:
+      - '*'
+    paths:
+      - 'common/**'
+      - 'flink/**'
+      - 'flink-shaded/**'
+      - 'pom.xml'
+      - 'python/**'
+      - '.github/workflows/pyflink.yml'
+
+jobs:
+  build:
+    runs-on: ubuntu-22.04
+    strategy:
+      matrix:
+        include:
+          - python: '3.10'
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-java@v4
+        with:
+          distribution: 'zulu'
+          java-version: '8'
+      - uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python }}
+      - run: sudo apt-get -y install python3-pip python-dev-is-python3
+      - run: mvn package -pl "org.apache.sedona:sedona-flink-shaded_2.12" -am -DskipTests
+      - run: sudo pip3 install -U setuptools
+      - run: sudo pip3 install -U wheel
+      - run: sudo pip3 install -U virtualenvwrapper
+      - run: python3 -m pip install uv
+      - run: cd python
+      - run: uv init --no-workspace
+      - run: uv add apache-flink==1.20.1
+      - run: uv add pytest --dev
+      - run: |
+          SEDONA_PYFLINK_EXTRA_JARS=${PWD}/$(find flink-shaded/target -name sedona-flink*.jar)
+          uv run pytest -v -m flink ./tests
diff --git a/python/sedona/flink/__init__.py b/python/sedona/flink/__init__.py
new file mode 100644
index 0000000000..0d8af26e15
--- /dev/null
+++ b/python/sedona/flink/__init__.py
@@ -0,0 +1,12 @@
+import logging
+
+try:
+    from sedona.flink.context import SedonaContext
+
+    __all__ = ["SedonaContext"]
+except ImportError:
+    logging.log(
+        logging.WARN,
+        "SedonaContext could not be imported. This is likely due to a missing flink dependency.",
+    )
+    __all__ = []
diff --git a/python/sedona/flink/context.py b/python/sedona/flink/context.py
new file mode 100644
index 0000000000..c6d6c2b38e
--- /dev/null
+++ b/python/sedona/flink/context.py
@@ -0,0 +1,23 @@
+from pyflink.table import EnvironmentSettings, StreamTableEnvironment
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.java_gateway import get_gateway
+
+
+class SedonaContext:
+
+    @classmethod
+    def create(
+        cls, env: StreamExecutionEnvironment, settings: EnvironmentSettings
+    ) -> StreamTableEnvironment:
+        table_env = StreamTableEnvironment.create(env, settings)
+        gateway = get_gateway()
+
+        flink_sedona_context = gateway.jvm.org.apache.sedona.flink.SedonaContext
+
+        table_env_j = flink_sedona_context.create(
+            env._j_stream_execution_environment, table_env._j_tenv
+        )
+
+        table_env._j_tenv = table_env_j
+
+        return table_env
diff --git a/python/tests/flink/conftest.py b/python/tests/flink/conftest.py
new file mode 100644
index 0000000000..fa333998ab
--- /dev/null
+++ b/python/tests/flink/conftest.py
@@ -0,0 +1,32 @@
+import os
+
+import pytest
+
+from sedona.flink import SedonaContext
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.table import EnvironmentSettings, StreamTableEnvironment
+
+
+EXTRA_JARS = os.getenv("SEDONA_PYFLINK_EXTRA_JARS")
+
+
[email protected](scope="module")
+def flink_settings():
+    return EnvironmentSettings.in_streaming_mode()
+
+
[email protected](scope="module")
+def stream_env() -> StreamExecutionEnvironment:
+    env = StreamExecutionEnvironment.get_execution_environment()
+    jars = EXTRA_JARS.split(",") if EXTRA_JARS else []
+    for jar in jars:
+        env.add_jars(f"file://{jar}")
+
+    return env
+
+
[email protected](scope="module")
+def table_env(
+    stream_env: StreamExecutionEnvironment, flink_settings: EnvironmentSettings
+) -> StreamTableEnvironment:
+    return SedonaContext.create(stream_env, flink_settings)
diff --git a/python/tests/flink/test_flink_registration.py b/python/tests/flink/test_flink_registration.py
new file mode 100644
index 0000000000..e99b933e3f
--- /dev/null
+++ b/python/tests/flink/test_flink_registration.py
@@ -0,0 +1,45 @@
+from pyflink.table import StreamTableEnvironment
+from pyflink.table.udf import ScalarFunction, udf
+from shapely.wkb import loads
+import pytest
+
+
+class Buffer(ScalarFunction):
+    def eval(self, s):
+        geom = loads(s)
+        return geom.buffer(1).wkb
+
+
[email protected]
+def test_register(table_env: StreamTableEnvironment):
+    result = (
+        table_env.sql_query("SELECT ST_ASBinary(ST_Point(1.0, 2.0))")
+        .execute()
+        .collect()
+    )
+    assert 1 == len(([el for el in result]))
+
+
[email protected]
+def test_register_udf(table_env: StreamTableEnvironment):
+    table_env.create_temporary_function(
+        "ST_BufferPython", udf(Buffer(), result_type="Binary")
+    )
+
+    buffer_table = table_env.sql_query(
+        "SELECT ST_BufferPython(ST_ASBinary(ST_Point(1.0, 2.0))) AS buffer"
+    )
+
+    table_env.create_temporary_view("buffer_table", buffer_table)
+
+    result = (
+        table_env.sql_query("SELECT ST_Area(ST_GeomFromWKB(buffer)) FROM buffer_table")
+        .execute()
+        .collect()
+    )
+
+    items = [el for el in result]
+    area = items[0][0]
+
+    assert 3.12 < area < 3.14
+    assert 1 == len(items)

Reply via email to