924060929 commented on code in PR #12583:
URL: https://github.com/apache/doris/pull/12583#discussion_r1020151860


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -257,6 +283,210 @@ public PlanFragment visitPhysicalAggregate(
         return currentFragment;
     }
 
    /**
     * Translate physicalRepeat to RepeatNode, which is added on top of the child's plan fragment.
     *
     * eg: select k1 from t1 group by grouping sets ((k1), (k2, k3), (k3));
     * groupingInfo:
     *      Contains:
     *      virtualTupleDesc: the tupleDescriptor generated for the virtual columns
     *      outputTupleDesc: the tupleDescriptor of the final output slots (finalSlots)
     *      preRepeatExprs: Expr used in repeatNode. eg: [k1, k2, k3]
     *
     * repeatSlotIdList: According to the bitmap corresponding to the original group by
     *                   and the slotId information contained in the slotDescriptor,
     *                   the bitmap corresponding to the slotId is obtained.
     * eg: groupingSetsBitSet: [{0}, {1, 2}, {2}]. slotIdList:[5, 6, 7]
     *     ==> [{5}, {6, 7}, {7}].
     *
     * groupingList: The value of the virtual column generated by each group by,
     *               taken from repeat.getVirtualSlotValues().
     * eg: [3, 4, 6]
     *
     * allSlotId: slotId of all used columns, i.e. the union of repeatSlotIdList. eg: [5, 6, 7]
     *
     * @param repeat the physical repeat operator to translate
     * @param context translation context, used to translate expressions, create tuple descriptors
     *                and allocate the plan node id
     * @return the child's plan fragment, with the RepeatNode added via addPlanRoot and its data
     *         partition updated to RANDOM
     */
    @Override
    public PlanFragment visitPhysicalRepeat(PhysicalRepeat repeat, PlanTranslatorContext context) {
        PlanFragment inputPlanFragment = repeat.child(0).accept(this, context);
        List<Expression> groupByExpressions = repeat.getGroupByExpressions();
        List<NamedExpression> outputExpressionList = repeat.getOutputExpressions();
        List<Expression> virtualGroupByExpressions = repeat.getVirtualGroupByExpressions();
        List<Expression> nonVirtualGroupByExpressions = repeat.getNonVirtualGroupByExpressions();

        // 1. create GroupingInfo
        /*
         * finalSlots: the final output columns of the grouping sets.
         * preExpressions: the fields required by GroupingInfo, i.e. every field the repeat node uses.
         *
         * eg: select sum(k2), grouping(k1) from t1 group by grouping sets((k1));
         * finalSlots: k1, k2, GROUPING_ID(), GROUPING_PREFIX_k1
         * preExpressions: k1, k2
         */
        // create virtual tupleDesc from the VirtualSlotReferences among the virtual group-by expressions
        List<VirtualSlotReference> virtualSlotList = virtualGroupByExpressions.stream()
                .filter(VirtualSlotReference.class::isInstance)
                .map(VirtualSlotReference.class::cast)
                .collect(Collectors.toList());
        List<Slot> virtualSlots = Lists.newArrayList(virtualSlotList);
        // NOTE(review): this tuple is created before the output tuple below; the relative order may
        // determine the ids allocated from context — confirm before reordering these statements.
        TupleDescriptor virtualTupleDesc = generateTupleDesc(virtualSlots, null, context);

        // create repeat slots and PreExpressions
        List<SlotReference> groupBySlot = genGroupBySlotList(groupByExpressions, outputExpressionList);
        List<Slot> finalSlots = genOutputSlots(groupBySlot, outputExpressionList, virtualSlots);

        List<Expression> preExpressions = genPreExpressions(nonVirtualGroupByExpressions, outputExpressionList);
        ArrayList<Expr> preRepeatExprs = preExpressions.stream()
                .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toCollection(ArrayList::new));

        // create output TupleDesc; every non-virtual slot is marked nullable
        // (presumably because a column outside the current grouping set is emitted as NULL).
        TupleResult tupleResult = genTupleDescAndDescList(finalSlots, null, context);
        TupleDescriptor outputTupleDesc = tupleResult.getTupleDescriptor();
        setSlotNullable(outputTupleDesc, finalSlots);

        GroupingInfo groupingInfo = new GroupingInfo(ExpressionTranslator.translateGroupingType(repeat),
                virtualTupleDesc, outputTupleDesc, preRepeatExprs);

        // 2. create repeat node
        // Replace the bitset of each grouping set with the set of slotIds it refers to.
        // NOTE(review): this assumes bit i of each grouping-set bitset matches the i-th slot descriptor
        // of the output tuple (genOutputSlots puts non-virtual group-by slots first) — confirm.
        List<Set<Integer>> repeatSlotIdList = genGroupingIdList(
                repeat.useBitsetsToRepresentGroupingSets(), tupleResult.getSlotDescriptors());
        Set<Integer> allSlotId = genAllSlotId(repeatSlotIdList);

        RepeatNode repeatNode = new RepeatNode(context.nextPlanNodeId(),
                inputPlanFragment.getPlanRoot(), groupingInfo, repeatSlotIdList,
                allSlotId, repeat.getVirtualSlotValues());
        repeatNode.setNumInstances(inputPlanFragment.getPlanRoot().getNumInstances());
        // No new fragment is created: the repeat node is added to the child's fragment as its plan root.
        inputPlanFragment.addPlanRoot(repeatNode);
        inputPlanFragment.updateDataPartition(DataPartition.RANDOM);
        return inputPlanFragment;
    }
