swuferhong commented on code in PR #22978: URL: https://github.com/apache/flink/pull/22978#discussion_r1266275902
########## flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/WrapJsonAggFunctionArgumentsRule.java: ########## @@ -80,81 +81,90 @@ public WrapJsonAggFunctionArgumentsRule(Config config) { @Override public void onMatch(RelOptRuleCall call) { final LogicalAggregate aggregate = call.rel(0); - final AggregateCall aggCall = aggregate.getAggCallList().get(0); - final RelNode aggInput = aggregate.getInput(); final RelBuilder relBuilder = call.builder().push(aggInput); - final List<Integer> affectedArgs = getAffectedArgs(aggCall); - addProjections(aggregate.getCluster(), relBuilder, affectedArgs); - - final TargetMapping argsMapping = - getAggArgsMapping(aggInput.getRowType().getFieldCount(), affectedArgs); - - final AggregateCall newAggregateCall = aggCall.transform(argsMapping); - final LogicalAggregate newAggregate = - aggregate.copy( - aggregate.getTraitSet(), - relBuilder.build(), - aggregate.getGroupSet(), - aggregate.getGroupSets(), - Collections.singletonList(newAggregateCall)); - call.transformTo(newAggregate.withHints(Collections.singletonList(MARKER_HINT))); + final LogicalAggregate wrappedAggregate = wrapJsonAggregate(aggregate, relBuilder); + call.transformTo(wrappedAggregate.withHints(Collections.singletonList(MARKER_HINT))); } - /** - * Returns the aggregation's arguments which need to be wrapped. - * - * <p>This list is a subset of {@link AggregateCall#getArgList()} as not every argument may need - * to be wrapped into a {@link BuiltInFunctionDefinitions#JSON_STRING} call. - * - * <p>Duplicates (e.g. for {@code JSON_OBJECTAGG(f0 VALUE f0)}) are removed as we only need to - * wrap them once. 
- */ - private List<Integer> getAffectedArgs(AggregateCall aggCall) { - if (aggCall.getAggregation() instanceof SqlJsonObjectAggAggFunction) { - // For JSON_OBJECTAGG we only need to wrap its second (= value) argument - final int valueIndex = aggCall.getArgList().get(1); - return Collections.singletonList(valueIndex); + private LogicalAggregate wrapJsonAggregate(LogicalAggregate aggregate, RelBuilder relBuilder) { + final int inputCount = aggregate.getInput().getRowType().getFieldCount(); + List<AggregateCall> aggCallList = new ArrayList<>(aggregate.getAggCallList()); + // This map is a mapping relationship between jsonObjectAggCall and the argument index + // need to be wrapped into a BuiltInFunctionDefinitions#JSON_STRING. This map will be used + // to create newWrappedArgCallList after creating a new Project. + Map<Integer, Integer> wrapIndicesMap = new HashMap<>(); + for (int i = 0; i < aggCallList.size(); i++) { + AggregateCall currentCall = aggCallList.get(i); + if (currentCall.getAggregation() instanceof SqlJsonObjectAggAggFunction) { + // For JSON_OBJECTAGG we only need to wrap its second (= value) argument + final int valueIndex = currentCall.getArgList().get(1); + wrapIndicesMap.put(i, valueIndex); + } else if (currentCall.getAggregation() instanceof SqlJsonArrayAggAggFunction) { + final int valueIndex = currentCall.getArgList().get(0); + wrapIndicesMap.put(i, valueIndex); + } + } + + // Create a new Project. 
+ Map<Integer, Integer> valueIndicesAfterProjection = new HashMap<>(); + addProjections( + aggregate.getCluster(), + relBuilder, + new ArrayList<>(wrapIndicesMap.values()), + inputCount, + valueIndicesAfterProjection); + + List<AggregateCall> newWrappedArgCallList = new ArrayList<>(aggCallList); + final int newInputCount = inputCount + valueIndicesAfterProjection.size(); + for (Integer jsonAggCallIndex : wrapIndicesMap.keySet()) { + final TargetMapping argsMapping = + Mappings.create(MappingType.BIJECTION, newInputCount, newInputCount); + Integer valueIndex = wrapIndicesMap.get(jsonAggCallIndex); + argsMapping.set(valueIndex, valueIndicesAfterProjection.get(valueIndex)); + final AggregateCall newAggregateCall = + newWrappedArgCallList.get(jsonAggCallIndex).transform(argsMapping); + newWrappedArgCallList.set(jsonAggCallIndex, newAggregateCall); } - return aggCall.getArgList().stream().distinct().collect(Collectors.toList()); + return aggregate.copy( + aggregate.getTraitSet(), + relBuilder.build(), + aggregate.getGroupSet(), + aggregate.getGroupSets(), + newWrappedArgCallList); } /** - * Adds (wrapped) projections for affected arguments of the aggregation. + * Adds (wrapped) projections for affected arguments of the aggregation. For duplicate + * projection fields, we only wrap them once and record the conversion relationship in the map + * valueIndicesAfterProjection. * * <p>Note that we cannot override any of the projections as a field may be used multiple times, * and in particular outside of the aggregation call. Therefore, we explicitly add the wrapped * projection as an additional one. 
*/ private void addProjections( - RelOptCluster cluster, RelBuilder relBuilder, List<Integer> affectedArgs) { + RelOptCluster cluster, + RelBuilder relBuilder, + List<Integer> affectedArgs, + int inputCount, + Map<Integer, Integer> valueIndicesAfterProjection) { final BridgingSqlFunction operandToStringOperator = BridgingSqlFunction.of(cluster, JSON_STRING); final List<RexNode> projects = new ArrayList<>(); - affectedArgs.stream() - .map(argIdx -> relBuilder.call(operandToStringOperator, relBuilder.field(argIdx))) - .forEach(projects::add); - - relBuilder.projectPlus(projects); - } - - /** - * Returns a {@link TargetMapping} that defines how the arguments of the aggregation must be - * mapped such that the wrapped arguments are used instead. - */ - private TargetMapping getAggArgsMapping(int inputCount, List<Integer> affectedArgs) { - final int newCount = inputCount + affectedArgs.size(); - - final TargetMapping argsMapping = - Mappings.create(MappingType.BIJECTION, newCount, newCount); - for (int i = 0; i < affectedArgs.size(); i++) { - argsMapping.set(affectedArgs.get(i), inputCount + i); + int newProjectCount = 0; + for (Integer argIdx : affectedArgs) {

Review Comment:

> we can change the param 'affectedArgs' to a `Set` and simplify the code further, e.g.,
>
> ```java
> for (Integer argIdx : affectedArgs) {
>     valueIndicesAfterProjection.put(argIdx, inputCount + projects.size());
>     projects.add(relBuilder.call(operandToStringOperator, relBuilder.field(argIdx)));
> }
> ```

Done!

--
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at: us...@infra.apache.org