ueshin commented on code in PR #46129:
URL: https://github.com/apache/spark/pull/46129#discussion_r1572787315
##########
python/pyspark/sql/dataframe.py:
##########
@@ -139,51 +123,29 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
created via using the constructor.
"""
- def __init__(
- self,
+ # HACK ALERT!! this is to reduce the backward compatibility concern, and returns
+ # Spark Classic DataFrame by default. This is NOT an API, and NOT supposed to
+ # be directly invoked. DO NOT use this constructor.
+ _sql_ctx: Optional["SQLContext"]
+ _session: "SparkSession"
+ _sc: "SparkContext"
+ _jdf: "JavaObject"
+ is_cached: bool
+ _schema: Optional[StructType]
+ _lazy_rdd: Optional["RDD[Row]"]
+ _support_repr_html: bool
+
+ def __new__(
+ cls,
jdf: "JavaObject",
sql_ctx: Union["SQLContext", "SparkSession"],
- ):
- from pyspark.sql.context import SQLContext
-
- self._sql_ctx: Optional["SQLContext"] = None
-
- if isinstance(sql_ctx, SQLContext):
- assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
- assert isinstance(sql_ctx, SQLContext)
- # We should remove this if-else branch in the future release, and rename
- # sql_ctx to session in the constructor. This is an internal code path but
- # was kept with a warning because it's used intensively by third-party libraries.
- warnings.warn("DataFrame constructor is internal. Do not directly use it.")
- self._sql_ctx = sql_ctx
- session = sql_ctx.sparkSession
- else:
- session = sql_ctx
- self._session: "SparkSession" = session
-
- self._sc: "SparkContext" = sql_ctx._sc
- self._jdf: "JavaObject" = jdf
- self.is_cached = False
- # initialized lazily
- self._schema: Optional[StructType] = None
- self._lazy_rdd: Optional["RDD[Row]"] = None
- # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
- # by __repr__ and _repr_html_ while eager evaluation opens.
- self._support_repr_html = False
-
- @property
- def sql_ctx(self) -> "SQLContext":
- from pyspark.sql.context import SQLContext
+ ) -> "DataFrame":
+ from pyspark.sql.classic.dataframe import DataFrame
- warnings.warn(
- "DataFrame.sql_ctx is an internal property, and will be removed "
- "in future releases. Use DataFrame.sparkSession instead."
- )
- if self._sql_ctx is None:
- self._sql_ctx = SQLContext._get_or_create(self._sc)
- return self._sql_ctx
+ return DataFrame.__new__(DataFrame, jdf, sql_ctx)
@property
+ @dispatch_df_method
Review Comment:
The dispatch for `property` doesn't seem to work.
```py
>>> class A:
... @property
... @dispatch_df_method
... def a(self):
... return 1
>>>
>>> a = A()
>>> A.a(a)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: 'property' object is not callable
```
I don't think we need this for `property` as this usage of `property` won't
work anyway?
##########
python/pyspark/sql/classic/dataframe.py:
##########
@@ -0,0 +1,1974 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import sys
+import random
+import warnings
+from collections.abc import Iterable
+from functools import reduce
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+ TYPE_CHECKING,
+)
+
+from pyspark import _NoValue
+from pyspark.resource import ResourceProfile
+from pyspark._globals import _NoValueType
+from pyspark.errors import (
+ PySparkTypeError,
+ PySparkValueError,
+ PySparkIndexError,
+ PySparkAttributeError,
+)
+from pyspark.util import (
+ is_remote_only,
+ _load_from_socket,
+ _local_iterator_from_socket,
+)
+from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
+from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
+from pyspark.sql.streaming import DataStreamWriter
+from pyspark.sql.types import (
+ StructType,
+ Row,
+ _parse_datatype_json_string,
+)
+from pyspark.sql.dataframe import (
+ DataFrame as ParentDataFrame,
+ DataFrameNaFunctions as ParentDataFrameNaFunctions,
+ DataFrameStatFunctions as ParentDataFrameStatFunctions,
+)
+from pyspark.sql.utils import get_active_spark_context, toJArray
+from pyspark.sql.pandas.conversion import PandasConversionMixin
+from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+
+if TYPE_CHECKING:
+ from py4j.java_gateway import JavaObject
+ from pyspark.core.rdd import RDD
+ from pyspark.core.context import SparkContext
+ from pyspark._typing import PrimitiveType
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.sql._typing import (
+ ColumnOrName,
+ ColumnOrNameOrOrdinal,
+ LiteralType,
+ OptionalPrimitiveType,
+ )
+ from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
+ from pyspark.sql.group import GroupedData
+ from pyspark.sql.observation import Observation
+
+
+class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
+ def __new__(
+ cls,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ) -> "DataFrame":
+ self = object.__new__(cls)
+ self.__init__(jdf, sql_ctx) # type: ignore[misc]
+ return self
+
+ def __init__(
+ self,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ):
+ from pyspark.sql.context import SQLContext
+
+ self._sql_ctx: Optional["SQLContext"] = None
+
+ if isinstance(sql_ctx, SQLContext):
+ assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
+ assert isinstance(sql_ctx, SQLContext)
+ # We should remove this if-else branch in the future release, and rename
+ # sql_ctx to session in the constructor. This is an internal code path but
+ # was kept with a warning because it's used intensively by third-party libraries.
+ warnings.warn("DataFrame constructor is internal. Do not directly use it.")
+ self._sql_ctx = sql_ctx
+ session = sql_ctx.sparkSession
+ else:
+ session = sql_ctx
+ self._session: "SparkSession" = session
+
+ self._sc: "SparkContext" = sql_ctx._sc
+ self._jdf: "JavaObject" = jdf
+ self.is_cached = False
+ # initialized lazily
+ self._schema: Optional[StructType] = None
+ self._lazy_rdd: Optional["RDD[Row]"] = None
+ # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
+ # by __repr__ and _repr_html_ while eager evaluation opens.
+ self._support_repr_html = False
+
+ @property
+ def sql_ctx(self) -> "SQLContext":
+ from pyspark.sql.context import SQLContext
+
+ warnings.warn(
+ "DataFrame.sql_ctx is an internal property, and will be removed "
+ "in future releases. Use DataFrame.sparkSession instead."
+ )
+ if self._sql_ctx is None:
+ self._sql_ctx = SQLContext._get_or_create(self._sc)
+ return self._sql_ctx
+
+ @property
+ def sparkSession(self) -> "SparkSession":
+ return self._session
+
+ if not is_remote_only():
+
+ @property
+ def rdd(self) -> "RDD[Row]":
+ from pyspark.core.rdd import RDD
+
+ if self._lazy_rdd is None:
+ jrdd = self._jdf.javaToPython()
+ self._lazy_rdd = RDD(
+ jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer())
+ )
+ return self._lazy_rdd
+
+ @property
+ def na(self) -> "DataFrameNaFunctions":
+ return DataFrameNaFunctions(self)
+
+ @property
+ def stat(self) -> "DataFrameStatFunctions":
+ return DataFrameStatFunctions(self)
+
+ if not is_remote_only():
Review Comment:
ditto?
##########
python/pyspark/sql/classic/dataframe.py:
##########
@@ -0,0 +1,1974 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import sys
+import random
+import warnings
+from collections.abc import Iterable
+from functools import reduce
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+ TYPE_CHECKING,
+)
+
+from pyspark import _NoValue
+from pyspark.resource import ResourceProfile
+from pyspark._globals import _NoValueType
+from pyspark.errors import (
+ PySparkTypeError,
+ PySparkValueError,
+ PySparkIndexError,
+ PySparkAttributeError,
+)
+from pyspark.util import (
+ is_remote_only,
+ _load_from_socket,
+ _local_iterator_from_socket,
+)
+from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
+from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
+from pyspark.sql.streaming import DataStreamWriter
+from pyspark.sql.types import (
+ StructType,
+ Row,
+ _parse_datatype_json_string,
+)
+from pyspark.sql.dataframe import (
+ DataFrame as ParentDataFrame,
+ DataFrameNaFunctions as ParentDataFrameNaFunctions,
+ DataFrameStatFunctions as ParentDataFrameStatFunctions,
+)
+from pyspark.sql.utils import get_active_spark_context, toJArray
+from pyspark.sql.pandas.conversion import PandasConversionMixin
+from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+
+if TYPE_CHECKING:
+ from py4j.java_gateway import JavaObject
+ from pyspark.core.rdd import RDD
+ from pyspark.core.context import SparkContext
+ from pyspark._typing import PrimitiveType
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.sql._typing import (
+ ColumnOrName,
+ ColumnOrNameOrOrdinal,
+ LiteralType,
+ OptionalPrimitiveType,
+ )
+ from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
+ from pyspark.sql.group import GroupedData
+ from pyspark.sql.observation import Observation
+
+
+class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
+ def __new__(
+ cls,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ) -> "DataFrame":
+ self = object.__new__(cls)
+ self.__init__(jdf, sql_ctx) # type: ignore[misc]
+ return self
+
+ def __init__(
+ self,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ):
+ from pyspark.sql.context import SQLContext
+
+ self._sql_ctx: Optional["SQLContext"] = None
+
+ if isinstance(sql_ctx, SQLContext):
+ assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
+ assert isinstance(sql_ctx, SQLContext)
+ # We should remove this if-else branch in the future release, and rename
+ # sql_ctx to session in the constructor. This is an internal code path but
+ # was kept with a warning because it's used intensively by third-party libraries.
+ warnings.warn("DataFrame constructor is internal. Do not directly use it.")
+ self._sql_ctx = sql_ctx
+ session = sql_ctx.sparkSession
+ else:
+ session = sql_ctx
+ self._session: "SparkSession" = session
+
+ self._sc: "SparkContext" = sql_ctx._sc
+ self._jdf: "JavaObject" = jdf
+ self.is_cached = False
+ # initialized lazily
+ self._schema: Optional[StructType] = None
+ self._lazy_rdd: Optional["RDD[Row]"] = None
+ # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
+ # by __repr__ and _repr_html_ while eager evaluation opens.
+ self._support_repr_html = False
+
+ @property
+ def sql_ctx(self) -> "SQLContext":
+ from pyspark.sql.context import SQLContext
+
+ warnings.warn(
+ "DataFrame.sql_ctx is an internal property, and will be removed "
+ "in future releases. Use DataFrame.sparkSession instead."
+ )
+ if self._sql_ctx is None:
+ self._sql_ctx = SQLContext._get_or_create(self._sc)
+ return self._sql_ctx
+
+ @property
+ def sparkSession(self) -> "SparkSession":
+ return self._session
+
+ if not is_remote_only():
+
+ @property
+ def rdd(self) -> "RDD[Row]":
+ from pyspark.core.rdd import RDD
+
+ if self._lazy_rdd is None:
+ jrdd = self._jdf.javaToPython()
+ self._lazy_rdd = RDD(
+ jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer())
+ )
+ return self._lazy_rdd
+
+ @property
+ def na(self) -> "DataFrameNaFunctions":
+ return DataFrameNaFunctions(self)
+
+ @property
+ def stat(self) -> "DataFrameStatFunctions":
Review Comment:
nit: we don't need the quotes (`"`) around `ParentXxx` type hints, as they are pre-defined.
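For example, something like this (just a sketch of what the nit suggests; `na` is simply the method from the diff above):
```py
# ParentDataFrameNaFunctions is imported at module level (not only under
# TYPE_CHECKING), so the annotation doesn't need to be a string:
@property
def na(self) -> ParentDataFrameNaFunctions:
    return DataFrameNaFunctions(self)
```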
##########
python/pyspark/sql/utils.py:
##########
@@ -302,6 +302,33 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
return cast(FuncT, wrapped)
+def dispatch_df_method(f: FuncT) -> FuncT:
+ """
+ For the usecases of direct DataFrame.union(df, ...), it checks if self
+ is a Connect DataFrame or Classic DataFrame, and dispatches.
+ """
+
+ @functools.wraps(f)
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+ from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+
+ if isinstance(args[0], ConnectDataFrame):
+ return getattr(ConnectDataFrame, f.__name__)(*args, **kwargs)
+ else:
+ from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
Review Comment:
We may want to try-catch here for `pyspark-connect` package users?
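Something like this, maybe (just a sketch; the error class and message are placeholders, and the code around the `try`/`except` is paraphrased from the diff above):
```py
@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
    if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
        from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame

        if isinstance(args[0], ConnectDataFrame):
            return getattr(ConnectDataFrame, f.__name__)(*args, **kwargs)
    else:
        try:
            from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
        except ImportError as e:
            # pyspark-connect-only installs don't ship the classic DataFrame,
            # so surface a PySpark error instead of a raw ImportError.
            # (PySparkNotImplementedError / NOT_IMPLEMENTED are placeholders here.)
            from pyspark.errors import PySparkNotImplementedError

            raise PySparkNotImplementedError(
                error_class="NOT_IMPLEMENTED",
                message_parameters={"feature": f.__name__},
            ) from e

        if isinstance(args[0], ClassicDataFrame):
            return getattr(ClassicDataFrame, f.__name__)(*args, **kwargs)
    return f(*args, **kwargs)
```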
##########
python/pyspark/sql/classic/dataframe.py:
##########
@@ -0,0 +1,1974 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import sys
+import random
+import warnings
+from collections.abc import Iterable
+from functools import reduce
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+ TYPE_CHECKING,
+)
+
+from pyspark import _NoValue
+from pyspark.resource import ResourceProfile
+from pyspark._globals import _NoValueType
+from pyspark.errors import (
+ PySparkTypeError,
+ PySparkValueError,
+ PySparkIndexError,
+ PySparkAttributeError,
+)
+from pyspark.util import (
+ is_remote_only,
+ _load_from_socket,
+ _local_iterator_from_socket,
+)
+from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
+from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
+from pyspark.sql.streaming import DataStreamWriter
+from pyspark.sql.types import (
+ StructType,
+ Row,
+ _parse_datatype_json_string,
+)
+from pyspark.sql.dataframe import (
+ DataFrame as ParentDataFrame,
+ DataFrameNaFunctions as ParentDataFrameNaFunctions,
+ DataFrameStatFunctions as ParentDataFrameStatFunctions,
+)
+from pyspark.sql.utils import get_active_spark_context, toJArray
+from pyspark.sql.pandas.conversion import PandasConversionMixin
+from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+
+if TYPE_CHECKING:
+ from py4j.java_gateway import JavaObject
+ from pyspark.core.rdd import RDD
+ from pyspark.core.context import SparkContext
+ from pyspark._typing import PrimitiveType
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.sql._typing import (
+ ColumnOrName,
+ ColumnOrNameOrOrdinal,
+ LiteralType,
+ OptionalPrimitiveType,
+ )
+ from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
+ from pyspark.sql.group import GroupedData
+ from pyspark.sql.observation import Observation
+
+
+class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
+ def __new__(
+ cls,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ) -> "DataFrame":
+ self = object.__new__(cls)
+ self.__init__(jdf, sql_ctx) # type: ignore[misc]
+ return self
+
+ def __init__(
+ self,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ):
+ from pyspark.sql.context import SQLContext
+
+ self._sql_ctx: Optional["SQLContext"] = None
+
+ if isinstance(sql_ctx, SQLContext):
+ assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
+ assert isinstance(sql_ctx, SQLContext)
+ # We should remove this if-else branch in the future release, and rename
+ # sql_ctx to session in the constructor. This is an internal code path but
+ # was kept with a warning because it's used intensively by third-party libraries.
+ warnings.warn("DataFrame constructor is internal. Do not directly use it.")
+ self._sql_ctx = sql_ctx
+ session = sql_ctx.sparkSession
+ else:
+ session = sql_ctx
+ self._session: "SparkSession" = session
+
+ self._sc: "SparkContext" = sql_ctx._sc
+ self._jdf: "JavaObject" = jdf
+ self.is_cached = False
+ # initialized lazily
+ self._schema: Optional[StructType] = None
+ self._lazy_rdd: Optional["RDD[Row]"] = None
+ # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
+ # by __repr__ and _repr_html_ while eager evaluation opens.
+ self._support_repr_html = False
+
+ @property
+ def sql_ctx(self) -> "SQLContext":
+ from pyspark.sql.context import SQLContext
+
+ warnings.warn(
+ "DataFrame.sql_ctx is an internal property, and will be removed "
+ "in future releases. Use DataFrame.sparkSession instead."
+ )
+ if self._sql_ctx is None:
+ self._sql_ctx = SQLContext._get_or_create(self._sc)
+ return self._sql_ctx
+
+ @property
+ def sparkSession(self) -> "SparkSession":
+ return self._session
+
+ if not is_remote_only():
+
+ @property
+ def rdd(self) -> "RDD[Row]":
+ from pyspark.core.rdd import RDD
+
+ if self._lazy_rdd is None:
+ jrdd = self._jdf.javaToPython()
+ self._lazy_rdd = RDD(
+ jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer())
+ )
+ return self._lazy_rdd
+
+ @property
+ def na(self) -> "DataFrameNaFunctions":
+ return DataFrameNaFunctions(self)
+
+ @property
+ def stat(self) -> "DataFrameStatFunctions":
+ return DataFrameStatFunctions(self)
+
+ if not is_remote_only():
+
+ def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
+ from pyspark.core.rdd import RDD
+
+ rdd = self._jdf.toJSON()
+ return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
+
+ def registerTempTable(self, name: str) -> None:
+ warnings.warn("Deprecated in 2.0, use createOrReplaceTempView
instead.", FutureWarning)
+ self._jdf.createOrReplaceTempView(name)
+
+ def createTempView(self, name: str) -> None:
+ self._jdf.createTempView(name)
+
+ def createOrReplaceTempView(self, name: str) -> None:
+ self._jdf.createOrReplaceTempView(name)
+
+ def createGlobalTempView(self, name: str) -> None:
+ self._jdf.createGlobalTempView(name)
+
+ def createOrReplaceGlobalTempView(self, name: str) -> None:
+ self._jdf.createOrReplaceGlobalTempView(name)
+
+ @property
+ def write(self) -> DataFrameWriter:
+ return DataFrameWriter(self)
+
+ @property
+ def writeStream(self) -> DataStreamWriter:
+ return DataStreamWriter(self)
+
+ @property
+ def schema(self) -> StructType:
+ if self._schema is None:
+ try:
+ self._schema = cast(
+ StructType, _parse_datatype_json_string(self._jdf.schema().json())
+ )
+ except Exception as e:
+ raise PySparkValueError(
+ error_class="CANNOT_PARSE_DATATYPE",
+ message_parameters={"error": str(e)},
+ )
+ return self._schema
+
+ def printSchema(self, level: Optional[int] = None) -> None:
+ if level:
+ print(self._jdf.schema().treeString(level))
+ else:
+ print(self._jdf.schema().treeString())
+
+ def explain(
+ self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None
+ ) -> None:
+ if extended is not None and mode is not None:
+ raise PySparkValueError(
+ error_class="CANNOT_SET_TOGETHER",
+ message_parameters={"arg_list": "extended and mode"},
+ )
+
+ # For the no argument case: df.explain()
+ is_no_argument = extended is None and mode is None
+
+ # For the cases below:
+ # explain(True)
+ # explain(extended=False)
+ is_extended_case = isinstance(extended, bool) and mode is None
+
+ # For the case when extended is mode:
+ # df.explain("formatted")
+ is_extended_as_mode = isinstance(extended, str) and mode is None
+
+ # For the mode specified:
+ # df.explain(mode="formatted")
+ is_mode_case = extended is None and isinstance(mode, str)
+
+ if not (is_no_argument or is_extended_case or is_extended_as_mode or is_mode_case):
+ if (extended is not None) and (not isinstance(extended, (bool, str))):
+ raise PySparkTypeError(
+ error_class="NOT_BOOL_OR_STR",
+ message_parameters={
+ "arg_name": "extended",
+ "arg_type": type(extended).__name__,
+ },
+ )
+ if (mode is not None) and (not isinstance(mode, str)):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "mode", "arg_type":
type(mode).__name__},
+ )
+
+ # Sets an explain mode depending on a given argument
+ if is_no_argument:
+ explain_mode = "simple"
+ elif is_extended_case:
+ explain_mode = "extended" if extended else "simple"
+ elif is_mode_case:
+ explain_mode = cast(str, mode)
+ elif is_extended_as_mode:
+ explain_mode = cast(str, extended)
+ assert self._sc._jvm is not None
+ print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode))
+
+ def exceptAll(self, other: "ParentDataFrame") -> "ParentDataFrame":
+ return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession)
+
+ def isLocal(self) -> bool:
+ return self._jdf.isLocal()
+
+ @property
+ def isStreaming(self) -> bool:
+ return self._jdf.isStreaming()
+
+ def isEmpty(self) -> bool:
+ return self._jdf.isEmpty()
+
+ def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
+ print(self._show_string(n, truncate, vertical))
+
+ def _show_string(
+ self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False
+ ) -> str:
+ if not isinstance(n, int) or isinstance(n, bool):
+ raise PySparkTypeError(
+ error_class="NOT_INT",
+ message_parameters={"arg_name": "n", "arg_type":
type(n).__name__},
+ )
+
+ if not isinstance(vertical, bool):
+ raise PySparkTypeError(
+ error_class="NOT_BOOL",
+ message_parameters={"arg_name": "vertical", "arg_type":
type(vertical).__name__},
+ )
+
+ if isinstance(truncate, bool) and truncate:
+ return self._jdf.showString(n, 20, vertical)
+ else:
+ try:
+ int_truncate = int(truncate)
+ except ValueError:
+ raise PySparkTypeError(
+ error_class="NOT_BOOL",
+ message_parameters={
+ "arg_name": "truncate",
+ "arg_type": type(truncate).__name__,
+ },
+ )
+
+ return self._jdf.showString(n, int_truncate, vertical)
+
+ def __repr__(self) -> str:
+ if not self._support_repr_html and self.sparkSession._jconf.isReplEagerEvalEnabled():
+ vertical = False
+ return self._jdf.showString(
+ self.sparkSession._jconf.replEagerEvalMaxNumRows(),
+ self.sparkSession._jconf.replEagerEvalTruncate(),
+ vertical,
+ )
+ else:
+ return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in
self.dtypes))
+
+ def _repr_html_(self) -> Optional[str]:
+ """Returns a :class:`DataFrame` with html code when you enabled eager
evaluation
+ by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are
+ using support eager evaluation with HTML.
+ """
+ if not self._support_repr_html:
+ self._support_repr_html = True
+ if self.sparkSession._jconf.isReplEagerEvalEnabled():
+ return self._jdf.htmlString(
+ self.sparkSession._jconf.replEagerEvalMaxNumRows(),
+ self.sparkSession._jconf.replEagerEvalTruncate(),
+ )
+ else:
+ return None
+
+ def checkpoint(self, eager: bool = True) -> "ParentDataFrame":
+ jdf = self._jdf.checkpoint(eager)
+ return DataFrame(jdf, self.sparkSession)
+
+ def localCheckpoint(self, eager: bool = True) -> "ParentDataFrame":
+ jdf = self._jdf.localCheckpoint(eager)
+ return DataFrame(jdf, self.sparkSession)
+
+ def withWatermark(self, eventTime: str, delayThreshold: str) -> "ParentDataFrame":
+ if not eventTime or type(eventTime) is not str:
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "eventTime", "arg_type":
type(eventTime).__name__},
+ )
+ if not delayThreshold or type(delayThreshold) is not str:
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={
+ "arg_name": "delayThreshold",
+ "arg_type": type(delayThreshold).__name__,
+ },
+ )
+ jdf = self._jdf.withWatermark(eventTime, delayThreshold)
+ return DataFrame(jdf, self.sparkSession)
+
+ def hint(
+ self, name: str, *parameters: Union["PrimitiveType", "Column", List["PrimitiveType"]]
+ ) -> "ParentDataFrame":
+ if len(parameters) == 1 and isinstance(parameters[0], list):
+ parameters = parameters[0] # type: ignore[assignment]
+
+ if not isinstance(name, str):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "name", "arg_type":
type(name).__name__},
+ )
+
+ allowed_types = (str, float, int, Column, list)
+ allowed_primitive_types = (str, float, int)
+ allowed_types_repr = ", ".join(
+ [t.__name__ for t in allowed_types[:-1]]
+ + ["list[" + t.__name__ + "]" for t in allowed_primitive_types]
+ )
+ for p in parameters:
+ if not isinstance(p, allowed_types):
+ raise PySparkTypeError(
+ error_class="DISALLOWED_TYPE_FOR_CONTAINER",
+ message_parameters={
+ "arg_name": "parameters",
+ "arg_type": type(parameters).__name__,
+ "allowed_types": allowed_types_repr,
+ "item_type": type(p).__name__,
+ },
+ )
+ if isinstance(p, list):
+ if not all(isinstance(e, allowed_primitive_types) for e in p):
+ raise PySparkTypeError(
+ error_class="DISALLOWED_TYPE_FOR_CONTAINER",
+ message_parameters={
+ "arg_name": "parameters",
+ "arg_type": type(parameters).__name__,
+ "allowed_types": allowed_types_repr,
+ "item_type": type(p).__name__ + "[" +
type(p[0]).__name__ + "]",
+ },
+ )
+
+ def _converter(parameter: Union[str, list, float, int, Column]) -> Any:
+ if isinstance(parameter, Column):
+ return _to_java_column(parameter)
+ elif isinstance(parameter, list):
+ # for list input, we are assuming only one element type exist in the list.
+ # for empty list, we are converting it into an empty long[] in the JVM side.
+ gateway = self._sc._gateway
+ assert gateway is not None
+ jclass = gateway.jvm.long
+ if len(parameter) >= 1:
+ mapping = {
+ str: gateway.jvm.java.lang.String,
+ float: gateway.jvm.double,
+ int: gateway.jvm.long,
+ }
+ jclass = mapping[type(parameter[0])]
+ return toJArray(gateway, jclass, parameter)
+ else:
+ return parameter
+
+ jdf = self._jdf.hint(name, self._jseq(parameters, _converter))
+ return DataFrame(jdf, self.sparkSession)
+
+ def count(self) -> int:
+ return int(self._jdf.count())
+
+ def collect(self) -> List[Row]:
+ with SCCallSiteSync(self._sc):
+ sock_info = self._jdf.collectToPython()
+ return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))
+
+ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
+ with SCCallSiteSync(self._sc):
+ sock_info = self._jdf.toPythonIterator(prefetchPartitions)
+ return _local_iterator_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))
+
+ def limit(self, num: int) -> "ParentDataFrame":
+ jdf = self._jdf.limit(num)
+ return DataFrame(jdf, self.sparkSession)
+
+ def offset(self, num: int) -> "ParentDataFrame":
+ jdf = self._jdf.offset(num)
+ return DataFrame(jdf, self.sparkSession)
+
+ def take(self, num: int) -> List[Row]:
+ return self.limit(num).collect()
+
+ def tail(self, num: int) -> List[Row]:
+ with SCCallSiteSync(self._sc):
+ sock_info = self._jdf.tailToPython(num)
+ return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))
+
+ def foreach(self, f: Callable[[Row], None]) -> None:
+ self.rdd.foreach(f)
+
+ def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
+ self.rdd.foreachPartition(f) # type: ignore[arg-type]
+
+ def cache(self) -> "ParentDataFrame":
+ self.is_cached = True
+ self._jdf.cache()
+ return self
+
+ def persist(
+ self,
+ storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER),
+ ) -> "ParentDataFrame":
+ self.is_cached = True
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ self._jdf.persist(javaStorageLevel)
+ return self
+
+ @property
+ def storageLevel(self) -> StorageLevel:
+ java_storage_level = self._jdf.storageLevel()
+ storage_level = StorageLevel(
+ java_storage_level.useDisk(),
+ java_storage_level.useMemory(),
+ java_storage_level.useOffHeap(),
+ java_storage_level.deserialized(),
+ java_storage_level.replication(),
+ )
+ return storage_level
+
+ def unpersist(self, blocking: bool = False) -> "ParentDataFrame":
+ self.is_cached = False
+ self._jdf.unpersist(blocking)
+ return self
+
+ def coalesce(self, numPartitions: int) -> "ParentDataFrame":
+ return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession)
+
+ @overload
+ def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> "ParentDataFrame":
+ ...
Review Comment:
I'm wondering if we need `@overload` definitions in the subclasses?
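For example, a rough sketch (assuming type checkers resolve the overloads from the parent class, which is the type users are annotated with):
```py
# pyspark/sql/dataframe.py (parent): keep the @overload stubs here.
@overload
def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> "DataFrame":
    ...

@overload
def repartition(self, *cols: "ColumnOrName") -> "DataFrame":
    ...

# pyspark/sql/classic/dataframe.py (subclass): a single implementation without
# @overload might be enough, since callers see the parent type.
def repartition(
    self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> "ParentDataFrame":
    ...
```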
##########
python/pyspark/sql/classic/dataframe.py:
##########
@@ -0,0 +1,1974 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import sys
+import random
+import warnings
+from collections.abc import Iterable
+from functools import reduce
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+ TYPE_CHECKING,
+)
+
+from pyspark import _NoValue
+from pyspark.resource import ResourceProfile
+from pyspark._globals import _NoValueType
+from pyspark.errors import (
+ PySparkTypeError,
+ PySparkValueError,
+ PySparkIndexError,
+ PySparkAttributeError,
+)
+from pyspark.util import (
+ is_remote_only,
+ _load_from_socket,
+ _local_iterator_from_socket,
+)
+from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
+from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
+from pyspark.sql.streaming import DataStreamWriter
+from pyspark.sql.types import (
+ StructType,
+ Row,
+ _parse_datatype_json_string,
+)
+from pyspark.sql.dataframe import (
+ DataFrame as ParentDataFrame,
+ DataFrameNaFunctions as ParentDataFrameNaFunctions,
+ DataFrameStatFunctions as ParentDataFrameStatFunctions,
+)
+from pyspark.sql.utils import get_active_spark_context, toJArray
+from pyspark.sql.pandas.conversion import PandasConversionMixin
+from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+
+if TYPE_CHECKING:
+ from py4j.java_gateway import JavaObject
+ from pyspark.core.rdd import RDD
+ from pyspark.core.context import SparkContext
+ from pyspark._typing import PrimitiveType
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.sql._typing import (
+ ColumnOrName,
+ ColumnOrNameOrOrdinal,
+ LiteralType,
+ OptionalPrimitiveType,
+ )
+ from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
+ from pyspark.sql.group import GroupedData
+ from pyspark.sql.observation import Observation
+
+
+class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
+ def __new__(
+ cls,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ) -> "DataFrame":
+ self = object.__new__(cls)
+ self.__init__(jdf, sql_ctx) # type: ignore[misc]
+ return self
+
+ def __init__(
+ self,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ):
+ from pyspark.sql.context import SQLContext
+
+ self._sql_ctx: Optional["SQLContext"] = None
+
+ if isinstance(sql_ctx, SQLContext):
+ assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
+ assert isinstance(sql_ctx, SQLContext)
+ # We should remove this if-else branch in the future release, and rename
+ # sql_ctx to session in the constructor. This is an internal code path but
+ # was kept with a warning because it's used intensively by third-party libraries.
+ warnings.warn("DataFrame constructor is internal. Do not directly use it.")
+ self._sql_ctx = sql_ctx
+ session = sql_ctx.sparkSession
+ else:
+ session = sql_ctx
+ self._session: "SparkSession" = session
+
+ self._sc: "SparkContext" = sql_ctx._sc
+ self._jdf: "JavaObject" = jdf
+ self.is_cached = False
+ # initialized lazily
+ self._schema: Optional[StructType] = None
+ self._lazy_rdd: Optional["RDD[Row]"] = None
+ # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
+ # by __repr__ and _repr_html_ while eager evaluation opens.
+ self._support_repr_html = False
+
+ @property
+ def sql_ctx(self) -> "SQLContext":
+ from pyspark.sql.context import SQLContext
+
+ warnings.warn(
+ "DataFrame.sql_ctx is an internal property, and will be removed "
+ "in future releases. Use DataFrame.sparkSession instead."
+ )
+ if self._sql_ctx is None:
+ self._sql_ctx = SQLContext._get_or_create(self._sc)
+ return self._sql_ctx
+
+ @property
+ def sparkSession(self) -> "SparkSession":
+ return self._session
+
+ if not is_remote_only():
Review Comment:
We can remove this `if` now?
##########
python/pyspark/sql/connect/session.py:
##########
@@ -325,7 +325,7 @@ def active(cls) -> "SparkSession":
active.__doc__ = PySparkSession.active.__doc__
- def table(self, tableName: str) -> DataFrame:
+ def table(self, tableName: str) -> ParentDataFrame:
Review Comment:
I guess we can leave it as-is? And the following changes?
##########
python/pyspark/sql/connect/dataframe.py:
##########
@@ -2306,7 +2183,7 @@ def _test() -> None:
)
(failure_count, test_count) = doctest.testmod(
- pyspark.sql.connect.dataframe,
+ pyspark.sql.dataframe,
Review Comment:
Shall we put the comment here?
##########
python/pyspark/sql/classic/dataframe.py:
##########
@@ -0,0 +1,1974 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import sys
+import random
+import warnings
+from collections.abc import Iterable
+from functools import reduce
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+ TYPE_CHECKING,
+)
+
+from pyspark import _NoValue
+from pyspark.resource import ResourceProfile
+from pyspark._globals import _NoValueType
+from pyspark.errors import (
+ PySparkTypeError,
+ PySparkValueError,
+ PySparkIndexError,
+ PySparkAttributeError,
+)
+from pyspark.util import (
+ is_remote_only,
+ _load_from_socket,
+ _local_iterator_from_socket,
+)
+from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
+from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
+from pyspark.sql.streaming import DataStreamWriter
+from pyspark.sql.types import (
+ StructType,
+ Row,
+ _parse_datatype_json_string,
+)
+from pyspark.sql.dataframe import (
+ DataFrame as ParentDataFrame,
+ DataFrameNaFunctions as ParentDataFrameNaFunctions,
+ DataFrameStatFunctions as ParentDataFrameStatFunctions,
+)
+from pyspark.sql.utils import get_active_spark_context, toJArray
+from pyspark.sql.pandas.conversion import PandasConversionMixin
+from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+
+if TYPE_CHECKING:
+ from py4j.java_gateway import JavaObject
+ from pyspark.core.rdd import RDD
+ from pyspark.core.context import SparkContext
+ from pyspark._typing import PrimitiveType
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.sql._typing import (
+ ColumnOrName,
+ ColumnOrNameOrOrdinal,
+ LiteralType,
+ OptionalPrimitiveType,
+ )
+ from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
+ from pyspark.sql.group import GroupedData
+ from pyspark.sql.observation import Observation
+
+
+class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
+ def __new__(
+ cls,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ) -> "DataFrame":
+ self = object.__new__(cls)
+ self.__init__(jdf, sql_ctx) # type: ignore[misc]
+ return self
+
+ def __init__(
+ self,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ):
+ from pyspark.sql.context import SQLContext
+
+ self._sql_ctx: Optional["SQLContext"] = None
+
+ if isinstance(sql_ctx, SQLContext):
+ assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
+ assert isinstance(sql_ctx, SQLContext)
+ # We should remove this if-else branch in the future release, and rename
+ # sql_ctx to session in the constructor. This is an internal code path but
+ # was kept with a warning because it's used intensively by third-party libraries.
+ warnings.warn("DataFrame constructor is internal. Do not directly use it.")
+ self._sql_ctx = sql_ctx
+ session = sql_ctx.sparkSession
+ else:
+ session = sql_ctx
+ self._session: "SparkSession" = session
+
+ self._sc: "SparkContext" = sql_ctx._sc
+ self._jdf: "JavaObject" = jdf
+ self.is_cached = False
+ # initialized lazily
+ self._schema: Optional[StructType] = None
+ self._lazy_rdd: Optional["RDD[Row]"] = None
+ # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
+ # by __repr__ and _repr_html_ while eager evaluation opens.
+ self._support_repr_html = False
+
+ @property
+ def sql_ctx(self) -> "SQLContext":
+ from pyspark.sql.context import SQLContext
+
+ warnings.warn(
+ "DataFrame.sql_ctx is an internal property, and will be removed "
+ "in future releases. Use DataFrame.sparkSession instead."
+ )
+ if self._sql_ctx is None:
+ self._sql_ctx = SQLContext._get_or_create(self._sc)
+ return self._sql_ctx
+
+ @property
+ def sparkSession(self) -> "SparkSession":
+ return self._session
+
+ if not is_remote_only():
+
+ @property
+ def rdd(self) -> "RDD[Row]":
+ from pyspark.core.rdd import RDD
+
+ if self._lazy_rdd is None:
+ jrdd = self._jdf.javaToPython()
+ self._lazy_rdd = RDD(
+ jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer())
+ )
+ return self._lazy_rdd
+
+ @property
+ def na(self) -> "DataFrameNaFunctions":
Review Comment:
`-> ParentDataFrameNaFunctions`?
##########
python/pyspark/sql/classic/dataframe.py:
##########
@@ -0,0 +1,1974 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import sys
+import random
+import warnings
+from collections.abc import Iterable
+from functools import reduce
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+ TYPE_CHECKING,
+)
+
+from pyspark import _NoValue
+from pyspark.resource import ResourceProfile
+from pyspark._globals import _NoValueType
+from pyspark.errors import (
+ PySparkTypeError,
+ PySparkValueError,
+ PySparkIndexError,
+ PySparkAttributeError,
+)
+from pyspark.util import (
+ is_remote_only,
+ _load_from_socket,
+ _local_iterator_from_socket,
+)
+from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
+from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
+from pyspark.sql.streaming import DataStreamWriter
+from pyspark.sql.types import (
+ StructType,
+ Row,
+ _parse_datatype_json_string,
+)
+from pyspark.sql.dataframe import (
+ DataFrame as ParentDataFrame,
+ DataFrameNaFunctions as ParentDataFrameNaFunctions,
+ DataFrameStatFunctions as ParentDataFrameStatFunctions,
+)
+from pyspark.sql.utils import get_active_spark_context, toJArray
+from pyspark.sql.pandas.conversion import PandasConversionMixin
+from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+
+if TYPE_CHECKING:
+ from py4j.java_gateway import JavaObject
+ from pyspark.core.rdd import RDD
+ from pyspark.core.context import SparkContext
+ from pyspark._typing import PrimitiveType
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.sql._typing import (
+ ColumnOrName,
+ ColumnOrNameOrOrdinal,
+ LiteralType,
+ OptionalPrimitiveType,
+ )
+ from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
+ from pyspark.sql.group import GroupedData
+ from pyspark.sql.observation import Observation
+
+
+class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
+ def __new__(
+ cls,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ) -> "DataFrame":
+ self = object.__new__(cls)
+ self.__init__(jdf, sql_ctx) # type: ignore[misc]
+ return self
+
+ def __init__(
+ self,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ):
+ from pyspark.sql.context import SQLContext
+
+ self._sql_ctx: Optional["SQLContext"] = None
+
+ if isinstance(sql_ctx, SQLContext):
+ assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
+ assert isinstance(sql_ctx, SQLContext)
+ # We should remove this if-else branch in the future release, and rename
+ # sql_ctx to session in the constructor. This is an internal code path but
+ # was kept with a warning because it's used intensively by third-party libraries.
+ warnings.warn("DataFrame constructor is internal. Do not directly use it.")
+ self._sql_ctx = sql_ctx
+ session = sql_ctx.sparkSession
+ else:
+ session = sql_ctx
+ self._session: "SparkSession" = session
+
+ self._sc: "SparkContext" = sql_ctx._sc
+ self._jdf: "JavaObject" = jdf
+ self.is_cached = False
+ # initialized lazily
+ self._schema: Optional[StructType] = None
+ self._lazy_rdd: Optional["RDD[Row]"] = None
+ # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
+ # by __repr__ and _repr_html_ while eager evaluation opens.
+ self._support_repr_html = False
+
+ @property
+ def sql_ctx(self) -> "SQLContext":
+ from pyspark.sql.context import SQLContext
+
+ warnings.warn(
+ "DataFrame.sql_ctx is an internal property, and will be removed "
+ "in future releases. Use DataFrame.sparkSession instead."
+ )
+ if self._sql_ctx is None:
+ self._sql_ctx = SQLContext._get_or_create(self._sc)
+ return self._sql_ctx
+
+ @property
+ def sparkSession(self) -> "SparkSession":
+ return self._session
+
+ if not is_remote_only():
+
+ @property
+ def rdd(self) -> "RDD[Row]":
+ from pyspark.core.rdd import RDD
+
+ if self._lazy_rdd is None:
+ jrdd = self._jdf.javaToPython()
+ self._lazy_rdd = RDD(
+ jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer())
+ )
+ return self._lazy_rdd
+
+ @property
+ def na(self) -> "DataFrameNaFunctions":
+ return DataFrameNaFunctions(self)
+
+ @property
+ def stat(self) -> "DataFrameStatFunctions":
Review Comment:
`-> ParentDataFrameStatFunctions`?
##########
python/pyspark/sql/classic/dataframe.py:
##########
@@ -0,0 +1,1974 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import sys
+import random
+import warnings
+from collections.abc import Iterable
+from functools import reduce
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+ TYPE_CHECKING,
+)
+
+from pyspark import _NoValue
+from pyspark.resource import ResourceProfile
+from pyspark._globals import _NoValueType
+from pyspark.errors import (
+ PySparkTypeError,
+ PySparkValueError,
+ PySparkIndexError,
+ PySparkAttributeError,
+)
+from pyspark.util import (
+ is_remote_only,
+ _load_from_socket,
+ _local_iterator_from_socket,
+)
+from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
+from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
+from pyspark.sql.streaming import DataStreamWriter
+from pyspark.sql.types import (
+ StructType,
+ Row,
+ _parse_datatype_json_string,
+)
+from pyspark.sql.dataframe import (
+ DataFrame as ParentDataFrame,
+ DataFrameNaFunctions as ParentDataFrameNaFunctions,
+ DataFrameStatFunctions as ParentDataFrameStatFunctions,
+)
+from pyspark.sql.utils import get_active_spark_context, toJArray
+from pyspark.sql.pandas.conversion import PandasConversionMixin
+from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+
+if TYPE_CHECKING:
+ from py4j.java_gateway import JavaObject
+ from pyspark.core.rdd import RDD
+ from pyspark.core.context import SparkContext
+ from pyspark._typing import PrimitiveType
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.sql._typing import (
+ ColumnOrName,
+ ColumnOrNameOrOrdinal,
+ LiteralType,
+ OptionalPrimitiveType,
+ )
+ from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
+ from pyspark.sql.group import GroupedData
+ from pyspark.sql.observation import Observation
+
+
+class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
+ def __new__(
+ cls,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ) -> "DataFrame":
+ self = object.__new__(cls)
+ self.__init__(jdf, sql_ctx) # type: ignore[misc]
+ return self
+
+ def __init__(
+ self,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ):
+ from pyspark.sql.context import SQLContext
+
+ self._sql_ctx: Optional["SQLContext"] = None
+
+ if isinstance(sql_ctx, SQLContext):
+ assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
+ assert isinstance(sql_ctx, SQLContext)
+ # We should remove this if-else branch in the future release, and rename
+ # sql_ctx to session in the constructor. This is an internal code path but
+ # was kept with a warning because it's used intensively by third-party libraries.
+ warnings.warn("DataFrame constructor is internal. Do not directly use it.")
+ self._sql_ctx = sql_ctx
+ session = sql_ctx.sparkSession
+ else:
+ session = sql_ctx
+ self._session: "SparkSession" = session
+
+ self._sc: "SparkContext" = sql_ctx._sc
+ self._jdf: "JavaObject" = jdf
+ self.is_cached = False
+ # initialized lazily
+ self._schema: Optional[StructType] = None
+ self._lazy_rdd: Optional["RDD[Row]"] = None
+ # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
+ # by __repr__ and _repr_html_ while eager evaluation opens.
+ self._support_repr_html = False
+
+ @property
+ def sql_ctx(self) -> "SQLContext":
+ from pyspark.sql.context import SQLContext
+
+ warnings.warn(
+ "DataFrame.sql_ctx is an internal property, and will be removed "
+ "in future releases. Use DataFrame.sparkSession instead."
+ )
+ if self._sql_ctx is None:
+ self._sql_ctx = SQLContext._get_or_create(self._sc)
+ return self._sql_ctx
+
+ @property
+ def sparkSession(self) -> "SparkSession":
+ return self._session
+
+ if not is_remote_only():
+
+ @property
+ def rdd(self) -> "RDD[Row]":
+ from pyspark.core.rdd import RDD
+
+ if self._lazy_rdd is None:
+ jrdd = self._jdf.javaToPython()
+ self._lazy_rdd = RDD(
+ jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer())
+ )
+ return self._lazy_rdd
+
+ @property
+ def na(self) -> "DataFrameNaFunctions":
+ return DataFrameNaFunctions(self)
+
+ @property
+ def stat(self) -> "DataFrameStatFunctions":
Review Comment:
Note: for return type hints, we can use the subclass.
If we use `-> "DataFrameStatFunctions"` here, we may also want to use
`-> "DataFrame"` instead of `-> ParentDataFrame` for the following functions.
##########
python/pyspark/sql/classic/dataframe.py:
##########
@@ -0,0 +1,1974 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import sys
+import random
+import warnings
+from collections.abc import Iterable
+from functools import reduce
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+ TYPE_CHECKING,
+)
+
+from pyspark import _NoValue
+from pyspark.resource import ResourceProfile
+from pyspark._globals import _NoValueType
+from pyspark.errors import (
+ PySparkTypeError,
+ PySparkValueError,
+ PySparkIndexError,
+ PySparkAttributeError,
+)
+from pyspark.util import (
+ is_remote_only,
+ _load_from_socket,
+ _local_iterator_from_socket,
+)
+from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
+from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
+from pyspark.sql.streaming import DataStreamWriter
+from pyspark.sql.types import (
+ StructType,
+ Row,
+ _parse_datatype_json_string,
+)
+from pyspark.sql.dataframe import (
+ DataFrame as ParentDataFrame,
+ DataFrameNaFunctions as ParentDataFrameNaFunctions,
+ DataFrameStatFunctions as ParentDataFrameStatFunctions,
+)
+from pyspark.sql.utils import get_active_spark_context, toJArray
+from pyspark.sql.pandas.conversion import PandasConversionMixin
+from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+
+if TYPE_CHECKING:
+ from py4j.java_gateway import JavaObject
+ from pyspark.core.rdd import RDD
+ from pyspark.core.context import SparkContext
+ from pyspark._typing import PrimitiveType
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.sql._typing import (
+ ColumnOrName,
+ ColumnOrNameOrOrdinal,
+ LiteralType,
+ OptionalPrimitiveType,
+ )
+ from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
+ from pyspark.sql.group import GroupedData
+ from pyspark.sql.observation import Observation
+
+
+class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
+ def __new__(
+ cls,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ) -> "DataFrame":
+ self = object.__new__(cls)
+ self.__init__(jdf, sql_ctx) # type: ignore[misc]
+ return self
+
+ def __init__(
+ self,
+ jdf: "JavaObject",
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ):
+ from pyspark.sql.context import SQLContext
+
+ self._sql_ctx: Optional["SQLContext"] = None
+
+ if isinstance(sql_ctx, SQLContext):
+ assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
+ assert isinstance(sql_ctx, SQLContext)
+ # We should remove this if-else branch in the future release, and rename
+ # sql_ctx to session in the constructor. This is an internal code path but
+ # was kept with a warning because it's used intensively by third-party libraries.
+ warnings.warn("DataFrame constructor is internal. Do not directly use it.")
+ self._sql_ctx = sql_ctx
+ session = sql_ctx.sparkSession
+ else:
+ session = sql_ctx
+ self._session: "SparkSession" = session
+
+ self._sc: "SparkContext" = sql_ctx._sc
+ self._jdf: "JavaObject" = jdf
+ self.is_cached = False
+ # initialized lazily
+ self._schema: Optional[StructType] = None
+ self._lazy_rdd: Optional["RDD[Row]"] = None
+ # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
+ # by __repr__ and _repr_html_ while eager evaluation opens.
+ self._support_repr_html = False
+
+ @property
+ def sql_ctx(self) -> "SQLContext":
+ from pyspark.sql.context import SQLContext
+
+ warnings.warn(
+ "DataFrame.sql_ctx is an internal property, and will be removed "
+ "in future releases. Use DataFrame.sparkSession instead."
+ )
+ if self._sql_ctx is None:
+ self._sql_ctx = SQLContext._get_or_create(self._sc)
+ return self._sql_ctx
+
+ @property
+ def sparkSession(self) -> "SparkSession":
+ return self._session
+
+ if not is_remote_only():
+
+ @property
+ def rdd(self) -> "RDD[Row]":
+ from pyspark.core.rdd import RDD
+
+ if self._lazy_rdd is None:
+ jrdd = self._jdf.javaToPython()
+ self._lazy_rdd = RDD(
+ jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer())
+ )
+ return self._lazy_rdd
+
+ @property
+ def na(self) -> "DataFrameNaFunctions":
+ return DataFrameNaFunctions(self)
+
+ @property
+ def stat(self) -> "DataFrameStatFunctions":
+ return DataFrameStatFunctions(self)
+
+ if not is_remote_only():
+
+ def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
+ from pyspark.core.rdd import RDD
+
+ rdd = self._jdf.toJSON()
+ return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
+
+ def registerTempTable(self, name: str) -> None:
+ warnings.warn("Deprecated in 2.0, use createOrReplaceTempView instead.", FutureWarning)
+ self._jdf.createOrReplaceTempView(name)
+
+ def createTempView(self, name: str) -> None:
+ self._jdf.createTempView(name)
+
+ def createOrReplaceTempView(self, name: str) -> None:
+ self._jdf.createOrReplaceTempView(name)
+
+ def createGlobalTempView(self, name: str) -> None:
+ self._jdf.createGlobalTempView(name)
+
+ def createOrReplaceGlobalTempView(self, name: str) -> None:
+ self._jdf.createOrReplaceGlobalTempView(name)
+
+ @property
+ def write(self) -> DataFrameWriter:
+ return DataFrameWriter(self)
+
+ @property
+ def writeStream(self) -> DataStreamWriter:
+ return DataStreamWriter(self)
+
+ @property
+ def schema(self) -> StructType:
+ if self._schema is None:
+ try:
+ self._schema = cast(
+ StructType, _parse_datatype_json_string(self._jdf.schema().json())
+ )
+ except Exception as e:
+ raise PySparkValueError(
+ error_class="CANNOT_PARSE_DATATYPE",
+ message_parameters={"error": str(e)},
+ )
+ return self._schema
+
+ def printSchema(self, level: Optional[int] = None) -> None:
+ if level:
+ print(self._jdf.schema().treeString(level))
+ else:
+ print(self._jdf.schema().treeString())
+
+ def explain(
+ self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None
+ ) -> None:
+ if extended is not None and mode is not None:
+ raise PySparkValueError(
+ error_class="CANNOT_SET_TOGETHER",
+ message_parameters={"arg_list": "extended and mode"},
+ )
+
+ # For the no argument case: df.explain()
+ is_no_argument = extended is None and mode is None
+
+ # For the cases below:
+ # explain(True)
+ # explain(extended=False)
+ is_extended_case = isinstance(extended, bool) and mode is None
+
+ # For the case when extended is mode:
+ # df.explain("formatted")
+ is_extended_as_mode = isinstance(extended, str) and mode is None
+
+ # For the mode specified:
+ # df.explain(mode="formatted")
+ is_mode_case = extended is None and isinstance(mode, str)
+
+ if not (is_no_argument or is_extended_case or is_extended_as_mode or is_mode_case):
+ if (extended is not None) and (not isinstance(extended, (bool, str))):
+ raise PySparkTypeError(
+ error_class="NOT_BOOL_OR_STR",
+ message_parameters={
+ "arg_name": "extended",
+ "arg_type": type(extended).__name__,
+ },
+ )
+ if (mode is not None) and (not isinstance(mode, str)):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "mode", "arg_type": type(mode).__name__},
+ )
+
+ # Sets an explain mode depending on a given argument
+ if is_no_argument:
+ explain_mode = "simple"
+ elif is_extended_case:
+ explain_mode = "extended" if extended else "simple"
+ elif is_mode_case:
+ explain_mode = cast(str, mode)
+ elif is_extended_as_mode:
+ explain_mode = cast(str, extended)
+ assert self._sc._jvm is not None
+ print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode))
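    # Usage sketch (assuming a DataFrame `df`): the accepted call shapes map onto
    # the explain modes resolved above.
    #
    #   df.explain()                   # "simple"
    #   df.explain(True)               # "extended"
    #   df.explain("formatted")        # extended argument used as a mode
    #   df.explain(mode="cost")        # explicit mode keyword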
+
+ def exceptAll(self, other: "ParentDataFrame") -> "ParentDataFrame":
+ return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession)
+
+ def isLocal(self) -> bool:
+ return self._jdf.isLocal()
+
+ @property
+ def isStreaming(self) -> bool:
+ return self._jdf.isStreaming()
+
+ def isEmpty(self) -> bool:
+ return self._jdf.isEmpty()
+
+ def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
+ print(self._show_string(n, truncate, vertical))
+
+ def _show_string(
+ self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False
+ ) -> str:
+ if not isinstance(n, int) or isinstance(n, bool):
+ raise PySparkTypeError(
+ error_class="NOT_INT",
+ message_parameters={"arg_name": "n", "arg_type": type(n).__name__},
+ )
+
+ if not isinstance(vertical, bool):
+ raise PySparkTypeError(
+ error_class="NOT_BOOL",
+ message_parameters={"arg_name": "vertical", "arg_type": type(vertical).__name__},
+ )
+
+ if isinstance(truncate, bool) and truncate:
+ return self._jdf.showString(n, 20, vertical)
+ else:
+ try:
+ int_truncate = int(truncate)
+ except ValueError:
+ raise PySparkTypeError(
+ error_class="NOT_BOOL",
+ message_parameters={
+ "arg_name": "truncate",
+ "arg_type": type(truncate).__name__,
+ },
+ )
+
+ return self._jdf.showString(n, int_truncate, vertical)
+
+ def __repr__(self) -> str:
+ if not self._support_repr_html and self.sparkSession._jconf.isReplEagerEvalEnabled():
+ vertical = False
+ return self._jdf.showString(
+ self.sparkSession._jconf.replEagerEvalMaxNumRows(),
+ self.sparkSession._jconf.replEagerEvalTruncate(),
+ vertical,
+ )
+ else:
+ return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+
+ def _repr_html_(self) -> Optional[str]:
+ """Returns the HTML representation of this :class:`DataFrame` when eager evaluation
+ is enabled via 'spark.sql.repl.eagerEval.enabled'. This is only called by REPLs
+ that support eager evaluation with HTML output.
+ """
+ if not self._support_repr_html:
+ self._support_repr_html = True
+ if self.sparkSession._jconf.isReplEagerEvalEnabled():
+ return self._jdf.htmlString(
+ self.sparkSession._jconf.replEagerEvalMaxNumRows(),
+ self.sparkSession._jconf.replEagerEvalTruncate(),
+ )
+ else:
+ return None
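    # Usage sketch (assuming an active SparkSession `spark`): the HTML repr is
    # only produced once eager evaluation is switched on; otherwise it is None.
    #
    #   spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")
    #   spark.range(3)._repr_html_()   # returns an HTML table as a string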
+
+ def checkpoint(self, eager: bool = True) -> "ParentDataFrame":
+ jdf = self._jdf.checkpoint(eager)
+ return DataFrame(jdf, self.sparkSession)
+
+ def localCheckpoint(self, eager: bool = True) -> "ParentDataFrame":
+ jdf = self._jdf.localCheckpoint(eager)
+ return DataFrame(jdf, self.sparkSession)
+
+ def withWatermark(self, eventTime: str, delayThreshold: str) -> "ParentDataFrame":
+ if not eventTime or type(eventTime) is not str:
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "eventTime", "arg_type": type(eventTime).__name__},
+ )
+ if not delayThreshold or type(delayThreshold) is not str:
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={
+ "arg_name": "delayThreshold",
+ "arg_type": type(delayThreshold).__name__,
+ },
+ )
+ jdf = self._jdf.withWatermark(eventTime, delayThreshold)
+ return DataFrame(jdf, self.sparkSession)
+
+ def hint(
+ self, name: str, *parameters: Union["PrimitiveType", "Column", List["PrimitiveType"]]
+ ) -> "ParentDataFrame":
+ if len(parameters) == 1 and isinstance(parameters[0], list):
+ parameters = parameters[0] # type: ignore[assignment]
+
+ if not isinstance(name, str):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "name", "arg_type": type(name).__name__},
+ )
+
+ allowed_types = (str, float, int, Column, list)
+ allowed_primitive_types = (str, float, int)
+ allowed_types_repr = ", ".join(
+ [t.__name__ for t in allowed_types[:-1]]
+ + ["list[" + t.__name__ + "]" for t in allowed_primitive_types]
+ )
+ for p in parameters:
+ if not isinstance(p, allowed_types):
+ raise PySparkTypeError(
+ error_class="DISALLOWED_TYPE_FOR_CONTAINER",
+ message_parameters={
+ "arg_name": "parameters",
+ "arg_type": type(parameters).__name__,
+ "allowed_types": allowed_types_repr,
+ "item_type": type(p).__name__,
+ },
+ )
+ if isinstance(p, list):
+ if not all(isinstance(e, allowed_primitive_types) for e in p):
+ raise PySparkTypeError(
+ error_class="DISALLOWED_TYPE_FOR_CONTAINER",
+ message_parameters={
+ "arg_name": "parameters",
+ "arg_type": type(parameters).__name__,
+ "allowed_types": allowed_types_repr,
+ "item_type": type(p).__name__ + "[" + type(p[0]).__name__ + "]",
+ },
+ )
+
+ def _converter(parameter: Union[str, list, float, int, Column]) -> Any:
+ if isinstance(parameter, Column):
+ return _to_java_column(parameter)
+ elif isinstance(parameter, list):
+ # for list input, we are assuming only one element type exists in the list.
+ # for an empty list, we are converting it into an empty long[] on the JVM side.
+ gateway = self._sc._gateway
+ assert gateway is not None
+ jclass = gateway.jvm.long
+ if len(parameter) >= 1:
+ mapping = {
+ str: gateway.jvm.java.lang.String,
+ float: gateway.jvm.double,
+ int: gateway.jvm.long,
+ }
+ jclass = mapping[type(parameter[0])]
+ return toJArray(gateway, jclass, parameter)
+ else:
+ return parameter
+
+ jdf = self._jdf.hint(name, self._jseq(parameters, _converter))
+ return DataFrame(jdf, self.sparkSession)
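    # Usage sketch (assuming DataFrames `df1` and `df2` sharing an "id" column):
    #
    #   df1.join(df2.hint("broadcast"), "id")   # name-only hint
    #   df1.hint("rebalance", 8)                # hint with an int parameter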
+
+ def count(self) -> int:
+ return int(self._jdf.count())
+
+ def collect(self) -> List[Row]:
+ with SCCallSiteSync(self._sc):
+ sock_info = self._jdf.collectToPython()
+ return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))
+
+ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
+ with SCCallSiteSync(self._sc):
+ sock_info = self._jdf.toPythonIterator(prefetchPartitions)
+ return _local_iterator_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))
+
+ def limit(self, num: int) -> "ParentDataFrame":
+ jdf = self._jdf.limit(num)
+ return DataFrame(jdf, self.sparkSession)
+
+ def offset(self, num: int) -> "ParentDataFrame":
+ jdf = self._jdf.offset(num)
+ return DataFrame(jdf, self.sparkSession)
+
+ def take(self, num: int) -> List[Row]:
+ return self.limit(num).collect()
+
+ def tail(self, num: int) -> List[Row]:
+ with SCCallSiteSync(self._sc):
+ sock_info = self._jdf.tailToPython(num)
+ return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))
+
+ def foreach(self, f: Callable[[Row], None]) -> None:
+ self.rdd.foreach(f)
+
+ def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
+ self.rdd.foreachPartition(f) # type: ignore[arg-type]
+
+ def cache(self) -> "ParentDataFrame":
+ self.is_cached = True
+ self._jdf.cache()
+ return self
+
+ def persist(
+ self,
+ storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER),
+ ) -> "ParentDataFrame":
+ self.is_cached = True
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ self._jdf.persist(javaStorageLevel)
+ return self
+
+ @property
+ def storageLevel(self) -> StorageLevel:
+ java_storage_level = self._jdf.storageLevel()
+ storage_level = StorageLevel(
+ java_storage_level.useDisk(),
+ java_storage_level.useMemory(),
+ java_storage_level.useOffHeap(),
+ java_storage_level.deserialized(),
+ java_storage_level.replication(),
+ )
+ return storage_level
+
+ def unpersist(self, blocking: bool = False) -> "ParentDataFrame":
+ self.is_cached = False
+ self._jdf.unpersist(blocking)
+ return self
+
+ def coalesce(self, numPartitions: int) -> "ParentDataFrame":
+ return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession)
+
+ @overload
+ def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> "ParentDataFrame":
+ ...
+
+ @overload
+ def repartition(self, *cols: "ColumnOrName") -> "ParentDataFrame":
+ ...
+
+ def repartition( # type: ignore[misc]
+ self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
+ ) -> "ParentDataFrame":
+ if isinstance(numPartitions, int):
+ if len(cols) == 0:
+ return DataFrame(self._jdf.repartition(numPartitions), self.sparkSession)
+ else:
+ return DataFrame(
+ self._jdf.repartition(numPartitions, self._jcols(*cols)),
+ self.sparkSession,
+ )
+ elif isinstance(numPartitions, (str, Column)):
+ cols = (numPartitions,) + cols
+ return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sparkSession)
+ else:
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN_OR_STR",
+ message_parameters={
+ "arg_name": "numPartitions",
+ "arg_type": type(numPartitions).__name__,
+ },
+ )
+
+ @overload
+ def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> "ParentDataFrame":
+ ...
+
+ @overload
+ def repartitionByRange(self, *cols: "ColumnOrName") -> "ParentDataFrame":
+ ...
+
+ def repartitionByRange( # type: ignore[misc]
+ self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
+ ) -> "ParentDataFrame":
+ if isinstance(numPartitions, int):
+ if len(cols) == 0:
+ raise PySparkValueError(
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"item": "partition-by expression"},
+ )
+ else:
+ return DataFrame(
+ self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)),
+ self.sparkSession,
+ )
+ elif isinstance(numPartitions, (str, Column)):
+ cols = (numPartitions,) + cols
+ return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sparkSession)
+ else:
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN_OR_INT_OR_STR",
+ message_parameters={
+ "arg_name": "numPartitions",
+ "arg_type": type(numPartitions).__name__,
+ },
+ )
+
+ def distinct(self) -> "ParentDataFrame":
+ return DataFrame(self._jdf.distinct(), self.sparkSession)
+
+ @overload
+ def sample(self, fraction: float, seed: Optional[int] = ...) -> "ParentDataFrame":
+ ...
+
+ @overload
+ def sample(
+ self,
+ withReplacement: Optional[bool],
+ fraction: float,
+ seed: Optional[int] = ...,
+ ) -> "ParentDataFrame":
+ ...
+
+ def sample( # type: ignore[misc]
+ self,
+ withReplacement: Optional[Union[float, bool]] = None,
+ fraction: Optional[Union[int, float]] = None,
+ seed: Optional[int] = None,
+ ) -> "ParentDataFrame":
+ # For the cases below:
+ # sample(True, 0.5 [, seed])
+ # sample(True, fraction=0.5 [, seed])
+ # sample(withReplacement=False, fraction=0.5 [, seed])
+ is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float)
+
+ # For the case below:
+ # sample(fraction=0.5 [, seed])
+ is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float)
+
+ # For the case below:
+ # sample(0.5 [, seed])
+ is_withReplacement_omitted_args = isinstance(withReplacement, float)
+
+ if not (
+ is_withReplacement_set
+ or is_withReplacement_omitted_kwargs
+ or is_withReplacement_omitted_args
+ ):
+ argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]]
+ raise PySparkTypeError(
+ error_class="NOT_BOOL_OR_FLOAT_OR_INT",
+ message_parameters={
+ "arg_name": "withReplacement (optional), "
+ + "fraction (required) and seed (optional)",
+ "arg_type": ", ".join(argtypes),
+ },
+ )
+
+ if is_withReplacement_omitted_args:
+ if fraction is not None:
+ seed = cast(int, fraction)
+ fraction = withReplacement
+ withReplacement = None
+
+ seed = int(seed) if seed is not None else None
+ args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
+ jdf = self._jdf.sample(*args)
+ return DataFrame(jdf, self.sparkSession)
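    # Usage sketch (assuming a DataFrame `df`) of the call shapes validated above:
    #
    #   df.sample(0.1, seed=3)                         # fraction only
    #   df.sample(withReplacement=True, fraction=0.2)  # replacement + fraction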
+
+ def sampleBy(
+ self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
+ ) -> "ParentDataFrame":
+ if isinstance(col, str):
+ col = Column(col)
+ elif not isinstance(col, Column):
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN_OR_STR",
+ message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
+ )
+ if not isinstance(fractions, dict):
+ raise PySparkTypeError(
+ error_class="NOT_DICT",
+ message_parameters={"arg_name": "fractions", "arg_type": type(fractions).__name__},
+ )
+ for k, v in fractions.items():
+ if not isinstance(k, (float, int, str)):
+ raise PySparkTypeError(
+ error_class="DISALLOWED_TYPE_FOR_CONTAINER",
+ message_parameters={
+ "arg_name": "fractions",
+ "arg_type": type(fractions).__name__,
+ "allowed_types": "float, int, str",
+ "item_type": type(k).__name__,
+ },
+ )
+ fractions[k] = float(v)
+ col = col._jc
+ seed = seed if seed is not None else random.randint(0, sys.maxsize)
+ return DataFrame(
+ self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sparkSession
+ )
+
+ def randomSplit(
+ self, weights: List[float], seed: Optional[int] = None
+ ) -> List["ParentDataFrame"]:
+ for w in weights:
+ if w < 0.0:
+ raise PySparkValueError(
+ error_class="VALUE_NOT_POSITIVE",
+ message_parameters={"arg_name": "weights", "arg_value": str(w)},
+ )
+ seed = seed if seed is not None else random.randint(0, sys.maxsize)
+ df_array = self._jdf.randomSplit(
+ _to_list(self.sparkSession._sc, cast(List["ColumnOrName"], weights)), int(seed)
+ )
+ return [DataFrame(df, self.sparkSession) for df in df_array]
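    # Usage sketch (assuming a DataFrame `df`): weights that do not sum to 1.0
    # are normalized, and a seed makes the split reproducible.
    #
    #   train, test = df.randomSplit([0.8, 0.2], seed=42)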
+
+ @property
+ def dtypes(self) -> List[Tuple[str, str]]:
+ return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
+
+ @property
+ def columns(self) -> List[str]:
+ return [f.name for f in self.schema.fields]
+
+ def colRegex(self, colName: str) -> Column:
+ if not isinstance(colName, str):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "colName", "arg_type": type(colName).__name__},
+ )
+ jc = self._jdf.colRegex(colName)
+ return Column(jc)
+
+ def to(self, schema: StructType) -> "ParentDataFrame":
+ assert schema is not None
+ jschema = self._jdf.sparkSession().parseDataType(schema.json())
+ return DataFrame(self._jdf.to(jschema), self.sparkSession)
+
+ def alias(self, alias: str) -> "ParentDataFrame":
+ assert isinstance(alias, str), "alias should be a string"
+ return DataFrame(getattr(self._jdf, "as")(alias), self.sparkSession)
+
+ def crossJoin(self, other: "ParentDataFrame") -> "ParentDataFrame":
+ jdf = self._jdf.crossJoin(other._jdf)
+ return DataFrame(jdf, self.sparkSession)
+
+ def join(
+ self,
+ other: "ParentDataFrame",
+ on: Optional[Union[str, List[str], Column, List[Column]]] = None,
+ how: Optional[str] = None,
+ ) -> "ParentDataFrame":
+ if on is not None and not isinstance(on, list):
+ on = [on] # type: ignore[assignment]
+
+ if on is not None:
+ if isinstance(on[0], str):
+ on = self._jseq(cast(List[str], on))
+ else:
+ assert isinstance(on[0], Column), "on should be Column or list of Column"
+ on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on))
+ on = on._jc
+
+ if on is None and how is None:
+ jdf = self._jdf.join(other._jdf)
+ else:
+ if how is None:
+ how = "inner"
+ if on is None:
+ on = self._jseq([])
+ assert isinstance(how, str), "how should be a string"
+ jdf = self._jdf.join(other._jdf, on, how)
+ return DataFrame(jdf, self.sparkSession)
+
+ # TODO(SPARK-22947): Fix the DataFrame API.
+ def _joinAsOf(
+ self,
+ other: "ParentDataFrame",
+ leftAsOfColumn: Union[str, Column],
+ rightAsOfColumn: Union[str, Column],
+ on: Optional[Union[str, List[str], Column, List[Column]]] = None,
+ how: Optional[str] = None,
+ *,
+ tolerance: Optional[Column] = None,
+ allowExactMatches: bool = True,
+ direction: str = "backward",
+ ) -> "ParentDataFrame":
+ """
+ Perform an as-of join.
+
+ This is similar to a left-join except that we match on the nearest
+ key rather than equal keys.
+
+ .. versionchanged:: 4.0.0
+ Supports Spark Connect.
+
+ Parameters
+ ----------
+ other : :class:`DataFrame`
+ Right side of the join
+ leftAsOfColumn : str or :class:`Column`
+ a string for the as-of join column name, or a Column
+ rightAsOfColumn : str or :class:`Column`
+ a string for the as-of join column name, or a Column
+ on : str, list or :class:`Column`, optional
+ a string for the join column name, a list of column names,
+ a join expression (Column), or a list of Columns.
+ If `on` is a string or a list of strings indicating the name of the join column(s),
+ the column(s) must exist on both sides, and this performs an equi-join.
+ how : str, optional
+ default ``inner``. Must be one of: ``inner`` and ``left``.
+ tolerance : :class:`Column`, optional
+ an asof tolerance within this range; must be compatible
+ with the merge index.
+ allowExactMatches : bool, optional
+ default ``True``.
+ direction : str, optional
+ default ``backward``. Must be one of: ``backward``, ``forward``, and ``nearest``.
+
+ Examples
+ --------
+ The following performs an as-of join between ``left`` and ``right``.
+
+ >>> left = spark.createDataFrame([(1, "a"), (5, "b"), (10, "c")], ["a", "left_val"])
+ >>> right = spark.createDataFrame([(1, 1), (2, 2), (3, 3), (6, 6), (7, 7)],
+ ... ["a", "right_val"])
+ >>> left._joinAsOf(
+ ... right, leftAsOfColumn="a", rightAsOfColumn="a"
+ ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
+ [Row(a=1, left_val='a', right_val=1),
+ Row(a=5, left_val='b', right_val=3),
+ Row(a=10, left_val='c', right_val=7)]
+
+ >>> from pyspark.sql import functions as sf
+ >>> left._joinAsOf(
+ ... right, leftAsOfColumn="a", rightAsOfColumn="a", tolerance=sf.lit(1)
+ ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
+ [Row(a=1, left_val='a', right_val=1)]
+
+ >>> left._joinAsOf(
+ ... right, leftAsOfColumn="a", rightAsOfColumn="a", how="left", tolerance=sf.lit(1)
+ ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
+ [Row(a=1, left_val='a', right_val=1),
+ Row(a=5, left_val='b', right_val=None),
+ Row(a=10, left_val='c', right_val=None)]
+
+ >>> left._joinAsOf(
+ ... right, leftAsOfColumn="a", rightAsOfColumn="a", allowExactMatches=False
+ ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
+ [Row(a=5, left_val='b', right_val=3),
+ Row(a=10, left_val='c', right_val=7)]
+
+ >>> left._joinAsOf(
+ ... right, leftAsOfColumn="a", rightAsOfColumn="a", direction="forward"
+ ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
+ [Row(a=1, left_val='a', right_val=1),
+ Row(a=5, left_val='b', right_val=6)]
+ """
+ if isinstance(leftAsOfColumn, str):
+ leftAsOfColumn = self[leftAsOfColumn]
+ left_as_of_jcol = leftAsOfColumn._jc
+ if isinstance(rightAsOfColumn, str):
+ rightAsOfColumn = other[rightAsOfColumn]
+ right_as_of_jcol = rightAsOfColumn._jc
+
+ if on is not None and not isinstance(on, list):
+ on = [on] # type: ignore[assignment]
+
+ if on is not None:
+ if isinstance(on[0], str):
+ on = self._jseq(cast(List[str], on))
+ else:
+ assert isinstance(on[0], Column), "on should be Column or list of Column"
+ on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on))
+ on = on._jc
+
+ if how is None:
+ how = "inner"
+ assert isinstance(how, str), "how should be a string"
+
+ if tolerance is not None:
+ assert isinstance(tolerance, Column), "tolerance should be Column"
+ tolerance = tolerance._jc
+
+ jdf = self._jdf.joinAsOf(
+ other._jdf,
+ left_as_of_jcol,
+ right_as_of_jcol,
+ on,
+ how,
+ tolerance,
+ allowExactMatches,
+ direction,
+ )
+ return DataFrame(jdf, self.sparkSession)
+
+ def sortWithinPartitions(
+ self,
+ *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+ **kwargs: Any,
+ ) -> "ParentDataFrame":
+ jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
+ return DataFrame(jdf, self.sparkSession)
+
+ def sort(
+ self,
+ *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+ **kwargs: Any,
+ ) -> "ParentDataFrame":
+ jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
+ return DataFrame(jdf, self.sparkSession)
+
+ orderBy = sort
+
+ def _jseq(
+ self,
+ cols: Sequence,
+ converter: Optional[Callable[..., Union["PrimitiveType", "JavaObject"]]] = None,
+ ) -> "JavaObject":
+ """Return a JVM Seq of Columns from a list of Column or names"""
+ return _to_seq(self.sparkSession._sc, cols, converter)
+
+ def _jmap(self, jm: Dict) -> "JavaObject":
+ """Return a JVM Scala Map from a dict"""
+ return _to_scala_map(self.sparkSession._sc, jm)
+
+ def _jcols(self, *cols: "ColumnOrName") -> "JavaObject":
+ """Return a JVM Seq of Columns from a list of Column or column names
+
+ If `cols` has only one list in it, cols[0] will be used as the list.
+ """
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0]
+ return self._jseq(cols, _to_java_column)
+
+ def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> "JavaObject":
+ """Return a JVM Seq of Columns from a list of Column or column names or column ordinals.
+
+ If `cols` has only one list in it, cols[0] will be used as the list.
+ """
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0]
+
+ _cols = []
+ for c in cols:
+ if isinstance(c, int) and not isinstance(c, bool):
+ if c < 1:
+ raise PySparkIndexError(
+ error_class="INDEX_NOT_POSITIVE", message_parameters={"index": str(c)}
+ )
+ # ordinal is 1-based
+ _cols.append(self[c - 1])
+ else:
+ _cols.append(c) # type: ignore[arg-type]
+ return self._jseq(_cols, _to_java_column)
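    # Illustrative sketch of the 1-based ordinals resolved above (assuming a
    # DataFrame `df` whose first column is "name"):
    #
    #   df.groupBy(1).count()   # equivalent to df.groupBy("name").count()
    #   df.groupBy(0)           # raises INDEX_NOT_POSITIVE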
+
+ def _sort_cols(
+ self,
+ cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+ kwargs: Dict[str, Any],
+ ) -> "JavaObject":
+ """Return a JVM Seq of Columns that describes the sort order"""
+ if not cols:
+ raise PySparkValueError(
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"item": "column"},
+ )
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0]
+
+ jcols = []
+ for c in cols:
+ if isinstance(c, int) and not isinstance(c, bool):
+ # ordinal is 1-based
+ if c > 0:
+ _c = self[c - 1]
+ # negative ordinal means sort by desc
+ elif c < 0:
+ _c = self[-c - 1].desc()
+ else:
+ raise PySparkIndexError(
+ error_class="ZERO_INDEX",
+ message_parameters={},
+ )
+ else:
+ _c = c # type: ignore[assignment]
+ jcols.append(_to_java_column(cast("ColumnOrName", _c)))
+
+ ascending = kwargs.get("ascending", True)
+ if isinstance(ascending, (bool, int)):
+ if not ascending:
+ jcols = [jc.desc() for jc in jcols]
+ elif isinstance(ascending, list):
+ jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, jcols)]
+ else:
+ raise PySparkTypeError(
+ error_class="NOT_BOOL_OR_LIST",
+ message_parameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
+ )
+ return self._jseq(jcols)
+
+ def describe(self, *cols: Union[str, List[str]]) -> "ParentDataFrame":
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0] # type: ignore[assignment]
+ jdf = self._jdf.describe(self._jseq(cols))
+ return DataFrame(jdf, self.sparkSession)
+
+ def summary(self, *statistics: str) -> "ParentDataFrame":
+ if len(statistics) == 1 and isinstance(statistics[0], list):
+ statistics = statistics[0]
+ jdf = self._jdf.summary(self._jseq(statistics))
+ return DataFrame(jdf, self.sparkSession)
+
+ @overload
+ def head(self) -> Optional[Row]:
+ ...
+
+ @overload
+ def head(self, n: int) -> List[Row]:
+ ...
+
+ def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
+ if n is None:
+ rs = self.head(1)
+ return rs[0] if rs else None
+ return self.take(n)
+
+ def first(self) -> Optional[Row]:
+ return self.head()
+
+ @overload
+ def __getitem__(self, item: Union[int, str]) -> Column:
+ ...
+
+ @overload
+ def __getitem__(self, item: Union[Column, List, Tuple]) -> "ParentDataFrame":
+ ...
+
+ def __getitem__(
+ self, item: Union[int, str, Column, List, Tuple]
+ ) -> Union[Column, "ParentDataFrame"]:
+ if isinstance(item, str):
+ jc = self._jdf.apply(item)
+ return Column(jc)
+ elif isinstance(item, Column):
+ return self.filter(item)
+ elif isinstance(item, (list, tuple)):
+ return self.select(*item)
+ elif isinstance(item, int):
+ jc = self._jdf.apply(self.columns[item])
+ return Column(jc)
+ else:
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN_OR_FLOAT_OR_INT_OR_LIST_OR_STR",
+ message_parameters={"arg_name": "item", "arg_type": type(item).__name__},
+ )
+
+ def __getattr__(self, name: str) -> Column:
+ if name not in self.columns:
+ raise PySparkAttributeError(
+ error_class="ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name}
+ )
+ jc = self._jdf.apply(name)
+ return Column(jc)
+
+ def __dir__(self) -> List[str]:
+ attrs = set(dir(DataFrame))
+ attrs.update(filter(lambda s: s.isidentifier(), self.columns))
+ return sorted(attrs)
+
+ @overload
+ def select(self, *cols: "ColumnOrName") -> "ParentDataFrame":
+ ...
+
+ @overload
+ def select(self, __cols: Union[List[Column], List[str]]) -> "ParentDataFrame":
+ ...
+
+ def select(self, *cols: "ColumnOrName") -> "ParentDataFrame": # type: ignore[misc]
+ jdf = self._jdf.select(self._jcols(*cols))
+ return DataFrame(jdf, self.sparkSession)
+
+ @overload
+ def selectExpr(self, *expr: str) -> "ParentDataFrame":
+ ...
+
+ @overload
+ def selectExpr(self, *expr: List[str]) -> "ParentDataFrame":
+ ...
+
+ def selectExpr(self, *expr: Union[str, List[str]]) -> "ParentDataFrame":
+ if len(expr) == 1 and isinstance(expr[0], list):
+ expr = expr[0] # type: ignore[assignment]
+ jdf = self._jdf.selectExpr(self._jseq(expr))
+ return DataFrame(jdf, self.sparkSession)
+
+ def filter(self, condition: "ColumnOrName") -> "ParentDataFrame":
+ if isinstance(condition, str):
+ jdf = self._jdf.filter(condition)
+ elif isinstance(condition, Column):
+ jdf = self._jdf.filter(condition._jc)
+ else:
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN_OR_STR",
+ message_parameters={"arg_name": "condition", "arg_type": type(condition).__name__},
+ )
+ return DataFrame(jdf, self.sparkSession)
+
+ @overload
+ def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
+ ...
+
+ @overload
+ def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
+ ...
+
+ def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc]
+ jgd = self._jdf.groupBy(self._jcols_ordinal(*cols))
+ from pyspark.sql.group import GroupedData
+
+ return GroupedData(jgd, self)
+
+ @overload
+ def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+ ...
+
+ @overload
+ def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ ...
+
+ def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc]
+ jgd = self._jdf.rollup(self._jcols_ordinal(*cols))
+ from pyspark.sql.group import GroupedData
+
+ return GroupedData(jgd, self)
+
+ @overload
+ def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+ ...
+
+ @overload
+ def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ ...
+
+ def cube(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
+ jgd = self._jdf.cube(self._jcols_ordinal(*cols))
+ from pyspark.sql.group import GroupedData
+
+ return GroupedData(jgd, self)
+
+ def groupingSets(
+ self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
+ ) -> "GroupedData":
+ from pyspark.sql.group import GroupedData
+
+ jgrouping_sets = _to_seq(self._sc, [self._jcols(*inner) for inner in groupingSets])
+
+ jgd = self._jdf.groupingSets(jgrouping_sets, self._jcols(*cols))
+ return GroupedData(jgd, self)
+
+ def unpivot(
+ self,
+ ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]],
+ values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+ variableColumnName: str,
+ valueColumnName: str,
+ ) -> "ParentDataFrame":
+ assert ids is not None, "ids must not be None"
+
+ def to_jcols(
+ cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+ ) -> "JavaObject":
+ if isinstance(cols, list):
+ return self._jcols(*cols)
+ if isinstance(cols, tuple):
+ return self._jcols(*list(cols))
+ return self._jcols(cols)
+
+ jids = to_jcols(ids)
+ if values is None:
+ jdf = self._jdf.unpivotWithSeq(jids, variableColumnName, valueColumnName)
+ else:
+ jvals = to_jcols(values)
+ jdf = self._jdf.unpivotWithSeq(jids, jvals, variableColumnName, valueColumnName)
+
+ return DataFrame(jdf, self.sparkSession)
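    # Usage sketch (assuming a DataFrame `df` with columns "id", "x" and "y"):
    # keeps "id" as-is and stacks "x"/"y" into (var, val) rows.
    #
    #   df.unpivot("id", ["x", "y"], "var", "val")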
+
+ def melt(
+ self,
+ ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]],
+ values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+ variableColumnName: str,
+ valueColumnName: str,
+ ) -> "ParentDataFrame":
+ return self.unpivot(ids, values, variableColumnName, valueColumnName)
+
+ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "ParentDataFrame":
+ return self.groupBy().agg(*exprs) # type: ignore[arg-type]
+
+ def observe(
+ self,
+ observation: Union["Observation", str],
+ *exprs: Column,
+ ) -> "ParentDataFrame":
+ from pyspark.sql import Observation
+
+ if len(exprs) == 0:
+ raise PySparkValueError(
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"item": "exprs"},
+ )
+ if not all(isinstance(c, Column) for c in exprs):
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OF_COLUMN",
+ message_parameters={"arg_name": "exprs"},
+ )
+
+ if isinstance(observation, Observation):
+ return observation._on(self, *exprs)
+ elif isinstance(observation, str):
+ return DataFrame(
+ self._jdf.observe(
+ observation, exprs[0]._jc, _to_seq(self._sc, [c._jc for c in exprs[1:]])
+ ),
+ self.sparkSession,
+ )
+ else:
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OF_COLUMN",
+ message_parameters={
+ "arg_name": "observation",
+ "arg_type": type(observation).__name__,
+ },
+ )
+
+ def union(self, other: "ParentDataFrame") -> "ParentDataFrame":
+ return DataFrame(self._jdf.union(other._jdf), self.sparkSession)
+
+ def unionAll(self, other: "ParentDataFrame") -> "ParentDataFrame":
+ return self.union(other)
+
+ def unionByName(
+ self, other: "ParentDataFrame", allowMissingColumns: bool = False
+ ) -> "ParentDataFrame":
+ return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sparkSession)
+
+ def intersect(self, other: "ParentDataFrame") -> "ParentDataFrame":
+ return DataFrame(self._jdf.intersect(other._jdf), self.sparkSession)
+
+ def intersectAll(self, other: "ParentDataFrame") -> "ParentDataFrame":
+ return DataFrame(self._jdf.intersectAll(other._jdf), self.sparkSession)
+
+ def subtract(self, other: "ParentDataFrame") -> "ParentDataFrame":
+ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sparkSession)
+
+ def dropDuplicates(self, subset: Optional[List[str]] = None) -> "ParentDataFrame":
+ if subset is not None and (not isinstance(subset, Iterable) or isinstance(subset, str)):
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OR_TUPLE",
+ message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
+ )
+
+ if subset is None:
+ jdf = self._jdf.dropDuplicates()
+ else:
+ jdf = self._jdf.dropDuplicates(self._jseq(subset))
+ return DataFrame(jdf, self.sparkSession)
+
+ def dropDuplicatesWithinWatermark(
+ self, subset: Optional[List[str]] = None
+ ) -> "ParentDataFrame":
+ if subset is not None and (not isinstance(subset, Iterable) or isinstance(subset, str)):
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OR_TUPLE",
+ message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
+ )
+
+ if subset is None:
+ jdf = self._jdf.dropDuplicatesWithinWatermark()
+ else:
+ jdf = self._jdf.dropDuplicatesWithinWatermark(self._jseq(subset))
+ return DataFrame(jdf, self.sparkSession)
+
+ def dropna(
+ self,
+ how: str = "any",
+ thresh: Optional[int] = None,
+ subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
+ ) -> "ParentDataFrame":
+ if how is not None and how not in ["any", "all"]:
+ raise PySparkValueError(
+ error_class="VALUE_NOT_ANY_OR_ALL",
+ message_parameters={"arg_name": "how", "arg_type": how},
+ )
+
+ if subset is None:
+ subset = self.columns
+ elif isinstance(subset, str):
+ subset = [subset]
+ elif not isinstance(subset, (list, tuple)):
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OR_STR_OR_TUPLE",
+ message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
+ )
+
+ if thresh is None:
+ thresh = len(subset) if how == "any" else 1
+
+ return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sparkSession)
+
+ @overload
+ def fillna(
+ self,
+ value: "LiteralType",
+ subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ...,
+ ) -> "ParentDataFrame":
+ ...
+
+ @overload
+ def fillna(self, value: Dict[str, "LiteralType"]) -> "ParentDataFrame":
+ ...
+
+ def fillna(
+ self,
+ value: Union["LiteralType", Dict[str, "LiteralType"]],
+ subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
+ ) -> "ParentDataFrame":
+ if not isinstance(value, (float, int, str, bool, dict)):
+ raise PySparkTypeError(
+ error_class="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_STR",
+ message_parameters={"arg_name": "value", "arg_type": type(value).__name__},
+ )
+
+ # Note that bool validates isinstance(int), but we don't want to
+ # convert bools to floats
+
+ if not isinstance(value, bool) and isinstance(value, int):
+ value = float(value)
+
+ if isinstance(value, dict):
+ return DataFrame(self._jdf.na().fill(value), self.sparkSession)
+ elif subset is None:
+ return DataFrame(self._jdf.na().fill(value), self.sparkSession)
+ else:
+ if isinstance(subset, str):
+ subset = [subset]
+ elif not isinstance(subset, (list, tuple)):
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OR_TUPLE",
+ message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
+ )
+
+ return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sparkSession)
+
+ @overload
+ def replace(
+ self,
+ to_replace: "LiteralType",
+ value: "OptionalPrimitiveType",
+ subset: Optional[List[str]] = ...,
+ ) -> "ParentDataFrame":
+ ...
+
+ @overload
+ def replace(
+ self,
+ to_replace: List["LiteralType"],
+ value: List["OptionalPrimitiveType"],
+ subset: Optional[List[str]] = ...,
+ ) -> "ParentDataFrame":
+ ...
+
+ @overload
+ def replace(
+ self,
+ to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
+ subset: Optional[List[str]] = ...,
+ ) -> "ParentDataFrame":
+ ...
+
+ @overload
+ def replace(
+ self,
+ to_replace: List["LiteralType"],
+ value: "OptionalPrimitiveType",
+ subset: Optional[List[str]] = ...,
+ ) -> "ParentDataFrame":
+ ...
+
+ def replace( # type: ignore[misc]
+ self,
+ to_replace: Union[
+ "LiteralType", List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]
+ ],
+ value: Optional[
+ Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
+ ] = _NoValue,
+ subset: Optional[List[str]] = None,
+ ) -> "ParentDataFrame":
+ if value is _NoValue:
+ if isinstance(to_replace, dict):
+ value = None
+ else:
+ raise PySparkTypeError(
+ error_class="ARGUMENT_REQUIRED",
+ message_parameters={"arg_name": "value", "condition": "`to_replace` is dict"},
+ )
+
+ # Helper functions
+ def all_of(types: Union[Type, Tuple[Type, ...]]) -> Callable[[Iterable], bool]:
+ """Given a type or tuple of types and a sequence of xs
+ check if each x is instance of type(s)
+
+ >>> all_of(bool)([True, False])
+ True
+ >>> all_of(str)(["a", 1])
+ False
+ """
+
+ def all_of_(xs: Iterable) -> bool:
+ return all(isinstance(x, types) for x in xs)
+
+ return all_of_
+
+ all_of_bool = all_of(bool)
+ all_of_str = all_of(str)
+ all_of_numeric = all_of((float, int))
+
+ # Validate input types
+ valid_types = (bool, float, int, str, list, tuple)
+ if not isinstance(to_replace, valid_types + (dict,)):
+ raise PySparkTypeError(
+ error_class="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
+ message_parameters={
+ "arg_name": "to_replace",
+ "arg_type": type(to_replace).__name__,
+ },
+ )
+
+ if (
+ not isinstance(value, valid_types)
+ and value is not None
+ and not isinstance(to_replace, dict)
+ ):
+ raise PySparkTypeError(
+ error_class="NOT_BOOL_OR_FLOAT_OR_INT_OR_LIST_OR_NONE_OR_STR_OR_TUPLE",
+ message_parameters={
+ "arg_name": "value",
+ "arg_type": type(value).__name__,
+ },
+ )
+
+ if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
+ if len(to_replace) != len(value):
+ raise PySparkValueError(
+ error_class="LENGTH_SHOULD_BE_THE_SAME",
+ message_parameters={
+ "arg1": "to_replace",
+ "arg2": "value",
+ "arg1_length": str(len(to_replace)),
+ "arg2_length": str(len(value)),
+ },
+ )
+
+ if not (subset is None or isinstance(subset, (list, tuple, str))):
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OR_STR_OR_TUPLE",
+ message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
+ )
+
+ # Reshape input arguments if necessary
+ if isinstance(to_replace, (float, int, str)):
+ to_replace = [to_replace]
+
+ if isinstance(to_replace, dict):
+ rep_dict = to_replace
+ if value is not None:
+ warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
+ else:
+ if isinstance(value, (float, int, str)) or value is None:
+ value = [value for _ in range(len(to_replace))]
+ rep_dict = dict(zip(to_replace, cast("Iterable[Optional[Union[float, str]]]", value)))
+
+ if isinstance(subset, str):
+ subset = [subset]
+
+ # Verify we were not passed in mixed type generics.
+ if not any(
+ all_of_type(rep_dict.keys())
+ and all_of_type(x for x in rep_dict.values() if x is not None)
+ for all_of_type in [all_of_bool, all_of_str, all_of_numeric]
+ ):
+ raise PySparkValueError(
+ error_class="MIXED_TYPE_REPLACEMENT",
+ message_parameters={},
+ )
+
+ if subset is None:
+ return DataFrame(self._jdf.na().replace("*", rep_dict), self.sparkSession)
+ else:
+ return DataFrame(
+ self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)),
+ self.sparkSession,
+ )
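    # Usage sketch (assuming a DataFrame `df` with a string column "name"):
    #
    #   df.replace("Alice", "A", subset=["name"])   # scalar to_replace/value
    #   df.replace({"Alice": "A", "Bob": "B"})      # dict form, value omitted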
+
+ @overload
+ def approxQuantile(
+ self,
+ col: str,
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float,
+ ) -> List[float]:
+ ...
+
+ @overload
+ def approxQuantile(
+ self,
+ col: Union[List[str], Tuple[str]],
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float,
+ ) -> List[List[float]]:
+ ...
+
+ def approxQuantile(
+ self,
+ col: Union[str, List[str], Tuple[str]],
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float,
+ ) -> Union[List[float], List[List[float]]]:
+ if not isinstance(col, (str, list, tuple)):
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OR_STR_OR_TUPLE",
+ message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
+ )
+
+ isStr = isinstance(col, str)
+
+ if isinstance(col, tuple):
+ col = list(col)
+ elif isStr:
+ col = [cast(str, col)]
+
+ for c in col:
+ if not isinstance(c, str):
+ raise PySparkTypeError(
+ error_class="DISALLOWED_TYPE_FOR_CONTAINER",
+ message_parameters={
+ "arg_name": "col",
+ "arg_type": type(col).__name__,
+ "allowed_types": "str",
+ "item_type": type(c).__name__,
+ },
+ )
+ col = _to_list(self._sc, cast(List["ColumnOrName"], col))
+
+ if not isinstance(probabilities, (list, tuple)):
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OR_TUPLE",
+ message_parameters={
+ "arg_name": "probabilities",
+ "arg_type": type(probabilities).__name__,
+ },
+ )
+ if isinstance(probabilities, tuple):
+ probabilities = list(probabilities)
+ for p in probabilities:
+ if not isinstance(p, (float, int)) or p < 0 or p > 1:
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OF_FLOAT_OR_INT",
+ message_parameters={
+ "arg_name": "probabilities",
+ "arg_type": type(p).__name__,
+ },
+ )
+ probabilities = _to_list(self._sc, cast(List["ColumnOrName"], probabilities))
+
+ if not isinstance(relativeError, (float, int)):
+ raise PySparkTypeError(
+ error_class="NOT_FLOAT_OR_INT",
+ message_parameters={
+ "arg_name": "relativeError",
+ "arg_type": type(relativeError).__name__,
+ },
+ )
+ if relativeError < 0:
+ raise PySparkValueError(
+ error_class="NEGATIVE_VALUE",
+ message_parameters={
+ "arg_name": "relativeError",
+ "arg_value": str(relativeError),
+ },
+ )
+ relativeError = float(relativeError)
+
+ jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError)
+ jaq_list = [list(j) for j in jaq]
+ return jaq_list[0] if isStr else jaq_list
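    # Usage sketch (assuming a DataFrame `df` with a numeric column "age"):
    #
    #   df.approxQuantile("age", [0.25, 0.5, 0.75], 0.05)    # -> List[float]
    #   df.approxQuantile(["age"], [0.5], 0.0)                # -> List[List[float]]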
+
+ def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
+ if not isinstance(col1, str):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "col1", "arg_type": type(col1).__name__},
+ )
+ if not isinstance(col2, str):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "col2", "arg_type": type(col2).__name__},
+ )
+ if not method:
+ method = "pearson"
+ if not method == "pearson":
+ raise PySparkValueError(
+ error_class="VALUE_NOT_PEARSON",
+ message_parameters={"arg_name": "method", "arg_value": method},
+ )
+ return self._jdf.stat().corr(col1, col2, method)
+
+ def cov(self, col1: str, col2: str) -> float:
+ if not isinstance(col1, str):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "col1", "arg_type": type(col1).__name__},
+ )
+ if not isinstance(col2, str):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "col2", "arg_type": type(col2).__name__},
+ )
+ return self._jdf.stat().cov(col1, col2)
+
+ def crosstab(self, col1: str, col2: str) -> "ParentDataFrame":
+ if not isinstance(col1, str):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "col1", "arg_type": type(col1).__name__},
+ )
+ if not isinstance(col2, str):
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "col2", "arg_type": type(col2).__name__},
+ )
+ return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sparkSession)
+
+ def freqItems(
+ self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
+ ) -> "ParentDataFrame":
+ if isinstance(cols, tuple):
+ cols = list(cols)
+ if not isinstance(cols, list):
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OR_TUPLE",
+ message_parameters={"arg_name": "cols", "arg_type": type(cols).__name__},
+ )
+ if not support:
+ support = 0.01
+ return DataFrame(
+ self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sparkSession
+ )
+
+ def _ipython_key_completions_(self) -> List[str]:
+ """Returns the names of columns in this :class:`DataFrame`.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
+ >>> df._ipython_key_completions_()
+ ['age', 'name']
+
+ Would return illegal identifiers.
+ >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age 1", "name?1"])
+ >>> df._ipython_key_completions_()
+ ['age 1', 'name?1']
+ """
+ return self.columns
+
+ def withColumns(self, *colsMap: Dict[str, Column]) -> "ParentDataFrame":
+ # Below code is to help enable kwargs in future.
+ assert len(colsMap) == 1
+ colsMap = colsMap[0] # type: ignore[assignment]
+
+ if not isinstance(colsMap, dict):
+ raise PySparkTypeError(
+ error_class="NOT_DICT",
+ message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__},
+ )
+
+ col_names = list(colsMap.keys())
+ cols = list(colsMap.values())
+
+ return DataFrame(
+ self._jdf.withColumns(_to_seq(self._sc, col_names), self._jcols(*cols)),
+ self.sparkSession,
+ )
+
+ def withColumn(self, colName: str, col: Column) -> "ParentDataFrame":
+ if not isinstance(col, Column):
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN",
+ message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
+ )
+ return DataFrame(self._jdf.withColumn(colName, col._jc), self.sparkSession)
+
+ def withColumnRenamed(self, existing: str, new: str) -> "ParentDataFrame":
+ return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sparkSession)
+
+ def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "ParentDataFrame":
+ if not isinstance(colsMap, dict):
+ raise PySparkTypeError(
+ error_class="NOT_DICT",
+ message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__},
+ )
+
+ col_names: List[str] = []
+ new_col_names: List[str] = []
+ for k, v in colsMap.items():
+ col_names.append(k)
+ new_col_names.append(v)
+
+ return DataFrame(
+ self._jdf.withColumnsRenamed(
+ _to_seq(self._sc, col_names), _to_seq(self._sc, new_col_names)
+ ),
+ self.sparkSession,
+ )
+
+ def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "ParentDataFrame":
+ from py4j.java_gateway import JVMView
+
+ if not isinstance(metadata, dict):
+ raise PySparkTypeError(
+ error_class="NOT_DICT",
+ message_parameters={"arg_name": "metadata", "arg_type": type(metadata).__name__},
+ )
+ sc = get_active_spark_context()
+ jmeta = cast(JVMView, sc._jvm).org.apache.spark.sql.types.Metadata.fromJson(
+ json.dumps(metadata)
+ )
+ return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sparkSession)
+
+ @overload
+ def drop(self, cols: "ColumnOrName") -> "ParentDataFrame":
+ ...
+
+ @overload
+ def drop(self, *cols: str) -> "ParentDataFrame":
+ ...
+
+ def drop(self, *cols: "ColumnOrName") -> "ParentDataFrame": # type: ignore[misc]
+ column_names: List[str] = []
+ java_columns: List["JavaObject"] = []
+
+ for c in cols:
+ if isinstance(c, str):
+ column_names.append(c)
+ elif isinstance(c, Column):
+ java_columns.append(c._jc)
+ else:
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN_OR_STR",
+ message_parameters={"arg_name": "col", "arg_type": type(c).__name__},
+ )
+
+ jdf = self._jdf
+ if len(java_columns) > 0:
+ first_column, *remaining_columns = java_columns
+ jdf = jdf.drop(first_column, self._jseq(remaining_columns))
+ if len(column_names) > 0:
+ jdf = jdf.drop(self._jseq(column_names))
+
+ return DataFrame(jdf, self.sparkSession)
+
+ def toDF(self, *cols: str) -> "ParentDataFrame":
+ for col in cols:
+ if not isinstance(col, str):
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OF_STR",
+ message_parameters={"arg_name": "cols", "arg_type": type(col).__name__},
+ )
+ jdf = self._jdf.toDF(self._jseq(cols))
+ return DataFrame(jdf, self.sparkSession)
+
+ def transform(
+ self, func: Callable[..., "ParentDataFrame"], *args: Any, **kwargs: Any
+ ) -> "ParentDataFrame":
+ result = func(self, *args, **kwargs)
+ assert isinstance(
+ result, DataFrame
+ ), "Func returned an instance of type [%s], " "should have been DataFrame." % type(result)
+ return result
+
+ def sameSemantics(self, other: "ParentDataFrame") -> bool:
+ if not isinstance(other, DataFrame):
+ raise PySparkTypeError(
+ error_class="NOT_DATAFRAME",
+ message_parameters={"arg_name": "other", "arg_type": type(other).__name__},
+ )
+ return self._jdf.sameSemantics(other._jdf)
+
+ def semanticHash(self) -> int:
+ return self._jdf.semanticHash()
+
+ def inputFiles(self) -> List[str]:
+ return list(self._jdf.inputFiles())
+
+ def where(self, condition: "ColumnOrName") -> "ParentDataFrame":
+ return self.filter(condition)
+
+ # Two aliases below were added for pandas compatibility many years ago.
+ # There are too many differences compared to pandas and we cannot just
+ # make it "compatible" by adding aliases. Therefore, we stop adding such
+ # aliases as of Spark 3.0. Two methods below remain just
+ # for legacy users currently.
+ @overload
+ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
+ ...
+
+ @overload
+ def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
+ ...
+
+ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc]
+ return self.groupBy(*cols)
+
+ def drop_duplicates(self, subset: Optional[List[str]] = None) -> "ParentDataFrame":
+ return self.dropDuplicates(subset)
+
+ def writeTo(self, table: str) -> DataFrameWriterV2:
+ return DataFrameWriterV2(self, table)
+
+ def pandas_api(
+ self, index_col: Optional[Union[str, List[str]]] = None
+ ) -> "PandasOnSparkDataFrame":
+ from pyspark.pandas.namespace import _get_index_map
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.pandas.internal import InternalFrame
+
+ index_spark_columns, index_names = _get_index_map(self, index_col)
+ internal = InternalFrame(
+ spark_frame=self,
+ index_spark_columns=index_spark_columns,
+ index_names=index_names, # type: ignore[arg-type]
+ )
+ return PandasOnSparkDataFrame(internal)
+
+ def mapInPandas(
+ self,
+ func: "PandasMapIterFunction",
+ schema: Union[StructType, str],
+ barrier: bool = False,
+ profile: Optional[ResourceProfile] = None,
+ ) -> "DataFrame":
+ return PandasMapOpsMixin.mapInPandas(self, func, schema, barrier, profile)
+
+ def mapInArrow(
+ self,
+ func: "ArrowMapIterFunction",
+ schema: Union[StructType, str],
+ barrier: bool = False,
+ profile: Optional[ResourceProfile] = None,
+ ) -> "DataFrame":
+ return PandasMapOpsMixin.mapInArrow(self, func, schema, barrier, profile)
+
+ def toPandas(self) -> "PandasDataFrameLike":
+ return PandasConversionMixin.toPandas(self)
+
+
+def _to_scala_map(sc: "SparkContext", jm: Dict) -> "JavaObject":
+ """
+ Convert a dict into a JVM Map.
+ """
+ assert sc._jvm is not None
+ return sc._jvm.PythonUtils.toScalaMap(jm)
+
+
+class DataFrameNaFunctions(ParentDataFrameNaFunctions):
+ def __init__(self, df: ParentDataFrame):
+ self.df = df
+
+ def drop(
+ self,
+ how: str = "any",
+ thresh: Optional[int] = None,
+ subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
+ ) -> ParentDataFrame:
+ return self.df.dropna(how=how, thresh=thresh, subset=subset)
+
+ @overload
+ def fill(self, value: "LiteralType", subset: Optional[List[str]] = ...) -> ParentDataFrame:
+ ...
+
+ @overload
+ def fill(self, value: Dict[str, "LiteralType"]) -> ParentDataFrame:
+ ...
+
+ def fill(
+ self,
+ value: Union["LiteralType", Dict[str, "LiteralType"]],
+ subset: Optional[List[str]] = None,
+ ) -> ParentDataFrame:
+ return self.df.fillna(value=value, subset=subset) # type: ignore[arg-type]
+
+ @overload
+ def replace(
+ self,
+ to_replace: List["LiteralType"],
+ value: List["OptionalPrimitiveType"],
+ subset: Optional[List[str]] = ...,
+ ) -> ParentDataFrame:
+ ...
+
+ @overload
+ def replace(
+ self,
+ to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
+ subset: Optional[List[str]] = ...,
+ ) -> ParentDataFrame:
+ ...
+
+ @overload
+ def replace(
+ self,
+ to_replace: List["LiteralType"],
+ value: "OptionalPrimitiveType",
+ subset: Optional[List[str]] = ...,
+ ) -> ParentDataFrame:
+ ...
+
+ def replace( # type: ignore[misc]
+ self,
+ to_replace: Union[List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]],
+ value: Optional[
+ Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
+ ] = _NoValue,
+ subset: Optional[List[str]] = None,
+ ) -> ParentDataFrame:
+ return self.df.replace(to_replace, value, subset) # type: ignore[arg-type]
+
+
+class DataFrameStatFunctions(ParentDataFrameStatFunctions):
+ def __init__(self, df: ParentDataFrame):
+ self.df = df
+
+ @overload
+ def approxQuantile(
+ self,
+ col: str,
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float,
+ ) -> List[float]:
+ ...
+
+ @overload
+ def approxQuantile(
+ self,
+ col: Union[List[str], Tuple[str]],
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float,
+ ) -> List[List[float]]:
+ ...
+
+ def approxQuantile(
+ self,
+ col: Union[str, List[str], Tuple[str]],
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float,
+ ) -> Union[List[float], List[List[float]]]:
+ return self.df.approxQuantile(col, probabilities, relativeError)
+
+ def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
+ return self.df.corr(col1, col2, method)
+
+ def cov(self, col1: str, col2: str) -> float:
+ return self.df.cov(col1, col2)
+
+ def crosstab(self, col1: str, col2: str) -> ParentDataFrame:
+ return self.df.crosstab(col1, col2)
+
+ def freqItems(self, cols: List[str], support: Optional[float] = None) -> ParentDataFrame:
+ return self.df.freqItems(cols, support)
+
+ def sampleBy(
+ self, col: str, fractions: Dict[Any, float], seed: Optional[int] = None
+ ) -> ParentDataFrame:
+ return self.df.sampleBy(col, fractions, seed)
+
+
+def _test() -> None:
+ import doctest
+ from pyspark.sql import SparkSession
+ import pyspark.sql.dataframe
+
+ globs = pyspark.sql.dataframe.__dict__.copy()
+ spark = (
+ SparkSession.builder.master("local[4]").appName("sql.classic.dataframe tests").getOrCreate()
+ )
+ globs["spark"] = spark
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.dataframe,
Review Comment:
ditto.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]