This is an automated email from the ASF dual-hosted git repository.

gian pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 67de7d1889c fix: GroupByFrameCombiner selectors must track cursor 
changes. (#19238)
67de7d1889c is described below

commit 67de7d1889c7e0511d5d4547df70dd960a571d77
Author: Gian Merlino <[email protected]>
AuthorDate: Tue Mar 31 13:05:48 2026 -0700

    fix: GroupByFrameCombiner selectors must track cursor changes. (#19238)
    
    The groupBy combiner introduced in #19193 created selectors incorrectly:
    they tracked a single underlying cursor rather than following changes
    in the cursor. This patch addresses it by adding indirection through
    TrackingDimensionSelector and TrackingColumnValueSelector.
---
 .../msq/querykit/groupby/GroupByFrameCombiner.java |  34 +--
 .../querykit/groupby/GroupByFrameCombinerTest.java | 248 +++++++++++++++++++++
 .../processor/TrackingColumnValueSelector.java     | 100 +++++++++
 .../frame/processor/TrackingDimensionSelector.java | 140 ++++++++++++
 .../frame/processor/SummingFrameCombiner.java      | 199 ++++++++++-------
 .../druid/frame/processor/SuperSorterTest.java     | 105 ++++++++-
 6 files changed, 706 insertions(+), 120 deletions(-)

diff --git 
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByFrameCombiner.java
 
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByFrameCombiner.java
index 6f51b7569f2..d87ab13cc9f 100644
--- 
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByFrameCombiner.java
+++ 
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByFrameCombiner.java
@@ -22,6 +22,8 @@ package org.apache.druid.msq.querykit.groupby;
 import org.apache.druid.frame.Frame;
 import org.apache.druid.frame.processor.FrameCombiner;
 import org.apache.druid.frame.processor.FrameProcessors;
+import org.apache.druid.frame.processor.TrackingColumnValueSelector;
+import org.apache.druid.frame.processor.TrackingDimensionSelector;
 import org.apache.druid.frame.read.FrameReader;
 import org.apache.druid.frame.segment.FrameCursor;
 import org.apache.druid.query.aggregation.AggregatorFactory;
@@ -134,9 +136,6 @@ public class GroupByFrameCombiner implements FrameCombiner
       cachedFrame = frame;
       cachedCursor = FrameProcessors.makeCursor(frame, frameReader);
 
-      // Reset dimension selectors, they need to be recreated for the new 
cursor.
-      combinedColumnSelectorFactory.resetSelectorCache();
-
       // Rebuild aggregator selectors for the new cursor.
       final ColumnSelectorFactory columnSelectorFactory = 
cachedCursor.getColumnSelectorFactory();
       cachedAggregatorSelectors = new 
ColumnValueSelector<?>[aggregatorFactories.size()];
@@ -156,15 +155,8 @@ public class GroupByFrameCombiner implements FrameCombiner
    */
   private class CombinedColumnSelectorFactory implements ColumnSelectorFactory
   {
-    /**
-     * Cached dimension value selectors from {@link #cachedCursor}.
-     */
-    private final Map<String, ColumnValueSelector<?>> 
valueDimensionSelectorCache = new HashMap<>();
-
-    /**
-     * Cached dimension string selectors from {@link #cachedCursor}.
-     */
-    private final Map<DimensionSpec, DimensionSelector> 
stringDimensionSelectorCache = new HashMap<>();
+    private final Map<String, TrackingColumnValueSelector> 
columnValueSelectorCache = new HashMap<>();
+    private final Map<DimensionSpec, TrackingDimensionSelector> 
dimensionSelectorCache = new HashMap<>();
 
     @Override
     public DimensionSelector makeDimensionSelector(final DimensionSpec 
dimensionSpec)
@@ -193,10 +185,10 @@ public class GroupByFrameCombiner implements FrameCombiner
           }
         };
       } else {
-        // Dimension: delegate to cached dimension selector.
-        return stringDimensionSelectorCache.computeIfAbsent(
+        // Dimension: delegate to a cursor-tracking selector that refreshes 
when cachedCursor changes.
+        return dimensionSelectorCache.computeIfAbsent(
             dimensionSpec,
-            spec -> 
cachedCursor.getColumnSelectorFactory().makeDimensionSelector(spec)
+            spec -> new TrackingDimensionSelector(spec, () -> 
cachedCursor.getColumnSelectorFactory())
         );
       }
     }
