LuciferYang commented on code in PR #54134:
URL: https://github.com/apache/spark/pull/54134#discussion_r2853789405


##########
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala:
##########
@@ -1156,6 +1156,307 @@ class DataFrameAggregateSuite extends QueryTest
     }
   }
 
+  test("max_by and min_by with k") {
+    // Basic: string values, integer ordering
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2), min_by(x, y, 2)
+          |FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)
+        """.stripMargin),
+      Row(Seq("b", "c"), Seq("a", "c")) :: Nil
+    )
+
+    // DataFrame API
+    checkAnswer(
+      spark.sql("SELECT * FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS 
tab(x, y)")
+        .agg(max_by(col("x"), col("y"), 2), min_by(col("x"), col("y"), 2)),
+      Row(Seq("b", "c"), Seq("a", "c")) :: Nil
+    )
+
+    // k larger than available rows
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 5), min_by(x, y, 5)
+          |FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)
+        """.stripMargin),
+      Row(Seq("b", "c", "a"), Seq("a", "c", "b")) :: Nil
+    )
+
+    // k = 1
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 1), min_by(x, y, 1)
+          |FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)
+        """.stripMargin),
+      Row(Seq("b"), Seq("a")) :: Nil
+    )
+
+    // NULL orderings are skipped
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2), min_by(x, y, 2)
+          |FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)
+        """.stripMargin),
+      Row(Seq("c", "a"), Seq("a", "c")) :: Nil
+    )
+
+    // All NULL orderings yields null
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2), min_by(x, y, 2)
+          |FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)
+        """.stripMargin),
+      Row(null, null) :: Nil
+    )
+
+    // Empty input yields null
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2), min_by(x, y, 2)
+          |FROM VALUES (('a', 10)) AS tab(x, y) WHERE false
+        """.stripMargin),
+      Row(null, null) :: Nil
+    )
+
+    // Integer values, integer ordering
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2), min_by(x, y, 2)
+          |FROM VALUES ((1, 100)), ((2, 200)), ((3, 150)) AS tab(x, y)
+        """.stripMargin),
+      Row(Seq(2, 3), Seq(1, 3)) :: Nil
+    )
+
+    // 10 elements, k=3 - forces heap replacements
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 3), min_by(x, y, 3)
+          |FROM VALUES ((1, 50)), ((2, 30)), ((3, 80)), ((4, 10)), ((5, 90)),
+          |            ((6, 20)), ((7, 70)), ((8, 40)), ((9, 60)), ((10, 100))
+          |AS tab(x, y)
+        """.stripMargin),
+      Row(Seq(10, 5, 3), Seq(4, 6, 2)) :: Nil
+    )
+
+    // descending input order (worst case for min-heap)
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 3)
+          |FROM VALUES ((1, 100)), ((2, 90)), ((3, 80)), ((4, 70)), ((5, 60)),
+          |            ((6, 50)), ((7, 40)), ((8, 30)), ((9, 20)), ((10, 10))
+          |AS tab(x, y)
+        """.stripMargin),
+      Row(Seq(1, 2, 3)) :: Nil
+    )
+
+    // ascending input order (worst case for max-heap in min_by)
+    checkAnswer(
+      sql(
+        """
+          |SELECT min_by(x, y, 3)
+          |FROM VALUES ((1, 10)), ((2, 20)), ((3, 30)), ((4, 40)), ((5, 50)),
+          |            ((6, 60)), ((7, 70)), ((8, 80)), ((9, 90)), ((10, 100))
+          |AS tab(x, y)
+        """.stripMargin),
+      Row(Seq(1, 2, 3)) :: Nil
+    )
+
+    // Large k with many elements
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 5), min_by(x, y, 5)
+          |FROM VALUES ((1, 15)), ((2, 25)), ((3, 35)), ((4, 45)), ((5, 55)),
+          |            ((6, 65)), ((7, 75)), ((8, 85))
+          |AS tab(x, y)
+        """.stripMargin),
+      Row(Seq(8, 7, 6, 5, 4), Seq(1, 2, 3, 4, 5)) :: Nil
+    )
+
+    // Duplicate ordering values (non-deterministic order within ties, but set 
should match)
+    val dupsResult = sql(
+      """
+        |SELECT max_by(x, y, 3)
+        |FROM VALUES ((1, 50)), ((2, 50)), ((3, 50)), ((4, 10)), ((5, 90))
+        |AS tab(x, y)
+      """.stripMargin).collect()(0).getSeq[Int](0).toSet
+    assert(dupsResult.contains(5))  // 90 is highest, must be included
+    assert(dupsResult.size == 3)
+
+    // Struct ordering
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2), min_by(x, y, 2)
+          |FROM VALUES (('a', (10, 20))), (('b', (10, 50))), (('c', (10, 60))) 
AS tab(x, y)
+        """.stripMargin),
+      Row(Seq("c", "b"), Seq("a", "b")) :: Nil
+    )
+
+    // Array ordering
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2), min_by(x, y, 2)
+          |FROM VALUES (('a', array(10, 20))), (('b', array(10, 50))), (('c', 
array(10, 60)))
+          |AS tab(x, y)
+        """.stripMargin),
+      Row(Seq("c", "b"), Seq("a", "b")) :: Nil
+    )
+
+    // Struct values
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2), min_by(x, y, 2)
+          |FROM VALUES ((('a', 1), 10)), ((('b', 2), 50)), ((('c', 3), 20)) AS 
tab(x, y)
+        """.stripMargin),
+      Row(Seq(Row("b", 2), Row("c", 3)), Seq(Row("a", 1), Row("c", 3))) :: Nil
+    )
+
+    // Array values
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2), min_by(x, y, 2)
+          |FROM VALUES ((array(1, 2), 10)), ((array(3, 4), 50)), ((array(5, 
6), 20)) AS tab(x, y)
+        """.stripMargin),
+      Row(Seq(Seq(3, 4), Seq(5, 6)), Seq(Seq(1, 2), Seq(5, 6))) :: Nil
+    )
+
+    // Map values
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2), min_by(x, y, 2)
+          |FROM VALUES ((map('a', 1), 10)), ((map('b', 2), 50)), ((map('c', 
3), 20)) AS tab(x, y)
+        """.stripMargin),
+      Row(Seq(Map("b" -> 2), Map("c" -> 3)), Seq(Map("a" -> 1), Map("c" -> 
3))) :: Nil
+    )
+
+    // GROUP BY
+    checkAnswer(
+      sql(
+        """
+          |SELECT course, max_by(year, earnings, 2), min_by(year, earnings, 2)
+          |FROM VALUES
+          |  (('Java', 2012, 20000)), (('Java', 2013, 30000)),
+          |  (('dotNET', 2012, 15000)), (('dotNET', 2013, 48000))
+          |AS tab(course, year, earnings)
+          |GROUP BY course
+          |ORDER BY course
+        """.stripMargin),
+      Row("Java", Seq(2013, 2012), Seq(2012, 2013)) ::
+        Row("dotNET", Seq(2013, 2012), Seq(2012, 2013)) :: Nil
+    )
+
+    // Error: k must be a constant (not a column reference)
+    Seq("max_by", "min_by").foreach { fn =>
+      val error = intercept[AnalysisException] {
+        sql(s"SELECT $fn(x, y, z) FROM VALUES (('a', 10, 2)) AS tab(x, y, 
z)").collect()
+      }
+      assert(error.getMessage.contains("NON_FOLDABLE_INPUT") ||
+        error.getMessage.contains("constant integer"))
+    }
+
+    // Float k is implicitly cast to integer (truncated) - 2.5 becomes 2
+    checkAnswer(
+      sql(
+        """
+          |SELECT max_by(x, y, 2.5), min_by(x, y, 2.5)
+          |FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)
+        """.stripMargin),
+      Row(Seq("b", "c"), Seq("a", "c")) :: Nil
+    )
+
+    // Error: string k cannot be cast to integer
+    Seq("max_by", "min_by").foreach { fn =>
+      val error = intercept[Exception] {
+        sql(s"SELECT $fn(x, y, 'two') FROM VALUES (('a', 10)) AS tab(x, 
y)").collect()

Review Comment:
   In non-ANSI mode, the exception thrown by this test case includes 
`VALUE_OUT_OF_RANGE`, which causes the daily non-ANSI test run to fail:
   
   - https://github.com/apache/spark/actions/runs/22247813526/job/64365502163
   
   <img width="2954" height="1114" alt="Image" 
src="https://github.com/user-attachments/assets/c6acdae3-d60b-47d6-aa7f-cc35ddeafb9b"
 />
   
   An attempted fix is at https://github.com/apache/spark/pull/54484/changes.
   
   @AlSchlo @viirya @gengliangwang 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to