This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new b1ceb1e8c [SEDONA-661] add local outlier factor implementation. (#1623)
b1ceb1e8c is described below
commit b1ceb1e8c5e5fdc1c4e1bdd9ec929aa1393045de
Author: James Willis <[email protected]>
AuthorDate: Tue Oct 15 20:49:18 2024 -0700
[SEDONA-661] add local outlier factor implementation. (#1623)
* add local outlier factor implementation.
* LOF docs
* precommit changes
* precommit formatting changes
---------
Co-authored-by: jameswillis <[email protected]>
---
docs/api/stats/sql.md | 22 ++-
docs/tutorial/sql.md | 67 +++++++++-
python/sedona/stats/outlier_detection/__init__.py | 18 +++
.../outlier_detection/local_outlier_factor.py | 60 +++++++++
python/tests/stats/test_local_outlier_factor.py | 107 +++++++++++++++
.../outlierDetection/LocalOutlierFactor.scala | 148 +++++++++++++++++++++
.../org/apache/sedona/sql/TestBaseScala.scala | 2 +
.../outlierDetection/LocalOutlierFactorTest.scala | 59 ++++++++
8 files changed, 479 insertions(+), 4 deletions(-)
diff --git a/docs/api/stats/sql.md b/docs/api/stats/sql.md
index fe7c0e90e..290691710 100644
--- a/docs/api/stats/sql.md
+++ b/docs/api/stats/sql.md
@@ -9,7 +9,7 @@ complete set of geospatial analysis tools.
## Using DBSCAN
-The DBSCAN function is provided at `org.apache.sedona.stats.DBSCAN.dbscan` in scala/java and `sedona.stats.dbscan.dbscan` in python.
+The DBSCAN function is provided at `org.apache.sedona.stats.clustering.DBSCAN.dbscan` in scala/java and `sedona.stats.clustering.dbscan.dbscan` in python.
The function annotates a dataframe with a cluster label for each data record using the DBSCAN algorithm.
The dataframe should contain at least one `GeometryType` column. Rows must be unique. If one
@@ -29,3 +29,23 @@ names in parentheses are python variable names
- useSpheroid (use_spheroid) - whether to use a cartesian or spheroidal distance calculation. Default is false
The output is the input DataFrame with the cluster label added to each row.
Outliers will have a cluster value of -1 if included.
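For reference, a minimal Python sketch of a DBSCAN call against the new package path documented above; `df` is assumed to be a Spark DataFrame with a point geometry column, and the 0.1 / 5 arguments are the epsilon and minimum-points values used in the tutorial further down.

```python
# Minimal sketch, assuming `df` is a Spark DataFrame with a point geometry
# column and a Sedona-enabled SparkSession is active. 0.1 and 5 are the
# epsilon and minimum-points arguments shown in the tutorial example.
from sedona.stats.clustering.dbscan import dbscan

dbscan(df, 0.1, 5).show()
```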
+
+## Using Local Outlier Factor (LOF)
+
+The LOF function is provided at `org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor` in scala/java and `sedona.stats.outlier_detection.local_outlier_factor.local_outlier_factor` in python.
+
+The function annotates a dataframe with a column containing the local outlier factor for each data record.
+The dataframe should contain at least one `GeometryType` column. Rows must be unique. If one
+geometry column is present it will be used automatically. If two are present, the one named
+'geometry' will be used. If more than one are present and none is named 'geometry', the
+column name must be provided.
+
+### Parameters
+
+names in parentheses are python variable names
+
+- dataframe - dataframe containing the point geometries
+- k - number of nearest neighbors that will be considered for the LOF calculation
+- geometry - name of the geometry column
+- handleTies (handle_ties) - whether to handle ties in the k-distance calculation. Default is false
+- useSpheroid (use_spheroid) - whether to use a cartesian or spheroidal distance calculation. Default is false
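To make the parameter list concrete, here is a hedged Python sketch calling the function with every argument spelled out; `df` is assumed to be a DataFrame whose point column is named 'geometry'.

```python
# Hedged sketch: every parameter from the list above, using the Python names
# given in parentheses. `df` is assumed to hold a point column named "geometry".
from sedona.stats.outlier_detection.local_outlier_factor import local_outlier_factor

result = local_outlier_factor(
    dataframe=df,
    k=20,                 # nearest neighbors considered for the LOF calculation
    geometry="geometry",  # optional when only one geometry column is present
    handle_ties=False,    # handle ties in the k-distance calculation
    use_spheroid=False,   # cartesian (False) vs. spheroidal (True) distance
)
result.show()
```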
diff --git a/docs/tutorial/sql.md b/docs/tutorial/sql.md
index 75a013f85..7754338a8 100644
--- a/docs/tutorial/sql.md
+++ b/docs/tutorial/sql.md
@@ -842,7 +842,7 @@ The first parameter is the dataframe, the next two are the epsilon and min_point
=== "Scala"
```scala
- import org.apache.sedona.stats.DBSCAN.dbscan
+ import org.apache.sedona.stats.clustering.DBSCAN.dbscan
dbscan(df, 0.1, 5).show()
```
@@ -850,7 +850,7 @@ The first parameter is the dataframe, the next two are the epsilon and min_point
=== "Java"
```java
- import org.apache.sedona.stats.DBSCAN;
+ import org.apache.sedona.stats.clustering.DBSCAN;
DBSCAN.dbscan(df, 0.1, 5).show();
```
@@ -858,7 +858,7 @@ The first parameter is the dataframe, the next two are the epsilon and min_point
=== "Python"
```python
- from sedona.stats.dbscan import dbscan
+ from sedona.stats.clustering.dbscan import dbscan
dbscan(df, 0.1, 5).show()
```
@@ -885,6 +885,67 @@ The output will look like this:
+----------------+---+------+-------+
```
+## Calculate the Local Outlier Factor (LOF)
+
+Sedona provides an implementation of the [Local Outlier Factor](https://en.wikipedia.org/wiki/Local_outlier_factor) algorithm to identify anomalous data.
+
+The algorithm is available as a Scala and Python function called on a spatial dataframe. The returned dataframe has an additional column containing the local outlier factor.
+
+The first parameter is the dataframe, the next is the number of nearest neighbors to use in calculating the score.
+
+=== "Scala"
+
+ ```scala
+ import org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor
+
+ localOutlierFactor(df, 20).show()
+ ```
+
+=== "Java"
+
+ ```java
+ import org.apache.sedona.stats.outlierDetection.LocalOutlierFactor;
+
+ LocalOutlierFactor.localOutlierFactor(df, 20).show();
+ ```
+
+=== "Python"
+
+ ```python
+ from sedona.stats.outlier_detection.local_outlier_factor import local_outlier_factor
+
+ local_outlier_factor(df, 20).show()
+ ```
+
+The output will look like this:
+
+```
++--------------------+------------------+
+| geometry| lof|
++--------------------+------------------+
+|POINT (-2.0231305...| 0.952098153363662|
+|POINT (-2.0346944...|0.9975325496668104|
+|POINT (-2.2040074...|1.0825843906411081|
+|POINT (1.61573501...|1.7367129352162634|
+|POINT (-2.1176324...|1.5714144683150393|
+|POINT (-2.2349759...|0.9167275845938276|
+|POINT (1.65470192...| 1.046231536764447|
+|POINT (0.62624112...|1.1988700676990034|
+|POINT (2.01746261...|1.1060219481067417|
+|POINT (-2.0483857...|1.0775553430145446|
+|POINT (2.43969463...|1.1129132178576646|
+|POINT (-2.2425480...| 1.104108012697006|
+|POINT (-2.7859235...| 2.86371824574529|
+|POINT (-1.9738858...|1.0398822680356794|
+|POINT (2.00153403...| 0.927409656346015|
+|POINT (2.06422812...|0.9222203762264445|
+|POINT (-1.7533819...|1.0273650471626696|
+|POINT (-2.2030766...| 0.964744555830738|
+|POINT (-1.8509857...|1.0375927869698574|
+|POINT (2.10849080...|1.0753419197322656|
++--------------------+------------------+
+```
+
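A common follow-up, sketched below under the assumption that scores well above 1 indicate outliers; the 1.5 cutoff is an arbitrary illustrative choice, not part of the API, and `df` is the dataframe from the examples above.

```python
# Hedged follow-up sketch: flag likely outliers by thresholding the "lof"
# column produced above. The 1.5 cutoff is an arbitrary illustrative choice.
import pyspark.sql.functions as f

from sedona.stats.outlier_detection.local_outlier_factor import local_outlier_factor

scored = local_outlier_factor(df, 20)
scored.filter(f.col("lof") > 1.5).show()
```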
## Run spatial queries
After creating a Geometry type column, you are able to run spatial queries.
diff --git a/python/sedona/stats/outlier_detection/__init__.py b/python/sedona/stats/outlier_detection/__init__.py
new file mode 100644
index 000000000..4dd25a3ff
--- /dev/null
+++ b/python/sedona/stats/outlier_detection/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Algorithms for detecting outliers in spatial datasets."""
diff --git a/python/sedona/stats/outlier_detection/local_outlier_factor.py b/python/sedona/stats/outlier_detection/local_outlier_factor.py
new file mode 100644
index 000000000..3050d216b
--- /dev/null
+++ b/python/sedona/stats/outlier_detection/local_outlier_factor.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Functions related to calculating the local outlier factor of a dataset."""
+from typing import Optional
+
+from pyspark.sql import DataFrame, SparkSession
+
+ID_COLUMN_NAME = "__id"
+CONTENTS_COLUMN_NAME = "__contents"
+
+
+def local_outlier_factor(
+ dataframe: DataFrame,
+ k: int = 20,
+ geometry: Optional[str] = None,
+ handle_ties: bool = False,
+ use_spheroid=False,
+):
+ """Annotates a dataframe with a column containing the local outlier factor
for each data record.
+
+ The dataframe should contain at least one GeometryType column. Rows must be unique. If one geometry column is
+ present it will be used automatically. If two are present, the one named 'geometry' will be used. If more than one
+ are present and none is named 'geometry', the column name must be provided.
+
+ Args:
+ dataframe: dataframe containing the point geometries
+ k: number of nearest neighbors that will be considered for the LOF calculation
+ geometry: name of the geometry column
+ handle_ties: whether to handle ties in the k-distance calculation. Default is false
+ use_spheroid: whether to use a cartesian or spheroidal distance calculation. Default is false
+
+ Returns:
+ A PySpark DataFrame containing the lof for each row
+ """
+ sedona = SparkSession.getActiveSession()
+
+ result_df = sedona._jvm.org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor(
+ dataframe._jdf,
+ k,
+ geometry,
+ handle_ties,
+ use_spheroid,
+ )
+
+ return DataFrame(result_df, sedona)
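A hedged end-to-end usage sketch for the wrapper above; the coordinates and the `sedona` session name are illustrative, and a checkpoint directory is assumed to be configured (see the sketch after the TestBaseScala diff below).

```python
# Hedged usage sketch for the wrapper above. A Sedona-enabled SparkSession
# named `sedona` with a configured checkpoint directory is assumed; the
# sample coordinates are illustrative only.
from sedona.sql.st_constructors import ST_MakePoint
from sedona.stats.outlier_detection.local_outlier_factor import local_outlier_factor

points = sedona.createDataFrame(
    [(1.0, 2.0), (2.0, 2.0), (3.0, 3.0), (9.0, 9.0)], ["x", "y"]
).select(ST_MakePoint("x", "y").alias("geometry"))

# k=2 nearest neighbors; the remaining parameters keep their documented defaults
local_outlier_factor(points, k=2).show()
```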
diff --git a/python/tests/stats/test_local_outlier_factor.py b/python/tests/stats/test_local_outlier_factor.py
new file mode 100644
index 000000000..52ec860a0
--- /dev/null
+++ b/python/tests/stats/test_local_outlier_factor.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pyspark.sql.functions as f
+import pytest
+from pyspark.sql import DataFrame
+from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType
+from sklearn.neighbors import LocalOutlierFactor
+from tests.test_base import TestBase
+
+from sedona.sql.st_constructors import ST_MakePoint
+from sedona.sql.st_functions import ST_X, ST_Y
+from sedona.stats.outlier_detection.local_outlier_factor import local_outlier_factor
+
+
+class TestLOF(TestBase):
+ def get_small_data(self) -> DataFrame:
+ schema = StructType(
+ [
+ StructField("id", IntegerType(), True),
+ StructField("x", DoubleType(), True),
+ StructField("y", DoubleType(), True),
+ ]
+ )
+ return self.spark.createDataFrame(
+ [
+ (1, 1.0, 2.0),
+ (2, 2.0, 2.0),
+ (3, 3.0, 3.0),
+ ],
+ schema,
+ ).select("id", ST_MakePoint("x", "y").alias("geometry"))
+
+ def get_medium_data(self):
+ np.random.seed(42)
+
+ X_inliers = 0.3 * np.random.randn(100, 2)
+ X_inliers = np.r_[X_inliers + 2, X_inliers - 2]
+ X_outliers = np.random.uniform(low=-4, high=4, size=(20, 2))
+ return np.r_[X_inliers, X_outliers]
+
+ def get_medium_dataframe(self, data):
+ schema = StructType(
+ [StructField("x", DoubleType(), True), StructField("y", DoubleType(), True)]
+ )
+
+ return (
+ self.spark.createDataFrame(data, schema)
+ .select(ST_MakePoint("x", "y").alias("geometry"))
+ .withColumn("anotherColumn", f.rand())
+ )
+
+ def compare_results(self, actual, expected, k):
+ assert len(actual) == len(expected)
+ missing = set(expected.keys()) - set(actual.keys())
+ assert len(missing) == 0
+ big_diff = {
+ k: (v, expected[k], abs(1 - v / expected[k]))
+ for k, v in actual.items()
+ if abs(1 - v / expected[k]) > 0.0000000001
+ }
+ assert len(big_diff) == 0
+
+ @pytest.mark.parametrize("k", [5, 21, 3])
+ def test_lof_matches_sklearn(self, k):
+ data = self.get_medium_data()
+ actual = {
+ tuple(x[0]): x[1]
+ for x in local_outlier_factor(self.get_medium_dataframe(data.tolist()), k)
+ .select(f.array(ST_X("geometry"), ST_Y("geometry")), "lof")
+ .collect()
+ }
+ clf = LocalOutlierFactor(n_neighbors=k, contamination="auto")
+ clf.fit_predict(data)
+ expected = dict(
+ zip(
+ [tuple(x) for x in data],
+ [float(-x) for x in clf.negative_outlier_factor_],
+ )
+ )
+ self.compare_results(actual, expected, k)
+
+ # TODO uncomment when KNN join supports empty dfs
+ # def test_handle_empty_dataframe(self):
+ # empty_df = self.spark.createDataFrame([], self.get_small_data().schema)
+ # result_df = local_outlier_factor(empty_df, 2)
+ #
+ # assert 0 == result_df.count()
+
+ def test_raise_error_for_invalid_k_value(self):
+ with pytest.raises(Exception):
+ local_outlier_factor(self.get_small_data(), -1)
diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala b/spark/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala
new file mode 100644
index 000000000..b98919de2
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.stats.outlierDetection
+
+import org.apache.sedona.stats.Util.getGeometryColumnName
+import org.apache.spark.sql.sedona_sql.expressions.st_functions.{ST_Distance, ST_DistanceSpheroid}
+import org.apache.spark.sql.{Column, DataFrame, SparkSession, functions => f}
+
+object LocalOutlierFactor {
+
+ private val ID_COLUMN_NAME = "__id"
+ private val CONTENTS_COLUMN_NAME = "__contents"
+
+ /**
+ * Annotates a dataframe with a column containing the local outlier factor for each data record.
+ * The dataframe should contain at least one GeometryType column. Rows must be unique. If one
+ * geometry column is present it will be used automatically. If two are present, the one named
+ * 'geometry' will be used. If more than one are present and none is named 'geometry', the
+ * column name must be provided.
+ *
+ * @param dataframe
+ * dataframe containing the point geometries
+ * @param k
+ * number of nearest neighbors that will be considered for the LOF calculation
+ * @param geometry
+ * name of the geometry column
+ * @param handleTies
+ * whether to handle ties in the k-distance calculation. Default is false
+ * @param useSpheroid
+ * whether to use a cartesian or spheroidal distance calculation. Default is false
+ *
+ * @return
+ * A DataFrame containing the lof for each row
+ */
+ def localOutlierFactor(
+ dataframe: DataFrame,
+ k: Int = 20,
+ geometry: String = null,
+ handleTies: Boolean = false,
+ useSpheroid: Boolean = false): DataFrame = {
+
+ if (k < 1)
+ throw new IllegalArgumentException("k must be a positive integer")
+
+ val prior: String = if (handleTies) {
+ val prior =
+ SparkSession.getActiveSession.get.conf
+ .get("spark.sedona.join.knn.includeTieBreakers", "false")
+ SparkSession.getActiveSession.get.conf.set("spark.sedona.join.knn.includeTieBreakers", true)
+ prior
+ } else "false" // else case to make compiler happy
+
+ val distanceFunction: (Column, Column) => Column =
+ if (useSpheroid) ST_DistanceSpheroid else ST_Distance
+ val useSpheroidString = if (useSpheroid) "True" else "False" // for the SQL expression
+
+ val geometryColumn = if (geometry == null) getGeometryColumnName(dataframe) else geometry
+
+ val KNNFunction = "ST_KNN"
+
+ // Store original contents, prep necessary columns
+ val formattedDataframe = dataframe
+ .withColumn(CONTENTS_COLUMN_NAME, f.struct("*"))
+ .withColumn(ID_COLUMN_NAME, f.sha2(f.to_json(f.col(CONTENTS_COLUMN_NAME)), 256))
+ .withColumnRenamed(geometryColumn, "geometry")
+
+ val kDistanceDf = formattedDataframe
+ .alias("l")
+ .join(
+ formattedDataframe.alias("r"),
+ // k + 1 because we are not counting the row matching to itself
+ f.expr(f"$KNNFunction(l.geometry, r.geometry, $k + 1, $useSpheroidString)") && f.col(
+ f"l.$ID_COLUMN_NAME") =!= f.col(f"r.$ID_COLUMN_NAME"))
+ .groupBy(f"l.$ID_COLUMN_NAME")
+ .agg(
+ f.first("l.geometry").alias("geometry"),
+ f.first(f"l.$CONTENTS_COLUMN_NAME").alias(CONTENTS_COLUMN_NAME),
+ f.max(distanceFunction(f.col("l.geometry"), f.col("r.geometry"))).alias("k_distance"),
+ f.collect_list(f"r.$ID_COLUMN_NAME").alias("neighbors"))
+ .checkpoint()
+
+ val lrdDf = kDistanceDf
+ .alias("A")
+ .select(
+ f.col(ID_COLUMN_NAME).alias("a_id"),
+ f.col(CONTENTS_COLUMN_NAME),
+ f.col("geometry").alias("a_geometry"),
+ f.explode(f.col("neighbors")).alias("n_id"))
+ .join(
+ kDistanceDf.select(
+ f.col(ID_COLUMN_NAME).alias("b_id"),
+ f.col("geometry").alias("b_geometry"),
+ f.col("k_distance").alias("b_k_distance")),
+ f.expr("n_id = b_id"))
+ .select(
+ f.col("a_id"),
+ f.col("b_id"),
+ f.col(CONTENTS_COLUMN_NAME),
+ f.array_max(
+ f.array(
+ f.col("b_k_distance"),
+ distanceFunction(f.col("a_geometry"), f.col("b_geometry"))))
+ .alias("rd"))
+ .groupBy("a_id")
+ .agg(
+ // + 1e-10 to avoid division by zero, matches sklearn impl
+ (f.lit(1.0) / (f.mean("rd") + 1e-10)).alias("lrd"),
+ f.collect_list(f.col("b_id")).alias("neighbors"),
+ f.first(CONTENTS_COLUMN_NAME).alias(CONTENTS_COLUMN_NAME))
+
+ val ret = lrdDf
+ .select(
+ f.col("a_id"),
+ f.col("lrd").alias("a_lrd"),
+ f.col(CONTENTS_COLUMN_NAME),
+ f.explode(f.col("neighbors")).alias("n_id"))
+ .join(
+ lrdDf.select(f.col("a_id").alias("b_id"), f.col("lrd").alias("b_lrd")),
+ f.expr("n_id = b_id"))
+ .groupBy("a_id")
+ .agg(
+ f.first(CONTENTS_COLUMN_NAME).alias(CONTENTS_COLUMN_NAME),
+ (f.sum("b_lrd") / (f.count("b_lrd") * f.first("a_lrd"))).alias("lof"))
+ .select(f.col(f"$CONTENTS_COLUMN_NAME.*"), f.col("lof"))
+
+ if (handleTies)
+ SparkSession.getActiveSession.get.conf
+ .set("spark.sedona.join.knn.includeTieBreakers", prior)
+ ret
+ }
+
+}
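For readers tracing the column algebra above, here is a hedged single-machine NumPy sketch of the same steps (k-distance, reachability distance, local reachability density, LOF). It ignores tie handling and uses plain Euclidean distance, so it illustrates the math rather than reproducing the distributed implementation.

```python
# Illustrative NumPy sketch of the LOF steps mirrored by the Scala code above:
# k-distance -> reachability distance -> local reachability density -> LOF.
# Ties are ignored and distances are Euclidean; not the distributed algorithm.
import numpy as np

def lof(points: np.ndarray, k: int) -> np.ndarray:
    # pairwise Euclidean distances
    d = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)             # a point is not its own neighbor
    knn = np.argsort(d, axis=1)[:, :k]      # indices of the k nearest neighbors
    k_dist = np.take_along_axis(d, knn[:, -1:], axis=1).ravel()  # k-distance per point
    # reachability distance from each point to each of its neighbors
    reach = np.maximum(k_dist[knn], np.take_along_axis(d, knn, axis=1))
    lrd = 1.0 / (reach.mean(axis=1) + 1e-10)  # local reachability density (+1e-10 as above)
    return lrd[knn].mean(axis=1) / lrd        # LOF = mean neighbor lrd / own lrd
```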
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index b6c073574..dc9c841eb 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -32,6 +32,7 @@ import org.locationtech.jts.geom._
import org.scalatest.{BeforeAndAfterAll, FunSpec}
import java.io.File
+import java.nio.file.Files
trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
Logger.getRootLogger.setLevel(Level.WARN)
@@ -97,6 +98,7 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
override def beforeAll(): Unit = {
super.beforeAll()
SedonaContext.create(sparkSession)
+ sc.setCheckpointDir(Files.createTempDirectory("checkpoints").toString)
}
override def afterAll(): Unit = {
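The checkpoint directory added to the test base above is needed because the implementation calls `checkpoint()` on an intermediate DataFrame; callers likely need the same setup in their own sessions. A hedged PySpark sketch follows (the temporary directory is an arbitrary choice).

```python
# Hedged sketch: configure a checkpoint directory before calling the LOF
# function, since the implementation checkpoints an intermediate DataFrame.
# The temporary directory used here is an arbitrary choice.
import tempfile
from pyspark.sql import SparkSession

spark = SparkSession.getActiveSession()
spark.sparkContext.setCheckpointDir(tempfile.mkdtemp())
```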
diff --git a/spark/common/src/test/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactorTest.scala b/spark/common/src/test/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactorTest.scala
new file mode 100644
index 000000000..c401f599b
--- /dev/null
+++ b/spark/common/src/test/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactorTest.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.stats.outlierDetection
+
+import org.apache.sedona.sql.TestBaseScala
+import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_MakePoint
+import org.apache.spark.sql.{DataFrame, functions => f}
+
+class LocalOutlierFactorTest extends TestBaseScala {
+
+ case class Point(id: Int, x: Double, y: Double, expected_lof: Double)
+
+ def get_data(): DataFrame = {
+ // expected pulled from sklearn.neighbors.LocalOutlierFactor
+ sparkSession
+ .createDataFrame(
+ Seq(
+ Point(0, 2.0, 2.0, 0.8747092756023607),
+ Point(1, 2.0, 3.0, 0.9460678118688717),
+ Point(2, 3.0, 3.0, 1.0797104443580348),
+ Point(3, 3.0, 2.0, 1.0517766952923475),
+ Point(4, 3.0, 1.0, 1.0797104443580348),
+ Point(5, 2.0, 1.0, 0.9460678118688719),
+ Point(6, 1.0, 1.0, 1.0797104443580348),
+ Point(7, 1.0, 2.0, 1.0517766952923475),
+ Point(8, 1.0, 3.0, 1.0797104443580348),
+ Point(9, 0.0, 2.0, 1.0517766952923475),
+ Point(10, 4.0, 2.0, 1.0517766952923475)))
+ .withColumn("geometry", ST_MakePoint("x", "y"))
+ .drop("x", "y")
+ }
+
+ describe("LocalOutlierFactor") {
+ it("returns correct results") {
+ val resultDf = LocalOutlierFactor.localOutlierFactor(get_data(), 4)
+ assert(resultDf.count() == 11)
+ assert(
+ resultDf
+ .filter(f.abs(f.col("expected_lof") - f.col("lof")) < .00000001)
+ .count() == 11)
+ }
+ }
+}
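Since the test comment above notes the expected_lof values were pulled from sklearn.neighbors.LocalOutlierFactor, here is a hedged sketch of how they could be regenerated; n_neighbors=4 matches the k passed to localOutlierFactor in the test.

```python
# Hedged sketch: regenerate the expected_lof values with scikit-learn, as the
# test comment indicates. sklearn stores negated LOF scores, so negate them back.
import numpy as np
from sklearn.neighbors import LocalOutlierFactor

pts = np.array([[2.0, 2.0], [2.0, 3.0], [3.0, 3.0], [3.0, 2.0], [3.0, 1.0],
                [2.0, 1.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [0.0, 2.0],
                [4.0, 2.0]])

clf = LocalOutlierFactor(n_neighbors=4, contamination="auto")
clf.fit_predict(pts)
print(-clf.negative_outlier_factor_)  # should be close to the expected_lof column
```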