@@ -257,20 +249,14 @@ public class GroupByFrameCombiner implements FrameCombiner
           }
         };
       } else {
-        // Dimension: delegate to cached dimension value selector.
-        return valueDimensionSelectorCache.computeIfAbsent(
+        // Dimension: delegate to a cursor-tracking selector that refreshes 
when cachedCursor changes.
+        return columnValueSelectorCache.computeIfAbsent(
             columnName,
-            name -> 
cachedCursor.getColumnSelectorFactory().makeColumnValueSelector(name)
+            name -> new TrackingColumnValueSelector(name, () -> 
cachedCursor.getColumnSelectorFactory())
         );
       }
     }
 
-    private void resetSelectorCache()
-    {
-      valueDimensionSelectorCache.clear();
-      stringDimensionSelectorCache.clear();
-    }
-
     @Nullable
     @Override
     public ColumnCapabilities getColumnCapabilities(final String column)
diff --git 
a/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/groupby/GroupByFrameCombinerTest.java
 
b/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/groupby/GroupByFrameCombinerTest.java
new file mode 100644
index 00000000000..70dd858f0ed
--- /dev/null
+++ 
b/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/groupby/GroupByFrameCombinerTest.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.querykit.groupby;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.MoreExecutors;
+import org.apache.druid.data.input.MapBasedRow;
+import org.apache.druid.data.input.Row;
+import org.apache.druid.frame.Frame;
+import org.apache.druid.frame.FrameType;
+import org.apache.druid.frame.allocation.ArenaMemoryAllocator;
+import org.apache.druid.frame.allocation.ArenaMemoryAllocatorFactory;
+import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
+import org.apache.druid.frame.channel.ReadableFrameChannel;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
+import org.apache.druid.frame.processor.FrameChannelMerger;
+import org.apache.druid.frame.processor.FrameProcessorExecutor;
+import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.testutil.FrameSequenceBuilder;
+import org.apache.druid.frame.testutil.FrameTestUtil;
+import org.apache.druid.frame.write.FrameWriters;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.guava.Sequence;
+import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
+import org.apache.druid.segment.RowAdapters;
+import org.apache.druid.segment.RowBasedCursorFactory;
+import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Tests for {@link GroupByFrameCombiner} used with {@link FrameChannelMerger}.
+ *
+ * The existing combiner tests in SuperSorterTest use {@link 
org.apache.druid.frame.processor.SummingFrameCombiner},
+ * which used to create a new ColumnSelectorFactory on each call to 
getCombinedColumnSelectorFactory(). That
+ * masked the bug where dimension selectors become stale when the underlying 
frame cursor changes. This test
+ * uses the real {@link GroupByFrameCombiner} to verify correct behavior 
across frame boundaries.
+ */
+public class GroupByFrameCombinerTest extends InitializedNullHandlingTest
+{
+  private static final RowSignature SIGNATURE = RowSignature.builder()
+                                                            .add("key_str", 
ColumnType.STRING)
+                                                            .add("key_long", 
ColumnType.LONG)
+                                                            .add("value", 
ColumnType.LONG)
+                                                            .build();
+
+  private static final List<KeyColumn> SORT_KEY =
+      ImmutableList.of(
+          new KeyColumn("key_str", KeyOrder.ASCENDING),
+          new KeyColumn("key_long", KeyOrder.ASCENDING)
+      );
+
+  private static final RowSignature SORTABLE_SIGNATURE =
+      FrameWriters.sortableSignature(SIGNATURE, SORT_KEY);
+
+  /**
+   * Two channels, five distinct keys, one row per frame. The combiner must 
switch frames between
+   * every group, exercising the cursor-tracking logic for dimension selectors.
+   */
+  @Test
+  public void testTwoChannelsOneRowPerFrame() throws Exception
+  {
+    final List<List<Object>> rows = runMergerWithCombiner(
+        ImmutableList.of(
+            ImmutableList.of(row("a", 100L, 1L), row("b", 200L, 2L), row("c", 
300L, 3L), row("d", 400L, 4L), row("e", 500L, 5L)),
+            ImmutableList.of(row("a", 100L, 10L), row("b", 200L, 20L), 
row("c", 300L, 30L), row("d", 400L, 40L), row("e", 500L, 50L))
+        ),
+        1 // one row per frame
+    );
+
+    Assert.assertEquals(
+        ImmutableList.of(
+            ImmutableList.of("a", 100L, 11L),
+            ImmutableList.of("b", 200L, 22L),
+            ImmutableList.of("c", 300L, 33L),
+            ImmutableList.of("d", 400L, 44L),
+            ImmutableList.of("e", 500L, 55L)
+        ),
+        rows
+    );
+  }
+
+  /**
+   * Four channels, twenty distinct keys, one row per frame.
+   */
+  @Test
+  public void testFourChannelsOneRowPerFrame() throws Exception
+  {
+    final int numChannels = 4;
+    final int numKeys = 20;
+
+    final List<List<Row>> channelData =
+        IntStream.range(0, numChannels)
+                 .mapToObj(ch -> IntStream.range(0, numKeys)
+                                          .mapToObj(k -> 
row(StringUtils.format("key_%02d", k), k, 1L))
+                                          .collect(Collectors.toList()))
+                 .collect(Collectors.toList());
+
+    final List<List<Object>> rows = runMergerWithCombiner(channelData, 1);
+
+    Assert.assertEquals(numKeys, rows.size());
+    for (int k = 0; k < numKeys; k++) {
+      Assert.assertEquals(StringUtils.format("key_%02d", k), 
rows.get(k).get(0));
+      Assert.assertEquals((long) k, rows.get(k).get(1));
+      Assert.assertEquals((long) numChannels, rows.get(k).get(2));
+    }
+  }
+
+  /**
+   * Two channels, three distinct keys, multiple rows per frame. Some 
combining happens within a single frame
+   * (same key appears in consecutive rows of the same frame) and some across 
frame boundaries.
+   */
+  @Test
+  public void testTwoChannelsMultipleRowsPerFrame() throws Exception
+  {
+    final List<List<Object>> rows = runMergerWithCombiner(
+        ImmutableList.of(
+            ImmutableList.of(
+                row("a", 1L, 10L),
+                row("a", 1L, 20L),
+                row("b", 2L, 30L),
+                row("c", 3L, 40L),
+                row("c", 3L, 50L)
+            ),
+            ImmutableList.of(
+                row("a", 1L, 100L),
+                row("b", 2L, 200L),
+                row("b", 2L, 300L),
+                row("c", 3L, 400L)
+            )
+        ),
+        3 // multiple rows per frame
+    );
+
+    Assert.assertEquals(
+        ImmutableList.of(
+            ImmutableList.of("a", 1L, 130L),
+            ImmutableList.of("b", 2L, 530L),
+            ImmutableList.of("c", 3L, 490L)
+        ),
+        rows
+    );
+  }
+
+  private List<List<Object>> runMergerWithCombiner(
+      final List<List<Row>> channelData,
+      final int maxRowsPerFrame
+  ) throws Exception
+  {
+    final FrameReader frameReader = FrameReader.create(SORTABLE_SIGNATURE);
+    final BlockingQueueFrameChannel outputChannel = 
BlockingQueueFrameChannel.minimal();
+
+    final List<ReadableFrameChannel> channels = new ArrayList<>();
+    for (final List<Row> data : channelData) {
+      channels.add(makeFrameChannel(data, maxRowsPerFrame));
+    }
+
+    final FrameChannelMerger merger = new FrameChannelMerger(
+        channels,
+        frameReader,
+        outputChannel.writable(),
+        FrameWriters.makeFrameWriterFactory(
+            FrameType.latestRowBased(),
+            new ArenaMemoryAllocatorFactory(1_000_000),
+            SORTABLE_SIGNATURE,
+            Collections.emptyList(),
+            false
+        ),
+        SORT_KEY,
+        new GroupByFrameCombiner(
+            SORTABLE_SIGNATURE,
+            ImmutableList.of(new LongSumAggregatorFactory("value", "value")),
+            2 // aggregatorStart: columns 0-1 are keys, column 2 is the 
aggregate
+        ),
+        null,
+        -1
+    );
+
+    new FrameProcessorExecutor(MoreExecutors.newDirectExecutorService())
+        .runFully(merger, null);
+
+    final List<List<Object>> rows = new ArrayList<>();
+    FrameTestUtil.readRowsFromFrameChannel(outputChannel.readable(), 
frameReader)
+                 .forEach(rows::add);
+    return rows;
+  }
+
+  private static Row row(final String key, final long keyLong, final long 
value)
+  {
+    return new MapBasedRow(0L, Map.of("key_str", key, "key_long", keyLong, 
"value", value));
+  }
+
+  private static ReadableFrameChannel makeFrameChannel(
+      final List<Row> rows,
+      final int maxRowsPerFrame
+  ) throws IOException
+  {
+    final Sequence<Frame> frames = FrameSequenceBuilder
+        .fromCursorFactory(new RowBasedCursorFactory<>(Sequences.simple(rows), 
RowAdapters.standardRow(), SIGNATURE))
+        .maxRowsPerFrame(maxRowsPerFrame)
+        .sortBy(SORT_KEY)
+        .allocator(ArenaMemoryAllocator.create(ByteBuffer.allocate(1_000_000)))
+        .frameType(FrameType.latestRowBased())
+        .frames();
+
+    final BlockingQueueFrameChannel channel = new 
BlockingQueueFrameChannel(100);
+    frames.forEach(frame -> {
+      try {
+        channel.writable().write(frame);
+      }
+      catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    });
+    channel.writable().close();
+    return channel.readable();
+  }
+}
diff --git 
a/processing/src/main/java/org/apache/druid/frame/processor/TrackingColumnValueSelector.java
 
