wbo4958 opened a new pull request, #49503: URL: https://github.com/apache/spark/pull/49503
### What changes were proposed in this pull request?

This PR introduces Connect ML with a plugin mechanism that can replace spark.ml operators with third-party implementations.

### Why are the changes needed?

Connect ML uses spark.ml as its backend implementation, which is not GPU-accelerated. This PR allows a third-party library to serve as the Connect ML backend, accelerating ML workloads with GPUs without changing any user code.

### Does this PR introduce _any_ user-facing change?

Yes, it's a new feature. This PR defines a Spark configuration, `spark.connect.ml.backend.classes`, to specify the plugin, which must implement `org.apache.spark.sql.connect.plugin.MLBackendPlugin`.

### How was this patch tested?

Make sure the CI passes (the newly added tests cover the plugin cases). The code below was also run manually without any exception.

1. Build your own plugin package

- Customize the estimator:

``` scala
package com.example.ml

import org.apache.commons.logging.LogFactory
import org.apache.spark.ml.classification.{LogisticRegression => SparkLR, LogisticRegressionModel => SparkLRModel}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset

class LogisticRegression(override val uid: String) extends SparkLR {
  private val logger = LogFactory.getLog("com.example.ml.LogisticRegression")

  def this() = this(Identifiable.randomUID("logreg"))

  override def train(dataset: Dataset[_]): SparkLRModel = {
    logger.info("Bobby train in SparkRapidsML library.")
    super.train(dataset)
  }
}
```

Then create an `org.apache.spark.ml.Estimator` file under `META-INF/services` and register the customized LogisticRegression:

```
com.example.ml.LogisticRegression
```

- Implement the MLBackendPlugin:

``` scala
package com.example.ml

import java.util.Optional

import org.apache.spark.sql.connect.plugin.MLBackendPlugin

class Plugin extends MLBackendPlugin {
  override def transform(mlName: String): Optional[String] = {
    mlName match {
      case "org.apache.spark.ml.classification.LogisticRegression" =>
        Optional.of("com.example.ml.LogisticRegression")
      case _ => Optional.empty()
    }
  }
}
```

- Compile

After compilation, we get a jar similar to `com.example.ml-1.0-SNAPSHOT.jar`.

2. Run the Python tests:

``` python
import os

from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

os.environ["PYSPARK_PYTHON"] = "/home/xxx/anaconda3/envs/pyspark/bin/python"
os.environ["PYSPARK_DRIVER_PYTHON"] = "/home/xxx/anaconda3/envs/pyspark/bin/python"
os.environ["SPARK_CONNECT_MODE_ENABLED"] = ""

spark = (SparkSession.builder.remote("sc://localhost")
         .config("spark.connect.ml.backend.classes", "com.example.ml.Plugin")
         .getOrCreate())
spark.addArtifact("target/com.example.ml-1.0-SNAPSHOT.jar")

def run_test():
    df = spark.createDataFrame([
        (Vectors.dense([1.0, 2.0]), 1),
        (Vectors.dense([2.0, -1.0]), 1),
        (Vectors.dense([-3.0, -2.0]), 0),
        (Vectors.dense([-1.0, -2.0]), 0),
    ], schema=['features', 'label'])

    lr = LogisticRegression()
    lr.setMaxIter(30)
    lr.setThreshold(0.8)
    model: LogisticRegressionModel = lr.fit(df)
    assert model.getMaxIter() == 30
    assert model.getThreshold() == 0.8
    print(f"model: {model}")
    x = model.predictRaw(Vectors.dense([1.0, 2.0]))
    print(f"predictRaw {x}")
    print(f"model.coefficients: {model.coefficients}")
    print(f"model.intercept: {model.intercept}")
    print("-------------------- done ---------------- ")

# Train with com.example.ml.LogisticRegression
run_test()
```

Then check whether `SPARK_HOME/logs/**org.apache.spark.sql.connect.service.SparkConnectServer-1**` contains `25/01/15 15:49:27 INFO LogisticRegression: Bobby train in SparkRapidsML library.`.

### Was this patch authored or co-authored using generative AI tooling?

No

--
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For queries about this service, please contact Infrastructure at: us...@infra.apache.org