zhengruifeng commented on code in PR #46129: URL: https://github.com/apache/spark/pull/46129#discussion_r1571874148
########## python/pyspark/sql/classic/dataframe.py: ########## @@ -0,0 +1,1952 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import json +import sys +import random +import warnings +from collections.abc import Iterable +from functools import reduce +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + overload, + TYPE_CHECKING, +) + +from pyspark import _NoValue +from pyspark._globals import _NoValueType +from pyspark.errors import ( + PySparkTypeError, + PySparkValueError, + PySparkIndexError, + PySparkAttributeError, +) +from pyspark.util import ( + is_remote_only, + _load_from_socket, + _local_iterator_from_socket, +) +from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer +from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync +from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column +from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2 +from pyspark.sql.streaming import DataStreamWriter +from pyspark.sql.types import ( + StructType, + Row, + _parse_datatype_json_string, +) +from pyspark.sql.dataframe import ( + DataFrame as 
ParentDataFrame, + DataFrameNaFunctions as ParentDataFrameNaFunctions, + DataFrameStatFunctions as ParentDataFrameStatFunctions, +) +from pyspark.sql.utils import get_active_spark_context, toJArray +from pyspark.sql.pandas.conversion import PandasConversionMixin +from pyspark.sql.pandas.map_ops import PandasMapOpsMixin + +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + from pyspark.core.rdd import RDD + from pyspark.core.context import SparkContext + from pyspark._typing import PrimitiveType + from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame + from pyspark.sql._typing import ( + ColumnOrName, + ColumnOrNameOrOrdinal, + LiteralType, + OptionalPrimitiveType, + ) + from pyspark.sql.context import SQLContext + from pyspark.sql.session import SparkSession + from pyspark.sql.group import GroupedData + from pyspark.sql.observation import Observation + + +class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin): + def __new__( + cls, + jdf: "JavaObject", + sql_ctx: Union["SQLContext", "SparkSession"], + ) -> "DataFrame": + self = object.__new__(cls) + self.__init__(jdf, sql_ctx) # type: ignore[misc] + return self + + def __init__( + self, + jdf: "JavaObject", + sql_ctx: Union["SQLContext", "SparkSession"], + ): + from pyspark.sql.context import SQLContext + + self._sql_ctx: Optional["SQLContext"] = None + + if isinstance(sql_ctx, SQLContext): + assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage. + assert isinstance(sql_ctx, SQLContext) + # We should remove this if-else branch in the future release, and rename + # sql_ctx to session in the constructor. This is an internal code path but + # was kept with a warning because it's used intensively by third-party libraries. + warnings.warn("DataFrame constructor is internal. 
Do not directly use it.") + self._sql_ctx = sql_ctx + session = sql_ctx.sparkSession + else: + session = sql_ctx + self._session: "SparkSession" = session + + self._sc: "SparkContext" = sql_ctx._sc + self._jdf: "JavaObject" = jdf + self.is_cached = False + # initialized lazily + self._schema: Optional[StructType] = None + self._lazy_rdd: Optional["RDD[Row]"] = None + # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice + # by __repr__ and _repr_html_ while eager evaluation opens. + self._support_repr_html = False + + @property + def sql_ctx(self) -> "SQLContext": + from pyspark.sql.context import SQLContext + + warnings.warn( + "DataFrame.sql_ctx is an internal property, and will be removed " + "in future releases. Use DataFrame.sparkSession instead." + ) + if self._sql_ctx is None: + self._sql_ctx = SQLContext._get_or_create(self._sc) + return self._sql_ctx + + @property + def sparkSession(self) -> "SparkSession": + return self._session + + if not is_remote_only(): + + @property + def rdd(self) -> "RDD[Row]": + from pyspark.core.rdd import RDD + + if self._lazy_rdd is None: + jrdd = self._jdf.javaToPython() + self._lazy_rdd = RDD( + jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer()) + ) + return self._lazy_rdd + + @property + def na(self) -> "DataFrameNaFunctions": + return DataFrameNaFunctions(self) + + @property + def stat(self) -> "DataFrameStatFunctions": + return DataFrameStatFunctions(self) + + if not is_remote_only(): + + def toJSON(self, use_unicode: bool = True) -> "RDD[str]": + from pyspark.core.rdd import RDD + + rdd = self._jdf.toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + + def registerTempTable(self, name: str) -> None: + warnings.warn("Deprecated in 2.0, use createOrReplaceTempView instead.", FutureWarning) + self._jdf.createOrReplaceTempView(name) + + def createTempView(self, name: str) -> None: + self._jdf.createTempView(name) + + def 
createOrReplaceTempView(self, name: str) -> None: + self._jdf.createOrReplaceTempView(name) + + def createGlobalTempView(self, name: str) -> None: + self._jdf.createGlobalTempView(name) + + def createOrReplaceGlobalTempView(self, name: str) -> None: + self._jdf.createOrReplaceGlobalTempView(name) + + @property + def write(self) -> DataFrameWriter: + return DataFrameWriter(self) + + @property + def writeStream(self) -> DataStreamWriter: + return DataStreamWriter(self) + + @property + def schema(self) -> StructType: + if self._schema is None: + try: + self._schema = cast( + StructType, _parse_datatype_json_string(self._jdf.schema().json()) + ) + except Exception as e: + raise PySparkValueError( + error_class="CANNOT_PARSE_DATATYPE", + message_parameters={"error": str(e)}, + ) + return self._schema + + def printSchema(self, level: Optional[int] = None) -> None: + if level: + print(self._jdf.schema().treeString(level)) + else: + print(self._jdf.schema().treeString()) + + def explain( + self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None + ) -> None: + if extended is not None and mode is not None: + raise PySparkValueError( + error_class="CANNOT_SET_TOGETHER", + message_parameters={"arg_list": "extended and mode"}, + ) + + # For the no argument case: df.explain() + is_no_argument = extended is None and mode is None + + # For the cases below: + # explain(True) + # explain(extended=False) + is_extended_case = isinstance(extended, bool) and mode is None + + # For the case when extended is mode: + # df.explain("formatted") + is_extended_as_mode = isinstance(extended, str) and mode is None + + # For the mode specified: + # df.explain(mode="formatted") + is_mode_case = extended is None and isinstance(mode, str) + + if not (is_no_argument or is_extended_case or is_extended_as_mode or is_mode_case): + if (extended is not None) and (not isinstance(extended, (bool, str))): + raise PySparkTypeError( + error_class="NOT_BOOL_OR_STR", + message_parameters={ 
+ "arg_name": "extended", + "arg_type": type(extended).__name__, + }, + ) + if (mode is not None) and (not isinstance(mode, str)): + raise PySparkTypeError( + error_class="NOT_STR", + message_parameters={"arg_name": "mode", "arg_type": type(mode).__name__}, + ) + + # Sets an explain mode depending on a given argument + if is_no_argument: + explain_mode = "simple" + elif is_extended_case: + explain_mode = "extended" if extended else "simple" + elif is_mode_case: + explain_mode = cast(str, mode) + elif is_extended_as_mode: + explain_mode = cast(str, extended) + assert self._sc._jvm is not None + print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode)) + + def exceptAll(self, other: "ParentDataFrame") -> "ParentDataFrame": + return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession) + + def isLocal(self) -> bool: + return self._jdf.isLocal() + + @property + def isStreaming(self) -> bool: + return self._jdf.isStreaming() + + def isEmpty(self) -> bool: + return self._jdf.isEmpty() + + def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: Review Comment: Will the docstrings be inherited from the superclass, or do we still need to copy them one by one (like `show.__doc__ = PySparkDataFrame.show.__doc__` in Connect)? ########## python/pyspark/sql/classic/dataframe.py: ########## @@ -0,0 +1,1952 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import json +import sys +import random +import warnings +from collections.abc import Iterable +from functools import reduce +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + overload, + TYPE_CHECKING, +) + +from pyspark import _NoValue +from pyspark._globals import _NoValueType +from pyspark.errors import ( + PySparkTypeError, + PySparkValueError, + PySparkIndexError, + PySparkAttributeError, +) +from pyspark.util import ( + is_remote_only, + _load_from_socket, + _local_iterator_from_socket, +) +from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer +from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync +from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column +from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2 +from pyspark.sql.streaming import DataStreamWriter +from pyspark.sql.types import ( + StructType, + Row, + _parse_datatype_json_string, +) +from pyspark.sql.dataframe import ( + DataFrame as ParentDataFrame, + DataFrameNaFunctions as ParentDataFrameNaFunctions, + DataFrameStatFunctions as ParentDataFrameStatFunctions, +) +from pyspark.sql.utils import get_active_spark_context, toJArray +from pyspark.sql.pandas.conversion import PandasConversionMixin +from pyspark.sql.pandas.map_ops import PandasMapOpsMixin + +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + from pyspark.core.rdd import RDD + from pyspark.core.context import 
SparkContext + from pyspark._typing import PrimitiveType + from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame + from pyspark.sql._typing import ( + ColumnOrName, + ColumnOrNameOrOrdinal, + LiteralType, + OptionalPrimitiveType, + ) + from pyspark.sql.context import SQLContext + from pyspark.sql.session import SparkSession + from pyspark.sql.group import GroupedData + from pyspark.sql.observation import Observation + + +class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin): + def __new__( + cls, + jdf: "JavaObject", + sql_ctx: Union["SQLContext", "SparkSession"], + ) -> "DataFrame": + self = object.__new__(cls) + self.__init__(jdf, sql_ctx) # type: ignore[misc] + return self + + def __init__( + self, + jdf: "JavaObject", + sql_ctx: Union["SQLContext", "SparkSession"], + ): + from pyspark.sql.context import SQLContext + + self._sql_ctx: Optional["SQLContext"] = None + + if isinstance(sql_ctx, SQLContext): + assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage. + assert isinstance(sql_ctx, SQLContext) + # We should remove this if-else branch in the future release, and rename + # sql_ctx to session in the constructor. This is an internal code path but + # was kept with a warning because it's used intensively by third-party libraries. + warnings.warn("DataFrame constructor is internal. Do not directly use it.") + self._sql_ctx = sql_ctx + session = sql_ctx.sparkSession + else: + session = sql_ctx + self._session: "SparkSession" = session + + self._sc: "SparkContext" = sql_ctx._sc + self._jdf: "JavaObject" = jdf + self.is_cached = False + # initialized lazily + self._schema: Optional[StructType] = None + self._lazy_rdd: Optional["RDD[Row]"] = None + # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice + # by __repr__ and _repr_html_ while eager evaluation opens. 
+ self._support_repr_html = False + + @property + def sql_ctx(self) -> "SQLContext": + from pyspark.sql.context import SQLContext + + warnings.warn( + "DataFrame.sql_ctx is an internal property, and will be removed " + "in future releases. Use DataFrame.sparkSession instead." + ) + if self._sql_ctx is None: + self._sql_ctx = SQLContext._get_or_create(self._sc) + return self._sql_ctx + + @property + def sparkSession(self) -> "SparkSession": + return self._session + + if not is_remote_only(): + + @property + def rdd(self) -> "RDD[Row]": + from pyspark.core.rdd import RDD + + if self._lazy_rdd is None: + jrdd = self._jdf.javaToPython() + self._lazy_rdd = RDD( + jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer()) + ) + return self._lazy_rdd + + @property + def na(self) -> "DataFrameNaFunctions": + return DataFrameNaFunctions(self) + + @property + def stat(self) -> "DataFrameStatFunctions": + return DataFrameStatFunctions(self) + + if not is_remote_only(): + + def toJSON(self, use_unicode: bool = True) -> "RDD[str]": + from pyspark.core.rdd import RDD + + rdd = self._jdf.toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + + def registerTempTable(self, name: str) -> None: + warnings.warn("Deprecated in 2.0, use createOrReplaceTempView instead.", FutureWarning) + self._jdf.createOrReplaceTempView(name) + + def createTempView(self, name: str) -> None: + self._jdf.createTempView(name) + + def createOrReplaceTempView(self, name: str) -> None: + self._jdf.createOrReplaceTempView(name) + + def createGlobalTempView(self, name: str) -> None: + self._jdf.createGlobalTempView(name) + + def createOrReplaceGlobalTempView(self, name: str) -> None: + self._jdf.createOrReplaceGlobalTempView(name) + + @property + def write(self) -> DataFrameWriter: + return DataFrameWriter(self) + + @property + def writeStream(self) -> DataStreamWriter: + return DataStreamWriter(self) + + @property + def schema(self) -> StructType: + if self._schema is None: 
+ try: + self._schema = cast( + StructType, _parse_datatype_json_string(self._jdf.schema().json()) + ) + except Exception as e: + raise PySparkValueError( + error_class="CANNOT_PARSE_DATATYPE", + message_parameters={"error": str(e)}, + ) + return self._schema + + def printSchema(self, level: Optional[int] = None) -> None: + if level: + print(self._jdf.schema().treeString(level)) + else: + print(self._jdf.schema().treeString()) + + def explain( + self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None + ) -> None: + if extended is not None and mode is not None: Review Comment: we can move such complex preprocessing to the superclasses later -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