b/processing/src/main/java/org/apache/druid/frame/processor/TrackingColumnValueSelector.java
new file mode 100644
index 00000000000..f1e46b338c2
--- /dev/null
+++ 
b/processing/src/main/java/org/apache/druid/frame/processor/TrackingColumnValueSelector.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.frame.processor;
+
+import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
+
+import javax.annotation.Nullable;
+import java.util.function.Supplier;
+
+/**
+ * {@link ColumnValueSelector} that delegates to a selector obtained from a 
{@link ColumnSelectorFactory} supplier.
+ * The delegate is refreshed when the supplier returns a different factory (by 
identity). This allows the selector
+ * to remain valid across changes to the underlying data source, such as frame 
changes during merging or combining.
+ */
+public class TrackingColumnValueSelector implements ColumnValueSelector<Object>
+{
+  private final String columnName;
+  private final Supplier<ColumnSelectorFactory> factorySupplier;
+  private ColumnSelectorFactory delegateFactory;
+  private ColumnValueSelector<?> delegate;
+
+  public TrackingColumnValueSelector(final String columnName, final 
Supplier<ColumnSelectorFactory> factorySupplier)
+  {
+    this.columnName = columnName;
+    this.factorySupplier = factorySupplier;
+  }
+
+  private ColumnValueSelector<?> delegate()
+  {
+    final ColumnSelectorFactory currentFactory = factorySupplier.get();
+    //noinspection ObjectEquality
+    if (currentFactory != delegateFactory) {
+      delegateFactory = currentFactory;
+      delegate = currentFactory.makeColumnValueSelector(columnName);
+    }
+    return delegate;
+  }
+
+  @Override
+  public double getDouble()
+  {
+    return delegate().getDouble();
+  }
+
+  @Override
+  public float getFloat()
+  {
+    return delegate().getFloat();
+  }
+
+  @Override
+  public long getLong()
+  {
+    return delegate().getLong();
+  }
+
+  @Override
+  public boolean isNull()
+  {
+    return delegate().isNull();
+  }
+
+  @Nullable
+  @Override
+  public Object getObject()
+  {
+    return delegate().getObject();
+  }
+
+  @Override
+  public Class<?> classOfObject()
+  {
+    return Object.class;
+  }
+
+  @Override
+  public void inspectRuntimeShape(final RuntimeShapeInspector inspector)
+  {
+    // Do nothing.
+  }
+}
diff --git 
a/processing/src/main/java/org/apache/druid/frame/processor/TrackingDimensionSelector.java
 