+
+    /**
+     * Generate outputSlot based on groupBy and output information.
+     * The generated outputSlot needs to guarantee the order,
+     * first slotReference and then virtualSlotReference.
+     *
+     * First collect the slot and then collect the virtualSlot.
+     * 1. Put the columns in groupBy into the slot
+     * 2. Put the columns in the output into the slot
+     * 3. Put slot and virtualSlot into finalSLots in turn.
+     *
+     * eg: select sum(k2) grouping(k1) from t1 group by grouping sets((k1));
+     * 1. slots:[k1]
+     * 2. slots:[k1, k2]
+     * 3. finalSlots: [k1, k2, GROUPING_PREFIX_k1]
+     */
+    private List<Slot> genOutputSlots(
+            List<SlotReference> groupBySlot,
+            List<NamedExpression> outputExpressionList,
+            List<Slot> virtualSlots) {
+
+        Map<Boolean, List<Slot>> allSlots = groupBySlot.stream()
+                
.collect(Collectors.groupingBy(VirtualSlotReference.class::isInstance,
+                        LinkedHashMap::new, Collectors.toList()));
+        List<Slot> nonVirtualSlots = allSlots.get(false);
+        Map<Boolean, List<NamedExpression>> allOutput = 
outputExpressionList.stream()
+                
.collect(Collectors.groupingBy(VirtualSlotReference.class::isInstance,
+                        LinkedHashMap::new, Collectors.toList()));
+        List<NamedExpression> nonVirtualOutput = allOutput.get(false);
+
+        nonVirtualSlots.addAll(nonVirtualOutput.stream()
+                .filter(e -> !nonVirtualSlots.contains(e))
+                .filter(SlotReference.class::isInstance)
+                .map(NamedExpression::toSlot)
+                .collect(Collectors.toSet()));
+        List<Slot> finalSlots = Lists.newArrayList();
+        finalSlots.addAll(nonVirtualSlots);
+        finalSlots.addAll(virtualSlots);
+        return finalSlots;
+    }
+
+    /**
+     * Get all the used columns that appear in repeatNode according to the 
column of groupBy
+     * and the column of aggFunc in output.
+     *
+     * eg: select sum(k2) grouping(k1) from t1 group by grouping sets((k1));
+     * preExpressions: [k1, k2].
+     */
+    private List<Expression> genPreExpressions(
+            List<Expression> nonVirtualGroupByExpressions, 
List<NamedExpression> outputExpressionList) {
+        List<Expression> preExpressions = Lists.newArrayList();
+        preExpressions.addAll(nonVirtualGroupByExpressions);
+        Map<Boolean, List<NamedExpression>> allOutput = 
outputExpressionList.stream()
+                
.collect(Collectors.groupingBy(VirtualSlotReference.class::isInstance,
+                        LinkedHashMap::new, Collectors.toList()));
+        List<NamedExpression> nonVirtualOutput = allOutput.get(false);
+        nonVirtualOutput.stream()
+                .filter(e -> !preExpressions.contains(e))
+                .forEach(preExpressions::add);
+        return preExpressions;
+    }
+
+    /**
+     * Generate bitSets corresponding to SlotId according to the original 
groupBy bitSets and the actual SlotId.
+     * eg:
+     * groupingSetsList: [(k1), (k2, k3), (k3)]
+     * originalGroupingIdList: [{0}, {1,2}, {2}]
+     * SlotIds in groupingSlotDescs: [3, 4, 5]
+     *
+     * return: [{3}, {4, 5}, {5}]
+     */
+    private List<Set<Integer>> genGroupingIdList(
+            List<BitSet> originalGroupingIdList, List<SlotDescriptor> 
groupingSlotDescs) {
+        List<Set<Integer>> groupingIdList = Lists.newArrayList();
+        for (BitSet bitSet : originalGroupingIdList) {
+            Set<Integer> slotIdSet = new HashSet<>();
+            for (int i = 0; i < groupingSlotDescs.size(); i++) {
+                if (bitSet.get(i)) {
+                    slotIdSet.add(groupingSlotDescs.get(i).getId().asInt());
+                }
+            }
+            groupingIdList.add(slotIdSet);
+        }
+        return groupingIdList;
+    }
+
+    private static Set<Integer> genAllSlotId(List<Set<Integer>> 
repeatSlotIdList) {
+        Set<Integer> allSlotId = new LinkedHashSet<>();
+        for (Set<Integer> s : repeatSlotIdList) {
+            allSlotId.addAll(s);
+        }
+        return allSlotId;
+    }
+
+    private void setSlotNullable(TupleDescriptor tupleDescriptor, List<Slot> 
slots) {
+        for (int i = 0; i < slots.size(); ++i) {
+            if (!(slots.get(i) instanceof VirtualSlotReference)) {
+                SlotDescriptor slotDescriptor = 
tupleDescriptor.getSlots().get(i);
+                slotDescriptor.setIsNullable(true);
+            }
+        }
+    }
+
+    private void setSlotNullable(TupleDescriptor tupleDescriptor, List<Slot> 
slots, List<NamedExpression> projects) {
+        Preconditions.checkState(projects.size() == slots.size());
+        Set<Integer> notSetSlotNullable = new HashSet<>();
+        for (int i = 0; i < projects.size(); i++) {
+            if (projects.get(i) instanceof Alias
+                    && ((Alias) projects.get(i)).child() instanceof 
GroupingScalarFunction) {
+                notSetSlotNullable.add(i);
+            }
+        }
+        for (int i = 0; i < slots.size(); ++i) {
+            if (!notSetSlotNullable.contains(i)) {
+                SlotDescriptor slotDescriptor = 
tupleDescriptor.getSlots().get(i);
+                slotDescriptor.setIsNullable(true);
+            }
+        }
+    }
+
+    private boolean needSetSlotToNullable(List<NamedExpression> slots) {
+        boolean hasVirtualSlotReference = 
slots.stream().anyMatch(VirtualSlotReference.class::isInstance);
+        boolean hasGroupingFunc = slots.stream()
+                .filter(Alias.class::isInstance)
+                .map(Alias.class::cast)
+                .anyMatch(s -> 
s.child().anyMatch(GroupingScalarFunction.class::isInstance));
+        return hasGroupingFunc || hasVirtualSlotReference;
+    }

Review Comment:
   So why not split this method into two separate methods, `containsGroupingScalarFunction` and `containsVirtualSlot`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to