cloud-fan commented on code in PR #48748:
URL: https://github.com/apache/spark/pull/48748#discussion_r1843593193


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -265,3 +271,257 @@ private[aggregate] object CollectTopK {
     case _ => throw QueryCompilationErrors.invalidNumParameter(e)
   }
 }
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr[, delimiter])[ WITHIN GROUP (ORDER BY key [ASC | DESC] 
[,...])] - Returns
+    the concatenation of non-null input values, separated by the delimiter 
ordered by key.
+    If all values are null, null is returned.
+    """,
+  arguments = """
+    Arguments:
+      * expr - a string or binary expression to be concatenated.
+      * delimiter - an optional string or binary foldable expression used to 
separate the input values.
+        If null, the concatenation will be performed without a delimiter. 
Default is null.
+      * key - an optional expression for ordering the input values. Multiple 
keys can be specified.
+        If none are specified, the order of the rows in the result is 
non-deterministic.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col);
+       abc
+      > SELECT _FUNC_(col) WITHIN GROUP (ORDER BY col DESC) FROM VALUES ('a'), 
('b'), ('c') AS tab(col);
+       cba
+      > SELECT _FUNC_(col) FROM VALUES ('a'), (NULL), ('b') AS tab(col);
+       ab
+      > SELECT _FUNC_(col) FROM VALUES ('a'), ('a') AS tab(col);
+       aa
+      > SELECT _FUNC_(DISTINCT col) FROM VALUES ('a'), ('a'), ('b') AS 
tab(col);
+       ab
+      > SELECT _FUNC_(col, ', ') FROM VALUES ('a'), ('b'), ('c') AS tab(col);
+       a, b, c
+      > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);
+       NULL
+  """,
+  note = """
+    * If the order is not specified, the function is non-deterministic because
+    the order of the rows may be non-deterministic after a shuffle.
+    * If DISTINCT is specified, then expr and key must be the same expression.
+  """,
+  group = "agg_funcs",
+  since = "4.0.0"
+)
+// scalastyle:on line.size.limit
+case class ListAgg(
+    child: Expression,
+    delimiter: Expression = Literal(null),
+    orderExpressions: Seq[SortOrder] = Nil,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends Collect[mutable.ArrayBuffer[Any]]
+  with SupportsOrderingWithinGroup
+  with ImplicitCastInputTypes {
+
+  override def isOrderingMandatory: Boolean = false
+  override def isDistinctSupported: Boolean = true
+  override protected lazy val bufferElementType: DataType = {
+    if (noNeedSaveOrderValue) {
+      child.dataType
+    } else {
+      StructType(
+        StructField("value", child.dataType)
+        +: orderValuesField
+      )
+    }
+  }
+  /** Indicates that the result of [[child]] is enough for evaluation  */
+  private lazy val noNeedSaveOrderValue: Boolean = 
isOrderCompatible(orderExpressions)
+
+  def this(child: Expression) =
+    this(child, Literal(null), Nil, 0, 0)
+
+  def this(child: Expression, delimiter: Expression) =
+    this(child, delimiter, Nil, 0, 0)
+
+  override def nullable: Boolean = true
+
+  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = 
mutable.ArrayBuffer.empty
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def defaultResult: Option[Literal] = Option(Literal.create(null, 
dataType))
+
+  override def sql(isDistinct: Boolean): String = {
+    val distinct = if (isDistinct) "DISTINCT " else ""
+    val withinGroup = if (orderingFilled) {
+      s" WITHIN GROUP (ORDER BY ${orderExpressions.map(_.sql).mkString(", ")})"
+    } else {
+      ""
+    }
+    s"$prettyName($distinct${child.sql}, ${delimiter.sql})$withinGroup"
+  }
+
+  override def inputTypes: Seq[AbstractDataType] =
+    TypeCollection(
+      StringTypeWithCollation(supportsTrimCollation = true),
+      BinaryType
+    ) +:
+    TypeCollection(
+      StringTypeWithCollation(supportsTrimCollation = true),
+      BinaryType,
+      NullType
+    ) +:
+    orderExpressions.map(_ => AnyDataType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val matchInputTypes = super.checkInputDataTypes()
+    if (matchInputTypes.isFailure) {
+      return matchInputTypes
+    }
+    if (!delimiter.foldable) {
+      return DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("delimiter"),
+          "inputType" -> toSQLType(delimiter.dataType),
+          "inputExpr" -> toSQLExpr(delimiter)
+        )
+      )
+    }
+    if (delimiter.dataType == NullType) {
+      // null is the default empty delimiter so type is not important
+      TypeCheckSuccess
+    } else {
+      TypeUtils.checkForSameTypeInputExpr(child.dataType :: delimiter.dataType 
:: Nil, prettyName)
+    }
+  }
+
+  override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
+    if (buffer.nonEmpty) {
+      val sortedBufferWithoutNulls = sortBuffer(buffer)
+      concatSkippingNulls(sortedBufferWithoutNulls)
+    } else {
+      null
+    }
+  }
+
+  private[this] def sortBuffer(buffer: mutable.ArrayBuffer[Any]): 
mutable.ArrayBuffer[Any] = {
+    if (!orderingFilled) {
+      return buffer
+    }
+    if (noNeedSaveOrderValue) {
+      val ascendingOrdering = 
PhysicalDataType.ordering(orderExpressions.head.dataType)
+      val ordering =
+        if (orderExpressions.head.direction == Ascending) ascendingOrdering
+        else ascendingOrdering.reverse
+      buffer.sorted(ordering)
+    } else {
+      buffer
+        .asInstanceOf[mutable.ArrayBuffer[InternalRow]]
+        .sorted(bufferOrdering)
+        // drop order values after sort
+        .map(_.get(0, child.dataType))
+    }
+  }
+
+  private[this] def bufferOrdering: Ordering[InternalRow] = {
+    val bufferSortOrder = orderExpressions.zipWithIndex.map {
+      case (originalOrder, i) =>
+        originalOrder.copy(
+          // first value is the evaluated child so add +1 for order's values
+          child = BoundReference(i + 1, originalOrder.dataType, 
originalOrder.child.nullable)
+        )
+    }
+    new InterpretedOrdering(bufferSortOrder)
+  }
+
+  override def orderingFilled: Boolean = orderExpressions.nonEmpty

Review Comment:
   nit: let's put it together with the other methods that come from `SupportsOrderingWithinGroup`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to