wbo4958 commented on code in PR #49596: URL: https://github.com/apache/spark/pull/49596#discussion_r1924953969
########## python/pyspark/ml/connect/readwrite.py: ########## @@ -37,52 +38,99 @@ def sc(self) -> "SparkContext": raise RuntimeError("Accessing SparkContext is not supported on Connect") def save(self, path: str) -> None: - from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer - from pyspark.ml.evaluation import JavaEvaluator from pyspark.sql.connect.session import SparkSession session = SparkSession.getActiveSession() assert session is not None + RemoteMLWriter.saveInstance( + self._instance, + path, + session, + self.shouldOverwrite, + self.optionMap, + ) + + @staticmethod + def saveInstance( + instance: "JavaMLWritable", + path: str, + session: "SparkSession", + shouldOverwrite: bool = False, + optionMap: Dict[str, Any] = {}, + ) -> None: + from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer + from pyspark.ml.evaluation import JavaEvaluator + from pyspark.ml.pipeline import Pipeline, PipelineModel + # Spark Connect ML is built on scala Spark.ML, that means we're only # supporting JavaModel or JavaEstimator or JavaEvaluator - if isinstance(self._instance, JavaModel): - model = cast("JavaModel", self._instance) + if isinstance(instance, JavaModel): + model = cast("JavaModel", instance) params = serialize_ml_params(model, session.client) assert isinstance(model._java_obj, str) writer = pb2.MlCommand.Write( obj_ref=pb2.ObjectRef(id=model._java_obj), params=params, path=path, - should_overwrite=self.shouldOverwrite, - options=self.optionMap, + should_overwrite=shouldOverwrite, + options=optionMap, ) - else: + command = pb2.Command() + command.ml_command.write.CopyFrom(writer) + session.client.execute_command(command) + + elif isinstance(instance, (JavaEstimator, JavaTransformer, JavaEvaluator)): operator: Union[JavaEstimator, JavaTransformer, JavaEvaluator] - if isinstance(self._instance, JavaEstimator): + if isinstance(instance, JavaEstimator): ml_type = pb2.MlOperator.ESTIMATOR - operator = cast("JavaEstimator", 
self._instance) - elif isinstance(self._instance, JavaEvaluator): + operator = cast("JavaEstimator", instance) + elif isinstance(instance, JavaEvaluator): ml_type = pb2.MlOperator.EVALUATOR - operator = cast("JavaEvaluator", self._instance) - elif isinstance(self._instance, JavaTransformer): - ml_type = pb2.MlOperator.TRANSFORMER - operator = cast("JavaTransformer", self._instance) + operator = cast("JavaEvaluator", instance) else: - raise NotImplementedError(f"Unsupported writing for {self._instance}") + ml_type = pb2.MlOperator.TRANSFORMER + operator = cast("JavaTransformer", instance) params = serialize_ml_params(operator, session.client) assert isinstance(operator._java_obj, str) writer = pb2.MlCommand.Write( operator=pb2.MlOperator(name=operator._java_obj, uid=operator.uid, type=ml_type), params=params, path=path, - should_overwrite=self.shouldOverwrite, - options=self.optionMap, + should_overwrite=shouldOverwrite, + options=optionMap, ) - command = pb2.Command() - command.ml_command.write.CopyFrom(writer) - session.client.execute_command(command) + command = pb2.Command() + command.ml_command.write.CopyFrom(writer) + session.client.execute_command(command) + + elif isinstance(instance, (Pipeline, PipelineModel)): + from pyspark.ml.pipeline import PipelineSharedReadWrite + + if shouldOverwrite: Review Comment: I'm wondering if we can have a DeleteFile command to do that? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org