[ https://issues.apache.org/jira/browse/FLINK-5315?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16614083#comment-16614083 ]
ASF GitHub Bot commented on FLINK-5315:
---------------------------------------

asfgit closed pull request #6521: [FLINK-5315][table] Adding support for distinct operation for table API on DataStream
URL: https://github.com/apache/flink/pull/6521

This PR was merged from a forked repository. Because GitHub hides the original diff on merge, it is reproduced below for the sake of provenance:

diff --git a/docs/dev/table/tableApi.md b/docs/dev/table/tableApi.md index f8bcd3da1af..a9b92fad995 100644 --- a/docs/dev/table/tableApi.md +++ b/docs/dev/table/tableApi.md @@ -370,6 +370,44 @@ Table result = orders <p><b>Note:</b> All aggregates must be defined over the same window, i.e., same partitioning, sorting, and range. Currently, only windows with PRECEDING (UNBOUNDED and bounded) to CURRENT ROW range are supported. Ranges with FOLLOWING are not supported yet. ORDER BY must be specified on a single <a href="streaming.html#time-attributes">time attribute</a>.</p> </td> </tr> + <tr> + <td> + <strong>Distinct Aggregation</strong><br> + <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> <br> + <span class="label label-info">Result Updating</span> + </td> + <td> + <p>Similar to a SQL DISTINCT aggregation clause such as COUNT(DISTINCT a). Distinct aggregation declares that an aggregation function (built-in or user-defined) is only applied on distinct input values. 
Distinct can be applied to <b>GroupBy Aggregation</b>, <b>GroupBy Window Aggregation</b> and <b>Over Window Aggregation</b>.</p> +{% highlight java %} +Table orders = tableEnv.scan("Orders"); +// Distinct aggregation on group by +Table groupByDistinctResult = orders + .groupBy("a") + .select("a, b.sum.distinct as d"); +// Distinct aggregation on time window group by +Table groupByWindowDistinctResult = orders + .window(Tumble.over("5.minutes").on("rowtime").as("w")).groupBy("a, w") + .select("a, b.sum.distinct as d"); +// Distinct aggregation on over window +Table result = orders + .window(Over + .partitionBy("a") + .orderBy("rowtime") + .preceding("UNBOUNDED_RANGE") + .as("w")) + .select("a, b.avg.distinct over w, b.max over w, b.min over w"); +{% endhighlight %} + <p>User-defined aggregation function can also be used with DISTINCT modifiers. To calculate the aggregate results only for distinct values, simply add the distinct modifier towards the aggregation function. </p> +{% highlight java %} +Table orders = tEnv.scan("Orders"); + +// Use distinct aggregation for user-defined aggregate functions +tEnv.registerFunction("myUdagg", new MyUdagg()); +orders.groupBy("users").select("users, myUdagg.distinct(points) as myDistinctResult"); +{% endhighlight %} + <p><b>Note:</b> For streaming queries the required state to compute the query result might grow infinitely depending on the number of distinct fields. Please provide a query configuration with valid retention interval to prevent excessive state size. See <a href="streaming.html">Streaming Concepts</a> for details.</p> + </td> + </tr> <tr> <td> <strong>Distinct</strong><br> @@ -453,6 +491,44 @@ val result: Table = orders <p><b>Note:</b> All aggregates must be defined over the same window, i.e., same partitioning, sorting, and range. Currently, only windows with PRECEDING (UNBOUNDED and bounded) to CURRENT ROW range are supported. Ranges with FOLLOWING are not supported yet. 
ORDER BY must be specified on a single <a href="streaming.html#time-attributes">time attribute</a>.</p> </td> </tr> + <tr> + <td> + <strong>Distinct Aggregation</strong><br> + <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> <br> + <span class="label label-info">Result Updating</span> + </td> + <td> + <p>Similar to a SQL DISTINCT AGGREGATION clause such as COUNT(DISTINCT a). Distinct aggregation declares that an aggregation function (built-in or user-defined) is only applied on distinct input values. Distinct can be applied to <b>GroupBy Aggregation</b>, <b>GroupBy Window Aggregation</b> and <b>Over Window Aggregation</b>.</p> +{% highlight scala %} +val orders: Table = tableEnv.scan("Orders"); +// Distinct aggregation on group by +val groupByDistinctResult = orders + .groupBy('a) + .select('a, 'b.sum.distinct as 'd) +// Distinct aggregation on time window group by +val groupByWindowDistinctResult = orders + .window(Tumble over 5.minutes on 'rowtime as 'w).groupBy('a, 'w) + .select('a, 'b.sum.distinct as 'd) +// Distinct aggregation on over window +val result = orders + .window(Over + partitionBy 'a + orderBy 'rowtime + preceding UNBOUNDED_RANGE + as 'w) + .select('a, 'b.avg.distinct over 'w, 'b.max over 'w, 'b.min over 'w) +{% endhighlight %} + <p>User-defined aggregation function can also be used with DISTINCT modifiers. To calculate the aggregate results only for distinct values, simply add the distinct modifier towards the aggregation function. </p> +{% highlight scala %} +val orders: Table = tEnv.scan("Orders"); + +// Use distinct aggregation for user-defined aggregate functions +val myUdagg = new MyUdagg(); +orders.groupBy('users).select('users, myUdagg.distinct('points) as 'myDistinctResult); +{% endhighlight %} + <p><b>Note:</b> For streaming queries the required state to compute the query result might grow infinitely depending on the number of distinct fields. 
Please provide a query configuration with valid retention interval to prevent excessive state size. See <a href="streaming.html">Streaming Concepts</a> for details.</p> + </td> + </tr> <tr> <td> <strong>Distinct</strong><br> diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index dfe69cb0411..d8a68f30d73 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ -22,12 +22,12 @@ import java.sql.{Date, Time, Timestamp} import org.apache.calcite.avatica.util.DateTimeUtils._ import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation} -import org.apache.flink.table.api.{TableException, CurrentRow, CurrentRange, UnboundedRow, UnboundedRange} +import org.apache.flink.table.api.{CurrentRange, CurrentRow, TableException, UnboundedRange, UnboundedRow} import org.apache.flink.table.expressions.ExpressionUtils.{convertArray, toMilliInterval, toMonthInterval, toRowInterval} import org.apache.flink.table.api.Table import org.apache.flink.table.expressions.TimeIntervalUnit.TimeIntervalUnit import org.apache.flink.table.expressions._ -import org.apache.flink.table.functions.AggregateFunction +import org.apache.flink.table.functions.{AggregateFunction, DistinctAggregateFunction} import scala.language.implicitConversions @@ -214,7 +214,7 @@ trait ImplicitExpressionOperations { def varSamp = VarSamp(expr) /** - * Returns multiset aggregate of a given expression. + * Returns multiset aggregate of a given expression. 
*/ def collect = Collect(expr) @@ -972,6 +972,10 @@ trait ImplicitExpressionConversions { implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array) implicit def userDefinedAggFunctionConstructor[T: TypeInformation, ACC: TypeInformation] (udagg: AggregateFunction[T, ACC]): UDAGGExpression[T, ACC] = UDAGGExpression(udagg) + implicit def toDistinct(agg: Aggregation): DistinctAgg = DistinctAgg(agg) + implicit def toDistinct[T: TypeInformation, ACC: TypeInformation] + (agg: AggregateFunction[T, ACC]): DistinctAggregateFunction[T, ACC] = + DistinctAggregateFunction(agg) } // ------------------------------------------------------------------------------------------------ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala index 4b2440cf673..d7972110191 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala @@ -81,6 +81,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val GET: Keyword = Keyword("get") lazy val FLATTEN: Keyword = Keyword("flatten") lazy val OVER: Keyword = Keyword("over") + lazy val DISTINCT: Keyword = Keyword("distinct") lazy val CURRENT_ROW: Keyword = Keyword("current_row") lazy val CURRENT_RANGE: Keyword = Keyword("current_range") lazy val UNBOUNDED_ROW: Keyword = Keyword("unbounded_row") @@ -311,6 +312,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val suffixFlattening: PackratParser[Expression] = composite <~ "." ~ FLATTEN ~ opt("()") ^^ { e => Flattening(e) } + lazy val suffixDistinct: PackratParser[Expression] = + composite <~ "." 
~ DISTINCT ~ opt("()") ^^ { e => DistinctAgg(e) } + lazy val suffixAs: PackratParser[Expression] = composite ~ "." ~ AS ~ "(" ~ rep1sep(fieldReference, ",") ~ ")" ^^ { case e ~ _ ~ _ ~ _ ~ target ~ _ => Alias(e, target.head.name, target.tail.map(_.name)) @@ -330,6 +334,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { suffixGet | // expression with special identifier suffixIf | + // expression with distinct suffix modifier + suffixDistinct | // function call must always be at the end suffixFunctionCall | suffixFunctionCallOneArg @@ -397,6 +403,11 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val prefixToTime: PackratParser[Expression] = TO_TIME ~ "(" ~> expression <~ ")" ^^ { e => Cast(e, SqlTimeTypeInfo.TIME) } + lazy val prefixDistinct: PackratParser[Expression] = + functionIdent ~ "." ~ DISTINCT ~ "(" ~ repsep(expression, ",") ~ ")" ^^ { + case name ~ _ ~ _ ~ _ ~ args ~ _ => DistinctAgg(Call(name.toUpperCase, args)) + } + lazy val prefixAs: PackratParser[Expression] = AS ~ "(" ~ expression ~ "," ~ rep1sep(fieldReference, ",") ~ ")" ^^ { case _ ~ _ ~ e ~ _ ~ target ~ _ => Alias(e, target.head.name, target.tail.map(_.name)) @@ -413,6 +424,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { prefixGet | // expression with special identifier prefixIf | + // expression with prefix distinct + prefixDistinct | // function call must always be at the end prefixFunctionCall | prefixFunctionCallOneArg diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala index b39bd9821d3..e03c5bef168 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala @@ -19,7 +19,6 @@ package 
org.apache.flink.table.expressions import org.apache.calcite.rex.RexNode import org.apache.calcite.sql.SqlAggFunction -import org.apache.calcite.sql.SqlKind._ import org.apache.calcite.sql.fun._ import org.apache.calcite.tools.RelBuilder import org.apache.calcite.tools.RelBuilder.AggCall @@ -43,7 +42,10 @@ abstract sealed class Aggregation extends Expression { /** * Convert Aggregate to its counterpart in Calcite, i.e. AggCall */ - private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall + private[flink] def toAggCall( + name: String, + isDistinct: Boolean = false + )(implicit relBuilder: RelBuilder): AggCall /** * Returns the SqlAggFunction for this Aggregation. @@ -52,12 +54,48 @@ abstract sealed class Aggregation extends Expression { } +case class DistinctAgg(child: Expression) extends Aggregation { + + private[flink] def distinct: Expression = DistinctAgg(child) + + override private[flink] def resultType: TypeInformation[_] = child.resultType + + override private[flink] def validateInput(): ValidationResult = { + super.validateInput() + child match { + case agg: Aggregation => + child.validateInput() + case _ => + ValidationFailure(s"Distinct modifier cannot be applied to $child! 
" + + s"It can only be applied to an aggregation expression, for example, " + + s"'a.count.distinct which is equivalent with COUNT(DISTINCT a).") + } + } + + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = true)(implicit relBuilder: RelBuilder) = { + child.asInstanceOf[Aggregation].toAggCall(name, isDistinct = true) + } + + override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = { + child.asInstanceOf[Aggregation].getSqlAggFunction() + } + + override private[flink] def children = Seq(child) +} + case class Sum(child: Expression) extends Aggregation { override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"sum($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { - relBuilder.aggregateCall(SqlStdOperatorTable.SUM, false, false, null, name, child.toRexNode) + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.SUM, + isDistinct, + false, + null, + name, + child.toRexNode) } override private[flink] def resultType = child.resultType @@ -77,8 +115,14 @@ case class Sum0(child: Expression) extends Aggregation { override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"sum0($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { - relBuilder.aggregateCall(SqlStdOperatorTable.SUM0, false, false, null, name, child.toRexNode) + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.SUM0, + isDistinct, + false, + null, + name, + child.toRexNode) } override private[flink] def resultType = child.resultType @@ -94,8 +138,14 @@ case class Min(child: Expression) extends Aggregation { override private[flink] 
def children: Seq[Expression] = Seq(child) override def toString = s"min($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { - relBuilder.aggregateCall(SqlStdOperatorTable.MIN, false, false, null, name, child.toRexNode) + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.MIN, + isDistinct, + false, + null, + name, + child.toRexNode) } override private[flink] def resultType = child.resultType @@ -112,8 +162,14 @@ case class Max(child: Expression) extends Aggregation { override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"max($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { - relBuilder.aggregateCall(SqlStdOperatorTable.MAX, false, false, null, name, child.toRexNode) + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.MAX, + isDistinct, + false, + null, + name, + child.toRexNode) } override private[flink] def resultType = child.resultType @@ -130,8 +186,14 @@ case class Count(child: Expression) extends Aggregation { override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"count($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { - relBuilder.aggregateCall(SqlStdOperatorTable.COUNT, false, false, null, name, child.toRexNode) + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.COUNT, + isDistinct, + false, + null, + name, + child.toRexNode) } override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO @@ -145,8 +207,14 @@ case class 
Avg(child: Expression) extends Aggregation { override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"avg($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { - relBuilder.aggregateCall(SqlStdOperatorTable.AVG, false, false, null, name, child.toRexNode) + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.AVG, + isDistinct, + false, + null, + name, + child.toRexNode) } override private[flink] def resultType = child.resultType @@ -171,8 +239,14 @@ case class Collect(child: Expression) extends Aggregation { override def toString: String = s"collect($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { - relBuilder.aggregateCall(SqlStdOperatorTable.COLLECT, false, false, null, name, child.toRexNode) + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.COLLECT, + isDistinct, + false, + null, + name, + child.toRexNode) } override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = { @@ -184,9 +258,15 @@ case class StddevPop(child: Expression) extends Aggregation { override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"stddev_pop($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { relBuilder.aggregateCall( - SqlStdOperatorTable.STDDEV_POP, false, false, null, name, child.toRexNode) + SqlStdOperatorTable.STDDEV_POP, + isDistinct, + false, + null, + name, + child.toRexNode) } override private[flink] def resultType = child.resultType @@ 
-202,9 +282,15 @@ case class StddevSamp(child: Expression) extends Aggregation { override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"stddev_samp($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { relBuilder.aggregateCall( - SqlStdOperatorTable.STDDEV_SAMP, false, false, null, name, child.toRexNode) + SqlStdOperatorTable.STDDEV_SAMP, + isDistinct, + false, + null, + name, + child.toRexNode) } override private[flink] def resultType = child.resultType @@ -220,8 +306,15 @@ case class VarPop(child: Expression) extends Aggregation { override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"var_pop($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { - relBuilder.aggregateCall(SqlStdOperatorTable.VAR_POP, false, false, null, name, child.toRexNode) + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall( + SqlStdOperatorTable.VAR_POP, + isDistinct, + false, + null, + name, + child.toRexNode) } override private[flink] def resultType = child.resultType @@ -237,9 +330,15 @@ case class VarSamp(child: Expression) extends Aggregation { override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"var_samp($child)" - override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + override private[flink] def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { relBuilder.aggregateCall( - SqlStdOperatorTable.VAR_SAMP, false, false, null, name, child.toRexNode) + SqlStdOperatorTable.VAR_SAMP, + isDistinct, + false, + null, + name, + child.toRexNode) } 
override private[flink] def resultType = child.resultType @@ -281,9 +380,15 @@ case class AggFunctionCall( override def toString: String = s"${aggregateFunction.getClass.getSimpleName}($args)" - override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + override def toAggCall( + name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = { relBuilder.aggregateCall( - this.getSqlAggFunction(), false, false, null, name, args.map(_.toRexNode): _*) + this.getSqlAggFunction(), + isDistinct, + false, + null, + name, + args.map(_.toRexNode): _*) } override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala new file mode 100644 index 00000000000..c75e6faf107 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.functions + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.expressions.{AggFunctionCall, DistinctAgg, Expression} +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getAccumulatorTypeOfAggregateFunction, getResultTypeOfAggregateFunction} + +/** + * Defines an implicit conversion method (distinct) that converts [[AggregateFunction]]s into + * [[DistinctAgg]] Expressions. + */ +private[flink] case class DistinctAggregateFunction[T: TypeInformation, ACC: TypeInformation] + (aggFunction: AggregateFunction[T, ACC]) { + + private[flink] def distinct(params: Expression*): Expression = { + val resultTypeInfo: TypeInformation[_] = getResultTypeOfAggregateFunction( + aggFunction, + implicitly[TypeInformation[T]]) + + val accTypeInfo: TypeInformation[_] = getAccumulatorTypeOfAggregateFunction( + aggFunction, + implicitly[TypeInformation[ACC]]) + + DistinctAgg( + AggFunctionCall(aggFunction, resultTypeInfo, accTypeInfo, params)) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index a2bd1e45124..7579621a1b2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -235,6 +235,14 @@ case class Aggregate( groupingExprs.foreach(validateGroupingExpression) def validateAggregateExpression(expr: Expression): Unit = expr match { + case distinctExpr: DistinctAgg => + distinctExpr.child match { + case _: DistinctAgg => failValidation( + "Chained distinct operators are not supported!") + case aggExpr: Aggregation => validateAggregateExpression(aggExpr) + case _ => failValidation( + "Distinct operator can only be applied to aggregation expressions!") + } // check aggregate 
function case aggExpr: Aggregation if aggExpr.getSqlAggFunction.requiresOver => diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala index 4bbb1012f2f..4e7270ffcd7 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala @@ -27,6 +27,19 @@ import org.junit._ class AggregateStringExpressionTest extends TableTestBase { + @Test + def testDistinctAggregationTypes(): Unit = { + val util = batchTestUtil() + val t = util.addTable[(Int, Long, String)]("Table3") + + val t1 = t.select('_1.sum.distinct, '_1.count.distinct, '_1.avg.distinct) + val t2 = t.select("_1.sum.distinct, _1.count.distinct, _1.avg.distinct") + val t3 = t.select("sum.distinct(_1), count.distinct(_1), avg.distinct(_1)") + + verifyTableEquals(t1, t2) + verifyTableEquals(t1, t3) + } + @Test def testAggregationTypes(): Unit = { val util = batchTestUtil() @@ -118,6 +131,19 @@ class AggregateStringExpressionTest extends TableTestBase { verifyTableEquals(distinct, distinct2) } + @Test + def testDistinctGroupedAggregate(): Unit = { + val util = batchTestUtil() + val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c) + + val t1 = t.groupBy('b).select('b, 'a.sum.distinct, 'a.sum) + val t2 = t.groupBy("b").select("b, a.sum.distinct, a.sum") + val t3 = t.groupBy("b").select("b, sum.distinct(a), sum(a)") + + verifyTableEquals(t1, t2) + verifyTableEquals(t1, t3) + } + @Test def testGroupedAggregate(): Unit = { val util = batchTestUtil() @@ -238,6 +264,22 @@ class AggregateStringExpressionTest extends TableTestBase { verifyTableEquals(resScala, resJava) } + @Test + def 
testDistinctAggregateWithUDAGG(): Unit = { + val util = batchTestUtil() + val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c) + + val myCnt = new CountAggFunction + util.tableEnv.registerFunction("myCnt", myCnt) + val myWeightedAvg = new WeightedAvgWithMergeAndReset + util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg) + + val t1 = t.select(myCnt.distinct('a) as 'aCnt, myWeightedAvg.distinct('b, 'a) as 'wAvg) + val t2 = t.select("myCnt.distinct(a) as aCnt, myWeightedAvg.distinct(b, a) as wAvg") + + verifyTableEquals(t1, t2) + } + @Test def testAggregateWithUDAGG(): Unit = { val util = batchTestUtil() @@ -254,6 +296,30 @@ class AggregateStringExpressionTest extends TableTestBase { verifyTableEquals(t1, t2) } + @Test + def testDistinctGroupedAggregateWithUDAGG(): Unit = { + val util = batchTestUtil() + val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c) + + + val myCnt = new CountAggFunction + util.tableEnv.registerFunction("myCnt", myCnt) + val myWeightedAvg = new WeightedAvgWithMergeAndReset + util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg) + + val t1 = t.groupBy('b) + .select('b, + myCnt.distinct('a) + 9 as 'aCnt, + myWeightedAvg.distinct('b, 'a) * 2 as 'wAvg, + myWeightedAvg.distinct('a, 'a) as 'distAgg, + myWeightedAvg('a, 'a) as 'agg) + val t2 = t.groupBy("b") + .select("b, myCnt.distinct(a) + 9 as aCnt, myWeightedAvg.distinct(b, a) * 2 as wAvg, " + + "myWeightedAvg.distinct(a, a) as distAgg, myWeightedAvg(a, a) as agg") + + verifyTableEquals(t1, t2) + } + @Test def testGroupedAggregateWithUDAGG(): Unit = { val util = batchTestUtil() diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala index 533235ad454..671f8dd8d1d 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala +++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
@@ -21,12 +21,64 @@ package org.apache.flink.table.api.stream.table
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.scala._
 import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
 import org.apache.flink.table.utils.TableTestUtil._
 import org.apache.flink.table.utils.TableTestBase
 import org.junit.Test
 
 class AggregateTest extends TableTestBase {
 
+  @Test
+  def testGroupDistinctAggregate(): Unit = {
+    val util = streamTestUtil()
+    val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+
+    val resultTable = table
+      .groupBy('b)
+      .select('a.sum.distinct, 'c.count.distinct)
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamGroupAggregate",
+          streamTableNode(0),
+          term("groupBy", "b"),
+          term("select", "b", "SUM(DISTINCT a) AS TMP_0", "COUNT(DISTINCT c) AS TMP_1")
+        ),
+        term("select", "TMP_0", "TMP_1")
+      )
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testGroupDistinctAggregateWithUDAGG(): Unit = {
+    val util = streamTestUtil()
+    val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+    val weightedAvg = new WeightedAvg
+
+    val resultTable = table
+      .groupBy('c)
+      .select(weightedAvg.distinct('a, 'b), weightedAvg('a, 'b))
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamGroupAggregate",
+          streamTableNode(0),
+          term("groupBy", "c"),
+          term(
+            "select",
+            "c",
+            "WeightedAvg(DISTINCT a, b) AS TMP_0",
+            "WeightedAvg(a, b) AS TMP_1")
+        ),
+        term("select", "TMP_0", "TMP_1")
+      )
+    util.verifyTable(resultTable, expected)
+  }
+
   @Test
   def testGroupAggregate() = {
     val util = streamTestUtil()
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
index 2bef95e5b40..ec57436b420 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
@@ -20,12 +20,80 @@ package org.apache.flink.table.api.stream.table.stringexpr
 
 import org.apache.flink.api.scala._
 import org.apache.flink.table.api.scala._
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
+import org.apache.flink.table.functions.aggfunctions.CountAggFunction
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMergeAndReset}
 import org.apache.flink.table.utils.TableTestBase
 import org.junit.Test
 
 class AggregateStringExpressionTest extends TableTestBase {
 
+
+  @Test
+  def testDistinctNonGroupedAggregate(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3")
+
+    val t1 = t.select('_1.sum.distinct, '_1.count.distinct, '_1.avg.distinct)
+    val t2 = t.select("_1.sum.distinct, _1.count.distinct, _1.avg.distinct")
+    val t3 = t.select("sum.distinct(_1), count.distinct(_1), avg.distinct(_1)")
+
+    verifyTableEquals(t1, t2)
+    verifyTableEquals(t1, t3)
+  }
+
+  @Test
+  def testDistinctGroupedAggregate(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+    val t1 = t.groupBy('b).select('b, 'a.sum.distinct, 'a.sum)
+    val t2 = t.groupBy("b").select("b, a.sum.distinct, a.sum")
+    val t3 = t.groupBy("b").select("b, sum.distinct(a), sum(a)")
+
+    verifyTableEquals(t1, t2)
+    verifyTableEquals(t1, t3)
+  }
+
+  @Test
+  def testDistinctNonGroupAggregateWithUDAGG(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+    val myCnt = new CountAggFunction
+    util.tableEnv.registerFunction("myCnt", myCnt)
+    val myWeightedAvg = new WeightedAvgWithMergeAndReset
+    util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+    val t1 = t.select(myCnt.distinct('a) as 'aCnt, myWeightedAvg.distinct('b, 'a) as 'wAvg)
+    val t2 = t.select("myCnt.distinct(a) as aCnt, myWeightedAvg.distinct(b, a) as wAvg")
+
+    verifyTableEquals(t1, t2)
+  }
+
+  @Test
+  def testDistinctGroupedAggregateWithUDAGG(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+
+    val myCnt = new CountAggFunction
+    util.tableEnv.registerFunction("myCnt", myCnt)
+    val myWeightedAvg = new WeightedAvgWithMergeAndReset
+    util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+    val t1 = t.groupBy('b)
+      .select('b,
+        myCnt.distinct('a) + 9 as 'aCnt,
+        myWeightedAvg.distinct('b, 'a) * 2 as 'wAvg,
+        myWeightedAvg.distinct('a, 'a) as 'distAgg,
+        myWeightedAvg('a, 'a) as 'agg)
+    val t2 = t.groupBy("b")
+      .select("b, myCnt.distinct(a) + 9 as aCnt, myWeightedAvg.distinct(b, a) * 2 as wAvg, " +
+        "myWeightedAvg.distinct(a, a) as distAgg, myWeightedAvg(a, a) as agg")
+
+    verifyTableEquals(t1, t2)
+  }
+
   @Test
   def testGroupedAggregate(): Unit = {
     val util = streamTestUtil()
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
index db6820ae90e..219b7653c57 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
@@ -25,7 +25,7 @@ import org.apache.flink.table.api.scala._
 import org.apache.flink.table.runtime.utils.StreamITCase.RetractingSink
 import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment, Types}
 import org.apache.flink.table.expressions.Null
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg}
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg, WeightedAvg}
 import org.apache.flink.table.runtime.utils.{JavaUserDefinedAggFunctions, StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.types.Row
 import org.junit.Assert.assertEquals
@@ -40,6 +40,84 @@ class AggregateITCase extends StreamingWithStateTestBase {
   private val queryConfig = new StreamQueryConfig()
   queryConfig.withIdleStateRetentionTime(Time.hours(1), Time.hours(2))
 
+  @Test
+  def testDistinctUDAGG(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val testAgg = new DataViewTestAgg
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, testAgg.distinct('d, 'e))
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,10", "2,21", "3,12")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
+  def testDistinctUDAGGMixedWithNonDistinctUsage(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val testAgg = new WeightedAvg
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, testAgg.distinct('a, 'a), testAgg('a, 'a))
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,3,3", "2,3,4", "3,4,4")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
+  def testDistinctAggregate(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, 'a.count.distinct)
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,4", "2,4", "3,2")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
+  def testDistinctAggregateMixedWithNonDistinct(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, 'a.count.distinct, 'b.count)
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,4,5", "2,4,7", "3,2,3")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
   @Test
   def testDistinct(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

> Support distinct aggregations in table api
> ------------------------------------------
>
>                 Key: FLINK-5315
>                 URL: https://issues.apache.org/jira/browse/FLINK-5315
>             Project: Flink
>          Issue Type: Sub-task
>          Components: Table API & SQL
>            Reporter: Kurt Young
>            Assignee: Rong Rong
>            Priority: Major
>              Labels: pull-request-available
>             Fix For: 1.7.0
>
>
> Support distinct aggregations in Table API in the following format:
> For Expressions:
> {code:scala}
> 'a.count.distinct // Expressions distinct modifier
> {code}
> For User-defined Function:
> {code:scala}
> singleArgUdaggFunc.distinct('a) // FunctionCall distinct modifier
> multiArgUdaggFunc.distinct('a, 'b) // FunctionCall distinct modifier
> {code}

--
This message was sent by Atlassian JIRA
(v7.6.3#76005)
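
For readers who want the semantics behind the distinct modifier (for example `'a.sum.distinct` or `'c.count.distinct`) spelled out, the following is a minimal sketch in plain Scala collections. It does not use the Flink Table API or its runtime at all: `DistinctAggSketch`, `Row`, `aggregate`, and `sampleRows` are hypothetical names and data introduced purely for illustration, and the sketch only mirrors the idea that a distinct aggregate sees each input value once per group.

```scala
// Illustration only: plain Scala collections, NOT the Flink Table API.
// All names and sample rows here are hypothetical.
object DistinctAggSketch {

  // A row with three fields, loosely mirroring the (Long, Int, String) test tables.
  final case class Row(a: Long, b: Int, c: String)

  // Hypothetical sample data; not taken from the Flink test utilities.
  val sampleRows: Seq[Row] = Seq(
    Row(1L, 1, "x"),
    Row(1L, 1, "x"),
    Row(2L, 1, "y"),
    Row(5L, 2, "z")
  )

  // Per group key b, returns (SUM(DISTINCT a), SUM(a), COUNT(DISTINCT c)).
  def aggregate(rows: Seq[Row]): Map[Int, (Long, Long, Int)] =
    rows.groupBy(_.b).map { case (key, group) =>
      val sumDistinct   = group.map(_.a).distinct.sum   // each value of a counted once
      val sumAll        = group.map(_.a).sum            // duplicates included
      val countDistinct = group.map(_.c).distinct.size  // number of different c values
      key -> ((sumDistinct, sumAll, countDistinct))
    }

  def main(args: Array[String]): Unit =
    aggregate(sampleRows).toSeq.sortBy(_._1).foreach(println)
}
```

In the sample data, group `b = 1` holds the `a` values 1, 1 and 2, so the distinct sum is 3 while the plain sum is 4. This is the same kind of difference the mixed distinct and non-distinct test cases in the diff check for.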