zhengruifeng commented on code in PR #49547:
URL: https://github.com/apache/spark/pull/49547#discussion_r1920896222


##########
python/pyspark/ml/tests/test_evaluation.py:
##########
@@ -14,18 +14,368 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import tempfile
 import unittest
 
 import numpy as np
 
-from pyspark.ml.evaluation import ClusteringEvaluator, RegressionEvaluator
+from pyspark.ml.evaluation import (
+    ClusteringEvaluator,
+    RegressionEvaluator,
+    BinaryClassificationEvaluator,
+    MulticlassClassificationEvaluator,
+    MultilabelClassificationEvaluator,
+    RankingEvaluator,
+)
 from pyspark.ml.linalg import Vectors
-from pyspark.sql import Row
-from pyspark.testing.mlutils import SparkSessionTestCase
+from pyspark.sql import Row, SparkSession
+
+
+class EvaluatorTestsMixin:
+    def test_ranking_evaluator(self):
+        scoreAndLabels = [
+            ([1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]),
+            ([4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0], [1.0, 2.0, 3.0]),
+            ([1.0, 2.0, 3.0, 4.0, 5.0], []),
+        ]
+        dataset = self.spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
+
+        # Initialize RankingEvaluator
+        evaluator = RankingEvaluator().setPredictionCol("prediction")
+
+        # Evaluate the dataset using the default metric (mean average precision)
+        mean_average_precision = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(mean_average_precision, 0.3550, atol=1e-4))
+
+        # Evaluate the dataset using precisionAtK for k=2
+        precision_at_k = evaluator.evaluate(
+            dataset, {evaluator.metricName: "precisionAtK", evaluator.k: 2}
+        )
+        self.assertTrue(np.allclose(precision_at_k, 0.3333, atol=1e-4))
+
+        # read/write
+        with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+            # Save the evaluator
+            ranke_path = tmp_dir + "/ranke"
+            evaluator.write().overwrite().save(ranke_path)
+            # Load the saved evaluator
+            evaluator2 = RankingEvaluator.load(ranke_path)
+            self.assertEqual(evaluator2.getPredictionCol(), "prediction")
+
+    def test_multilabel_classification_evaluator(self):

Review Comment:
   A pure Python implementation works for me. Let's open a ticket to track it.
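   For reference, a minimal pure-Python sketch of the two metrics this test checks, assuming Spark's `RankingMetrics` conventions (an empty label set contributes zero to both averages). The helper names here are illustrative, not part of any Spark API:

   ```python
   # Illustrative helpers only; names are hypothetical, not a Spark API.

   def average_precision(pred, labels):
       # Mean of precision at each relevant rank, divided by |labels|;
       # an empty label set scores 0.0, matching RankingMetrics.
       label_set = set(labels)
       if not label_set:
           return 0.0
       hits, total = 0, 0.0
       for i, p in enumerate(pred):
           if p in label_set:
               hits += 1
               total += hits / (i + 1)
       return total / len(label_set)

   def precision_at_k(pred, labels, k):
       # Fraction of the top-k predictions that appear in the label set.
       label_set = set(labels)
       return sum(1 for p in pred[:k] if p in label_set) / k

   data = [
       ([1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]),
       ([4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0], [1.0, 2.0, 3.0]),
       ([1.0, 2.0, 3.0, 4.0, 5.0], []),
   ]

   map_score = sum(average_precision(p, l) for p, l in data) / len(data)
   p_at_2 = sum(precision_at_k(p, l, 2) for p, l in data) / len(data)
   print(round(map_score, 4), round(p_at_2, 4))  # 0.355 0.3333
   ```

   These reproduce the expected values asserted in the test (0.3550 and 0.3333).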



