ueshin commented on code in PR #54319:
URL: https://github.com/apache/spark/pull/54319#discussion_r2874793488


##########
python/pyspark/pandas/frame.py:
##########
@@ -4955,29 +4970,52 @@ def nunique(
         from pyspark.pandas.series import first_series
 
         axis = validate_axis(axis)
-        if axis != 0:
-            raise NotImplementedError('axis should be either 0 or "index" currently.')
-        sdf = self._internal.spark_frame.select(
-            [F.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)]
-            + [
-                self._psser_for(label)._nunique(dropna, approx, rsd)
-                for label in self._internal.column_labels
-            ]
-        )
+        if axis == 0:
+            sdf = self._internal.spark_frame.select(
+                [F.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)]
+                + [
+                    self._psser_for(label)._nunique(dropna, approx, rsd)
+                    for label in self._internal.column_labels
+                ]
+            )
 
-        # The data is expected to be small so it's fine to transpose/use the default index.
-        with ps.option_context("compute.max_rows", 1):
-            internal = self._internal.copy(
-                spark_frame=sdf,
-                index_spark_columns=[scol_for(sdf, SPARK_DEFAULT_INDEX_NAME)],
-                index_names=[None],
-                index_fields=[None],
-                data_spark_columns=[
-                    scol_for(sdf, col) for col in self._internal.data_spark_column_names
-                ],
-                data_fields=None,
+            # The data is expected to be small so it's fine to transpose/use the default index.
+            with ps.option_context("compute.max_rows", 1):
+                internal = self._internal.copy(
+                    spark_frame=sdf,
+                    index_spark_columns=[scol_for(sdf, SPARK_DEFAULT_INDEX_NAME)],
+                    index_names=[None],
+                    index_fields=[None],
+                    data_spark_columns=[
+                        scol_for(sdf, col) for col in self._internal.data_spark_column_names
+                    ],
+                    data_fields=None,
+                )
+                return first_series(DataFrame(internal).transpose())
+        elif axis == 1:
+            from pyspark.pandas.series import first_series
+
+            arr = F.array(
+                *[self._internal.spark_column_for(label) for label in self._internal.column_labels]
+            )
+            arr = F.filter(arr, lambda x: x.isNotNull()) if dropna else arr
+
+            sdf = self._internal.spark_frame.select(

Review Comment:
   Sorry for the late comment, but I guess we can do this without creating a 
new Spark DataFrame, similar to the recent PRs?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to