[ https://issues.apache.org/jira/browse/BEAM-11808?focusedWorklogId=701609&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-701609 ]

ASF GitHub Bot logged work on BEAM-11808:
-----------------------------------------

                Author: ASF GitHub Bot
            Created on: 28/Dec/21 17:18
            Start Date: 28/Dec/21 17:18
    Worklog Time Spent: 10m 
      Work Description: ibzib commented on a change in pull request #16200:
URL: https://github.com/apache/beam/pull/16200#discussion_r771014170



##########
File path: 
sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
##########
@@ -148,24 +149,27 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject(
       // aggregation?
       ResolvedAggregateFunctionCall aggregateFunctionCall =
           ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr());
-      if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() == 1) {
-        ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0);
-
-        // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
-        // TODO: user might use multiple CAST so we need to handle this rare case.
-        projects.add(
-            getExpressionConverter()
-                .convertRexNodeFromResolvedExpr(
-                    resolvedExpr,
-                    node.getInputScan().getColumnList(),
-                    input.getRowType().getFieldList(),
-                    ImmutableMap.of()));
-        fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));
-      } else if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() > 1) {
-        throw new IllegalArgumentException(
-            aggregateFunctionCall.getFunction().getName() + " has more than one argument.");
+      ImmutableList<ResolvedExpr> argumentList =

Review comment:
       Why do we need to copy to an ImmutableList?

##########
File path: 
sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java
##########
@@ -105,4 +113,16 @@
           .put("nullif", new SqlNullIfOperatorRewriter())
           .put("$in", new SqlInOperatorRewriter())
           .build();
+
+  public static @Nullable SqlOperator create(

Review comment:
       Nit: this can/should be made package-private
   ```suggestion
     static @Nullable SqlOperator create(
   ```

##########
File path: 
sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
##########
@@ -248,6 +249,8 @@ private AggregateCall convertAggCall(
           || expr.nodeKind() == RESOLVED_COLUMN_REF
           || expr.nodeKind() == RESOLVED_GET_STRUCT_FIELD) {
         argList.add(columnRefOff);
+      } else if (expr.nodeKind() == RESOLVED_LITERAL) {

Review comment:
       We should have separate cases here:
   
   if i == 0, must be one of (RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD)
   else must be RESOLVED_LITERAL
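Here is a minimal sketch of that split. It is written outside Beam so that it compiles on its own. `ArgKind` is a hypothetical stand-in for ZetaSQL's `ResolvedNodeKind`, limited to the four kinds named above. `AggArgumentCheck` is an illustrative name, not existing Beam code.

```java
import java.util.EnumSet;
import java.util.List;
import java.util.Set;

/** Hypothetical stand-in for ZetaSQL's ResolvedNodeKind, limited to the kinds in the review. */
enum ArgKind {
  RESOLVED_CAST,
  RESOLVED_COLUMN_REF,
  RESOLVED_GET_STRUCT_FIELD,
  RESOLVED_LITERAL
}

final class AggArgumentCheck {
  // Kinds accepted for the first (aggregated) argument.
  private static final Set<ArgKind> FIRST_ARG_KINDS =
      EnumSet.of(ArgKind.RESOLVED_CAST, ArgKind.RESOLVED_COLUMN_REF, ArgKind.RESOLVED_GET_STRUCT_FIELD);

  private AggArgumentCheck() {}

  /** Throws if the argument at position i has a kind that is not allowed at that position. */
  static void checkArgumentKind(int i, ArgKind kind) {
    if (i == 0) {
      if (!FIRST_ARG_KINDS.contains(kind)) {
        throw new IllegalArgumentException("Unsupported first aggregate argument: " + kind);
      }
    } else if (kind != ArgKind.RESOLVED_LITERAL) {
      // Every additional argument (e.g. the STRING_AGG delimiter) must be a literal.
      throw new IllegalArgumentException(
          "Aggregate argument " + i + " must be a literal, got " + kind);
    }
  }

  /** Validates a whole argument list, position by position. */
  static void checkAll(List<ArgKind> kinds) {
    for (int i = 0; i < kinds.size(); i++) {
      checkArgumentKind(i, kinds.get(i));
    }
  }
}
```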

##########
File path: 
sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java
##########
@@ -17,85 +17,93 @@
  */
 package org.apache.beam.sdk.extensions.sql.zetasql.translation;
 
+import com.google.zetasql.resolvedast.ResolvedNodes;
 import java.util.Map;
+import java.util.function.Function;
 import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlOperator;
 import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 /** SqlOperatorMappingTable. */
 class SqlOperatorMappingTable {
 
   // todo: Some of operators defined here are later overridden in ZetaSQLPlannerImpl.
   // We should remove them from this table and add generic way to provide custom
   // implementation. (Ex.: timestamp_add)
-  static final Map<String, SqlOperator> ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR =
-      ImmutableMap.<String, SqlOperator>builder()
-          // grouped window function
-          .put("TUMBLE", SqlStdOperatorTable.TUMBLE_OLD)
-          .put("HOP", SqlStdOperatorTable.HOP_OLD)
-          .put("SESSION", SqlStdOperatorTable.SESSION_OLD)
+  static final Map<String, Function<ResolvedNodes.ResolvedFunctionCallBase, SqlOperator>>
+      ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR =
+          ImmutableMap
+              .<String, Function<ResolvedNodes.ResolvedFunctionCallBase, SqlOperator>>builder()
+              // grouped window function
+              .put("TUMBLE", resolvedFunction -> SqlStdOperatorTable.TUMBLE_OLD)
+              .put("HOP", resolvedFunction -> SqlStdOperatorTable.HOP_OLD)
+              .put("SESSION", resolvedFunction -> SqlStdOperatorTable.SESSION_OLD)
 
-          // ZetaSQL functions
-          .put("$and", SqlStdOperatorTable.AND)
-          .put("$or", SqlStdOperatorTable.OR)
-          .put("$not", SqlStdOperatorTable.NOT)
-          .put("$equal", SqlStdOperatorTable.EQUALS)
-          .put("$not_equal", SqlStdOperatorTable.NOT_EQUALS)
-          .put("$greater", SqlStdOperatorTable.GREATER_THAN)
-          .put("$greater_or_equal", SqlStdOperatorTable.GREATER_THAN_OR_EQUAL)
-          .put("$less", SqlStdOperatorTable.LESS_THAN)
-          .put("$less_or_equal", SqlStdOperatorTable.LESS_THAN_OR_EQUAL)
-          .put("$like", SqlOperators.LIKE)
-          .put("$is_null", SqlStdOperatorTable.IS_NULL)
-          .put("$is_true", SqlStdOperatorTable.IS_TRUE)
-          .put("$is_false", SqlStdOperatorTable.IS_FALSE)
-          .put("$add", SqlStdOperatorTable.PLUS)
-          .put("$subtract", SqlStdOperatorTable.MINUS)
-          .put("$multiply", SqlStdOperatorTable.MULTIPLY)
-          .put("$unary_minus", SqlStdOperatorTable.UNARY_MINUS)
-          .put("$divide", SqlStdOperatorTable.DIVIDE)
-          .put("concat", SqlOperators.CONCAT)
-          .put("substr", SqlOperators.SUBSTR)
-          .put("substring", SqlOperators.SUBSTR)
-          .put("trim", SqlOperators.TRIM)
-          .put("replace", SqlOperators.REPLACE)
-          .put("char_length", SqlOperators.CHAR_LENGTH)
-          .put("starts_with", SqlOperators.START_WITHS)
-          .put("ends_with", SqlOperators.ENDS_WITH)
-          .put("ltrim", SqlOperators.LTRIM)
-          .put("rtrim", SqlOperators.RTRIM)
-          .put("reverse", SqlOperators.REVERSE)
-          .put("$count_star", SqlStdOperatorTable.COUNT)
-          .put("max", SqlStdOperatorTable.MAX)
-          .put("min", SqlStdOperatorTable.MIN)
-          .put("avg", SqlStdOperatorTable.AVG)
-          .put("sum", SqlStdOperatorTable.SUM)
-          .put("any_value", SqlStdOperatorTable.ANY_VALUE)
-          .put("count", SqlStdOperatorTable.COUNT)
-          .put("bit_and", SqlStdOperatorTable.BIT_AND)
-          .put("string_agg", SqlOperators.STRING_AGG_STRING_FN) // NULL values not supported
-          .put("array_agg", SqlOperators.ARRAY_AGG_FN)
-          .put("bit_or", SqlStdOperatorTable.BIT_OR)
-          .put("bit_xor", SqlOperators.BIT_XOR)
-          .put("ceil", SqlStdOperatorTable.CEIL)
-          .put("floor", SqlStdOperatorTable.FLOOR)
-          .put("mod", SqlStdOperatorTable.MOD)
-          .put("timestamp", SqlOperators.TIMESTAMP_OP)
-          .put("$case_no_value", SqlStdOperatorTable.CASE)
+              // ZetaSQL functions
+              .put("$and", resolvedFunction -> SqlStdOperatorTable.AND)
+              .put("$or", resolvedFunction -> SqlStdOperatorTable.OR)
+              .put("$not", resolvedFunction -> SqlStdOperatorTable.NOT)
+              .put("$equal", resolvedFunction -> SqlStdOperatorTable.EQUALS)
+              .put("$not_equal", resolvedFunction -> SqlStdOperatorTable.NOT_EQUALS)
+              .put("$greater", resolvedFunction -> SqlStdOperatorTable.GREATER_THAN)
+              .put(
+                  "$greater_or_equal",
+                  resolvedFunction -> SqlStdOperatorTable.GREATER_THAN_OR_EQUAL)
+              .put("$less", resolvedFunction -> SqlStdOperatorTable.LESS_THAN)
+              .put("$less_or_equal", resolvedFunction -> SqlStdOperatorTable.LESS_THAN_OR_EQUAL)
+              .put("$like", resolvedFunction -> SqlOperators.LIKE)
+              .put("$is_null", resolvedFunction -> SqlStdOperatorTable.IS_NULL)
+              .put("$is_true", resolvedFunction -> SqlStdOperatorTable.IS_TRUE)
+              .put("$is_false", resolvedFunction -> SqlStdOperatorTable.IS_FALSE)
+              .put("$add", resolvedFunction -> SqlStdOperatorTable.PLUS)
+              .put("$subtract", resolvedFunction -> SqlStdOperatorTable.MINUS)
+              .put("$multiply", resolvedFunction -> SqlStdOperatorTable.MULTIPLY)
+              .put("$unary_minus", resolvedFunction -> SqlStdOperatorTable.UNARY_MINUS)
+              .put("$divide", resolvedFunction -> SqlStdOperatorTable.DIVIDE)
+              .put("concat", resolvedFunction -> SqlOperators.CONCAT)
+              .put("substr", resolvedFunction -> SqlOperators.SUBSTR)
+              .put("substring", resolvedFunction -> SqlOperators.SUBSTR)
+              .put("trim", resolvedFunction -> SqlOperators.TRIM)
+              .put("replace", resolvedFunction -> SqlOperators.REPLACE)
+              .put("char_length", resolvedFunction -> SqlOperators.CHAR_LENGTH)
+              .put("starts_with", resolvedFunction -> SqlOperators.START_WITHS)
+              .put("ends_with", resolvedFunction -> SqlOperators.ENDS_WITH)
+              .put("ltrim", resolvedFunction -> SqlOperators.LTRIM)
+              .put("rtrim", resolvedFunction -> SqlOperators.RTRIM)
+              .put("reverse", resolvedFunction -> SqlOperators.REVERSE)
+              .put("$count_star", resolvedFunction -> SqlStdOperatorTable.COUNT)
+              .put("max", resolvedFunction -> SqlStdOperatorTable.MAX)
+              .put("min", resolvedFunction -> SqlStdOperatorTable.MIN)
+              .put("avg", resolvedFunction -> SqlStdOperatorTable.AVG)
+              .put("sum", resolvedFunction -> SqlStdOperatorTable.SUM)
+              .put("any_value", resolvedFunction -> SqlStdOperatorTable.ANY_VALUE)
+              .put("count", resolvedFunction -> SqlStdOperatorTable.COUNT)
+              .put("bit_and", resolvedFunction -> SqlStdOperatorTable.BIT_AND)
+              .put("string_agg", SqlOperators::createStringAggOperator) // NULL values not supported

Review comment:
       If I understand correctly, you parameterized the string_agg operator on 
resolvedFunction, since unlike other functions that support multiple arguments, 
there's no way to make string_agg generic (because String and byte[] don't 
share a base class like Number). I think this is fine, and probably an 
inevitable change.
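The following is a simplified model of the new table shape that makes the design concrete. Hypothetical `FnCall` and `Op` classes stand in for `ResolvedNodes.ResolvedFunctionCallBase` and Calcite's `SqlOperator`. Most entries ignore the call. The `string_agg` entry picks its return type from the first argument's type name, as `createStringAggOperator` does in this PR. The body of `create` is an assumption, because the diff doesn't show it.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical stand-in for a resolved function call: a name plus argument type names. */
final class FnCall {
  final String name;
  final List<String> argumentTypeNames;

  FnCall(String name, List<String> argumentTypeNames) {
    this.name = name;
    this.argumentTypeNames = argumentTypeNames;
  }
}

/** Hypothetical stand-in for an operator: a name plus an optional return SQL type name. */
final class Op {
  final String name;
  final String returnType; // null when not fixed by the factory

  Op(String name, String returnType) {
    this.name = name;
    this.returnType = returnType;
  }
}

final class OperatorTable {
  private OperatorTable() {}

  // Values are factories, so an entry may (but need not) inspect the call.
  static final Map<String, Function<FnCall, Op>> TABLE =
      Map.<String, Function<FnCall, Op>>of(
          "sum", call -> new Op("SUM", null), // fixed operator; the call is ignored
          "max", call -> new Op("MAX", null), // fixed operator; the call is ignored
          "string_agg", OperatorTable::createStringAgg);

  // STRING and BYTES share no common base type, so the operator is chosen per call.
  static Op createStringAgg(FnCall call) {
    String inputType = call.argumentTypeNames.get(0);
    switch (inputType) {
      case "BYTES":
        return new Op("string_agg", "VARBINARY");
      case "STRING":
        return new Op("string_agg", "VARCHAR");
      default:
        throw new IllegalArgumentException("string_agg does not support " + inputType);
    }
  }

  /** Looks up the factory by function name; returns null for unknown functions. */
  static Op create(FnCall call) {
    Function<FnCall, Op> factory = TABLE.get(call.name);
    return factory == null ? null : factory.apply(call);
  }
}
```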

##########
File path: 
sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
##########
@@ -148,24 +149,27 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject(
       // aggregation?
       ResolvedAggregateFunctionCall aggregateFunctionCall =
           ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr());
-      if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() == 1) {
-        ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0);
-
-        // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
-        // TODO: user might use multiple CAST so we need to handle this rare case.
-        projects.add(
-            getExpressionConverter()
-                .convertRexNodeFromResolvedExpr(
-                    resolvedExpr,
-                    node.getInputScan().getColumnList(),
-                    input.getRowType().getFieldList(),
-                    ImmutableMap.of()));
-        fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));
-      } else if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() > 1) {
-        throw new IllegalArgumentException(
-            aggregateFunctionCall.getFunction().getName() + " has more than one argument.");
+      ImmutableList<ResolvedExpr> argumentList =
+          ImmutableList.copyOf(aggregateFunctionCall.getArgumentList());
+      if (argumentList != null && argumentList.size() >= 1) {
+        ResolvedExpr resolvedExpr = argumentList.get(0);
+        for (int i = 0; i < argumentList.size(); i++) {
+          if (i == 0) {
+            // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
+            // TODO: user might use multiple CAST so we need to handle this rare case.
+            projects.add(
+                getExpressionConverter()
+                    .convertRexNodeFromResolvedExpr(
+                        resolvedExpr,
+                        node.getInputScan().getColumnList(),
+                        input.getRowType().getFieldList(),
+                        ImmutableMap.of()));
+          } else {
+            projects.add(
+                getExpressionConverter().convertRexNodeFromResolvedExpr(argumentList.get(i)));
+          }
+          fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));

Review comment:
       This doesn't look like it belongs inside the for loop.

##########
File path: 
sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java
##########
@@ -180,6 +176,43 @@
           null,
           new CastFunctionImpl());
 
+  public static SqlOperator createStringAggOperator(
+      ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) {
+    List<ResolvedNodes.ResolvedExpr> args = aggregateFunctionCall.getArgumentList();
+    String inputType = args.get(0).getType().typeName();
+    Value delimiter = null;
+    if (args.size() == 2) {
+      delimiter = ((ResolvedNodes.ResolvedLiteral) args.get(1)).getValue();
+    }
+    switch (inputType) {
+      case "BYTES":
+        if (delimiter != null) {
+          return SqlOperators.createUdafOperator(
+              "string_agg",
+              x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARBINARY),
+              new UdafImpl<>(new StringAgg.StringAggByte(delimiter.getBytesValue().toByteArray())));
+        }
+        return SqlOperators.createUdafOperator(
+            "string_agg",
+            x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARBINARY),
+            new UdafImpl<>(new StringAgg.StringAggByte()));
+      case "STRING":
+        if (delimiter != null) {
+          return SqlOperators.createUdafOperator(
+              "string_agg",
+              x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR),
+              new UdafImpl<>(new StringAgg.StringAggString(delimiter.getStringValue())));
+        }
+        return SqlOperators.createUdafOperator(
+            "string_agg",
+            x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR),
+            new UdafImpl<>(new StringAgg.StringAggString()));

Review comment:
       Nit: I prefer to keep all the logic here. That way we can make the 
default delimiter explicit rather than implicit, and drop the extra constructor 
StringAggString(). And same with above for bytes.
   
   ```java
   return SqlOperators.createUdafOperator(
       "string_agg",
       x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR),
       new UdafImpl<>(new StringAgg.StringAggString(delimiter == null ? "," : delimiter.getStringValue())));
   ```
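For reference, here is a Beam-free sketch of the STRING side with the default made explicit. `StringAggSketch` is a hypothetical class. It only mirrors the shape of a `CombineFn`: create, add, merge and extract. The `","` default follows the snippet above and ZetaSQL's documented default delimiter for `STRING_AGG`. The sketch does not handle NULL inputs.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, Beam-free model of a STRING_AGG combiner for STRING input. */
final class StringAggSketch {
  static final String DEFAULT_DELIMITER = ",";

  private final String delimiter;

  /** A null delimiter means the SQL call had no second argument. */
  StringAggSketch(String delimiter) {
    this.delimiter = delimiter == null ? DEFAULT_DELIMITER : delimiter;
  }

  List<String> createAccumulator() {
    return new ArrayList<>();
  }

  List<String> addInput(List<String> acc, String input) {
    acc.add(input);
    return acc;
  }

  List<String> mergeAccumulators(Iterable<List<String>> accs) {
    List<String> merged = new ArrayList<>();
    for (List<String> acc : accs) {
      merged.addAll(acc);
    }
    return merged;
  }

  // Joins the collected values with the (explicit or default) delimiter.
  String extractOutput(List<String> acc) {
    return String.join(delimiter, acc);
  }
}
```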

##########
File path: 
sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
##########
@@ -148,24 +149,27 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject(
       // aggregation?
       ResolvedAggregateFunctionCall aggregateFunctionCall =
           ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr());
-      if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() == 1) {
-        ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0);
-
-        // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
-        // TODO: user might use multiple CAST so we need to handle this rare case.
-        projects.add(
-            getExpressionConverter()
-                .convertRexNodeFromResolvedExpr(
-                    resolvedExpr,
-                    node.getInputScan().getColumnList(),
-                    input.getRowType().getFieldList(),
-                    ImmutableMap.of()));
-        fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));
-      } else if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() > 1) {
-        throw new IllegalArgumentException(
-            aggregateFunctionCall.getFunction().getName() + " has more than one argument.");
+      ImmutableList<ResolvedExpr> argumentList =
+          ImmutableList.copyOf(aggregateFunctionCall.getArgumentList());
+      if (argumentList != null && argumentList.size() >= 1) {

Review comment:
       It seems like there is an assumption here that argumentList will never 
be null and will never have size 0. If that's true, can we check and throw an 
error for those cases?
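Here is a minimal sketch of that check, as a hypothetical `requireArguments` helper. Two things are worth confirming first:

- Guava's `ImmutableList.copyOf` already throws `NullPointerException` on a null argument. The `argumentList != null` test after the copy can therefore never be false.
- `$count_star` (COUNT(*)) appears in the operator table and may legitimately arrive with an empty argument list. The sketch lets the caller allow that case.

```java
import java.util.List;

final class AggregateArgs {
  private AggregateArgs() {}

  /**
   * Hypothetical helper: returns {@code args} unchanged, or throws if it is null, or if it is
   * empty while {@code allowEmpty} is false (allowEmpty would be true for e.g. $count_star).
   */
  static <T> List<T> requireArguments(List<T> args, String functionName, boolean allowEmpty) {
    if (args == null) {
      throw new IllegalArgumentException(functionName + " has a null argument list.");
    }
    if (args.isEmpty() && !allowEmpty) {
      throw new IllegalArgumentException(functionName + " requires at least one argument.");
    }
    return args;
  }
}
```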

##########
File path: 
sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java
##########
@@ -73,4 +79,53 @@ public String extractOutput(String output) {
       return output;
     }
   }
+
+  /** A {@link CombineFn} that aggregates bytes with a byte as delimiter. */

Review comment:
       ```suggestion
     /** A {@link CombineFn} that aggregates bytes with a byte array as delimiter. */
   ```
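Here is a matching Beam-free sketch of the bytes variant with a `byte[]` delimiter. It defaults to a single comma byte, as this review suggests for strings. `ByteAggSketch` is a hypothetical name, not the class in the PR. The sketch uses `ByteArrayOutputStream.write(byte[], int, int)` because, unlike `OutputStream.write(byte[])`, it declares no checked exception.

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, Beam-free model of a STRING_AGG combiner for BYTES input. */
final class ByteAggSketch {
  private static final byte[] DEFAULT_DELIMITER = ",".getBytes(StandardCharsets.UTF_8);

  private final byte[] delimiter;

  /** A null delimiter means the SQL call had no second argument. */
  ByteAggSketch(byte[] delimiter) {
    // Defensive copy so later changes to the caller's array don't affect the combiner.
    this.delimiter = delimiter == null ? DEFAULT_DELIMITER : delimiter.clone();
  }

  List<byte[]> createAccumulator() {
    return new ArrayList<>();
  }

  List<byte[]> addInput(List<byte[]> acc, byte[] input) {
    acc.add(input);
    return acc;
  }

  // Concatenates the values, writing the delimiter between consecutive values only.
  byte[] extractOutput(List<byte[]> acc) {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    for (int i = 0; i < acc.size(); i++) {
      if (i > 0) {
        out.write(delimiter, 0, delimiter.length);
      }
      byte[] value = acc.get(i);
      out.write(value, 0, value.length);
    }
    return out.toByteArray();
  }
}
```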




-- 
This is an automated message from the Apache Git Service.


Issue Time Tracking
-------------------

    Worklog Id:     (was: 701609)
    Time Spent: 1h 20m  (was: 1h 10m)

>  ZetaSQL layer doesn't support aggregate functions with two arguments
> ---------------------------------------------------------------------
>
>                 Key: BEAM-11808
>                 URL: https://issues.apache.org/jira/browse/BEAM-11808
>             Project: Beam
>          Issue Type: Improvement
>          Components: dsl-sql-zetasql
>            Reporter: Sonam Ramchand
>            Assignee: Benjamin Gonzalez
>            Priority: P3
>          Time Spent: 1h 20m
>  Remaining Estimate: 0h
>
> Blocked by: the translation of parsed ZetaSQL to a Calcite logical expression doesn't currently
> support aggregate functions with multiple columns.
> [https://github.com/apache/beam/blob/3bb232fb098700de408f574585dfe74bbaff7230/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java#L147]



--
This message was sent by Atlassian Jira
(v8.20.1#820001)
