This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 79b83d873 chore: Refactor JVM shuffle: Move `SpillSorter` to top level 
class and add tests (#3081)
79b83d873 is described below

commit 79b83d8733d976310552517f042c498c63b97592
Author: Andy Grove <[email protected]>
AuthorDate: Wed Jan 14 13:03:10 2026 -0700

    chore: Refactor JVM shuffle: Move `SpillSorter` to top level class and add 
tests (#3081)
---
 .github/workflows/pr_build_linux.yml               |   1 +
 .github/workflows/pr_build_macos.yml               |   1 +
 dev/ensure-jars-have-correct-contents.sh           |   1 +
 .../shuffle/sort/CometShuffleExternalSorter.java   | 268 ++--------------
 .../org/apache/spark/shuffle/sort/SpillSorter.java | 352 +++++++++++++++++++++
 .../spark/shuffle/sort/SpillSorterSuite.scala      | 262 +++++++++++++++
 6 files changed, 638 insertions(+), 247 deletions(-)

diff --git a/.github/workflows/pr_build_linux.yml 
b/.github/workflows/pr_build_linux.yml
index 9f5324b26..8e4dc5124 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -122,6 +122,7 @@ jobs:
               org.apache.comet.exec.CometAsyncShuffleSuite
               org.apache.comet.exec.DisableAQECometShuffleSuite
               org.apache.comet.exec.DisableAQECometAsyncShuffleSuite
+              org.apache.spark.shuffle.sort.SpillSorterSuite
           - name: "parquet"
             value: |
               org.apache.comet.parquet.CometParquetWriterSuite
diff --git a/.github/workflows/pr_build_macos.yml 
b/.github/workflows/pr_build_macos.yml
index 58ba48134..f94071dbc 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -85,6 +85,7 @@ jobs:
               org.apache.comet.exec.CometAsyncShuffleSuite
               org.apache.comet.exec.DisableAQECometShuffleSuite
               org.apache.comet.exec.DisableAQECometAsyncShuffleSuite
+              org.apache.spark.shuffle.sort.SpillSorterSuite
           - name: "parquet"
             value: |
               org.apache.comet.parquet.CometParquetWriterSuite
diff --git a/dev/ensure-jars-have-correct-contents.sh 
b/dev/ensure-jars-have-correct-contents.sh
index f698fe78f..570aeabb2 100755
--- a/dev/ensure-jars-have-correct-contents.sh
+++ b/dev/ensure-jars-have-correct-contents.sh
@@ -86,6 +86,7 @@ allowed_expr+="|^org/apache/spark/shuffle/$"
 allowed_expr+="|^org/apache/spark/shuffle/sort/$"
 allowed_expr+="|^org/apache/spark/shuffle/sort/CometShuffleExternalSorter.*$"
 allowed_expr+="|^org/apache/spark/shuffle/sort/RowPartition.class$"
+allowed_expr+="|^org/apache/spark/shuffle/sort/SpillSorter.*$"
 allowed_expr+="|^org/apache/spark/shuffle/comet/.*$"
 allowed_expr+="|^org/apache/spark/sql/$"
 # allow ExplainPlanGenerator trait since it may not be available in older 
Spark versions
diff --git 
a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
 
b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
index 8bc22b342..b026c6bc4 100644
--- 
a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
+++ 
b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
@@ -23,7 +23,6 @@ import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
 import java.util.concurrent.*;
-import javax.annotation.Nullable;
 
 import scala.Tuple2;
 
@@ -32,7 +31,6 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
-import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.memory.SparkOutOfMemoryError;
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.shuffle.comet.CometShuffleChecksumSupport;
@@ -41,17 +39,14 @@ import org.apache.spark.shuffle.comet.TooLargePageException;
 import org.apache.spark.sql.comet.execution.shuffle.CometUnsafeShuffleWriter;
 import org.apache.spark.sql.comet.execution.shuffle.ShuffleThreadPool;
 import org.apache.spark.sql.comet.execution.shuffle.SpillInfo;
-import org.apache.spark.sql.comet.execution.shuffle.SpillWriter;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TempShuffleBlockId;
-import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.UnsafeAlignedOffset;
 import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.util.Utils;
 
 import org.apache.comet.CometConf$;
-import org.apache.comet.Native;
 
 /**
  * An external sorter that is specialized for sort-based shuffle.
@@ -169,10 +164,28 @@ public final class CometShuffleExternalSorter implements 
CometShuffleChecksumSup
       this.threadPool = null;
     }
 
-    this.activeSpillSorter = new SpillSorter();
-
     this.preferDictionaryRatio =
         (double) 
CometConf$.MODULE$.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO().get();
+
+    this.activeSpillSorter = createSpillSorter();
+  }
+
+  /** Creates a new SpillSorter with all required dependencies. */
+  private SpillSorter createSpillSorter() {
+    return new SpillSorter(
+        allocator,
+        initialSize,
+        schema,
+        uaoSize,
+        preferDictionaryRatio,
+        compressionCodec,
+        compressionLevel,
+        checksumAlgorithm,
+        partitionChecksums,
+        writeMetrics,
+        taskContext,
+        spills,
+        this::spill);
   }
 
   public long[] getChecksums() {
@@ -237,7 +250,7 @@ public final class CometShuffleExternalSorter implements 
CometShuffleChecksumSup
         }
       }
 
-      activeSpillSorter = new SpillSorter();
+      activeSpillSorter = createSpillSorter();
     } else {
       activeSpillSorter.writeSortedFileNative(false, tracingEnabled);
       final long spillSize = activeSpillSorter.freeMemory();
@@ -410,243 +423,4 @@ public final class CometShuffleExternalSorter implements 
CometShuffleChecksumSup
 
     return spills.toArray(new SpillInfo[spills.size()]);
   }
