swuferhong commented on code in PR #22978:
URL: https://github.com/apache/flink/pull/22978#discussion_r1266275902


##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/WrapJsonAggFunctionArgumentsRule.java:
##########
@@ -80,81 +81,90 @@ public WrapJsonAggFunctionArgumentsRule(Config config) {
     @Override
     public void onMatch(RelOptRuleCall call) {
         final LogicalAggregate aggregate = call.rel(0);
-        final AggregateCall aggCall = aggregate.getAggCallList().get(0);
-
         final RelNode aggInput = aggregate.getInput();
         final RelBuilder relBuilder = call.builder().push(aggInput);
 
-        final List<Integer> affectedArgs = getAffectedArgs(aggCall);
-        addProjections(aggregate.getCluster(), relBuilder, affectedArgs);
-
-        final TargetMapping argsMapping =
-                getAggArgsMapping(aggInput.getRowType().getFieldCount(), 
affectedArgs);
-
-        final AggregateCall newAggregateCall = aggCall.transform(argsMapping);
-        final LogicalAggregate newAggregate =
-                aggregate.copy(
-                        aggregate.getTraitSet(),
-                        relBuilder.build(),
-                        aggregate.getGroupSet(),
-                        aggregate.getGroupSets(),
-                        Collections.singletonList(newAggregateCall));
-        call.transformTo(newAggregate.withHints(Collections.singletonList(MARKER_HINT)));
+        final LogicalAggregate wrappedAggregate = wrapJsonAggregate(aggregate, relBuilder);
+        call.transformTo(wrappedAggregate.withHints(Collections.singletonList(MARKER_HINT)));
     }
 
-    /**
-     * Returns the aggregation's arguments which need to be wrapped.
-     *
-     * <p>This list is a subset of {@link AggregateCall#getArgList()} as not 
every argument may need
-     * to be wrapped into a {@link BuiltInFunctionDefinitions#JSON_STRING} 
call.
-     *
-     * <p>Duplicates (e.g. for {@code JSON_OBJECTAGG(f0 VALUE f0)}) are 
removed as we only need to
-     * wrap them once.
-     */
-    private List<Integer> getAffectedArgs(AggregateCall aggCall) {
-        if (aggCall.getAggregation() instanceof SqlJsonObjectAggAggFunction) {
-            // For JSON_OBJECTAGG we only need to wrap its second (= value) 
argument
-            final int valueIndex = aggCall.getArgList().get(1);
-            return Collections.singletonList(valueIndex);
+    private LogicalAggregate wrapJsonAggregate(LogicalAggregate aggregate, 
RelBuilder relBuilder) {
+        final int inputCount = 
aggregate.getInput().getRowType().getFieldCount();
+        List<AggregateCall> aggCallList = new 
ArrayList<>(aggregate.getAggCallList());
+        // Maps the index of each JSON aggregate call (JSON_OBJECTAGG or JSON_ARRAYAGG) to the
+        // index of its argument that needs to be wrapped into a
+        // BuiltInFunctionDefinitions#JSON_STRING call. This map is used to build
+        // newWrappedArgCallList after the new Project has been created.
+        Map<Integer, Integer> wrapIndicesMap = new HashMap<>();
+        for (int i = 0; i < aggCallList.size(); i++) {
+            AggregateCall currentCall = aggCallList.get(i);
+            if (currentCall.getAggregation() instanceof 
SqlJsonObjectAggAggFunction) {
+                // For JSON_OBJECTAGG we only need to wrap its second (= 
value) argument
+                final int valueIndex = currentCall.getArgList().get(1);
+                wrapIndicesMap.put(i, valueIndex);
+            } else if (currentCall.getAggregation() instanceof 
SqlJsonArrayAggAggFunction) {
+                final int valueIndex = currentCall.getArgList().get(0);
+                wrapIndicesMap.put(i, valueIndex);
+            }
+        }
+
+        // Create a new Project.
+        Map<Integer, Integer> valueIndicesAfterProjection = new HashMap<>();
+        addProjections(
+                aggregate.getCluster(),
+                relBuilder,
+                new ArrayList<>(wrapIndicesMap.values()),
+                inputCount,
+                valueIndicesAfterProjection);
+
+        List<AggregateCall> newWrappedArgCallList = new 
ArrayList<>(aggCallList);
+        final int newInputCount = inputCount + 
valueIndicesAfterProjection.size();
+        for (Integer jsonAggCallIndex : wrapIndicesMap.keySet()) {
+            final TargetMapping argsMapping =
+                    Mappings.create(MappingType.BIJECTION, newInputCount, 
newInputCount);
+            Integer valueIndex = wrapIndicesMap.get(jsonAggCallIndex);
+            argsMapping.set(valueIndex, 
valueIndicesAfterProjection.get(valueIndex));
+            final AggregateCall newAggregateCall =
+                    
newWrappedArgCallList.get(jsonAggCallIndex).transform(argsMapping);
+            newWrappedArgCallList.set(jsonAggCallIndex, newAggregateCall);
         }
 
-        return 
aggCall.getArgList().stream().distinct().collect(Collectors.toList());
+        return aggregate.copy(
+                aggregate.getTraitSet(),
+                relBuilder.build(),
+                aggregate.getGroupSet(),
+                aggregate.getGroupSets(),
+                newWrappedArgCallList);
     }
 
     /**
-     * Adds (wrapped) projections for affected arguments of the aggregation.
+     * Adds (wrapped) projections for affected arguments of the aggregation. 
For duplicate
+     * projection fields, we only wrap them once and record the conversion 
relationship in the map
+     * valueIndicesAfterProjection.
      *
      * <p>Note that we cannot override any of the projections as a field may 
be used multiple times,
      * and in particular outside of the aggregation call. Therefore, we 
explicitly add the wrapped
      * projection as an additional one.
      */
     private void addProjections(
-            RelOptCluster cluster, RelBuilder relBuilder, List<Integer> 
affectedArgs) {
+            RelOptCluster cluster,
+            RelBuilder relBuilder,
+            List<Integer> affectedArgs,
+            int inputCount,
+            Map<Integer, Integer> valueIndicesAfterProjection) {
         final BridgingSqlFunction operandToStringOperator =
                 BridgingSqlFunction.of(cluster, JSON_STRING);
 
         final List<RexNode> projects = new ArrayList<>();
-        affectedArgs.stream()
-                .map(argIdx -> relBuilder.call(operandToStringOperator, 
relBuilder.field(argIdx)))
-                .forEach(projects::add);
-
-        relBuilder.projectPlus(projects);
-    }
-
-    /**
-     * Returns a {@link TargetMapping} that defines how the arguments of the 
aggregation must be
-     * mapped such that the wrapped arguments are used instead.
-     */
-    private TargetMapping getAggArgsMapping(int inputCount, List<Integer> 
affectedArgs) {
-        final int newCount = inputCount + affectedArgs.size();
-
-        final TargetMapping argsMapping =
-                Mappings.create(MappingType.BIJECTION, newCount, newCount);
-        for (int i = 0; i < affectedArgs.size(); i++) {
-            argsMapping.set(affectedArgs.get(i), inputCount + i);
+        int newProjectCount = 0;
+        for (Integer argIdx : affectedArgs) {

Review Comment:
   > We can change the param `affectedArgs` to a `Set` and simplify the code further, e.g.,
   > 
   > ```java
   >         for (Integer argIdx : affectedArgs) {
   >             valueIndicesAfterProjection.put(argIdx, inputCount + projects.size());
   >             projects.add(relBuilder.call(operandToStringOperator, relBuilder.field(argIdx)));
   >         }
   > ```
   
   Done!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to