wbo4958 opened a new pull request, #49503:
URL: https://github.com/apache/spark/pull/49503

   ### What changes were proposed in this pull request?
   
   This PR introduces Connect ML with a plugin mechanism that can replace spark.ml operators with third-party implementations.
   
   ### Why are the changes needed?
   
   Connect ML currently uses spark.ml as its backend implementation, which is not GPU-accelerated. This PR allows a third-party library to serve as the Connect ML backend, accelerating ML workloads with GPUs without changing any user code.
   
   
   ### Does this PR introduce _any_ user-facing change?
   
   Yes, it's a new feature. This PR defines a Spark configuration, `spark.connect.ml.backend.classes`, to specify the plugin, which must implement `org.apache.spark.sql.connect.plugin.MLBackendPlugin`.
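   Conceptually, the backend resolution maps a spark.ml operator class name through each configured plugin and falls back to the stock implementation when no plugin claims it. The sketch below illustrates that mechanism in Python; `make_plugin` and `resolve_operator` are illustrative names only, not part of this PR (the real hook is `MLBackendPlugin.transform`).

   ```python
   from typing import Callable, Dict, List, Optional

   # Illustrative sketch of the plugin mechanism: each plugin maps a
   # spark.ml operator class name to a replacement class name, or declines.
   # All names here are hypothetical; the real interface is
   # org.apache.spark.sql.connect.plugin.MLBackendPlugin.transform.

   def make_plugin(mapping: Dict[str, str]) -> Callable[[str], Optional[str]]:
       return lambda ml_name: mapping.get(ml_name)

   def resolve_operator(plugins: List[Callable[[str], Optional[str]]],
                        ml_name: str) -> str:
       for plugin in plugins:
           replacement = plugin(ml_name)
           if replacement is not None:
               return replacement
       return ml_name  # fall back to the stock spark.ml implementation

   plugins = [make_plugin({
       "org.apache.spark.ml.classification.LogisticRegression":
           "com.example.ml.LogisticRegression",
   })]

   print(resolve_operator(
       plugins, "org.apache.spark.ml.classification.LogisticRegression"))
   # -> com.example.ml.LogisticRegression
   ```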
   
   ### How was this patch tested?
   
   Make sure the CI passes (the newly added tests cover the plugin cases), and manually run the code below without any exception.
   
   1. Build your own plugin package
   
   
   - Customize the estimator
   
   ```scala
   package com.example.ml
   
   import org.apache.commons.logging.LogFactory
   import org.apache.spark.ml.classification.{LogisticRegression => SparkLR, 
LogisticRegressionModel => SparkLRModel}
   import org.apache.spark.ml.util.Identifiable
   import org.apache.spark.sql.Dataset
   
   class LogisticRegression(override val uid: String) extends SparkLR {

     private val logger = LogFactory.getLog("com.example.ml.LogisticRegression")

     def this() = this(Identifiable.randomUID("logreg"))

     // Delegate to the stock spark.ml training, logging a marker line so we
     // can later verify in the server logs that this class was picked up.
     override def train(dataset: Dataset[_]): SparkLRModel = {
       logger.info("Bobby train in SparkRapidsML library.")
       super.train(dataset)
     }

   }
   ```
   
   Then create an `org.apache.spark.ml.Estimator` file under `META-INF/services` and register the customized `LogisticRegression` in it:
   
   ```
   com.example.ml.LogisticRegression
   ```
   
   - Implement the MLBackendPlugin
   
   ```scala
   package com.example.ml
   
   import org.apache.spark.sql.connect.plugin.MLBackendPlugin
   
   import java.util.Optional
   
   class Plugin extends MLBackendPlugin {
   
     override def transform(mlName: String): Optional[String] = {
       mlName match {
         // Swap the stock spark.ml LogisticRegression for our implementation;
         // any other operator keeps the default spark.ml backend.
         case "org.apache.spark.ml.classification.LogisticRegression" =>
           Optional.of("com.example.ml.LogisticRegression")
         case _ => Optional.empty()
       }
     }
   }
   ```
   
   - Compile
   
   After compilation, we get a jar similar to `com.example.ml-1.0-SNAPSHOT.jar`.
   
   2. Run the python tests
   
   ```python
   import os

   from pyspark.ml.classification import (LogisticRegression,
                                          LogisticRegressionModel)
   from pyspark.ml.linalg import Vectors
   from pyspark.sql import SparkSession

   os.environ["PYSPARK_PYTHON"] = "/home/xxx/anaconda3/envs/pyspark/bin/python"
   os.environ["PYSPARK_DRIVER_PYTHON"] = "/home/xxx/anaconda3/envs/pyspark/bin/python"
   
   
   os.environ["SPARK_CONNECT_MODE_ENABLED"] = ""
   spark = (SparkSession.builder.remote("sc://localhost")
            .config("spark.connect.ml.backend.classes", "com.example.ml.Plugin")
            .getOrCreate())
   spark.addArtifact("target/com.example.ml-1.0-SNAPSHOT.jar")
   
   def run_test():
       df = spark.createDataFrame([
           (Vectors.dense([1.0, 2.0]), 1),
           (Vectors.dense([2.0, -1.0]), 1),
           (Vectors.dense([-3.0, -2.0]), 0),
           (Vectors.dense([-1.0, -2.0]), 0),
       ], schema=['features', 'label'])
       lr = LogisticRegression()
       lr.setMaxIter(30)
       lr.setThreshold(0.8)
       model: LogisticRegressionModel = lr.fit(df)
       assert model.getMaxIter() == 30
       assert model.getThreshold() == 0.8
       print(f"model: {model}")
       x = model.predictRaw(Vectors.dense([1.0, 2.0]))
       print(f"predictRaw {x}")
       print(f"model.coefficients: {model.coefficients}")
       print(f"model.intercept: {model.intercept}")
       print("-------------------- done ---------------- ")
   
   # Train with com.example.ml.LogisticRegression
   
   run_test()
   ```
   
   Then we can check that `SPARK_HOME/logs/**org.apache.spark.sql.connect.service.SparkConnectServer-1**` contains `25/01/15 15:49:27 INFO LogisticRegression: Bobby train in SparkRapidsML library.`.
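   That log check can also be scripted; the helper below is a sketch (the `find_marker` name and the glob pattern are illustrative, not part of this PR):

   ```python
   import glob
   import os

   MARKER = "Bobby train in SparkRapidsML library."

   def find_marker(log_glob: str, marker: str = MARKER) -> list:
       """Return the log files matching log_glob that contain the marker line."""
       hits = []
       for path in sorted(glob.glob(log_glob)):
           with open(path, encoding="utf-8", errors="replace") as f:
               if marker in f.read():
                   hits.append(path)
       return hits

   # In practice, point it at the Spark Connect server logs, e.g.:
   # find_marker(os.path.join(os.environ["SPARK_HOME"], "logs",
   #                          "*SparkConnectServer*"))
   ```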
   
   
   ### Was this patch authored or co-authored using generative AI tooling?
   No


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