-
-  class SpillSorter extends SpillWriter {
-    private boolean freed = false;
-
-    private SpillInfo spillInfo;
-
-    // These variables are reset after spilling:
-    @Nullable private ShuffleInMemorySorter inMemSorter;
-
-    // This external sorter can call native code to sort partition ids and 
record pointers of rows.
-    // In order to do that, we need pass the address of the internal array in 
the sorter to native.
-    // But we cannot access it as it is private member in the Spark sorter. 
Instead, we allocate
-    // the array and assign the pointer array in the sorter.
-    private LongArray sorterArray;
-
-    SpillSorter() {
-      this.spillInfo = null;
-
-      this.allocator = CometShuffleExternalSorter.this.allocator;
-
-      // Allocate array for in-memory sorter.
-      // As we cannot access the address of the internal array in the sorter, 
so we need to
-      // allocate the array manually and expand the pointer array in the 
sorter.
-      // We don't want in-memory sorter to allocate memory but the initial 
size cannot be zero.
-      try {
-        this.inMemSorter = new ShuffleInMemorySorter(allocator, 1, true);
-      } catch (java.lang.IllegalAccessError e) {
-        throw new java.lang.RuntimeException(
-            "Error loading in-memory sorter check class path -- see "
-                + 
"https://github.com/apache/arrow-datafusion-comet?tab=readme-ov-file#enable-comet-shuffle";,
-            e);
-      }
-      sorterArray = allocator.allocateArray(initialSize);
-      this.inMemSorter.expandPointerArray(sorterArray);
-
-      this.allocatedPages = new LinkedList<>();
-
-      this.nativeLib = new Native();
-      this.dataTypes = serializeSchema(schema);
-    }
-
-    /** Frees allocated memory pages of this writer */
-    @Override
-    public long freeMemory() {
-      // We need to synchronize here because we may get the memory usage by 
calling
-      // this method in the task thread.
-      synchronized (this) {
-        return super.freeMemory();
-      }
-    }
-
-    @Override
-    public long getMemoryUsage() {
-      // We need to synchronize here because we may free the memory pages in 
another thread,
-      // i.e. when spilling, but this method may be called in the task thread.
-      synchronized (this) {
-        long totalPageSize = super.getMemoryUsage();
-
-        if (freed) {
-          return totalPageSize;
-        } else {
-          return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + 
totalPageSize;
-        }
-      }
-    }
-
-    @Override
-    protected void spill(int required) throws IOException {
-      CometShuffleExternalSorter.this.spill();
-    }
-
-    /** Free the pointer array held by this sorter. */
-    public void freeArray() {
-      synchronized (this) {
-        inMemSorter.free();
-        freed = true;
-      }
-    }
-
-    /**
-     * Reset the in-memory sorter's pointer array only after freeing up the 
memory pages holding the
-     * records.
-     */
-    public void reset() {
-      // We allocate pointer array outside the sorter.
-      // So we can get array address which can be used by native code.
-      inMemSorter.reset();
-      sorterArray = allocator.allocateArray(initialSize);
-      inMemSorter.expandPointerArray(sorterArray);
-    }
-
-    void setSpillInfo(SpillInfo spillInfo) {
-      this.spillInfo = spillInfo;
-    }
-
-    public int numRecords() {
-      return this.inMemSorter.numRecords();
-    }
-
-    public void writeSortedFileNative(boolean isLastFile, boolean 
tracingEnabled)
-        throws IOException {
-      // This call performs the actual sort.
-      long arrayAddr = this.sorterArray.getBaseOffset();
-      int pos = inMemSorter.numRecords();
-      nativeLib.sortRowPartitionsNative(arrayAddr, pos, tracingEnabled);
-      ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
-          new ShuffleInMemorySorter.ShuffleSorterIterator(pos, 
this.sorterArray, 0);
-
-      // If there are no sorted records, so we don't need to create an empty 
spill file.
-      if (!sortedRecords.hasNext()) {
-        return;
-      }
-
-      final ShuffleWriteMetricsReporter writeMetricsToUse;
-
-      if (isLastFile) {
-        // We're writing the final non-spill file, so we _do_ want to count 
this as shuffle bytes.
-        writeMetricsToUse = writeMetrics;
-      } else {
-        // We're spilling, so bytes written should be counted towards spill 
rather than write.
-        // Create a dummy WriteMetrics object to absorb these metrics, since 
we don't want to count
-        // them towards shuffle bytes written.
-        writeMetricsToUse = new ShuffleWriteMetrics();
-      }
-
-      int currentPartition = -1;
-
-      final RowPartition rowPartition = new RowPartition(initialSize);
-
-      while (sortedRecords.hasNext()) {
-        sortedRecords.loadNext();
-        final int partition = 
sortedRecords.packedRecordPointer.getPartitionId();
-        assert (partition >= currentPartition);
-        if (partition != currentPartition) {
-          // Switch to the new partition
-          if (currentPartition != -1) {
-
-            if (partitionChecksums.length > 0) {
-              // If checksum is enabled, we need to update the checksum for 
the current partition.
-              setChecksum(partitionChecksums[currentPartition]);
-              setChecksumAlgo(checksumAlgorithm);
-            }
-
-            long written =
-                doSpilling(
-                    dataTypes,
-                    spillInfo.file,
-                    rowPartition,
-                    writeMetricsToUse,
-                    preferDictionaryRatio,
-                    compressionCodec,
-                    compressionLevel,
-                    tracingEnabled);
-            spillInfo.partitionLengths[currentPartition] = written;
-
-            // Store the checksum for the current partition.
-            partitionChecksums[currentPartition] = getChecksum();
-          }
-          currentPartition = partition;
-        }
-
-        final long recordPointer = 
sortedRecords.packedRecordPointer.getRecordPointer();
-        final long recordOffsetInPage = 
allocator.getOffsetInPage(recordPointer);
-        // Note that we need to skip over record key (partition id)
-        // Note that we already use off-heap memory for serialized rows, so 
recordPage is always
-        // null.
-        int recordSizeInBytes = UnsafeAlignedOffset.getSize(null, 
recordOffsetInPage) - 4;
-        long recordReadPosition = recordOffsetInPage + uaoSize + 4; // skip 
over record length too
-        rowPartition.addRow(recordReadPosition, recordSizeInBytes);
-      }
-
-      if (currentPartition != -1) {
-        long written =
-            doSpilling(
-                dataTypes,
-                spillInfo.file,
-                rowPartition,
-                writeMetricsToUse,
-                preferDictionaryRatio,
-                compressionCodec,
-                compressionLevel,
-                tracingEnabled);
-        spillInfo.partitionLengths[currentPartition] = written;
-
-        synchronized (spills) {
-          spills.add(spillInfo);
-        }
-      }
-
-      if (!isLastFile) { // i.e. this is a spill file
-        // The current semantics of `shuffleRecordsWritten` seem to be that 
it's updated when
-        // records
-        // are written to disk, not when they enter the shuffle sorting code. 
DiskBlockObjectWriter
-        // relies on its `recordWritten()` method being called in order to 
trigger periodic updates
-        // to
-        // `shuffleBytesWritten`. If we were to remove the `recordWritten()` 
call and increment that
-        // counter at a higher-level, then the in-progress metrics for records 
written and bytes
-        // written would get out of sync.
-        //
-        // When writing the last file, we pass `writeMetrics` directly to the 
DiskBlockObjectWriter;
-        // in all other cases, we pass in a dummy write metrics to capture 
metrics, then copy those
-        // metrics to the true write metrics here. The reason for performing 
this copying is so that
-        // we can avoid reporting spilled bytes as shuffle write bytes.
-        //
-        // Note that we intentionally ignore the value of 
`writeMetricsToUse.shuffleWriteTime()`.
-        // Consistent with ExternalSorter, we do not count this IO towards 
shuffle write time.
-        // SPARK-3577 tracks the spill time separately.
-
-        // This is guaranteed to be a ShuffleWriteMetrics based on the if 
check in the beginning
-        // of this method.
-        synchronized (writeMetrics) {
-          writeMetrics.incRecordsWritten(
-              ((ShuffleWriteMetrics) writeMetricsToUse).recordsWritten());
-          taskContext
-              .taskMetrics()
-              .incDiskBytesSpilled(((ShuffleWriteMetrics) 
writeMetricsToUse).bytesWritten());
-        }
-      }
-    }
-
-    public boolean hasSpaceForAnotherRecord() {
-      return inMemSorter.hasSpaceForAnotherRecord();
-    }
-
-    public void expandPointerArray(LongArray newArray) {
-      inMemSorter.expandPointerArray(newArray);
-      this.sorterArray = newArray;
-    }
-
-    public void insertRecord(Object recordBase, long recordOffset, int length, 
int partitionId) {
-      final Object base = currentPage.getBaseObject();
-      final long recordAddress = 
allocator.encodePageNumberAndOffset(currentPage, pageCursor);
-      UnsafeAlignedOffset.putSize(base, pageCursor, length);
-      pageCursor += uaoSize;
-      Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
-      pageCursor += length;
-      inMemSorter.insertRecord(recordAddress, partitionId);
-    }
-  }
 }
