This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 3b09d9a0e2 [SEDONA-706] Fix Python dataframe api for multi-threaded environment (#1785)
3b09d9a0e2 is described below
commit 3b09d9a0e2c113fd364b6212be507ccff6bd9041
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Tue Feb 4 00:29:00 2025 +0800
[SEDONA-706] Fix Python dataframe api for multi-threaded environment (#1785)
---
.github/workflows/python.yml | 8 ++------
python/sedona/sql/dataframe_api.py | 14 +++++++-------
python/tests/sql/test_dataframe_api.py | 27 +++++++++++++++++++++++++++
3 files changed, 36 insertions(+), 13 deletions(-)
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index aaca28df05..6aad4a97b7 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -163,13 +163,9 @@ jobs:
- name: Run Spark Connect tests
env:
PYTHON_VERSION: ${{ matrix.python }}
+ SPARK_VERSION: ${{ matrix.spark }}
+ if: ${{ matrix.spark >= '3.4.0' }}
run: |
- if [ ! -f "${VENV_PATH}/lib/python${PYTHON_VERSION}/site-packages/pyspark/sbin/start-connect-server.sh" ]
- then
- echo "Skipping connect tests for Spark $SPARK_VERSION"
- exit
- fi
-
export SPARK_HOME=${VENV_PATH}/lib/python${PYTHON_VERSION}/site-packages/pyspark
export SPARK_REMOTE=local
diff --git a/python/sedona/sql/dataframe_api.py b/python/sedona/sql/dataframe_api.py
index 2f56dfffa5..b1639a97bf 100644
--- a/python/sedona/sql/dataframe_api.py
+++ b/python/sedona/sql/dataframe_api.py
@@ -21,6 +21,7 @@ import itertools
import typing
from typing import Any, Callable, Iterable, List, Mapping, Tuple, Type, Union
+from pyspark import SparkContext
from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as f
@@ -57,12 +58,6 @@ def _convert_argument_to_java_column(arg: Any) -> Column:
def call_sedona_function(
object_name: str, function_name: str, args: Union[Any, Tuple[Any]]
) -> Column:
- spark = SparkSession.getActiveSession()
- if spark is None:
- raise ValueError(
- "No active spark session was detected. Unable to call sedona
function."
- )
-
# apparently a Column is an Iterable so we need to check for it explicitly
if (not isinstance(args, Iterable)) or isinstance(
args, (str, Column, ConnectColumn)
@@ -75,7 +70,12 @@ def call_sedona_function(
args = map(_convert_argument_to_java_column, args)
- jobject = getattr(spark._jvm, object_name)
+ jvm = SparkContext._jvm
+ if jvm is None:
+ raise ValueError(
+ "No active spark context was detected. Unable to call sedona
function."
+ )
+ jobject = getattr(jvm, object_name)
jfunc = getattr(jobject, function_name)
jc = jfunc(*args)
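
Context for the hunk above: SparkSession.getActiveSession() can return None when a Sedona function is called from a thread other than the one that created the session, which is why the patch resolves the JVM gateway through SparkContext._jvm, shared by the whole Python process. Below is a minimal sketch of that failure mode, not part of this commit, assuming a plain local PySpark session; the names and setup are illustrative only.

    # Sketch: the active session is tracked per thread, so a worker thread
    # may see None, while the py4j gateway on SparkContext._jvm is
    # process-wide and stays reachable.
    import threading

    from pyspark import SparkContext
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()

    def probe():
        # May print None here, even though a session exists in the process.
        print("active session:", SparkSession.getActiveSession())
        # The JVM gateway is still available to every thread.
        print("jvm gateway present:", SparkContext._jvm is not None)

    t = threading.Thread(target=probe)
    t.start()
    t.join()
    spark.stop()
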
diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py
index 7f64750190..de65f6f0f4 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
from math import radians
+import os
+import threading
+import concurrent.futures
from typing import Callable, Tuple
import pytest
@@ -1732,6 +1735,26 @@ class TestDataFrameAPI(TestBase):
):
func(*args)
+ def test_multi_thread(self):
+ df = self.spark.range(0, 100)
+
+ def run_spatial_query():
+ result = df.select(
+ stf.ST_Buffer(stc.ST_Point("id", f.col("id") + 1), 1.0).alias("geom")
+ ).collect()
+ assert len(result) == 100
+
+ # Create and run 4 threads
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ futures = [executor.submit(run_spatial_query) for _ in range(4)]
+ concurrent.futures.wait(futures)
+ for future in futures:
+ future.result()
+
+ @pytest.mark.skipif(
+ os.getenv("SPARK_REMOTE") is not None,
+ reason="Checkpoint dir is not available in Spark Connect",
+ )
def test_dbscan(self):
df = self.spark.createDataFrame([{"id": 1, "x": 2, "y": 3}]).withColumn(
"geometry", f.expr("ST_Point(x, y)")
@@ -1739,6 +1762,10 @@ class TestDataFrameAPI(TestBase):
df.withColumn("dbscan", ST_DBSCAN("geometry", 1.0, 2, False)).collect()
+ @pytest.mark.skipif(
+ os.getenv("SPARK_REMOTE") is not None,
+ reason="Checkpoint dir is not available in Spark Connect",
+ )
def test_lof(self):
df = self.spark.createDataFrame([{"id": 1, "x": 2, "y": 3}]).withColumn(
"geometry", f.expr("ST_Point(x, y)")