devin-petersohn commented on code in PR #54009:
URL: https://github.com/apache/spark/pull/54009#discussion_r2878825724


##########
python/pyspark/pandas/frame.py:
##########
@@ -11630,17 +11643,88 @@ def rank(
         1  2.5
         2  2.5
         3  4.0
+
+        Rank across columns with axis=1:
+
+        >>> df = ps.DataFrame({'A': [1, 2, 2, 3], 'B': [4, 3, 2, 1]}, columns=['A', 'B'])
+        >>> df.rank(axis=1).sort_index()
+             A    B
+        0  1.0  2.0
+        1  1.0  2.0
+        2  1.5  1.5
+        3  2.0  1.0
         """
+        axis = validate_axis(axis)
+
         if numeric_only:
             numeric_col_names = [
                 self._psser_for(label).name
                 for label in self._internal.column_labels
                 if isinstance(self._psser_for(label).spark.data_type, (NumericType, BooleanType))
             ]
-        psdf = self[numeric_col_names] if numeric_only else self
-        return psdf._apply_series_op(
-            lambda psser: psser._rank(method=method, ascending=ascending), should_resolve=True
-        )
+            psdf = self[numeric_col_names]
+        else:
+            psdf = self
+
+        if axis == 0:
+            return psdf._apply_series_op(
+                lambda psser: psser._rank(method=method, ascending=ascending), should_resolve=True
+            )
+        else:
+            # Fast path for small dataframes
+            limit = get_option("compute.shortcut_limit")
+            pdf = psdf.head(limit + 1)._to_internal_pandas()
+            if len(pdf) <= limit:
+                pdf_rank = pdf.rank(method=method, ascending=ascending, axis=1, numeric_only=False)
+                return DataFrame(InternalFrame.from_pandas(pdf_rank))
+
+            column_label_strings = [
+                name_like_string(label) for label in psdf._internal.column_labels
+            ]
+
+            @pandas_udf(  # type: ignore[call-overload]
+                returnType=StructType(
+                    [
+                        StructField(col_name, DoubleType(), nullable=True)
+                        for col_name in column_label_strings
+                    ]
+                )
+            )
+            def rank_axis_1(*cols: pd.Series) -> pd.DataFrame:
+                pdf_row = pd.concat(cols, axis=1, keys=column_label_strings)
+                return pdf_row.rank(method=method, ascending=ascending, axis=1).rename(
+                    columns=dict(zip(pdf_row.columns, column_label_strings))
+                )
+
+            ranked_struct_col = rank_axis_1(*psdf._internal.data_spark_columns)
+            new_data_columns = [
+                ranked_struct_col[col_name].alias(col_name) for col_name in column_label_strings
+            ]
+            sdf = psdf._internal.spark_frame.select(
+                psdf._internal.index_spark_columns + new_data_columns
+            )
+            data_fields = [
+                InternalField(
+                    dtype=np.dtype("float64"),
+                    struct_field=StructField(name_like_string(label), DoubleType(), nullable=True),
+                )
+                for label in psdf._internal.column_labels
+            ]
+            internal = InternalFrame(

Review Comment:
   Good catch, updated to use internal.with_new_columns instead. Thanks!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to