b/processing/src/main/java/org/apache/druid/frame/processor/TrackingDimensionSelector.java
new file mode 100644
index 00000000000..db7695e97f7
--- /dev/null
+++ 
b/processing/src/main/java/org/apache/druid/frame/processor/TrackingDimensionSelector.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.frame.processor;
+
+import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.query.filter.DruidPredicateFactory;
+import org.apache.druid.query.filter.ValueMatcher;
+import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.DimensionSelector;
+import org.apache.druid.segment.DimensionSelectorUtils;
+import org.apache.druid.segment.IdLookup;
+import org.apache.druid.segment.data.IndexedInts;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+import java.util.function.Supplier;
+
+/**
+ * {@link DimensionSelector} that delegates to a selector obtained from a 
{@link ColumnSelectorFactory} supplier.
+ * The delegate is refreshed when the supplier returns a different factory (by 
identity). This allows the selector
+ * to remain valid across changes to the underlying data source, such as frame 
changes during merging or combining.
+ */
+public class TrackingDimensionSelector implements DimensionSelector
+{
+  private final DimensionSpec dimensionSpec;
+  private final Supplier<ColumnSelectorFactory> factorySupplier;
+  private ColumnSelectorFactory delegateFactory;
+  private DimensionSelector delegate;
+
+  public TrackingDimensionSelector(final DimensionSpec dimensionSpec, final 
Supplier<ColumnSelectorFactory> factorySupplier)
+  {
+    this.dimensionSpec = dimensionSpec;
+    this.factorySupplier = factorySupplier;
+  }
+
+  private DimensionSelector delegate()
+  {
+    final ColumnSelectorFactory currentFactory = factorySupplier.get();
+    //noinspection ObjectEquality
+    if (currentFactory != delegateFactory) {
+      delegateFactory = currentFactory;
+      delegate = currentFactory.makeDimensionSelector(dimensionSpec);
+    }
+    return delegate;
+  }
+
+  @Override
+  public IndexedInts getRow()
+  {
+    return delegate().getRow();
+  }
+
+  @Override
+  public ValueMatcher makeValueMatcher(@Nullable final String value)
+  {
+    return DimensionSelectorUtils.makeValueMatcherGeneric(this, value);
+  }
+
+  @Override
+  public ValueMatcher makeValueMatcher(final DruidPredicateFactory 
predicateFactory)
+  {
+    return DimensionSelectorUtils.makeValueMatcherGeneric(this, 
predicateFactory);
+  }
+
+  @Nullable
+  @Override
+  public Object getObject()
+  {
+    return delegate().getObject();
+  }
+
+  @Override
+  public Class<?> classOfObject()
+  {
+    return delegate().classOfObject();
+  }
+
+  @Override
+  public int getValueCardinality()
+  {
+    return CARDINALITY_UNKNOWN;
+  }
+
+  @Nullable
+  @Override
+  public String lookupName(final int id)
+  {
+    return delegate().lookupName(id);
+  }
+
+  @Nullable
+  @Override
+  public ByteBuffer lookupNameUtf8(final int id)
+  {
+    return delegate().lookupNameUtf8(id);
+  }
+
+  @Override
+  public boolean supportsLookupNameUtf8()
+  {
+    return delegate().supportsLookupNameUtf8();
+  }
+
+  @Override
+  public boolean nameLookupPossibleInAdvance()
+  {
+    return false;
+  }
+
+  @Nullable
+  @Override
+  public IdLookup idLookup()
+  {
+    return null;
+  }
+
+  @Override
+  public void inspectRuntimeShape(final RuntimeShapeInspector inspector)
+  {
+    // Do nothing.
+  }
+}
diff --git 
a/processing/src/test/java/org/apache/druid/frame/processor/SummingFrameCombiner.java
 