diff --git a/spark/src/main/java/org/apache/spark/shuffle/sort/SpillSorter.java 
b/spark/src/main/java/org/apache/spark/shuffle/sort/SpillSorter.java
new file mode 100644
index 000000000..36b50e620
--- /dev/null
+++ b/spark/src/main/java/org/apache/spark/shuffle/sort/SpillSorter.java
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import javax.annotation.Nullable;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocatorTrait;
+import org.apache.spark.sql.comet.execution.shuffle.SpillInfo;
+import org.apache.spark.sql.comet.execution.shuffle.SpillWriter;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.array.LongArray;
+
+import org.apache.comet.Native;
+
+/**
+ * A spill sorter that buffers records in memory, sorts them by partition ID, 
and writes them to
+ * disk. This class is used by CometShuffleExternalSorter to manage individual 
spill operations.
+ *
+ * <p>Each SpillSorter instance manages its own memory pages and pointer 
array. When spilling is
+ * triggered, the records are sorted by partition ID using native code and 
written to a spill file.
+ */
+public class SpillSorter extends SpillWriter {
+
+  /** Callback interface for triggering spill operations in the parent sorter. 
*/
+  @FunctionalInterface
+  public interface SpillCallback {
+    void onSpillRequired() throws IOException;
+  }
+
+  // Configuration fields (immutable after construction)
+  private final int initialSize;
+  private final int uaoSize;
+  private final double preferDictionaryRatio;
+  private final String compressionCodec;
+  private final int compressionLevel;
+  private final String checksumAlgorithm;
+
+  // Shared state (mutable, passed by reference from parent)
+  private final long[] partitionChecksums;
+  private final ShuffleWriteMetricsReporter writeMetrics;
+  private final TaskContext taskContext;
+  private final LinkedList<SpillInfo> spills;
+  private final SpillCallback spillCallback;
+
+  // Internal state
+  private boolean freed = false;
+  private SpillInfo spillInfo;
+  @Nullable private ShuffleInMemorySorter inMemSorter;
+  private LongArray sorterArray;
+
+  /**
+   * Creates a new SpillSorter with explicit dependencies.
+   *
+   * @param allocator Memory allocator for pages and arrays
+   * @param initialSize Initial size for the pointer array
+   * @param schema Schema of the records being sorted
+   * @param uaoSize Size of UnsafeAlignedOffset (4 or 8 bytes)
+   * @param preferDictionaryRatio Dictionary encoding preference ratio
+   * @param compressionCodec Compression codec for spill files
+   * @param compressionLevel Compression level
+   * @param checksumAlgorithm Checksum algorithm (e.g., "crc32", "adler32")
+   * @param partitionChecksums Array to store partition checksums (shared with 
parent)
+   * @param writeMetrics Metrics reporter for shuffle writes
+   * @param taskContext Task context for metrics updates
+   * @param spills List to accumulate spill info (shared with parent)
+   * @param spillCallback Callback to trigger spill in parent sorter
+   */
+  public SpillSorter(
+      CometShuffleMemoryAllocatorTrait allocator,
+      int initialSize,
+      StructType schema,
+      int uaoSize,
+      double preferDictionaryRatio,
+      String compressionCodec,
+      int compressionLevel,
+      String checksumAlgorithm,
+      long[] partitionChecksums,
+      ShuffleWriteMetricsReporter writeMetrics,
+      TaskContext taskContext,
+      LinkedList<SpillInfo> spills,
+      SpillCallback spillCallback) {
+
+    this.initialSize = initialSize;
+    this.uaoSize = uaoSize;
+    this.preferDictionaryRatio = preferDictionaryRatio;
+    this.compressionCodec = compressionCodec;
+    this.compressionLevel = compressionLevel;
+    this.checksumAlgorithm = checksumAlgorithm;
+    this.partitionChecksums = partitionChecksums;
+    this.writeMetrics = writeMetrics;
+    this.taskContext = taskContext;
+    this.spills = spills;
+    this.spillCallback = spillCallback;
+
+    this.spillInfo = null;
+    this.allocator = allocator;
+
+    // Allocate array for in-memory sorter.
+    // Since we cannot access the address of the internal array in the sorter,
+    // we need to allocate the array manually and expand the pointer array in
+    // the sorter.
+    // We don't want in-memory sorter to allocate memory but the initial size 
cannot be zero.
+    try {
+      this.inMemSorter = new ShuffleInMemorySorter(allocator, 1, true);
+    } catch (java.lang.IllegalAccessError e) {
+      throw new java.lang.RuntimeException(
+          "Error loading in-memory sorter check class path -- see "
+              + 
"https://github.com/apache/arrow-datafusion-comet?tab=readme-ov-file#enable-comet-shuffle";,
+          e);
+    }
+    sorterArray = allocator.allocateArray(initialSize);
+    this.inMemSorter.expandPointerArray(sorterArray);
+
+    this.allocatedPages = new LinkedList<>();
+
+    this.nativeLib = new Native();
+    this.dataTypes = serializeSchema(schema);
+  }
+
+  /** Frees allocated memory pages of this writer */
+  @Override
+  public long freeMemory() {
+    // We need to synchronize here because we may get the memory usage by 
calling
+    // this method in the task thread.
+    synchronized (this) {
+      return super.freeMemory();
+    }
+  }
+
+  @Override
+  public long getMemoryUsage() {
+    // We need to synchronize here because we may free the memory pages in 
another thread,
+    // i.e. when spilling, but this method may be called in the task thread.
+    synchronized (this) {
+      long totalPageSize = super.getMemoryUsage();
+
+      if (freed) {
+        return totalPageSize;
+      } else {
+        return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + 
totalPageSize;
+      }
+    }
+  }
+
+  @Override
+  protected void spill(int required) throws IOException {
+    spillCallback.onSpillRequired();
+  }
+
+  /** Free the pointer array held by this sorter. */
+  public void freeArray() {
+    synchronized (this) {
+      inMemSorter.free();
+      freed = true;
+    }
+  }
+
+  /**
+   * Reset the in-memory sorter's pointer array only after freeing up the 
memory pages holding the
+   * records.
+   */
+  public void reset() {
+    synchronized (this) {
+      // We allocate pointer array outside the sorter.
+      // So we can get array address which can be used by native code.
+      inMemSorter.reset();
+      sorterArray = allocator.allocateArray(initialSize);
+      inMemSorter.expandPointerArray(sorterArray);
+      freed = false;
+    }
+  }
+
+  void setSpillInfo(SpillInfo spillInfo) {
+    this.spillInfo = spillInfo;
+  }
+
+  public int numRecords() {
+    return this.inMemSorter.numRecords();
+  }
+
+  public void writeSortedFileNative(boolean isLastFile, boolean 
tracingEnabled) throws IOException {
+    // This call performs the actual sort.
+    long arrayAddr = this.sorterArray.getBaseOffset();
+    int pos = inMemSorter.numRecords();
+    nativeLib.sortRowPartitionsNative(arrayAddr, pos, tracingEnabled);
+    ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
+        new ShuffleInMemorySorter.ShuffleSorterIterator(pos, this.sorterArray, 
0);
+
+    // If there are no sorted records, we don't need to create an empty spill
+    // file.
+    if (!sortedRecords.hasNext()) {
+      return;
+    }
+
+    final ShuffleWriteMetricsReporter writeMetricsToUse;
+
+    if (isLastFile) {
+      // We're writing the final non-spill file, so we _do_ want to count this 
as shuffle bytes.
+      writeMetricsToUse = writeMetrics;
+    } else {
+      // We're spilling, so bytes written should be counted towards spill 
rather than write.
+      // Create a dummy WriteMetrics object to absorb these metrics, since we 
don't want to count
+      // them towards shuffle bytes written.
+      writeMetricsToUse = new ShuffleWriteMetrics();
+    }
+
+    int currentPartition = -1;
+
+    final RowPartition rowPartition = new RowPartition(initialSize);
+
+    while (sortedRecords.hasNext()) {
+      sortedRecords.loadNext();
+      final int partition = sortedRecords.packedRecordPointer.getPartitionId();
+      assert (partition >= currentPartition);
+      if (partition != currentPartition) {
+        // Switch to the new partition
+        if (currentPartition != -1) {
+
+          if (partitionChecksums.length > 0) {
+            // If checksum is enabled, we need to update the checksum for the 
current partition.
+            setChecksum(partitionChecksums[currentPartition]);
+            setChecksumAlgo(checksumAlgorithm);
+          }
+
+          long written =
+              doSpilling(
+                  dataTypes,
+                  spillInfo.file,
+                  rowPartition,
+                  writeMetricsToUse,
+                  preferDictionaryRatio,
+                  compressionCodec,
+                  compressionLevel,
+                  tracingEnabled);
+          spillInfo.partitionLengths[currentPartition] = written;
+
+          // Store the checksum for the current partition.
+          partitionChecksums[currentPartition] = getChecksum();
+        }
+        currentPartition = partition;
+      }
+
+      final long recordPointer = 
sortedRecords.packedRecordPointer.getRecordPointer();
+      final long recordOffsetInPage = allocator.getOffsetInPage(recordPointer);
+      // Note that we need to skip over record key (partition id)
+      // Note that we already use off-heap memory for serialized rows, so 
recordPage is always
+      // null.
+      int recordSizeInBytes = UnsafeAlignedOffset.getSize(null, 
recordOffsetInPage) - 4;
+      long recordReadPosition = recordOffsetInPage + uaoSize + 4; // skip over 
record length too
+      rowPartition.addRow(recordReadPosition, recordSizeInBytes);
+    }
+
+    if (currentPartition != -1) {
+      if (partitionChecksums.length > 0) {
+        // If checksum is enabled, we need to update the checksum for the last 
partition.
+        setChecksum(partitionChecksums[currentPartition]);
+        setChecksumAlgo(checksumAlgorithm);
+      }
+
+      long written =
+          doSpilling(
+              dataTypes,
+              spillInfo.file,
+              rowPartition,
+              writeMetricsToUse,
+              preferDictionaryRatio,
+              compressionCodec,
+              compressionLevel,
+              tracingEnabled);
+      spillInfo.partitionLengths[currentPartition] = written;
+
+      // Store the checksum for the last partition.
+      if (partitionChecksums.length > 0) {
+        partitionChecksums[currentPartition] = getChecksum();
+      }
+
+      synchronized (spills) {
+        spills.add(spillInfo);
+      }
+    }
+
+    if (!isLastFile) { // i.e. this is a spill file
+      // The current semantics of `shuffleRecordsWritten` seem to be that it's 
updated when
+      // records
+      // are written to disk, not when they enter the shuffle sorting code. 
DiskBlockObjectWriter
+      // relies on its `recordWritten()` method being called in order to 
trigger periodic updates
+      // to
+      // `shuffleBytesWritten`. If we were to remove the `recordWritten()` 
call and increment that
+      // counter at a higher-level, then the in-progress metrics for records 
written and bytes
+      // written would get out of sync.
+      //
+      // When writing the last file, we pass `writeMetrics` directly to the 
DiskBlockObjectWriter;
+      // in all other cases, we pass in a dummy write metrics to capture 
metrics, then copy those
+      // metrics to the true write metrics here. The reason for performing 
this copying is so that
+      // we can avoid reporting spilled bytes as shuffle write bytes.
+      //
+      // Note that we intentionally ignore the value of 
`writeMetricsToUse.shuffleWriteTime()`.
+      // Consistent with ExternalSorter, we do not count this IO towards 
shuffle write time.
+      // SPARK-3577 tracks the spill time separately.
+
+      // This is guaranteed to be a ShuffleWriteMetrics based on the if check 
in the beginning
+      // of this method.
+      synchronized (writeMetrics) {
+        writeMetrics.incRecordsWritten(((ShuffleWriteMetrics) 
writeMetricsToUse).recordsWritten());
+        taskContext
+            .taskMetrics()
+            .incDiskBytesSpilled(((ShuffleWriteMetrics) 
writeMetricsToUse).bytesWritten());
+      }
+    }
+  }
+
+  public boolean hasSpaceForAnotherRecord() {
+    return inMemSorter.hasSpaceForAnotherRecord();
+  }
+
+  public void expandPointerArray(LongArray newArray) {
+    inMemSorter.expandPointerArray(newArray);
+    this.sorterArray = newArray;
+  }
+
+  public void insertRecord(Object recordBase, long recordOffset, int length, 
int partitionId) {
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = 
allocator.encodePageNumberAndOffset(currentPage, pageCursor);
+    UnsafeAlignedOffset.putSize(base, pageCursor, length);
+    pageCursor += uaoSize;
+    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
+    pageCursor += length;
+    inMemSorter.insertRecord(recordAddress, partitionId);
+  }
+}
diff --git 
a/spark/src/test/scala/org/apache/spark/shuffle/sort/SpillSorterSuite.scala 
b/spark/src/test/scala/org/apache/spark/shuffle/sort/SpillSorterSuite.scala
new file mode 100644
index 000000000..dfbe38b64
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/shuffle/sort/SpillSorterSuite.scala
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocator
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.UnsafeAlignedOffset
+
+/**
+ * Unit tests for [[SpillSorter]].
+ *
+ * These tests verify SpillSorter behavior using Spark's test memory management infrastructure,
+ * without needing a full SparkContext.
+ */
+class SpillSorterSuite extends AnyFunSuite with BeforeAndAfterEach {
+
+  private val INITIAL_SIZE = 1024
+  private val UAO_SIZE = UnsafeAlignedOffset.getUaoSize
+  private val PAGE_SIZE = 4 * 1024 * 1024 // 4MB
+
+  private var conf: SparkConf = _
+  private var memoryManager: TestMemoryManager = _
+  private var taskMemoryManager: TaskMemoryManager = _
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    conf = new SparkConf()
+      .set("spark.memory.offHeap.enabled", "false")
+    memoryManager = new TestMemoryManager(conf)
+    memoryManager.limit(100 * 1024 * 1024) // 100MB
+    taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+  }
+
+  override def afterEach(): Unit = {
+    if (taskMemoryManager != null) {
+      taskMemoryManager.cleanUpAllAllocatedMemory()
+      taskMemoryManager = null
+    }
+    memoryManager = null
+    super.afterEach()
+  }
+
+  private def createTestSchema(): StructType = {
+    new StructType().add("id", IntegerType)
+  }
+
+  private def createSpillSorter(
+      spillCallback: SpillSorter.SpillCallback = () => {},
+      spills: java.util.LinkedList[org.apache.spark.sql.comet.execution.shuffle.SpillInfo] =
+        new java.util.LinkedList[org.apache.spark.sql.comet.execution.shuffle.SpillInfo](),
+      partitionChecksums: Array[Long] = new Array[Long](10)): SpillSorter = {
+    val allocator = CometShuffleMemoryAllocator.getInstance(conf, taskMemoryManager, PAGE_SIZE)
+    val schema = createTestSchema()
+    val writeMetrics = new ShuffleWriteMetrics()
+    val taskContext = TaskContext.empty()
+
+    new SpillSorter(
+      allocator,
+      INITIAL_SIZE,
+      schema,
+      UAO_SIZE,
+      0.5, // preferDictionaryRatio
+      "zstd", // compressionCodec
+      1, // compressionLevel
+      "adler32", // checksumAlgorithm
+      partitionChecksums,
+      writeMetrics,
+      taskContext,
+      spills,
+      spillCallback)
+  }
+
+  test("initial state") {
+    val sorter = createSpillSorter()
+    try {
+      assert(sorter.numRecords() === 0)
+      assert(sorter.hasSpaceForAnotherRecord())
+      assert(sorter.getMemoryUsage() > 0)
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("insert single record") {
+    val sorter = createSpillSorter()
+    try {
+      val recordData = Array[Byte](1, 2, 3, 4)
+      val partitionId = 0
+
+      sorter.initialCurrentPage(recordData.length + UAO_SIZE)
+      sorter.insertRecord(recordData, Platform.BYTE_ARRAY_OFFSET, recordData.length, partitionId)
+
+      assert(sorter.numRecords() === 1)
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("insert multiple records") {
+    val sorter = createSpillSorter()
+    try {
+      val recordData = Array[Byte](1, 2, 3, 4)
+      val numRecords = 100
+
+      sorter.initialCurrentPage(numRecords * (recordData.length + UAO_SIZE))
+
+      for (i <- 0 until numRecords) {
+        val partitionId = i % 10
+        sorter.insertRecord(
+          recordData,
+          Platform.BYTE_ARRAY_OFFSET,
+          recordData.length,
+          partitionId)
+      }
+
+      assert(sorter.numRecords() === numRecords)
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("reset after free") {
+    val sorter = createSpillSorter()
+    try {
+      val recordData = Array[Byte](1, 2, 3, 4)
+      sorter.initialCurrentPage(recordData.length + UAO_SIZE)
+      sorter.insertRecord(recordData, Platform.BYTE_ARRAY_OFFSET, recordData.length, 0)
+
+      assert(sorter.numRecords() === 1)
+
+      sorter.freeMemory()
+      sorter.reset()
+
+      assert(sorter.numRecords() === 0)
+      assert(sorter.hasSpaceForAnotherRecord())
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("free memory returns correct value") {
+    val sorter = createSpillSorter()
+    try {
+      sorter.initialCurrentPage(1024)
+      val memoryBefore = sorter.getMemoryUsage()
+      assert(memoryBefore > 0)
+
+      val freed = sorter.freeMemory()
+      assert(freed > 0)
+    } finally {
+      sorter.freeArray()
+    }
+  }
+
+  test("spill callback not triggered during normal operations") {
+    val spillCount = new AtomicInteger(0)
+    val callback: SpillSorter.SpillCallback = () => spillCount.incrementAndGet()
+
+    val sorter = createSpillSorter(spillCallback = callback)
+    try {
+      sorter.initialCurrentPage(1024)
+      val recordData = Array[Byte](1, 2, 3, 4)
+      sorter.insertRecord(recordData, Platform.BYTE_ARRAY_OFFSET, recordData.length, 0)
+
+      assert(spillCount.get() === 0, "Spill callback should not be triggered during normal ops")
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("getMemoryUsage is thread-safe") {
+    val sorter = createSpillSorter()
+    try {
+      sorter.initialCurrentPage(1024)
+
+      val threads = (0 until 10).map { _ =>
+        new Thread(() => {
+          for (_ <- 0 until 100) {
+            sorter.getMemoryUsage()
+          }
+        })
+      }
+
+      threads.foreach(_.start())
+      threads.foreach(_.join())
+      // Test passes if no exceptions thrown
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("expand pointer array") {
+    val sorter = createSpillSorter()
+    try {
+      val initialMemory = sorter.getMemoryUsage()
+      val allocator = CometShuffleMemoryAllocator.getInstance(conf, taskMemoryManager, PAGE_SIZE)
+      val newArray = allocator.allocateArray(INITIAL_SIZE * 2)
+      sorter.expandPointerArray(newArray)
+
+      assert(sorter.getMemoryUsage() >= initialMemory)
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("records distributed across partitions") {
+    val sorter = createSpillSorter()
+    try {
+      val recordData = Array[Byte](1, 2, 3, 4)
+      val numPartitions = 5
+      val recordsPerPartition = 20
+
+      sorter.initialCurrentPage(
+        numPartitions * recordsPerPartition * (recordData.length + UAO_SIZE))
+
+      for (p <- 0 until numPartitions) {
+        for (_ <- 0 until recordsPerPartition) {
+          sorter.insertRecord(recordData, Platform.BYTE_ARRAY_OFFSET, recordData.length, p)
+        }
+      }
+
+      assert(sorter.numRecords() === numPartitions * recordsPerPartition)
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to