This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new d4e2ee1d92 [feature] Support funnelMaxStep function (#13176)
d4e2ee1d92 is described below

commit d4e2ee1d9245ced1327e3e040e81fb73abfc6580
Author: Xiang Fu <[email protected]>
AuthorDate: Thu May 23 06:14:27 2024 -0700

    [feature] Support funnelMaxStep function (#13176)
    
    * Support funnelMaxStep function
    
    * address comments
---
 .../apache/pinot/core/common/ObjectSerDeUtils.java | 125 +++++---
 .../function/AggregationFunctionFactory.java       |   3 +
 .../funnel/FunnelMaxStepAggregationFunction.java   | 347 +++++++++++++++++++++
 .../function/funnel/FunnelStepEvent.java           | 112 +++++++
 .../pinot/core/common/ObjectSerDeUtilsTest.java    |  19 ++
 .../integration/tests/custom/WindowFunnelTest.java | 287 +++++++++++++++++
 .../pinot/segment/spi/AggregationFunctionType.java |   2 +
 7 files changed, 855 insertions(+), 40 deletions(-)

diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java 
b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
index d4a0098f1b..d68ba2bb80 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
@@ -63,6 +63,7 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.PriorityQueue;
 import java.util.Set;
 import org.apache.datasketches.common.ArrayOfStringsSerDe;
 import org.apache.datasketches.cpc.CpcSketch;
@@ -75,6 +76,7 @@ import org.apache.datasketches.tuple.aninteger.IntegerSummary;
 import org.apache.datasketches.tuple.aninteger.IntegerSummaryDeserializer;
 import org.apache.pinot.common.CustomObject;
 import org.apache.pinot.common.utils.HashUtil;
+import org.apache.pinot.core.query.aggregation.function.funnel.FunnelStepEvent;
 import 
org.apache.pinot.core.query.aggregation.utils.exprminmax.ExprMinMaxObject;
 import org.apache.pinot.core.query.distinct.DistinctTable;
 import org.apache.pinot.core.query.utils.idset.IdSet;
@@ -162,7 +164,8 @@ public class ObjectSerDeUtils {
     ThetaSketchAccumulator(47),
     TupleIntSketchAccumulator(48),
     CpcSketchAccumulator(49),
-    OrderedStringSet(50);
+    OrderedStringSet(50),
+    FunnelStepEventAccumulator(51);
 
     private final int _value;
 
@@ -294,6 +297,13 @@ public class ObjectSerDeUtils {
         return ObjectType.TupleIntSketchAccumulator;
       } else if (value instanceof CpcSketchAccumulator) {
         return ObjectType.CpcSketchAccumulator;
+      } else if (value instanceof PriorityQueue) {
+        PriorityQueue priorityQueue = (PriorityQueue) value;
+        if (priorityQueue.isEmpty() || priorityQueue.peek() instanceof 
FunnelStepEvent) {
+          return ObjectType.FunnelStepEventAccumulator;
+        }
+        throw new IllegalArgumentException(
+            "Unsupported type of value: " + 
priorityQueue.peek().getClass().getSimpleName());
       } else {
         throw new IllegalArgumentException("Unsupported type of value: " + 
value.getClass().getSimpleName());
       }
@@ -1690,47 +1700,81 @@ public class ObjectSerDeUtils {
   public static final ObjectSerDe<ObjectLinkedOpenHashSet<String>> 
ORDERED_STRING_SET_SER_DE =
       new ObjectSerDe<ObjectLinkedOpenHashSet<String>>() {
 
-    @Override
-    public byte[] serialize(ObjectLinkedOpenHashSet<String> stringSet) {
-      int size = stringSet.size();
-      // Besides the value bytes, we store: size, length for each value
-      long bufferSize = (1 + (long) size) * Integer.BYTES;
-      byte[][] valueBytesArray = new byte[size][];
-      int index = 0;
-      for (String value : stringSet) {
-        byte[] valueBytes = value.getBytes(UTF_8);
-        bufferSize += valueBytes.length;
-        valueBytesArray[index++] = valueBytes;
-      }
-      Preconditions.checkState(bufferSize <= Integer.MAX_VALUE, "Buffer size 
exceeds 2GB");
-      byte[] bytes = new byte[(int) bufferSize];
-      ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
-      byteBuffer.putInt(size);
-      for (byte[] valueBytes : valueBytesArray) {
-        byteBuffer.putInt(valueBytes.length);
-        byteBuffer.put(valueBytes);
-      }
-      return bytes;
-    }
+        @Override
+        public byte[] serialize(ObjectLinkedOpenHashSet<String> stringSet) {
+          int size = stringSet.size();
+          // Besides the value bytes, we store: size, length for each value
+          long bufferSize = (1 + (long) size) * Integer.BYTES;
+          byte[][] valueBytesArray = new byte[size][];
+          int index = 0;
+          for (String value : stringSet) {
+            byte[] valueBytes = value.getBytes(UTF_8);
+            bufferSize += valueBytes.length;
+            valueBytesArray[index++] = valueBytes;
+          }
+          Preconditions.checkState(bufferSize <= Integer.MAX_VALUE, "Buffer 
size exceeds 2GB");
+          byte[] bytes = new byte[(int) bufferSize];
+          ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
+          byteBuffer.putInt(size);
+          for (byte[] valueBytes : valueBytesArray) {
+            byteBuffer.putInt(valueBytes.length);
+            byteBuffer.put(valueBytes);
+          }
+          return bytes;
+        }
 
-    @Override
-    public ObjectLinkedOpenHashSet<String> deserialize(byte[] bytes) {
-      return deserialize(ByteBuffer.wrap(bytes));
-    }
+        @Override
+        public ObjectLinkedOpenHashSet<String> deserialize(byte[] bytes) {
+          return deserialize(ByteBuffer.wrap(bytes));
+        }
 
-    @Override
-    public ObjectLinkedOpenHashSet<String> deserialize(ByteBuffer byteBuffer) {
-      int size = byteBuffer.getInt();
-      ObjectLinkedOpenHashSet<String> stringSet = new 
ObjectLinkedOpenHashSet<>(size);
-      for (int i = 0; i < size; i++) {
-        int length = byteBuffer.getInt();
-        byte[] bytes = new byte[length];
-        byteBuffer.get(bytes);
-        stringSet.add(new String(bytes, UTF_8));
-      }
-      return stringSet;
-    }
-  };
+        @Override
+        public ObjectLinkedOpenHashSet<String> deserialize(ByteBuffer 
byteBuffer) {
+          int size = byteBuffer.getInt();
+          ObjectLinkedOpenHashSet<String> stringSet = new 
ObjectLinkedOpenHashSet<>(size);
+          for (int i = 0; i < size; i++) {
+            int length = byteBuffer.getInt();
+            byte[] bytes = new byte[length];
+            byteBuffer.get(bytes);
+            stringSet.add(new String(bytes, UTF_8));
+          }
+          return stringSet;
+        }
+      };
+
+  public static final ObjectSerDe<PriorityQueue<FunnelStepEvent>> 
FUNNEL_STEP_EVENT_ACCUMULATOR_SER_DE =
+      new ObjectSerDe<PriorityQueue<FunnelStepEvent>>() {
+
+        @Override
+        public byte[] serialize(PriorityQueue<FunnelStepEvent> 
funnelStepEvents) {
+          long bufferSize = Integer.BYTES + funnelStepEvents.size() * 
FunnelStepEvent.SIZE_IN_BYTES;
+          Preconditions.checkState(bufferSize <= Integer.MAX_VALUE, "Buffer 
size exceeds 2GB");
+          byte[] bytes = new byte[(int) bufferSize];
+          ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
+          byteBuffer.putInt(funnelStepEvents.size());
+          for (FunnelStepEvent funnelStepEvent : funnelStepEvents) {
+            byteBuffer.put(funnelStepEvent.getBytes());
+          }
+          return bytes;
+        }
+
+        @Override
+        public PriorityQueue<FunnelStepEvent> deserialize(byte[] bytes) {
+          return deserialize(ByteBuffer.wrap(bytes));
+        }
+
+        @Override
+        public PriorityQueue<FunnelStepEvent> deserialize(ByteBuffer 
byteBuffer) {
+          int size = byteBuffer.getInt();
+          PriorityQueue<FunnelStepEvent> funnelStepEvents = new 
PriorityQueue<>(size);
+          for (int i = 0; i < size; i++) {
+            byte[] bytes = new byte[FunnelStepEvent.SIZE_IN_BYTES];
+            byteBuffer.get(bytes);
+            funnelStepEvents.add(new FunnelStepEvent(bytes));
+          }
+          return funnelStepEvents;
+        }
+      };
 
   // NOTE: DO NOT change the order, it has to be the same order as the 
ObjectType
   //@formatter:off
@@ -1786,6 +1830,7 @@ public class ObjectSerDeUtils {
       DATA_SKETCH_INT_TUPLE_ACCUMULATOR_SER_DE,
       DATA_SKETCH_CPC_ACCUMULATOR_SER_DE,
       ORDERED_STRING_SET_SER_DE,
+      FUNNEL_STEP_EVENT_ACCUMULATOR_SER_DE,
   };
   //@formatter:on
 
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index d56a9429d5..5ef12fc661 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -36,6 +36,7 @@ import 
org.apache.pinot.core.query.aggregation.function.array.ArrayAggStringFunc
 import 
org.apache.pinot.core.query.aggregation.function.array.ListAggDistinctFunction;
 import org.apache.pinot.core.query.aggregation.function.array.ListAggFunction;
 import 
org.apache.pinot.core.query.aggregation.function.funnel.FunnelCountAggregationFunctionFactory;
+import 
org.apache.pinot.core.query.aggregation.function.funnel.FunnelMaxStepAggregationFunction;
 import org.apache.pinot.segment.spi.AggregationFunctionType;
 import org.apache.pinot.spi.data.FieldSpec.DataType;
 import org.apache.pinot.spi.exception.BadQueryRequestException;
@@ -452,6 +453,8 @@ public class AggregationFunctionFactory {
                 "Aggregation function: " + functionType + " is only supported 
in selection without alias.");
           case FUNNELCOUNT:
             return new FunnelCountAggregationFunctionFactory(arguments).get();
+          case FUNNELMAXSTEP:
+            return new FunnelMaxStepAggregationFunction(arguments);
           case FREQUENTSTRINGSSKETCH:
             return new FrequentStringsSketchAggregationFunction(arguments);
           case FREQUENTLONGSSKETCH:
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelMaxStepAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelMaxStepAggregationFunction.java
new file mode 100644
index 0000000000..e8f316e187
--- /dev/null
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelMaxStepAggregationFunction.java
@@ -0,0 +1,347 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function.funnel;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.stream.Collectors;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import 
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.segment.spi.AggregationFunctionType;
+
+
+public class FunnelMaxStepAggregationFunction
+    implements AggregationFunction<PriorityQueue<FunnelStepEvent>, Long> {
+  private final ExpressionContext _timestampExpression;
+  private final long _windowSize;
+  private final List<ExpressionContext> _stepExpressions;
+  private final FunnelModes _modes = new FunnelModes();
+  private final int _numSteps;
+
+  public FunnelMaxStepAggregationFunction(List<ExpressionContext> arguments) {
+    int numArguments = arguments.size();
+    Preconditions.checkArgument(numArguments > 2,
+        "FUNNELMAXSTEP expects >= 3 arguments, got: %s. The function can be 
used as "
+            + "funnelMaxStep(timestampExpression, windowSize, 
ARRAY[stepExpression, ..], [mode, [mode, ... ]])",
+        numArguments);
+    _timestampExpression = arguments.get(0);
+    _windowSize = arguments.get(1).getLiteral().getLongValue();
+    Preconditions.checkArgument(_windowSize > 0, "Window size must be > 0");
+    ExpressionContext stepExpressionContext = arguments.get(2);
+    if (stepExpressionContext.getFunction() != null) {
+      // LEAF stage init this function like 
funnelmaxstep($1,'1000',arrayValueConstructor($2,$3,$4,...))
+      _stepExpressions = stepExpressionContext.getFunction().getArguments();
+    } else {
+      // Intermediate Stage init this function like 
funnelmaxstep($1,'1000',__PLACEHOLDER__)
+      _stepExpressions = ImmutableList.of();
+    }
+    if (numArguments > 3) {
+      arguments.subList(3, numArguments)
+          .forEach(arg -> 
_modes.add(Mode.valueOf(arg.getLiteral().getStringValue().toUpperCase())));
+    }
+    _numSteps = _stepExpressions.size();
+  }
+
+  @Override
+  public AggregationFunctionType getType() {
+    return AggregationFunctionType.FUNNELMAXSTEP;
+  }
+
+  @Override
+  public String getResultColumnName() {
+    return getType().getName().toLowerCase() + "(" + _windowSize + ")  (" + 
_timestampExpression.toString() + ", "
+        + _stepExpressions.stream().map(ExpressionContext::toString)
+        .collect(Collectors.joining(",")) + ")";
+  }
+
+  @Override
+  public List<ExpressionContext> getInputExpressions() {
+    List<ExpressionContext> inputs = new ArrayList<>(1 + _numSteps);
+    inputs.add(_timestampExpression);
+    inputs.addAll(_stepExpressions);
+    return inputs;
+  }
+
+  @Override
+  public AggregationResultHolder createAggregationResultHolder() {
+    return new ObjectAggregationResultHolder();
+  }
+
+  @Override
+  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, 
int maxCapacity) {
+    return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+  }
+
+  @Override
+  public void aggregate(int length, AggregationResultHolder 
aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    long[] timestampBlock = 
blockValSetMap.get(_timestampExpression).getLongValuesSV();
+    List<int[]> stepBlocks = new ArrayList<>();
+    for (ExpressionContext stepExpression : _stepExpressions) {
+      stepBlocks.add(blockValSetMap.get(stepExpression).getIntValuesSV());
+    }
+    PriorityQueue<FunnelStepEvent> stepEvents = new PriorityQueue<>(length);
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numSteps; j++) {
+        if (stepBlocks.get(j)[i] == 1) {
+          stepEvents.add(new FunnelStepEvent(timestampBlock[i], j));
+          break;
+        }
+      }
+    }
+    aggregationResultHolder.setValue(stepEvents);
+  }
+
+  @Override
+  public void aggregateGroupBySV(int length, int[] groupKeyArray, 
GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    long[] timestampBlock = 
blockValSetMap.get(_timestampExpression).getLongValuesSV();
+    List<int[]> stepBlocks = new ArrayList<>();
+    for (ExpressionContext stepExpression : _stepExpressions) {
+      stepBlocks.add(blockValSetMap.get(stepExpression).getIntValuesSV());
+    }
+    for (int i = 0; i < length; i++) {
+      int groupKey = groupKeyArray[i];
+      for (int j = 0; j < _numSteps; j++) {
+        if (stepBlocks.get(j)[i] == 1) {
+          PriorityQueue<FunnelStepEvent> stepEvents = 
groupByResultHolder.getResult(groupKey);
+          if (stepEvents == null) {
+            stepEvents = new PriorityQueue<>();
+          }
+          stepEvents.add(new FunnelStepEvent(timestampBlock[i], j));
+          groupByResultHolder.setValueForKey(groupKey, stepEvents);
+          break;
+        }
+      }
+    }
+  }
+
+  @Override
+  public void aggregateGroupByMV(int length, int[][] groupKeysArray, 
GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    long[] timestampBlock = 
blockValSetMap.get(_timestampExpression).getLongValuesSV();
+    List<int[]> stepBlocks = new ArrayList<>();
+    for (ExpressionContext stepExpression : _stepExpressions) {
+      stepBlocks.add(blockValSetMap.get(stepExpression).getIntValuesSV());
+    }
+    for (int i = 0; i < length; i++) {
+      int[] groupKeys = groupKeysArray[i];
+      for (int j = 0; j < _numSteps; j++) {
+        if (stepBlocks.get(j)[i] == 1) {
+          for (int groupKey : groupKeys) {
+            PriorityQueue<FunnelStepEvent> stepEvents = 
groupByResultHolder.getResult(groupKey);
+            if (stepEvents == null) {
+              stepEvents = new PriorityQueue<>();
+            }
+            stepEvents.add(new FunnelStepEvent(timestampBlock[i], j));
+            groupByResultHolder.setValueForKey(groupKey, stepEvents);
+          }
+          break;
+        }
+      }
+    }
+  }
+
+  @Override
+  public PriorityQueue<FunnelStepEvent> 
extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+    return aggregationResultHolder.getResult();
+  }
+
+  @Override
+  public PriorityQueue<FunnelStepEvent> 
extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    return groupByResultHolder.getResult(groupKey);
+  }
+
+  @Override
+  public PriorityQueue<FunnelStepEvent> merge(PriorityQueue<FunnelStepEvent> 
intermediateResult1,
+      PriorityQueue<FunnelStepEvent> intermediateResult2) {
+    if (intermediateResult1 == null) {
+      return intermediateResult2;
+    }
+    if (intermediateResult2 == null) {
+      return intermediateResult1;
+    }
+    intermediateResult1.addAll(intermediateResult2);
+    return intermediateResult1;
+  }
+
+  @Override
+  public DataSchema.ColumnDataType getIntermediateResultColumnType() {
+    return DataSchema.ColumnDataType.OBJECT;
+  }
+
+  @Override
+  public DataSchema.ColumnDataType getFinalResultColumnType() {
+    return DataSchema.ColumnDataType.LONG;
+  }
+
+  @Override
+  public Long extractFinalResult(PriorityQueue<FunnelStepEvent> stepEvents) {
+    long finalMaxStep = 0;
+    if (stepEvents == null || stepEvents.isEmpty()) {
+      return finalMaxStep;
+    }
+    ArrayDeque<FunnelStepEvent> slidingWindow = new ArrayDeque<>();
+    while (!stepEvents.isEmpty()) {
+      fillWindow(stepEvents, slidingWindow);
+      if (slidingWindow.isEmpty()) {
+        break;
+      }
+      int maxSteps = processWindow(slidingWindow);
+      finalMaxStep = Math.max(finalMaxStep, maxSteps);
+      if (finalMaxStep == _numSteps) {
+        break;
+      }
+      if (!slidingWindow.isEmpty()) {
+        slidingWindow.pollFirst();
+      }
+    }
+    return finalMaxStep;
+  }
+
+  /**
+   * Fill the sliding window with the events that fall into the window.
+   * Note that the events from stepEvents are dequeued and added to the 
sliding window.
+   * This method ensures that the first event in the sliding window is the
+   * first step event.
+   * @param stepEvents The priority queue of step events
+   * @param slidingWindow The sliding window with events that fall into the 
window
+   */
+  private void fillWindow(PriorityQueue<FunnelStepEvent> stepEvents, 
ArrayDeque<FunnelStepEvent> slidingWindow) {
+    // Ensure for the sliding window, the first event is the first step
+    while ((!slidingWindow.isEmpty()) && slidingWindow.peek().getStep() != 0) {
+      slidingWindow.pollFirst();
+    }
+    if (slidingWindow.isEmpty()) {
+      while (!stepEvents.isEmpty() && stepEvents.peek().getStep() != 0) {
+        stepEvents.poll();
+      }
+      if (stepEvents.isEmpty()) {
+        return;
+      }
+      slidingWindow.addLast(stepEvents.poll());
+    }
+    // SlidingWindow is not empty
+    long windowStart = slidingWindow.peek().getTimestamp();
+    long windowEnd = windowStart + _windowSize;
+    while (!stepEvents.isEmpty() && (stepEvents.peek().getTimestamp() < 
windowEnd)) {
+      slidingWindow.addLast(stepEvents.poll());
+    }
+  }
+
+  private int processWindow(ArrayDeque<FunnelStepEvent> slidingWindow) {
+    int maxStep = 0;
+    long previousTimestamp = -1;
+    for (FunnelStepEvent event : slidingWindow) {
+      int currentEventStep = event.getStep();
+      // If the same condition holds for the sequence of events, then such a
+      // repeating event interrupts further processing.
+      if (_modes.hasStrictDeduplication()) {
+        if (currentEventStep == maxStep - 1) {
+          return maxStep;
+        }
+      }
+      // Don't allow interventions of other events. E.g. in the case of 
A->B->D->C, it stops finding A->B->C at the D
+      // and the max event level is 2.
+      if (_modes.hasStrictOrder()) {
+        if (currentEventStep != maxStep) {
+          return maxStep;
+        }
+      }
+      // Apply conditions only to events with strictly increasing timestamps.
+      if (_modes.hasStrictIncrease()) {
+        if (previousTimestamp == event.getTimestamp()) {
+          continue;
+        }
+      }
+      previousTimestamp = event.getTimestamp();
+      if (maxStep == currentEventStep) {
+        maxStep++;
+      }
+      if (maxStep == _numSteps) {
+        break;
+      }
+    }
+    return maxStep;
+  }
+
+  @Override
+  public String toExplainString() {
+    return "WindowFunnelAggregationFunction{"
+        + "_timestampExpression=" + _timestampExpression
+        + ", _windowSize=" + _windowSize
+        + ", _stepExpressions=" + _stepExpressions
+        + ", _numSteps=" + _numSteps
+        + '}';
+  }
+
+  enum Mode {
+    STRICT_DEDUPLICATION(1),
+    STRICT_ORDER(2),
+    STRICT_INCREASE(4);
+
+    private final int _value;
+
+    Mode(int value) {
+      _value = value;
+    }
+
+    public int getValue() {
+      return _value;
+    }
+  }
+
+  static class FunnelModes {
+    private int _bitmask = 0;
+
+    public void add(Mode mode) {
+      _bitmask |= mode.getValue();
+    }
+
+    public void remove(Mode mode) {
+      _bitmask &= ~mode.getValue();
+    }
+
+    public boolean contains(Mode mode) {
+      return (_bitmask & mode.getValue()) != 0;
+    }
+
+    public boolean hasStrictDeduplication() {
+      return contains(Mode.STRICT_DEDUPLICATION);
+    }
+
+    public boolean hasStrictOrder() {
+      return contains(Mode.STRICT_ORDER);
+    }
+
+    public boolean hasStrictIncrease() {
+      return contains(Mode.STRICT_INCREASE);
+    }
+  }
+}
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelStepEvent.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelStepEvent.java
new file mode 100644
index 0000000000..d0309b6e50
--- /dev/null
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelStepEvent.java
@@ -0,0 +1,112 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function.funnel;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+
+
+public class FunnelStepEvent implements Comparable<FunnelStepEvent> {
+  public final static int SIZE_IN_BYTES = Long.BYTES + Integer.BYTES;
+
+  private final long _timestamp;
+  private final int _step;
+
+  public FunnelStepEvent(long timestamp, int step) {
+    _timestamp = timestamp;
+    _step = step;
+  }
+
+  public FunnelStepEvent(byte[] bytes) {
+    try {
+      DataInputStream dataInputStream = new DataInputStream(new 
ByteArrayInputStream(bytes));
+      _timestamp = dataInputStream.readLong();
+      _step = dataInputStream.readInt();
+      dataInputStream.close();
+    } catch (Exception e) {
+      throw new RuntimeException("Caught exception while converting byte[] to 
FunnelStepEvent", e);
+    }
+  }
+
+  public long getTimestamp() {
+    return _timestamp;
+  }
+
+  public int getStep() {
+    return _step;
+  }
+
+  @Override
+  public String toString() {
+    return "StepEvent{"
+        + "timestamp=" + _timestamp
+        + ", step=" + _step
+        + '}';
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+
+    FunnelStepEvent stepEvent = (FunnelStepEvent) o;
+
+    if (_timestamp != stepEvent._timestamp) {
+      return false;
+    }
+    return _step == stepEvent._step;
+  }
+
+  @Override
+  public int hashCode() {
+    int result = Long.hashCode(_timestamp);
+    result = 31 * result + _step;
+    return result;
+  }
+
+  @Override
+  public int compareTo(FunnelStepEvent o) {
+    if (_timestamp < o._timestamp) {
+      return -1;
+    } else if (_timestamp > o._timestamp) {
+      return 1;
+    } else {
+      return Integer.compare(_step, o._step);
+    }
+  }
+
+  public byte[] getBytes() {
+    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+    DataOutputStream dataOutputStream = new 
DataOutputStream(byteArrayOutputStream);
+    try {
+      dataOutputStream.writeLong(_timestamp);
+      dataOutputStream.writeInt(_step);
+      dataOutputStream.close();
+    } catch (Exception e) {
+      throw new RuntimeException("Caught exception while converting 
FunnelStepEvent to byte[]", e);
+    }
+    return byteArrayOutputStream.toByteArray();
+  }
+}
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java
index 4de5c68135..33d18e19a1 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java
@@ -37,6 +37,7 @@ import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.PriorityQueue;
 import java.util.Random;
 import org.apache.commons.lang.RandomStringUtils;
 import org.apache.datasketches.cpc.CpcSketch;
@@ -49,6 +50,7 @@ import org.apache.datasketches.tuple.aninteger.IntegerSummary;
 import org.apache.datasketches.tuple.aninteger.IntegerSummarySetOperations;
 import 
org.apache.pinot.core.query.aggregation.function.PercentileEstAggregationFunction;
 import 
org.apache.pinot.core.query.aggregation.function.PercentileTDigestAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.funnel.FunnelStepEvent;
 import org.apache.pinot.segment.local.customobject.AvgPair;
 import org.apache.pinot.segment.local.customobject.CpcSketchAccumulator;
 import org.apache.pinot.segment.local.customobject.DoubleLongPair;
@@ -594,4 +596,21 @@ public class ObjectSerDeUtilsTest {
       }
     }
   }
+
+  @Test
+  public void testFunnelStepEventAccumulator() {
+    for (int i = 0; i < NUM_ITERATIONS; i++) {
+      int size = RANDOM.nextInt(1000);
+      PriorityQueue<FunnelStepEvent> expected = new 
PriorityQueue<FunnelStepEvent>();
+      for (int j = 0; j < size; j++) {
+        expected.add(new FunnelStepEvent(RANDOM.nextLong(), RANDOM.nextInt()));
+      }
+      byte[] bytes = ObjectSerDeUtils.serialize(expected);
+      PriorityQueue<FunnelStepEvent> actual =
+          ObjectSerDeUtils.deserialize(bytes, 
ObjectSerDeUtils.ObjectType.FunnelStepEventAccumulator);
+      while (!actual.isEmpty()) {
+        assertEquals(actual.poll(), expected.poll(), ERROR_MESSAGE);
+      }
+    }
+  }
 }
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/WindowFunnelTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/WindowFunnelTest.java
new file mode 100644
index 0000000000..4b373d5fac
--- /dev/null
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/WindowFunnelTest.java
@@ -0,0 +1,287 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.integration.tests.custom;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.google.common.collect.ImmutableList;
+import java.io.File;
+import org.apache.avro.file.DataFileWriter;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericDatumWriter;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.data.Schema;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+@Test(suiteName = "CustomClusterIntegrationTest")
+public class WindowFunnelTest extends CustomDataQueryClusterIntegrationTest {
+
+  private static final String DEFAULT_TABLE_NAME = "WindowFunnelTest";
+  private static final String URL_COLUMN = "url";
+  private static final String TIMESTAMP_COLUMN = "timestampCol";
+  private static final String USER_ID_COLUMN = "userId";
+  private static long _countStarResult = 0;
+
+  @Override
+  protected long getCountStarResult() {
+    return _countStarResult;
+  }
+
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testFunnelMaxStepQueries(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String query =
+        String.format("SELECT "
+            + "funnelMaxStep(timestampCol, '1000', "
+            + "ARRAY[ "
+            + "url = '/product/search', "
+            + "url = '/cart/add', "
+            + "url = '/checkout/start', "
+            + "url = '/checkout/confirmation' "
+            + "] ) "
+            + "FROM %s LIMIT %d", getTableName(), getCountStarResult());
+    JsonNode jsonNode = postQuery(query);
+    JsonNode rows = jsonNode.get("resultTable").get("rows");
+    assertEquals(rows.size(), 1);
+    JsonNode row = rows.get(0);
+    assertEquals(row.size(), 1);
+    assertEquals(row.get(0).intValue(), 4);
+  }
+
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testFunnelMaxStepGroupByQueries(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String query =
+        String.format("SELECT "
+            + "userId, funnelMaxStep(timestampCol, '1000', "
+            + "ARRAY[ "
+            + "url = '/product/search', "
+            + "url = '/cart/add', "
+            + "url = '/checkout/start', "
+            + "url = '/checkout/confirmation' "
+            + "] ) "
+            + "FROM %s GROUP BY userId ORDER BY userId LIMIT %d", 
getTableName(), getCountStarResult());
+    JsonNode jsonNode = postQuery(query);
+    JsonNode rows = jsonNode.get("resultTable").get("rows");
+    assertEquals(rows.size(), 40);
+    for (int i = 0; i < 40; i++) {
+      JsonNode row = rows.get(i);
+      assertEquals(row.size(), 2);
+      assertEquals(row.get(0).textValue(), "user" + (i / 10) + (i % 10));
+      switch (i / 10) {
+        case 0:
+          assertEquals(row.get(1).intValue(), 4);
+          break;
+        case 1:
+          assertEquals(row.get(1).intValue(), 3);
+          break;
+        case 2:
+          assertEquals(row.get(1).intValue(), 3);
+          break;
+        case 3:
+          assertEquals(row.get(1).intValue(), 1);
+          break;
+        default:
+          throw new IllegalStateException();
+      }
+    }
+  }
+
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testFunnelMaxStepGroupByQueriesWithMode(boolean 
useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String query =
+        String.format("SELECT "
+            + "userId, funnelMaxStep(timestampCol, '1000', "
+            + "ARRAY[ "
+            + "url = '/product/search', "
+            + "url = '/cart/add', "
+            + "url = '/checkout/start', "
+            + "url = '/checkout/confirmation' "
+            + "], 'strict_order' ) "
+            + "FROM %s GROUP BY userId ORDER BY userId LIMIT %d", 
getTableName(), getCountStarResult());
+    JsonNode jsonNode = postQuery(query);
+    JsonNode rows = jsonNode.get("resultTable").get("rows");
+    assertEquals(rows.size(), 40);
+    for (int i = 0; i < 40; i++) {
+      JsonNode row = rows.get(i);
+      assertEquals(row.size(), 2);
+      assertEquals(row.get(0).textValue(), "user" + (i / 10) + (i % 10));
+      switch (i / 10) {
+        case 0:
+          assertEquals(row.get(1).intValue(), 3);
+          break;
+        case 1:
+          assertEquals(row.get(1).intValue(), 3);
+          break;
+        case 2:
+          assertEquals(row.get(1).intValue(), 2);
+          break;
+        case 3:
+          assertEquals(row.get(1).intValue(), 1);
+          break;
+        default:
+          throw new IllegalStateException();
+      }
+    }
+
+    query =
+        String.format("SELECT "
+            + "userId, funnelMaxStep(timestampCol, '1000', "
+            + "ARRAY[ "
+            + "url = '/product/search', "
+            + "url = '/cart/add', "
+            + "url = '/checkout/start', "
+            + "url = '/checkout/confirmation' "
+            + "], 'strict_deduplication' ) "
+            + "FROM %s GROUP BY userId ORDER BY userId LIMIT %d", 
getTableName(), getCountStarResult());
+    jsonNode = postQuery(query);
+    rows = jsonNode.get("resultTable").get("rows");
+    assertEquals(rows.size(), 40);
+    for (int i = 0; i < 40; i++) {
+      JsonNode row = rows.get(i);
+      assertEquals(row.size(), 2);
+      assertEquals(row.get(0).textValue(), "user" + (i / 10) + (i % 10));
+      switch (i / 10) {
+        case 0:
+          assertEquals(row.get(1).intValue(), 4);
+          break;
+        case 1:
+          assertEquals(row.get(1).intValue(), 3);
+          break;
+        case 2:
+          assertEquals(row.get(1).intValue(), 2);
+          break;
+        case 3:
+          assertEquals(row.get(1).intValue(), 1);
+          break;
+        default:
+          throw new IllegalStateException();
+      }
+    }
+
+    query =
+        String.format("SELECT "
+            + "userId, funnelMaxStep(timestampCol, '1000', "
+            + "ARRAY[ "
+            + "url = '/product/search', "
+            + "url = '/cart/add', "
+            + "url = '/checkout/start', "
+            + "url = '/checkout/confirmation' "
+            + "], 'strict_increase' ) "
+            + "FROM %s GROUP BY userId ORDER BY userId LIMIT %d", 
getTableName(), getCountStarResult());
+    jsonNode = postQuery(query);
+    rows = jsonNode.get("resultTable").get("rows");
+    assertEquals(rows.size(), 40);
+    for (int i = 0; i < 40; i++) {
+      JsonNode row = rows.get(i);
+      assertEquals(row.size(), 2);
+      assertEquals(row.get(0).textValue(), "user" + (i / 10) + (i % 10));
+      switch (i / 10) {
+        case 0:
+          assertEquals(row.get(1).intValue(), 4);
+          break;
+        case 1:
+          assertEquals(row.get(1).intValue(), 2);
+          break;
+        case 2:
+          assertEquals(row.get(1).intValue(), 3);
+          break;
+        case 3:
+          assertEquals(row.get(1).intValue(), 1);
+          break;
+        default:
+          throw new IllegalStateException();
+      }
+    }
+  }
+
+  @Override
+  public String getTableName() {
+    return DEFAULT_TABLE_NAME;
+  }
+
+  @Override
+  public Schema createSchema() {
+    return new Schema.SchemaBuilder().setSchemaName(getTableName())
+        .addSingleValueDimension(URL_COLUMN, FieldSpec.DataType.STRING)
+        .addSingleValueDimension(TIMESTAMP_COLUMN, 
FieldSpec.DataType.TIMESTAMP)
+        .addSingleValueDimension(USER_ID_COLUMN, FieldSpec.DataType.STRING)
+        .build();
+  }
+
+  @Override
+  public File createAvroFile()
+      throws Exception {
+    // create avro schema
+    org.apache.avro.Schema avroSchema = 
org.apache.avro.Schema.createRecord("myRecord", null, null, false);
+    avroSchema.setFields(ImmutableList.of(
+        new org.apache.avro.Schema.Field(URL_COLUMN,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.STRING),
+            null, null),
+        new org.apache.avro.Schema.Field(TIMESTAMP_COLUMN,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.LONG),
+            null, null),
+        new org.apache.avro.Schema.Field(USER_ID_COLUMN,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.STRING),
+            null, null)
+    ));
+
+    long[][] userTimestampValues = new long[][]{
+        new long[]{1000, 1010, 1020, 1025, 1030},
+        new long[]{2010, 2010, 2000},
+        new long[]{1000, 1010, 1015, 1020, 11030},
+        new long[]{2020, 12010, 12050},
+    };
+    String[][] userUrlValues = new String[][]{
+        new String[]{"/product/search", "/cart/add", "/checkout/start", 
"/cart/add", "/checkout/confirmation"},
+        new String[]{"/checkout/start", "/cart/add", "/product/search"},
+        new String[]{"/product/search", "/cart/add", "/cart/add", 
"/checkout/start", "/checkout/confirmation"},
+        new String[]{"/checkout/start", "/cart/add", "/product/search"},
+    };
+    int repeats = 10;
+    long totalRows = 0;
+    for (String[] userUrlValue : userUrlValues) {
+      totalRows += userUrlValue.length;
+    }
+    _countStarResult = totalRows * repeats;
+    // create avro file
+    File avroFile = new File(_tempDir, "data.avro");
+    try (DataFileWriter<GenericData.Record> fileWriter = new 
DataFileWriter<>(new GenericDatumWriter<>(avroSchema))) {
+      fileWriter.create(avroSchema, avroFile);
+      for (int repeat = 0; repeat < repeats; repeat++) {
+        for (int i = 0; i < userUrlValues.length; i++) {
+          for (int j = 0; j < userUrlValues[i].length; j++) {
+            GenericData.Record record = new GenericData.Record(avroSchema);
+            record.put(TIMESTAMP_COLUMN, userTimestampValues[i][j]);
+            record.put(URL_COLUMN, userUrlValues[i][j]);
+            record.put(USER_ID_COLUMN, "user" + i + repeat);
+            fileWriter.append(record);
+          }
+        }
+      }
+    }
+    return avroFile;
+  }
+}
diff --git 
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
 
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index 929275fc17..a6c468d8fe 100644
--- 
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++ 
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -327,6 +327,8 @@ public enum AggregationFunctionType {
           ordinal -> ordinal > 1), ReturnTypes.VARCHAR, 
ReturnTypes.explicit(SqlTypeName.OTHER)),
 
   // funnel aggregate functions
+  FUNNELMAXSTEP("funnelMaxStep", null, SqlKind.OTHER_FUNCTION, 
SqlFunctionCategory.USER_DEFINED_FUNCTION,
+      OperandTypes.VARIADIC, ReturnTypes.BIGINT, 
ReturnTypes.explicit(SqlTypeName.OTHER)),
   // TODO: revisit support for funnel count in V2
   FUNNELCOUNT("funnelCount");
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to