b/processing/src/test/java/org/apache/druid/frame/processor/SummingFrameCombiner.java
index 26d4fa740d7..61022abf68c 100644
--- 
a/processing/src/test/java/org/apache/druid/frame/processor/SummingFrameCombiner.java
+++ 
b/processing/src/test/java/org/apache/druid/frame/processor/SummingFrameCombiner.java
@@ -32,6 +32,8 @@ import org.apache.druid.segment.column.ColumnCapabilities;
 import org.apache.druid.segment.column.RowSignature;
 
 import javax.annotation.Nullable;
+import java.util.HashMap;
+import java.util.Map;
 
 /**
  * Simple test combiner that sums a long column at {@link #sumColumnNumber}.
@@ -41,15 +43,26 @@ public class SummingFrameCombiner implements FrameCombiner
 {
   private final RowSignature signature;
   private final int sumColumnNumber;
+  private final CombinedColumnSelectorFactory combinedColumnSelectorFactory;
 
   private FrameReader frameReader;
-  private FrameCursor keyCursor;
+
+  @Nullable
+  private Frame cachedFrame;
+
+  @Nullable
+  private FrameCursor cachedCursor;
+
+  @Nullable
+  private ColumnValueSelector<?> cachedSumSelector;
+
   private long summedValue;
 
   public SummingFrameCombiner(final RowSignature signature, final int 
sumColumnNumber)
   {
     this.signature = signature;
     this.sumColumnNumber = sumColumnNumber;
+    this.combinedColumnSelectorFactory = new CombinedColumnSelectorFactory();
   }
 
   @Override
@@ -61,105 +74,123 @@ public class SummingFrameCombiner implements FrameCombiner
   @Override
   public void reset(final Frame frame, final int row)
   {
-    this.keyCursor = FrameProcessors.makeCursor(frame, frameReader);
-    this.keyCursor.setCurrentRow(row);
-    this.summedValue = readLongValue(frame, row);
+    final FrameCursor cursor = getCursor(frame);
+    cursor.setCurrentRow(row);
+    this.summedValue = cachedSumSelector.getLong();
   }
 
   @Override
   public void combine(final Frame frame, final int row)
   {
-    this.summedValue += readLongValue(frame, row);
+    final FrameCursor cursor = getCursor(frame);
+    cursor.setCurrentRow(row);
+    this.summedValue += cachedSumSelector.getLong();
   }
 
   @Override
   public ColumnSelectorFactory getCombinedColumnSelectorFactory()
   {
-    return new ColumnSelectorFactory()
+    return combinedColumnSelectorFactory;
+  }
+
+  private FrameCursor getCursor(final Frame frame)
+  {
+    //noinspection ObjectEquality
+    if (frame != cachedFrame) {
+      cachedFrame = frame;
+      cachedCursor = FrameProcessors.makeCursor(frame, frameReader);
+
+      final String sumColumnName = signature.getColumnName(sumColumnNumber);
+      cachedSumSelector = 
cachedCursor.getColumnSelectorFactory().makeColumnValueSelector(sumColumnName);
+    }
+    return cachedCursor;
+  }
+
+  private class CombinedColumnSelectorFactory implements ColumnSelectorFactory
+  {
+    private final Map<String, TrackingColumnValueSelector> 
columnValueSelectorCache = new HashMap<>();
+    private final Map<DimensionSpec, TrackingDimensionSelector> 
dimensionSelectorCache = new HashMap<>();
+
+    @Override
+    public DimensionSelector makeDimensionSelector(final DimensionSpec 
dimensionSpec)
     {
-      @Override
-      public DimensionSelector makeDimensionSelector(final DimensionSpec 
dimensionSpec)
-      {
-        final int columnNumber = 
signature.indexOf(dimensionSpec.getDimension());
-        if (columnNumber < 0) {
-          return DimensionSelector.constant(null, 
dimensionSpec.getExtractionFn());
-        } else if (columnNumber == sumColumnNumber) {
-          throw new UnsupportedOperationException();
-        } else {
-          return 
keyCursor.getColumnSelectorFactory().makeDimensionSelector(dimensionSpec);
-        }
+      final int columnNumber = signature.indexOf(dimensionSpec.getDimension());
+      if (columnNumber < 0) {
+        return DimensionSelector.constant(null, 
dimensionSpec.getExtractionFn());
+      } else if (columnNumber == sumColumnNumber) {
+        throw new UnsupportedOperationException();
+      } else {
+        return dimensionSelectorCache.computeIfAbsent(
+            dimensionSpec,
+            spec -> new TrackingDimensionSelector(spec, () -> 
cachedCursor.getColumnSelectorFactory())
+        );
       }
+    }
 
-      @Override
-      public ColumnValueSelector<?> makeColumnValueSelector(final String 
columnName)
-      {
-        final int columnNumber = signature.indexOf(columnName);
-        if (columnNumber < 0) {
-          return NilColumnValueSelector.instance();
-        } else if (columnNumber == sumColumnNumber) {
-          return new ColumnValueSelector<Long>()
+    @Override
+    public ColumnValueSelector<?> makeColumnValueSelector(final String 
columnName)
+    {
+      final int columnNumber = signature.indexOf(columnName);
+      if (columnNumber < 0) {
+        return NilColumnValueSelector.instance();
+      } else if (columnNumber == sumColumnNumber) {
+        return new ColumnValueSelector<Long>()
+        {
+          @Override
+          public double getDouble()
           {
-            @Override
-            public double getDouble()
-            {
-              return summedValue;
-            }
-
-            @Override
-            public float getFloat()
-            {
-              return summedValue;
-            }
-
-            @Override
-            public long getLong()
-            {
-              return summedValue;
-            }
-
-            @Override
-            public boolean isNull()
-            {
-              return false;
-            }
-
-            @Override
-            public Long getObject()
-            {
-              return summedValue;
-            }
-
-            @Override
-            public Class<Long> classOfObject()
-            {
-              return Long.class;
-            }
-
-            @Override
-            public void inspectRuntimeShape(RuntimeShapeInspector inspector)
-            {
-              // Nothing to do.
-            }
-          };
-        } else {
-          return 
keyCursor.getColumnSelectorFactory().makeColumnValueSelector(columnName);
-        }
-      }
+            return summedValue;
+          }
 
-      @Nullable
-      @Override
-      public ColumnCapabilities getColumnCapabilities(final String column)
-      {
-        return signature.getColumnCapabilities(column);
+          @Override
+          public float getFloat()
+          {
+            return summedValue;
+          }
+
+          @Override
+          public long getLong()
+          {
+            return summedValue;
+          }
+
+          @Override
+          public boolean isNull()
+          {
+            return false;
+          }
+
+          @Override
+          public Long getObject()
+          {
+            return summedValue;
+          }
+
+          @Override
+          public Class<Long> classOfObject()
+          {
+            return Long.class;
+          }
+
+          @Override
+          public void inspectRuntimeShape(final RuntimeShapeInspector 
inspector)
+          {
+            // Nothing to do.
+          }
+        };
+      } else {
+        return columnValueSelectorCache.computeIfAbsent(
+            columnName,
+            name -> new TrackingColumnValueSelector(name, () -> 
cachedCursor.getColumnSelectorFactory())
+        );
       }
-    };
-  }
+    }
 
-  private long readLongValue(final Frame frame, final int row)
-  {
-    final FrameCursor cursor = FrameProcessors.makeCursor(frame, frameReader);
-    cursor.setCurrentRow(row);
-    final String columnName = signature.getColumnName(sumColumnNumber);
-    return 
cursor.getColumnSelectorFactory().makeColumnValueSelector(columnName).getLong();
+    @Nullable
+    @Override
+    public ColumnCapabilities getColumnCapabilities(final String column)
+    {
+      return signature.getColumnCapabilities(column);
+    }
   }
 }
diff --git 
a/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java
 
b/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java
index bef24a3c846..9c8c5d19a1a 100644
--- 
a/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java
+++ 
b/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java
@@ -901,6 +901,9 @@ public class SuperSorterTest
       );
 
       final List<List<Object>> rows = runCombiningSuperSorter(
+          SIGNATURE,
+          CLUSTER_BY,
+          SORTABLE_SIGNATURE,
           channelData,
           ClusterByPartitions.oneUniversalPartition(),
           2,
@@ -929,6 +932,9 @@ public class SuperSorterTest
       }
 
       final List<List<Object>> rows = runCombiningSuperSorter(
+          SIGNATURE,
+          CLUSTER_BY,
+          SORTABLE_SIGNATURE,
           channelData,
           ClusterByPartitions.oneUniversalPartition(),
           1,
@@ -969,7 +975,15 @@ public class SuperSorterTest
           new Object[][]{{"a", 10L}, {"b", 20L}, {"c", 30L}, {"d", 40L}}
       );
 
-      final List<List<Object>> rows = runCombiningSuperSorter(channelData, partitions, 2, SuperSorter.UNLIMITED);
+      final List<List<Object>> rows = runCombiningSuperSorter(
+          SIGNATURE,
+          CLUSTER_BY,
+          SORTABLE_SIGNATURE,
+          channelData,
+          partitions,
+          2,
+          SuperSorter.UNLIMITED
+      );
 
       Assert.assertEquals(
           ImmutableList.of(
@@ -994,6 +1008,9 @@ public class SuperSorterTest
       }
 
       final List<List<Object>> rows = runCombiningSuperSorter(
+          SIGNATURE,
+          CLUSTER_BY,
+          SORTABLE_SIGNATURE,
           channelData,
           ClusterByPartitions.oneUniversalPartition(),
           2,
@@ -1020,6 +1037,9 @@ public class SuperSorterTest
 
       for (int limit = 1; limit <= 3; limit++) {
         final List<List<Object>> rows = runCombiningSuperSorter(
+            SIGNATURE,
+            CLUSTER_BY,
+            SORTABLE_SIGNATURE,
             channelData,
             ClusterByPartitions.oneUniversalPartition(),
             2,
@@ -1054,6 +1074,9 @@ public class SuperSorterTest
       );
 
       final List<List<Object>> rows = runCombiningSuperSorter(
+          SIGNATURE,
+          CLUSTER_BY,
+          SORTABLE_SIGNATURE,
           channelData,
           ClusterByPartitions.oneUniversalPartition(),
           2,
@@ -1078,6 +1101,9 @@ public class SuperSorterTest
       );
 
       final List<List<Object>> rows = runCombiningSuperSorter(
+          SIGNATURE,
+          CLUSTER_BY,
+          SORTABLE_SIGNATURE,
           channelData,
           ClusterByPartitions.oneUniversalPartition(),
           1,
@@ -1102,6 +1128,9 @@ public class SuperSorterTest
       final List<Object[][]> channelData = ImmutableList.of(new Object[][]{});
 
       final List<List<Object>> rows = runCombiningSuperSorter(
+          SIGNATURE,
+          CLUSTER_BY,
+          SORTABLE_SIGNATURE,
           channelData,
           ClusterByPartitions.oneUniversalPartition(),
           1,
@@ -1112,21 +1141,69 @@ public class SuperSorterTest
     }
 
     /**
-     * Helper that runs a combining SuperSorter with the given channel data, partitions, maxActiveProcessors,
-     * and rowLimit. Returns all output rows across all partitions.
+     * Test combining with a LONG key column, which exercises {@link TrackingColumnValueSelector}.
+     */
+    @Test
+    public void testCombineWithLongKey() throws Exception
+    {
+      final RowSignature longKeySignature =
+          RowSignature.builder()
+                      .add("key", ColumnType.LONG)
+                      .add("value", ColumnType.LONG)
+                      .build();
+
+      final ClusterBy longKeyClusterBy = new ClusterBy(
+          ImmutableList.of(new KeyColumn("key", KeyOrder.ASCENDING)),
+          0
+      );
+
+      final RowSignature longKeySortableSignature =
+          FrameWriters.sortableSignature(longKeySignature, longKeyClusterBy.getColumns());
+
+      final List<Object[][]> channelData = ImmutableList.of(
+          new Object[][]{{1L, 10L}, {2L, 20L}, {3L, 30L}},
+          new Object[][]{{1L, 100L}, {2L, 200L}, {3L, 300L}}
+      );
+
+      final List<List<Object>> rows = runCombiningSuperSorter(
+          longKeySignature,
+          longKeyClusterBy,
+          longKeySortableSignature,
+          channelData,
+          ClusterByPartitions.oneUniversalPartition(),
+          2,
+          SuperSorter.UNLIMITED
+      );
+
+      Assert.assertEquals(
+          ImmutableList.of(
+              ImmutableList.of(1L, 110L),
+              ImmutableList.of(2L, 220L),
+              ImmutableList.of(3L, 330L)
+          ),
+          rows
+      );
+    }
+
+    /**
+     * Helper that runs a combining SuperSorter with the given signature, channel data, partitions,
+     * maxActiveProcessors, and rowLimit. Returns all output rows across all partitions.
      */
     private List<List<Object>> runCombiningSuperSorter(
+        final RowSignature signature,
+        final ClusterBy clusterBy,
+        final RowSignature sortableSignature,
         final List<Object[][]> channelData,
         final ClusterByPartitions partitions,
         final int maxActiveProcessors,
         final long rowLimit
     ) throws Exception
     {
-      final FrameReader frameReader = FrameReader.create(SORTABLE_SIGNATURE);
+      final FrameReader frameReader = FrameReader.create(sortableSignature);
 
       final List<ReadableFrameChannel> channels = new ArrayList<>();
       for (final Object[][] data : channelData) {
-        channels.add(makeFrameChannel(data));
+        channels.add(makeFrameChannel(signature, clusterBy, data));
       }
 
       final File tempFolder = temporaryFolder.newFolder();
@@ -1134,7 +1211,7 @@ public class SuperSorterTest
       final SuperSorter superSorter = new SuperSorter(
           channels,
           frameReader,
-          CLUSTER_BY.getColumns(),
+          clusterBy.getColumns(),
           Futures.immediateFuture(partitions),
           exec,
           FrameProcessorDecorator.NONE,
@@ -1147,7 +1224,7 @@ public class SuperSorterTest
           null,
           new SuperSorterProgressTracker(),
           false,
-          () -> new SummingFrameCombiner(SORTABLE_SIGNATURE, 1)
+          () -> new SummingFrameCombiner(sortableSignature, 1)
       );
 
       final OutputChannels outputChannels = superSorter.run().get();
@@ -1161,13 +1238,17 @@ public class SuperSorterTest
       return rows;
     }
 
