wbo4958 commented on code in PR #49547: URL: https://github.com/apache/spark/pull/49547#discussion_r1920185334
########## python/pyspark/ml/tests/test_evaluation.py: ########## @@ -14,18 +14,368 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import tempfile import unittest import numpy as np -from pyspark.ml.evaluation import ClusteringEvaluator, RegressionEvaluator +from pyspark.ml.evaluation import ( + ClusteringEvaluator, + RegressionEvaluator, + BinaryClassificationEvaluator, + MulticlassClassificationEvaluator, + MultilabelClassificationEvaluator, + RankingEvaluator, +) from pyspark.ml.linalg import Vectors -from pyspark.sql import Row -from pyspark.testing.mlutils import SparkSessionTestCase +from pyspark.sql import Row, SparkSession + + +class EvaluatorTestsMixin: + def test_ranking_evaluator(self): + scoreAndLabels = [ + ([1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]), + ([4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0], [1.0, 2.0, 3.0]), + ([1.0, 2.0, 3.0, 4.0, 5.0], []), + ] + dataset = self.spark.createDataFrame(scoreAndLabels, ["prediction", "label"]) + + # Initialize RankingEvaluator + evaluator = RankingEvaluator().setPredictionCol("prediction") + + # Evaluate the dataset using the default metric (mean average precision) + mean_average_precision = evaluator.evaluate(dataset) + self.assertTrue(np.allclose(mean_average_precision, 0.3550, atol=1e-4)) + + # Evaluate the dataset using precisionAtK for k=2 + precision_at_k = evaluator.evaluate( + dataset, {evaluator.metricName: "precisionAtK", evaluator.k: 2} + ) + self.assertTrue(np.allclose(precision_at_k, 0.3333, atol=1e-4)) + + # read/write + with tempfile.TemporaryDirectory(prefix="save") as tmp_dir: + # Save the evaluator + ranke_path = tmp_dir + "/ranke" + evaluator.write().overwrite().save(ranke_path) + # Load the saved evaluator + evaluator2 = RankingEvaluator.load(ranke_path) + self.assertEqual(evaluator2.getPredictionCol(), "prediction") + + def test_multilabel_classification_evaluator(self): Review 
Comment: Wow, good catch: `isLargerBetter` does not work in the current design. I'm wondering whether we can refactor `isLargerBetter` on the PySpark side so that it is consistent with the Scala implementation? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org