wbo4958 commented on code in PR #49596:
URL: https://github.com/apache/spark/pull/49596#discussion_r1924953969


##########
python/pyspark/ml/connect/readwrite.py:
##########
@@ -37,52 +38,99 @@ def sc(self) -> "SparkContext":
         raise RuntimeError("Accessing SparkContext is not supported on Connect")
 
     def save(self, path: str) -> None:
-        from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
-        from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.sql.connect.session import SparkSession
 
         session = SparkSession.getActiveSession()
         assert session is not None
 
+        RemoteMLWriter.saveInstance(
+            self._instance,
+            path,
+            session,
+            self.shouldOverwrite,
+            self.optionMap,
+        )
+
+    @staticmethod
+    def saveInstance(
+        instance: "JavaMLWritable",
+        path: str,
+        session: "SparkSession",
+        shouldOverwrite: bool = False,
+        optionMap: Dict[str, Any] = {},
+    ) -> None:
+        from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
+        from pyspark.ml.evaluation import JavaEvaluator
+        from pyspark.ml.pipeline import Pipeline, PipelineModel
+
         # Spark Connect ML is built on scala Spark.ML, that means we're only
         # supporting JavaModel or JavaEstimator or JavaEvaluator
-        if isinstance(self._instance, JavaModel):
-            model = cast("JavaModel", self._instance)
+        if isinstance(instance, JavaModel):
+            model = cast("JavaModel", instance)
             params = serialize_ml_params(model, session.client)
             assert isinstance(model._java_obj, str)
             writer = pb2.MlCommand.Write(
                 obj_ref=pb2.ObjectRef(id=model._java_obj),
                 params=params,
                 path=path,
-                should_overwrite=self.shouldOverwrite,
-                options=self.optionMap,
+                should_overwrite=shouldOverwrite,
+                options=optionMap,
             )
-        else:
+            command = pb2.Command()
+            command.ml_command.write.CopyFrom(writer)
+            session.client.execute_command(command)
+
+        elif isinstance(instance, (JavaEstimator, JavaTransformer, JavaEvaluator)):
             operator: Union[JavaEstimator, JavaTransformer, JavaEvaluator]
-            if isinstance(self._instance, JavaEstimator):
+            if isinstance(instance, JavaEstimator):
                 ml_type = pb2.MlOperator.ESTIMATOR
-                operator = cast("JavaEstimator", self._instance)
-            elif isinstance(self._instance, JavaEvaluator):
+                operator = cast("JavaEstimator", instance)
+            elif isinstance(instance, JavaEvaluator):
                 ml_type = pb2.MlOperator.EVALUATOR
-                operator = cast("JavaEvaluator", self._instance)
-            elif isinstance(self._instance, JavaTransformer):
-                ml_type = pb2.MlOperator.TRANSFORMER
-                operator = cast("JavaTransformer", self._instance)
+                operator = cast("JavaEvaluator", instance)
             else:
-                raise NotImplementedError(f"Unsupported writing for {self._instance}")
+                ml_type = pb2.MlOperator.TRANSFORMER
+                operator = cast("JavaTransformer", instance)
 
             params = serialize_ml_params(operator, session.client)
             assert isinstance(operator._java_obj, str)
             writer = pb2.MlCommand.Write(
                 operator=pb2.MlOperator(name=operator._java_obj, uid=operator.uid, type=ml_type),
                 params=params,
                 path=path,
-                should_overwrite=self.shouldOverwrite,
-                options=self.optionMap,
+                should_overwrite=shouldOverwrite,
+                options=optionMap,
             )
-        command = pb2.Command()
-        command.ml_command.write.CopyFrom(writer)
-        session.client.execute_command(command)
+            command = pb2.Command()
+            command.ml_command.write.CopyFrom(writer)
+            session.client.execute_command(command)
+
+        elif isinstance(instance, (Pipeline, PipelineModel)):
+            from pyspark.ml.pipeline import PipelineSharedReadWrite
+
+            if shouldOverwrite:

Review Comment:
   I'm wondering if we could add a `DeleteFile` command to the protocol to handle the overwrite deletion, instead of doing it on the Python side?
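
   To make the suggestion concrete, here is a rough sketch of what I have in mind. The `Delete` message and the `ml_command.delete` field below are hypothetical, they don't exist in the current `pb2` protos:

   ```python
   # Hypothetical proto addition inside MlCommand (not in pb2 today):
   #   message Delete { string path = 1; }
   #
   # The Python side could then delegate the overwrite deletion to the server,
   # mirroring how the write command is dispatched above:
   def _delete_path(session: "SparkSession", path: str) -> None:
       command = pb2.Command()
       command.ml_command.delete.CopyFrom(pb2.MlCommand.Delete(path=path))
       session.client.execute_command(command)
   ```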



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