-    private ReadableFrameChannel makeFrameChannel(final Object[][] rows) throws IOException
+    private ReadableFrameChannel makeFrameChannel(
+        final RowSignature signature,
+        final ClusterBy clusterBy,
+        final Object[][] rows
+    ) throws IOException
     {
       final List<Row> rowList = new ArrayList<>();
       for (final Object[] row : rows) {
         final Map<String, Object> map = new HashMap<>();
-        for (int i = 0; i < SIGNATURE.size(); i++) {
-          map.put(SIGNATURE.getColumnName(i), row[i]);
+        for (int i = 0; i < signature.size(); i++) {
+          map.put(signature.getColumnName(i), row[i]);
         }
         rowList.add(new MapBasedRow(0L, map));
       }
@@ -1176,13 +1257,13 @@ public class SuperSorterTest
           new RowBasedCursorFactory<>(
               Sequences.simple(rowList),
               RowAdapters.standardRow(),
-              SIGNATURE
+              signature
           );
 
       final Sequence<Frame> frames =
           FrameSequenceBuilder.fromCursorFactory(cursorFactory)
                               .maxRowsPerFrame(maxRowsPerFrame)
-                              .sortBy(CLUSTER_BY.getColumns())
+                              .sortBy(clusterBy.getColumns())
                               .allocator(ArenaMemoryAllocator.create(ByteBuffer.allocate(FRAME_SIZE)))
                               .frameType(FrameType.latestRowBased())
                               .frames();


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to