This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 6bb5eb4e7 Improve sort-based shuffle: single spill file per partition and batch coalescing (#1431)
6bb5eb4e7 is described below

commit 6bb5eb4e7c5f9bc8ab073aab742903c1f3076d9c
Author: Andy Grove <[email protected]>
AuthorDate: Sun Feb 1 11:01:13 2026 -0700

    Improve sort-based shuffle: single spill file per partition and batch coalescing (#1431)
---
 Cargo.lock                                         | 251 +++++++++++++---
 ballista/core/proto/ballista.proto                 |   1 +
 ballista/core/src/config.rs                        |  14 +-
 .../src/execution_plans/sort_shuffle/buffer.rs     |  56 ++++
 .../src/execution_plans/sort_shuffle/config.rs     |   5 +
 .../core/src/execution_plans/sort_shuffle/spill.rs | 153 +++++-----
 .../src/execution_plans/sort_shuffle/writer.rs     |  27 +-
 ballista/core/src/serde/generated/ballista.rs      |   2 +
 ballista/core/src/serde/mod.rs                     |   7 +
 ballista/scheduler/src/planner.rs                  |   1 +
 benchmarks/Cargo.toml                              |   5 +
 benchmarks/benches/sort_shuffle.rs                 | 315 +++++++++++++++++++++
 benchmarks/src/bin/shuffle_bench.rs                |   1 +
 13 files changed, 728 insertions(+), 110 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 487dc4b29..981c8de66 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -61,6 +61,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "anes"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
+
 [[package]]
 name = "anstream"
 version = "0.6.21"
@@ -97,7 +103,7 @@ version = "1.1.5"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
 dependencies = [
- "windows-sys 0.60.2",
+ "windows-sys 0.61.2",
 ]
 
 [[package]]
@@ -108,7 +114,7 @@ checksum = 
"291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
 dependencies = [
  "anstyle",
  "once_cell_polyfill",
- "windows-sys 0.60.2",
+ "windows-sys 0.61.2",
 ]
 
 [[package]]
@@ -940,6 +946,7 @@ version = "52.0.0"
 dependencies = [
  "ballista",
  "ballista-core",
+ "criterion",
  "datafusion",
  "datafusion-proto",
  "env_logger",
@@ -982,7 +989,7 @@ dependencies = [
  "datafusion-proto",
  "datafusion-proto-common",
  "futures",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "md-5",
  "object_store",
@@ -1334,6 +1341,12 @@ dependencies = [
  "libbz2-rs-sys",
 ]
 
+[[package]]
+name = "cast"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
+
 [[package]]
 name = "cc"
 version = "1.2.54"
@@ -1380,6 +1393,33 @@ dependencies = [
  "phf",
 ]
 
+[[package]]
+name = "ciborium"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
+dependencies = [
+ "ciborium-io",
+ "ciborium-ll",
+ "serde",
+]
+
+[[package]]
+name = "ciborium-io"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
+
+[[package]]
+name = "ciborium-ll"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
+dependencies = [
+ "ciborium-io",
+ "half",
+]
+
 [[package]]
 name = "clap"
 version = "2.34.0"
@@ -1546,6 +1586,42 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "criterion"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
+dependencies = [
+ "anes",
+ "cast",
+ "ciborium",
+ "clap 4.5.56",
+ "criterion-plot",
+ "is-terminal",
+ "itertools 0.10.5",
+ "num-traits",
+ "once_cell",
+ "oorandom",
+ "plotters",
+ "rayon",
+ "regex",
+ "serde",
+ "serde_derive",
+ "serde_json",
+ "tinytemplate",
+ "walkdir",
+]
+
+[[package]]
+name = "criterion-plot"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
+dependencies = [
+ "cast",
+ "itertools 0.10.5",
+]
+
 [[package]]
 name = "crossbeam-channel"
 version = "0.5.15"
@@ -1555,6 +1631,25 @@ dependencies = [
  "crossbeam-utils",
 ]
 
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.6"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.18"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
+dependencies = [
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-utils"
 version = "0.8.21"
@@ -1737,7 +1832,7 @@ dependencies = [
  "datafusion-sql",
  "flate2",
  "futures",
- "itertools",
+ "itertools 0.14.0",
  "liblzma",
  "log",
  "object_store",
@@ -1771,7 +1866,7 @@ dependencies = [
  "datafusion-physical-plan",
  "datafusion-session",
  "futures",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "object_store",
  "parking_lot",
@@ -1796,7 +1891,7 @@ dependencies = [
  "datafusion-physical-expr-common",
  "datafusion-physical-plan",
  "futures",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "object_store",
 ]
@@ -1890,7 +1985,7 @@ dependencies = [
  "flate2",
  "futures",
  "glob",
- "itertools",
+ "itertools 0.14.0",
  "liblzma",
  "log",
  "object_store",
@@ -1920,7 +2015,7 @@ dependencies = [
  "datafusion-physical-plan",
  "datafusion-session",
  "futures",
- "itertools",
+ "itertools 0.14.0",
  "object_store",
  "tokio",
 ]
@@ -2012,7 +2107,7 @@ dependencies = [
  "datafusion-pruning",
  "datafusion-session",
  "futures",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "object_store",
  "parking_lot",
@@ -2064,7 +2159,7 @@ dependencies = [
  "datafusion-functions-window-common",
  "datafusion-physical-expr-common",
  "indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
  "paste",
  "recursive",
  "serde_json",
@@ -2080,7 +2175,7 @@ dependencies = [
  "arrow",
  "datafusion-common",
  "indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
  "paste",
 ]
 
@@ -2104,7 +2199,7 @@ dependencies = [
  "datafusion-expr-common",
  "datafusion-macros",
  "hex",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "md-5",
  "num-traits",
@@ -2167,7 +2262,7 @@ dependencies = [
  "datafusion-functions-aggregate-common",
  "datafusion-macros",
  "datafusion-physical-expr-common",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "paste",
 ]
@@ -2240,7 +2335,7 @@ dependencies = [
  "datafusion-expr-common",
  "datafusion-physical-expr",
  "indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "recursive",
  "regex",
@@ -2263,7 +2358,7 @@ dependencies = [
  "half",
  "hashbrown 0.16.1",
  "indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
  "parking_lot",
  "paste",
  "petgraph",
@@ -2283,7 +2378,7 @@ dependencies = [
  "datafusion-functions",
  "datafusion-physical-expr",
  "datafusion-physical-expr-common",
- "itertools",
+ "itertools 0.14.0",
 ]
 
 [[package]]
@@ -2299,7 +2394,7 @@ dependencies = [
  "datafusion-expr-common",
  "hashbrown 0.16.1",
  "indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
  "parking_lot",
 ]
 
@@ -2318,7 +2413,7 @@ dependencies = [
  "datafusion-physical-expr-common",
  "datafusion-physical-plan",
  "datafusion-pruning",
- "itertools",
+ "itertools 0.14.0",
  "recursive",
 ]
 
@@ -2346,7 +2441,7 @@ dependencies = [
  "half",
  "hashbrown 0.16.1",
  "indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "parking_lot",
  "pin-project-lite",
@@ -2404,7 +2499,7 @@ dependencies = [
  "datafusion-physical-expr",
  "datafusion-physical-expr-common",
  "datafusion-physical-plan",
- "itertools",
+ "itertools 0.14.0",
  "log",
 ]
 
@@ -2451,7 +2546,7 @@ dependencies = [
  "chrono",
  "datafusion",
  "half",
- "itertools",
+ "itertools 0.14.0",
  "object_store",
  "pbjson-types",
  "prost",
@@ -2500,7 +2595,7 @@ dependencies = [
  "libc",
  "option-ext",
  "redox_users",
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
 ]
 
 [[package]]
@@ -2615,7 +2710,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index";
 checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
 dependencies = [
  "libc",
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
 ]
 
 [[package]]
@@ -2971,6 +3066,12 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
 
+[[package]]
+name = "hermit-abi"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
+
 [[package]]
 name = "hex"
 version = "0.4.3"
@@ -3374,12 +3475,32 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "is-terminal"
+version = "0.4.17"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
+dependencies = [
+ "hermit-abi",
+ "libc",
+ "windows-sys 0.61.2",
+]
+
 [[package]]
 name = "is_terminal_polyfill"
 version = "1.70.2"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
 
+[[package]]
+name = "itertools"
+version = "0.10.5"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itertools"
 version = "0.14.0"
@@ -3709,7 +3830,7 @@ version = "0.50.3"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
 dependencies = [
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
 ]
 
 [[package]]
@@ -3818,7 +3939,7 @@ dependencies = [
  "http-body-util",
  "humantime",
  "hyper",
- "itertools",
+ "itertools 0.14.0",
  "md-5",
  "parking_lot",
  "percent-encoding",
@@ -3851,6 +3972,12 @@ version = "1.70.2"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
 
+[[package]]
+name = "oorandom"
+version = "11.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
+
 [[package]]
 name = "openssl-probe"
 version = "0.2.1"
@@ -3987,7 +4114,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index";
 checksum = "af22d08a625a2213a78dbb0ffa253318c5c79ce3133d32d296655a7bdfb02095"
 dependencies = [
  "heck 0.5.0",
- "itertools",
+ "itertools 0.14.0",
  "prost",
  "prost-types",
 ]
@@ -4124,6 +4251,34 @@ version = "0.3.32"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
 
+[[package]]
+name = "plotters"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
+dependencies = [
+ "num-traits",
+ "plotters-backend",
+ "plotters-svg",
+ "wasm-bindgen",
+ "web-sys",
+]
+
+[[package]]
+name = "plotters-backend"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
+
+[[package]]
+name = "plotters-svg"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
+dependencies = [
+ "plotters-backend",
+]
+
 [[package]]
 name = "portable-atomic"
 version = "1.13.0"
@@ -4271,7 +4426,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index";
 checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
 dependencies = [
  "heck 0.5.0",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "multimap",
  "petgraph",
@@ -4292,7 +4447,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index";
 checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
 dependencies = [
  "anyhow",
- "itertools",
+ "itertools 0.14.0",
  "proc-macro2",
  "quote",
  "syn 2.0.114",
@@ -4425,7 +4580,7 @@ dependencies = [
  "once_cell",
  "socket2",
  "tracing",
- "windows-sys 0.59.0",
+ "windows-sys 0.60.2",
 ]
 
 [[package]]
@@ -4482,6 +4637,26 @@ dependencies = [
  "getrandom 0.3.4",
 ]
 
+[[package]]
+name = "rayon"
+version = "1.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
+dependencies = [
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
+dependencies = [
+ "crossbeam-deque",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "recursive"
 version = "0.1.1"
@@ -4725,7 +4900,7 @@ dependencies = [
  "errno",
  "libc",
  "linux-raw-sys 0.11.0",
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
 ]
 
 [[package]]
@@ -5363,7 +5538,7 @@ dependencies = [
  "getrandom 0.3.4",
  "once_cell",
  "rustix 1.1.3",
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
 ]
 
 [[package]]
@@ -5381,7 +5556,7 @@ dependencies = [
  "etcetera",
  "ferroid",
  "futures",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "memchr",
  "parse-display",
@@ -5524,6 +5699,16 @@ dependencies = [
  "zerovec",
 ]
 
+[[package]]
+name = "tinytemplate"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "tinyvec"
 version = "1.10.0"
@@ -6188,7 +6373,7 @@ version = "0.1.11"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
 dependencies = [
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
 ]
 
 [[package]]
diff --git a/ballista/core/proto/ballista.proto 
b/ballista/core/proto/ballista.proto
index eb263fd2f..8660b6c72 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -58,6 +58,7 @@ message SortShuffleWriterExecNode {
   uint64 buffer_size = 5;
   uint64 memory_limit = 6;
   double spill_threshold = 7;
+  uint64 batch_size = 8;
 }
 
 message UnresolvedShuffleExecNode {
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index 86f486262..b9a78647f 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -69,6 +69,9 @@ pub const BALLISTA_SHUFFLE_SORT_BASED_MEMORY_LIMIT: &str =
 /// Configuration key for sort shuffle spill threshold (0.0-1.0).
 pub const BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD: &str =
     "ballista.shuffle.sort_based.spill_threshold";
+/// Configuration key for sort shuffle target batch size in rows.
+pub const BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE: &str =
+    "ballista.shuffle.sort_based.batch_size";
 
 /// Result type for configuration parsing operations.
 pub type ParseResult<T> = result::Result<T, String>;
@@ -129,7 +132,11 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, 
ConfigEntry>> = LazyLock::new(||
         
ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD.to_string(),
                          "Spill threshold as decimal fraction (0.0-1.0) of 
memory limit".to_string(),
                          DataType::Utf8,
-                         Some("0.8".to_string()))
+                         Some("0.8".to_string())),
+        ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE.to_string(),
+                         "Target batch size in rows for coalescing small 
batches in sort shuffle".to_string(),
+                         DataType::UInt64,
+                         Some((8192).to_string()))
     ];
     entries
         .into_iter()
@@ -316,6 +323,11 @@ impl BallistaConfig {
         self.get_f64_setting(BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD)
     }
 
+    /// Returns the target batch size for sort-based shuffle.
+    pub fn shuffle_sort_based_batch_size(&self) -> usize {
+        self.get_usize_setting(BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE)
+    }
+
     fn get_usize_setting(&self, key: &str) -> usize {
         if let Some(v) = self.settings.get(key) {
             // infallible because we validate all configs in the constructor
diff --git a/ballista/core/src/execution_plans/sort_shuffle/buffer.rs 
b/ballista/core/src/execution_plans/sort_shuffle/buffer.rs
index 61292af2c..c2aa5e85f 100644
--- a/ballista/core/src/execution_plans/sort_shuffle/buffer.rs
+++ b/ballista/core/src/execution_plans/sort_shuffle/buffer.rs
@@ -22,6 +22,7 @@
 
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::physical_plan::coalesce::LimitedBatchCoalescer;
 
 /// Buffer for accumulating record batches for a single output partition.
 ///
@@ -110,6 +111,61 @@ impl PartitionBuffer {
     pub fn take_batches(&mut self) -> Vec<RecordBatch> {
         std::mem::take(&mut self.batches)
     }
+
+    /// Drains batches from the buffer, coalescing small batches into
+    /// larger ones up to `target_batch_size` rows each.
+    pub fn drain_coalesced(&mut self, target_batch_size: usize) -> 
Vec<RecordBatch> {
+        self.memory_used = 0;
+        self.num_rows = 0;
+        let batches = std::mem::take(&mut self.batches);
+        coalesce_batches(batches, &self.schema, target_batch_size)
+    }
+
+    /// Takes all batches, coalescing small batches into larger ones
+    /// up to `target_batch_size` rows each.
+    pub fn take_batches_coalesced(
+        &mut self,
+        target_batch_size: usize,
+    ) -> Vec<RecordBatch> {
+        let batches = std::mem::take(&mut self.batches);
+        coalesce_batches(batches, &self.schema, target_batch_size)
+    }
+}
+
+/// Coalesces small batches into larger ones up to `target_batch_size`
+/// rows each using DataFusion's `LimitedBatchCoalescer`.
+fn coalesce_batches(
+    batches: Vec<RecordBatch>,
+    schema: &SchemaRef,
+    target_batch_size: usize,
+) -> Vec<RecordBatch> {
+    if batches.len() <= 1 {
+        return batches;
+    }
+
+    let mut coalescer =
+        LimitedBatchCoalescer::new(schema.clone(), target_batch_size, None);
+    let mut result = Vec::new();
+
+    for batch in batches {
+        if batch.num_rows() == 0 {
+            continue;
+        }
+        // push_batch can only fail on schema mismatch, which won't
+        // happen here since all batches share the same schema
+        let _ = coalescer.push_batch(batch);
+        while let Some(completed) = coalescer.next_completed_batch() {
+            result.push(completed);
+        }
+    }
+
+    // Flush remaining buffered rows
+    let _ = coalescer.finish();
+    while let Some(completed) = coalescer.next_completed_batch() {
+        result.push(completed);
+    }
+
+    result
 }
 
 #[cfg(test)]
diff --git a/ballista/core/src/execution_plans/sort_shuffle/config.rs 
b/ballista/core/src/execution_plans/sort_shuffle/config.rs
index 7c323f78d..adb306234 100644
--- a/ballista/core/src/execution_plans/sort_shuffle/config.rs
+++ b/ballista/core/src/execution_plans/sort_shuffle/config.rs
@@ -37,6 +37,8 @@ pub struct SortShuffleConfig {
     pub spill_threshold: f64,
     /// Compression codec for shuffle data (default: LZ4_FRAME)
     pub compression: CompressionType,
+    /// Target batch size in rows when coalescing small batches (default: 8192)
+    pub batch_size: usize,
 }
 
 impl Default for SortShuffleConfig {
@@ -47,6 +49,7 @@ impl Default for SortShuffleConfig {
             memory_limit: 256 * 1024 * 1024, // 256 MB
             spill_threshold: 0.8,
             compression: CompressionType::LZ4_FRAME,
+            batch_size: 8192,
         }
     }
 }
@@ -59,6 +62,7 @@ impl SortShuffleConfig {
         memory_limit: usize,
         spill_threshold: f64,
         compression: CompressionType,
+        batch_size: usize,
     ) -> Self {
         Self {
             enabled,
@@ -66,6 +70,7 @@ impl SortShuffleConfig {
             memory_limit,
             spill_threshold,
             compression,
+            batch_size,
         }
     }
 
diff --git a/ballista/core/src/execution_plans/sort_shuffle/spill.rs 
b/ballista/core/src/execution_plans/sort_shuffle/spill.rs
index 254775a1f..e832644b1 100644
--- a/ballista/core/src/execution_plans/sort_shuffle/spill.rs
+++ b/ballista/core/src/execution_plans/sort_shuffle/spill.rs
@@ -35,16 +35,17 @@ use std::path::PathBuf;
 /// Manages spill files for sort-based shuffle.
 ///
 /// When partition buffers exceed memory limits, they are spilled to disk
-/// as Arrow IPC files. During finalization, these spill files are read
-/// back and merged into the consolidated output file.
-#[derive(Debug)]
+/// as Arrow IPC files. Each output partition has at most one spill file
+/// that is appended to across multiple spill calls. During finalization,
+/// these spill files are read back and merged into the consolidated
+/// output file.
 pub struct SpillManager {
     /// Base directory for spill files
     spill_dir: PathBuf,
-    /// Spill files per output partition: partition_id -> Vec<spill_file_path>
-    spill_files: HashMap<usize, Vec<PathBuf>>,
-    /// Counter for generating unique spill file names
-    spill_counter: usize,
+    /// Spill file path per output partition: partition_id -> spill_file_path
+    spill_files: HashMap<usize, PathBuf>,
+    /// Active writers per partition, kept open for appending
+    active_writers: HashMap<usize, StreamWriter<BufWriter<File>>>,
     /// Compression codec for spill files
     compression: CompressionType,
     /// Total number of spills performed
@@ -81,7 +82,7 @@ impl SpillManager {
         Ok(Self {
             spill_dir,
             spill_files: HashMap::new(),
-            spill_counter: 0,
+            active_writers: HashMap::new(),
             compression,
             total_spills: 0,
             total_bytes_spilled: 0,
@@ -90,7 +91,9 @@ impl SpillManager {
 
     /// Spills batches for a partition to disk.
     ///
-    /// Returns the number of bytes written.
+    /// If a spill file already exists for this partition, batches are
+    /// appended to it. Otherwise a new spill file is created.
+    /// Returns the number of bytes written (estimated from batch sizes).
     pub fn spill(
         &mut self,
         partition_id: usize,
@@ -101,37 +104,35 @@ impl SpillManager {
             return Ok(0);
         }
 
-        let spill_path = self.next_spill_path(partition_id);
         debug!(
             "Spilling {} batches for partition {} to {:?}",
             batches.len(),
             partition_id,
-            spill_path
+            self.spill_path(partition_id)
         );
 
-        let file = File::create(&spill_path).map_err(BallistaError::IoError)?;
-        let buffered = BufWriter::new(file);
+        // Get or create the writer for this partition
+        if !self.active_writers.contains_key(&partition_id) {
+            let spill_path = self.spill_path(partition_id);
+            let file = 
File::create(&spill_path).map_err(BallistaError::IoError)?;
+            let buffered = BufWriter::new(file);
 
-        let options =
-            
IpcWriteOptions::default().try_with_compression(Some(self.compression))?;
+            let options = IpcWriteOptions::default()
+                .try_with_compression(Some(self.compression))?;
 
-        let mut writer = StreamWriter::try_new_with_options(buffered, schema, 
options)?;
+            let writer = StreamWriter::try_new_with_options(buffered, schema, 
options)?;
 
-        for batch in &batches {
-            writer.write(batch)?;
+            self.active_writers.insert(partition_id, writer);
+            self.spill_files.insert(partition_id, spill_path);
         }
 
-        writer.finish()?;
-
-        let bytes_written = std::fs::metadata(&spill_path)
-            .map_err(BallistaError::IoError)?
-            .len();
+        let writer = self.active_writers.get_mut(&partition_id).unwrap();
 
-        // Track the spill file
-        self.spill_files
-            .entry(partition_id)
-            .or_default()
-            .push(spill_path);
+        let mut bytes_written: u64 = 0;
+        for batch in &batches {
+            bytes_written += batch.get_array_memory_size() as u64;
+            writer.write(batch)?;
+        }
 
         self.total_spills += 1;
         self.total_bytes_spilled += bytes_written;
@@ -139,39 +140,42 @@ impl SpillManager {
         Ok(bytes_written)
     }
 
-    /// Returns the spill files for a partition.
-    pub fn get_spill_files(&self, partition_id: usize) -> &[PathBuf] {
-        self.spill_files
-            .get(&partition_id)
-            .map(|v| v.as_slice())
-            .unwrap_or(&[])
-    }
-
-    /// Returns true if the partition has spill files.
+    /// Returns true if the partition has a spill file.
     pub fn has_spill_files(&self, partition_id: usize) -> bool {
-        self.spill_files
-            .get(&partition_id)
-            .is_some_and(|v| !v.is_empty())
+        self.spill_files.contains_key(&partition_id)
     }
 
-    /// Reads all spill files for a partition and returns the batches.
-    pub fn read_spill_files(&self, partition_id: usize) -> 
Result<Vec<RecordBatch>> {
-        let mut all_batches = Vec::new();
-
-        for spill_path in self.get_spill_files(partition_id) {
-            let file = File::open(spill_path).map_err(BallistaError::IoError)?;
-            let reader = StreamReader::try_new(file, None)?;
+    /// Finishes all active writers so spill files can be read.
+    /// Must be called before `open_spill_reader`.
+    pub fn finish_writers(&mut self) -> Result<()> {
+        for (_, mut writer) in self.active_writers.drain() {
+            writer.finish()?;
+        }
+        Ok(())
+    }
 
-            for batch_result in reader {
-                all_batches.push(batch_result?);
+    /// Opens the spill file for a partition and returns a streaming
+    /// reader. `finish_writers` must be called before this method.
+    pub fn open_spill_reader(
+        &self,
+        partition_id: usize,
+    ) -> Result<Option<StreamReader<File>>> {
+        match self.spill_files.get(&partition_id) {
+            Some(spill_path) => {
+                let file = 
File::open(spill_path).map_err(BallistaError::IoError)?;
+                let reader = StreamReader::try_new(file, None)?;
+                Ok(Some(reader))
             }
+            None => Ok(None),
         }
-
-        Ok(all_batches)
     }
 
     /// Cleans up all spill files.
-    pub fn cleanup(&self) -> Result<()> {
+    pub fn cleanup(&mut self) -> Result<()> {
+        // Finish any active writers first
+        for (_, mut writer) in self.active_writers.drain() {
+            let _ = writer.finish();
+        }
         if self.spill_dir.exists() {
             
std::fs::remove_dir_all(&self.spill_dir).map_err(BallistaError::IoError)?;
         }
@@ -188,14 +192,21 @@ impl SpillManager {
         self.total_bytes_spilled
     }
 
-    /// Generates the next spill file path for a partition.
-    fn next_spill_path(&mut self, partition_id: usize) -> PathBuf {
-        let path = self.spill_dir.join(format!(
-            "part-{partition_id}-spill-{}.arrow",
-            self.spill_counter
-        ));
-        self.spill_counter += 1;
-        path
+    /// Returns the spill file path for a partition.
+    fn spill_path(&self, partition_id: usize) -> PathBuf {
+        self.spill_dir.join(format!("part-{partition_id}.arrow"))
+    }
+}
+
+impl std::fmt::Debug for SpillManager {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SpillManager")
+            .field("spill_dir", &self.spill_dir)
+            .field("spill_files", &self.spill_files)
+            .field("compression", &self.compression)
+            .field("total_spills", &self.total_spills)
+            .field("total_bytes_spilled", &self.total_bytes_spilled)
+            .finish()
     }
 }
 
@@ -251,8 +262,13 @@ mod tests {
         assert!(!manager.has_spill_files(1));
         assert_eq!(manager.total_spills(), 1);
 
-        // Read back
-        let read_batches = manager.read_spill_files(0)?;
+        // Finish writers before reading
+        manager.finish_writers()?;
+
+        // Read back via streaming reader
+        let reader = manager.open_spill_reader(0)?.unwrap();
+        let read_batches: Vec<_> =
+            reader.into_iter().collect::<std::result::Result<_, _>>()?;
         assert_eq!(read_batches.len(), 2);
         assert_eq!(read_batches[0].num_rows(), 3);
         assert_eq!(read_batches[1].num_rows(), 2);
@@ -273,15 +289,20 @@ mod tests {
             CompressionType::LZ4_FRAME,
         )?;
 
-        // Multiple spills for same partition
+        // Multiple spills for same partition append to same file
         manager.spill(0, vec![create_test_batch(&schema, vec![1, 2])], 
&schema)?;
         manager.spill(0, vec![create_test_batch(&schema, vec![3, 4])], 
&schema)?;
 
-        assert_eq!(manager.get_spill_files(0).len(), 2);
+        assert!(manager.has_spill_files(0));
         assert_eq!(manager.total_spills(), 2);
 
-        // Read all back
-        let batches = manager.read_spill_files(0)?;
+        // Finish writers before reading
+        manager.finish_writers()?;
+
+        // Read all back - both batches from single file
+        let reader = manager.open_spill_reader(0)?.unwrap();
+        let batches: Vec<_> =
+            reader.into_iter().collect::<std::result::Result<_, _>>()?;
         assert_eq!(batches.len(), 2);
 
         Ok(())
diff --git a/ballista/core/src/execution_plans/sort_shuffle/writer.rs 
b/ballista/core/src/execution_plans/sort_shuffle/writer.rs
index dfef86086..84c7f61ec 100644
--- a/ballista/core/src/execution_plans/sort_shuffle/writer.rs
+++ b/ballista/core/src/execution_plans/sort_shuffle/writer.rs
@@ -256,11 +256,17 @@ impl SortShuffleWriterExec {
                         &mut spill_manager,
                         &schema,
                         config.spill_memory_threshold() / 2,
+                        config.batch_size,
                     )?;
                     timer.done();
                 }
             }
 
+            // Finish spill writers before reading them back
+            spill_manager
+                .finish_writers()
+                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+
             // Finalize: write consolidated output file
             let timer = metrics.write_time.timer();
             let (data_path, index_path, partition_stats) = finalize_output(
@@ -325,6 +331,7 @@ fn spill_largest_buffers(
     spill_manager: &mut SpillManager,
     schema: &SchemaRef,
     target_memory: usize,
+    batch_size: usize,
 ) -> Result<()> {
     loop {
         let total_memory: usize = buffers.iter().map(|b| 
b.memory_used()).sum();
@@ -342,7 +349,7 @@ fn spill_largest_buffers(
         match largest_idx {
             Some(idx) if buffers[idx].memory_used() > 0 => {
                 let partition_id = buffers[idx].partition_id();
-                let batches = buffers[idx].drain();
+                let batches = buffers[idx].drain_coalesced(batch_size);
                 spill_manager
                     .spill(partition_id, batches, schema)
                     .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
@@ -405,13 +412,13 @@ fn finalize_output(
         let mut partition_batches: u64 = 0;
         let mut partition_bytes: u64 = 0;
 
-        // First, write any spill files for this partition
-        if spill_manager.has_spill_files(partition_id) {
-            let spill_batches = spill_manager
-                .read_spill_files(partition_id)
-                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
-
-            for batch in spill_batches {
+        // First, stream any spill file for this partition
+        if let Some(reader) = spill_manager
+            .open_spill_reader(partition_id)
+            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
+        {
+            for batch_result in reader {
+                let batch = batch_result?;
                 partition_rows += batch.num_rows() as u64;
                 partition_bytes += batch.get_array_memory_size() as u64;
                 partition_batches += 1;
@@ -419,8 +426,8 @@ fn finalize_output(
             }
         }
 
-        // Then write remaining buffered data
-        let buffered_batches = buffer.take_batches();
+        // Then write remaining buffered data (coalesced)
+        let buffered_batches = buffer.take_batches_coalesced(config.batch_size);
         for batch in buffered_batches {
             partition_rows += batch.num_rows() as u64;
             partition_bytes += batch.get_array_memory_size() as u64;
diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs
index 6a5d9b896..d0608d024 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -61,6 +61,8 @@ pub struct SortShuffleWriterExecNode {
     pub memory_limit: u64,
     #[prost(double, tag = "7")]
     pub spill_threshold: f64,
+    #[prost(uint64, tag = "8")]
+    pub batch_size: u64,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct UnresolvedShuffleExecNode {
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 13c4a53f4..7eb4ecefd 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -327,12 +327,18 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
                     )
                 })?;
 
+                let batch_size = if sort_shuffle_writer.batch_size > 0 {
+                    sort_shuffle_writer.batch_size as usize
+                } else {
+                    8192 // default for backwards compatibility
+                };
                 let config = SortShuffleConfig::new(
                     true,
                     sort_shuffle_writer.buffer_size as usize,
                     sort_shuffle_writer.memory_limit as usize,
                     sort_shuffle_writer.spill_threshold,
                     datafusion::arrow::ipc::CompressionType::LZ4_FRAME,
+                    batch_size,
                 );
 
                 Ok(Arc::new(SortShuffleWriterExec::try_new(
@@ -478,6 +484,7 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
                         buffer_size: config.buffer_size as u64,
                         memory_limit: config.memory_limit as u64,
                         spill_threshold: config.spill_threshold,
+                        batch_size: config.batch_size as u64,
                     },
                 )),
             };
diff --git a/ballista/scheduler/src/planner.rs b/ballista/scheduler/src/planner.rs
index fd5321ecb..d00f0bb67 100644
--- a/ballista/scheduler/src/planner.rs
+++ b/ballista/scheduler/src/planner.rs
@@ -359,6 +359,7 @@ fn create_shuffle_writer_with_config(
                 ballista_config.shuffle_sort_based_memory_limit(),
                 ballista_config.shuffle_sort_based_spill_threshold(),
                 datafusion::arrow::ipc::CompressionType::LZ4_FRAME,
+                ballista_config.shuffle_sort_based_batch_size(),
             );
 
             return Ok(Arc::new(SortShuffleWriterExec::try_new(
diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml
index ac6e262ef..97e4ece91 100644
--- a/benchmarks/Cargo.toml
+++ b/benchmarks/Cargo.toml
@@ -51,3 +51,8 @@ tokio = { version = "^1.44", features = [
 ] }
 
 [dev-dependencies]
+criterion = { version = "0.5", features = ["html_reports"] }
+
+[[bench]]
+harness = false
+name = "sort_shuffle"
diff --git a/benchmarks/benches/sort_shuffle.rs b/benchmarks/benches/sort_shuffle.rs
new file mode 100644
index 000000000..30a7423ac
--- /dev/null
+++ b/benchmarks/benches/sort_shuffle.rs
@@ -0,0 +1,315 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Criterion benchmarks for sort-based shuffle writer.
+
+use std::sync::Arc;
+
+use ballista_core::execution_plans::SortShuffleWriterExec;
+use ballista_core::execution_plans::sort_shuffle::SortShuffleConfig;
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion::arrow::array::{
+    BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array,
+    Int64Array, StringArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array,
+};
+use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion::arrow::ipc::CompressionType;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::memory::MemorySourceConfig;
+use datafusion::datasource::source::DataSourceExec;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::physical_plan::Partitioning;
+use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+use datafusion::physical_plan::expressions::Column;
+use datafusion::prelude::SessionContext;
+use futures::TryStreamExt;
+use rand::Rng;
+use tempfile::TempDir;
+
+const BATCH_SIZE: usize = 8192;
+const NUM_OUTPUT_PARTITIONS: usize = 200;
+// 100 columns: mix of primitive types
+const NUM_INT_COLS: usize = 12; // i8, i16, i32, i64 x3 each
+const NUM_UINT_COLS: usize = 12; // u8, u16, u32, u64 x3 each
+const NUM_FLOAT_COLS: usize = 6; // f32, f64 x3 each
+const NUM_BOOL_COLS: usize = 5;
+const NUM_STRING_COLS: usize = 5;
+// Remaining to reach 100
+const NUM_EXTRA_INT32_COLS: usize =
+    100 - NUM_INT_COLS - NUM_UINT_COLS - NUM_FLOAT_COLS - NUM_BOOL_COLS - NUM_STRING_COLS;
+
+fn build_schema() -> SchemaRef {
+    let mut fields = Vec::with_capacity(100);
+    let mut idx = 0;
+
+    for _ in 0..3 {
+        fields.push(Field::new(format!("c{idx}"), DataType::Int8, true));
+        idx += 1;
+    }
+    for _ in 0..3 {
+        fields.push(Field::new(format!("c{idx}"), DataType::Int16, true));
+        idx += 1;
+    }
+    for _ in 0..3 {
+        fields.push(Field::new(format!("c{idx}"), DataType::Int32, true));
+        idx += 1;
+    }
+    for _ in 0..3 {
+        fields.push(Field::new(format!("c{idx}"), DataType::Int64, true));
+        idx += 1;
+    }
+
+    for _ in 0..3 {
+        fields.push(Field::new(format!("c{idx}"), DataType::UInt8, true));
+        idx += 1;
+    }
+    for _ in 0..3 {
+        fields.push(Field::new(format!("c{idx}"), DataType::UInt16, true));
+        idx += 1;
+    }
+    for _ in 0..3 {
+        fields.push(Field::new(format!("c{idx}"), DataType::UInt32, true));
+        idx += 1;
+    }
+    for _ in 0..3 {
+        fields.push(Field::new(format!("c{idx}"), DataType::UInt64, true));
+        idx += 1;
+    }
+
+    for _ in 0..3 {
+        fields.push(Field::new(format!("c{idx}"), DataType::Float32, true));
+        idx += 1;
+    }
+    for _ in 0..3 {
+        fields.push(Field::new(format!("c{idx}"), DataType::Float64, true));
+        idx += 1;
+    }
+
+    for _ in 0..NUM_BOOL_COLS {
+        fields.push(Field::new(format!("c{idx}"), DataType::Boolean, true));
+        idx += 1;
+    }
+
+    for _ in 0..NUM_STRING_COLS {
+        fields.push(Field::new(format!("c{idx}"), DataType::Utf8, true));
+        idx += 1;
+    }
+
+    for _ in 0..NUM_EXTRA_INT32_COLS {
+        fields.push(Field::new(format!("c{idx}"), DataType::Int32, true));
+        idx += 1;
+    }
+
+    Arc::new(Schema::new(fields))
+}
+
+fn build_batch(schema: &SchemaRef) -> RecordBatch {
+    let mut rng = rand::rng();
+    let mut columns: Vec<Arc<dyn datafusion::arrow::array::Array>> =
+        Vec::with_capacity(100);
+
+    // i8 x3
+    for _ in 0..3 {
+        let vals: Vec<i8> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(Int8Array::from(vals)));
+    }
+    // i16 x3
+    for _ in 0..3 {
+        let vals: Vec<i16> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(Int16Array::from(vals)));
+    }
+    // i32 x3
+    for _ in 0..3 {
+        let vals: Vec<i32> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(Int32Array::from(vals)));
+    }
+    // i64 x3
+    for _ in 0..3 {
+        let vals: Vec<i64> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(Int64Array::from(vals)));
+    }
+
+    // u8 x3
+    for _ in 0..3 {
+        let vals: Vec<u8> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(UInt8Array::from(vals)));
+    }
+    // u16 x3
+    for _ in 0..3 {
+        let vals: Vec<u16> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(UInt16Array::from(vals)));
+    }
+    // u32 x3
+    for _ in 0..3 {
+        let vals: Vec<u32> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(UInt32Array::from(vals)));
+    }
+    // u64 x3
+    for _ in 0..3 {
+        let vals: Vec<u64> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(UInt64Array::from(vals)));
+    }
+
+    // f32 x3
+    for _ in 0..3 {
+        let vals: Vec<f32> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(Float32Array::from(vals)));
+    }
+    // f64 x3
+    for _ in 0..3 {
+        let vals: Vec<f64> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(Float64Array::from(vals)));
+    }
+
+    // bool
+    for _ in 0..NUM_BOOL_COLS {
+        let vals: Vec<bool> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(BooleanArray::from(vals)));
+    }
+
+    // string (short random strings)
+    for _ in 0..NUM_STRING_COLS {
+        let vals: Vec<String> = (0..BATCH_SIZE)
+            .map(|_| format!("s{}", rng.random_range(0..100_000)))
+            .collect();
+        columns.push(Arc::new(StringArray::from(vals)));
+    }
+
+    // remaining i32 columns
+    for _ in 0..NUM_EXTRA_INT32_COLS {
+        let vals: Vec<i32> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+        columns.push(Arc::new(Int32Array::from(vals)));
+    }
+
+    RecordBatch::try_new(schema.clone(), columns).unwrap()
+}
+
+fn create_input(num_batches: usize) -> Arc<dyn datafusion::physical_plan::ExecutionPlan> {
+    let schema = build_schema();
+    let batch = build_batch(&schema);
+    let partition: Vec<RecordBatch> = (0..num_batches).map(|_| batch.clone()).collect();
+    let partitions = vec![partition];
+
+    let memory_source =
+        Arc::new(MemorySourceConfig::try_new(&partitions, schema, None).unwrap());
+    Arc::new(CoalescePartitionsExec::new(Arc::new(DataSourceExec::new(
+        memory_source,
+    ))))
+}
+
+fn run_sort_shuffle(
+    input: Arc<dyn datafusion::physical_plan::ExecutionPlan>,
+    work_dir: &str,
+    memory_limit: usize,
+) {
+    let rt = tokio::runtime::Runtime::new().unwrap();
+    let session_ctx = SessionContext::new();
+    let task_ctx = session_ctx.task_ctx();
+
+    let config = SortShuffleConfig::new(
+        true,
+        1024 * 1024, // 1MB buffer
+        memory_limit,
+        0.8,
+        CompressionType::LZ4_FRAME,
+        8192,
+    );
+
+    let writer = SortShuffleWriterExec::try_new(
+        "bench_job".to_string(),
+        1,
+        input,
+        work_dir.to_string(),
+        Partitioning::Hash(vec![Arc::new(Column::new("c0", 0))], NUM_OUTPUT_PARTITIONS),
+        config,
+    )
+    .unwrap();
+
+    rt.block_on(async {
+        let mut stream = writer.execute(0, task_ctx).unwrap();
+        while let Some(_batch) = stream.try_next().await.unwrap() {}
+    });
+}
+
+fn bench_no_spill(c: &mut Criterion) {
+    let mut group = c.benchmark_group("sort_shuffle_no_spill");
+    group.sample_size(10);
+
+    // 10 batches of 8192 rows = 81920 rows, 256MB limit
+    let input = create_input(10);
+    let work_dir = TempDir::new().unwrap();
+
+    group.bench_function("10_batches_200_partitions", |b| {
+        b.iter(|| {
+            run_sort_shuffle(
+                input.clone(),
+                work_dir.path().to_str().unwrap(),
+                256 * 1024 * 1024,
+            );
+        });
+    });
+
+    // 50 batches
+    let input = create_input(50);
+    group.bench_function("50_batches_200_partitions", |b| {
+        b.iter(|| {
+            run_sort_shuffle(
+                input.clone(),
+                work_dir.path().to_str().unwrap(),
+                256 * 1024 * 1024,
+            );
+        });
+    });
+
+    group.finish();
+}
+
+fn bench_with_spill(c: &mut Criterion) {
+    let mut group = c.benchmark_group("sort_shuffle_with_spill");
+    group.sample_size(10);
+
+    let work_dir = TempDir::new().unwrap();
+
+    // 50 batches with 8MB memory limit to force spilling
+    let input = create_input(50);
+    group.bench_function("50_batches_200_partitions_8mb_limit", |b| {
+        b.iter(|| {
+            run_sort_shuffle(
+                input.clone(),
+                work_dir.path().to_str().unwrap(),
+                8 * 1024 * 1024,
+            );
+        });
+    });
+
+    // 50 batches with 2MB memory limit to force heavy spilling
+    let input = create_input(50);
+    group.bench_function("50_batches_200_partitions_2mb_limit", |b| {
+        b.iter(|| {
+            run_sort_shuffle(
+                input.clone(),
+                work_dir.path().to_str().unwrap(),
+                2 * 1024 * 1024,
+            );
+        });
+    });
+
+    group.finish();
+}
+
+criterion_group!(benches, bench_no_spill, bench_with_spill);
+criterion_main!(benches);
diff --git a/benchmarks/src/bin/shuffle_bench.rs b/benchmarks/src/bin/shuffle_bench.rs
index 202c123c8..5ced6ee49 100644
--- a/benchmarks/src/bin/shuffle_bench.rs
+++ b/benchmarks/src/bin/shuffle_bench.rs
@@ -226,6 +226,7 @@ async fn benchmark_sort_shuffle(
         memory_limit,
         0.8,
         CompressionType::LZ4_FRAME,
+        8192,
     );
 
     // Create sort shuffle writer


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to