This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new 6bb5eb4e7 Improve sort-based shuffle: single spill file per partition and batch coalescing (#1431)
6bb5eb4e7 is described below
commit 6bb5eb4e7c5f9bc8ab073aab742903c1f3076d9c
Author: Andy Grove <[email protected]>
AuthorDate: Sun Feb 1 11:01:13 2026 -0700
Improve sort-based shuffle: single spill file per partition and batch coalescing (#1431)
---
Cargo.lock | 251 +++++++++++++---
ballista/core/proto/ballista.proto | 1 +
ballista/core/src/config.rs | 14 +-
.../src/execution_plans/sort_shuffle/buffer.rs | 56 ++++
.../src/execution_plans/sort_shuffle/config.rs | 5 +
.../core/src/execution_plans/sort_shuffle/spill.rs | 153 +++++-----
.../src/execution_plans/sort_shuffle/writer.rs | 27 +-
ballista/core/src/serde/generated/ballista.rs | 2 +
ballista/core/src/serde/mod.rs | 7 +
ballista/scheduler/src/planner.rs | 1 +
benchmarks/Cargo.toml | 5 +
benchmarks/benches/sort_shuffle.rs | 315 +++++++++++++++++++++
benchmarks/src/bin/shuffle_bench.rs | 1 +
13 files changed, 728 insertions(+), 110 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 487dc4b29..981c8de66 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -61,6 +61,12 @@ dependencies = [
"libc",
]
+[[package]]
+name = "anes"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
+
[[package]]
name = "anstream"
version = "0.6.21"
@@ -97,7 +103,7 @@ version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
- "windows-sys 0.60.2",
+ "windows-sys 0.61.2",
]
[[package]]
@@ -108,7 +114,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
- "windows-sys 0.60.2",
+ "windows-sys 0.61.2",
]
[[package]]
@@ -940,6 +946,7 @@ version = "52.0.0"
dependencies = [
"ballista",
"ballista-core",
+ "criterion",
"datafusion",
"datafusion-proto",
"env_logger",
@@ -982,7 +989,7 @@ dependencies = [
"datafusion-proto",
"datafusion-proto-common",
"futures",
- "itertools",
+ "itertools 0.14.0",
"log",
"md-5",
"object_store",
@@ -1334,6 +1341,12 @@ dependencies = [
"libbz2-rs-sys",
]
+[[package]]
+name = "cast"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
+
[[package]]
name = "cc"
version = "1.2.54"
@@ -1380,6 +1393,33 @@ dependencies = [
"phf",
]
+[[package]]
+name = "ciborium"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
+dependencies = [
+ "ciborium-io",
+ "ciborium-ll",
+ "serde",
+]
+
+[[package]]
+name = "ciborium-io"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
+
+[[package]]
+name = "ciborium-ll"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
+dependencies = [
+ "ciborium-io",
+ "half",
+]
+
[[package]]
name = "clap"
version = "2.34.0"
@@ -1546,6 +1586,42 @@ dependencies = [
"cfg-if",
]
+[[package]]
+name = "criterion"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
+dependencies = [
+ "anes",
+ "cast",
+ "ciborium",
+ "clap 4.5.56",
+ "criterion-plot",
+ "is-terminal",
+ "itertools 0.10.5",
+ "num-traits",
+ "once_cell",
+ "oorandom",
+ "plotters",
+ "rayon",
+ "regex",
+ "serde",
+ "serde_derive",
+ "serde_json",
+ "tinytemplate",
+ "walkdir",
+]
+
+[[package]]
+name = "criterion-plot"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
+dependencies = [
+ "cast",
+ "itertools 0.10.5",
+]
+
[[package]]
name = "crossbeam-channel"
version = "0.5.15"
@@ -1555,6 +1631,25 @@ dependencies = [
"crossbeam-utils",
]
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
+dependencies = [
+ "crossbeam-utils",
+]
+
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
@@ -1737,7 +1832,7 @@ dependencies = [
"datafusion-sql",
"flate2",
"futures",
- "itertools",
+ "itertools 0.14.0",
"liblzma",
"log",
"object_store",
@@ -1771,7 +1866,7 @@ dependencies = [
"datafusion-physical-plan",
"datafusion-session",
"futures",
- "itertools",
+ "itertools 0.14.0",
"log",
"object_store",
"parking_lot",
@@ -1796,7 +1891,7 @@ dependencies = [
"datafusion-physical-expr-common",
"datafusion-physical-plan",
"futures",
- "itertools",
+ "itertools 0.14.0",
"log",
"object_store",
]
@@ -1890,7 +1985,7 @@ dependencies = [
"flate2",
"futures",
"glob",
- "itertools",
+ "itertools 0.14.0",
"liblzma",
"log",
"object_store",
@@ -1920,7 +2015,7 @@ dependencies = [
"datafusion-physical-plan",
"datafusion-session",
"futures",
- "itertools",
+ "itertools 0.14.0",
"object_store",
"tokio",
]
@@ -2012,7 +2107,7 @@ dependencies = [
"datafusion-pruning",
"datafusion-session",
"futures",
- "itertools",
+ "itertools 0.14.0",
"log",
"object_store",
"parking_lot",
@@ -2064,7 +2159,7 @@ dependencies = [
"datafusion-functions-window-common",
"datafusion-physical-expr-common",
"indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
"paste",
"recursive",
"serde_json",
@@ -2080,7 +2175,7 @@ dependencies = [
"arrow",
"datafusion-common",
"indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
"paste",
]
@@ -2104,7 +2199,7 @@ dependencies = [
"datafusion-expr-common",
"datafusion-macros",
"hex",
- "itertools",
+ "itertools 0.14.0",
"log",
"md-5",
"num-traits",
@@ -2167,7 +2262,7 @@ dependencies = [
"datafusion-functions-aggregate-common",
"datafusion-macros",
"datafusion-physical-expr-common",
- "itertools",
+ "itertools 0.14.0",
"log",
"paste",
]
@@ -2240,7 +2335,7 @@ dependencies = [
"datafusion-expr-common",
"datafusion-physical-expr",
"indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
"log",
"recursive",
"regex",
@@ -2263,7 +2358,7 @@ dependencies = [
"half",
"hashbrown 0.16.1",
"indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
"parking_lot",
"paste",
"petgraph",
@@ -2283,7 +2378,7 @@ dependencies = [
"datafusion-functions",
"datafusion-physical-expr",
"datafusion-physical-expr-common",
- "itertools",
+ "itertools 0.14.0",
]
[[package]]
@@ -2299,7 +2394,7 @@ dependencies = [
"datafusion-expr-common",
"hashbrown 0.16.1",
"indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
"parking_lot",
]
@@ -2318,7 +2413,7 @@ dependencies = [
"datafusion-physical-expr-common",
"datafusion-physical-plan",
"datafusion-pruning",
- "itertools",
+ "itertools 0.14.0",
"recursive",
]
@@ -2346,7 +2441,7 @@ dependencies = [
"half",
"hashbrown 0.16.1",
"indexmap 2.13.0",
- "itertools",
+ "itertools 0.14.0",
"log",
"parking_lot",
"pin-project-lite",
@@ -2404,7 +2499,7 @@ dependencies = [
"datafusion-physical-expr",
"datafusion-physical-expr-common",
"datafusion-physical-plan",
- "itertools",
+ "itertools 0.14.0",
"log",
]
@@ -2451,7 +2546,7 @@ dependencies = [
"chrono",
"datafusion",
"half",
- "itertools",
+ "itertools 0.14.0",
"object_store",
"pbjson-types",
"prost",
@@ -2500,7 +2595,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
]
[[package]]
@@ -2615,7 +2710,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
]
[[package]]
@@ -2971,6 +3066,12 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
+[[package]]
+name = "hermit-abi"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
+
[[package]]
name = "hex"
version = "0.4.3"
@@ -3374,12 +3475,32 @@ dependencies = [
"serde",
]
+[[package]]
+name = "is-terminal"
+version = "0.4.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
+dependencies = [
+ "hermit-abi",
+ "libc",
+ "windows-sys 0.61.2",
+]
+
[[package]]
name = "is_terminal_polyfill"
version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
+[[package]]
+name = "itertools"
+version = "0.10.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
+dependencies = [
+ "either",
+]
+
[[package]]
name = "itertools"
version = "0.14.0"
@@ -3709,7 +3830,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
]
[[package]]
@@ -3818,7 +3939,7 @@ dependencies = [
"http-body-util",
"humantime",
"hyper",
- "itertools",
+ "itertools 0.14.0",
"md-5",
"parking_lot",
"percent-encoding",
@@ -3851,6 +3972,12 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
+[[package]]
+name = "oorandom"
+version = "11.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
+
[[package]]
name = "openssl-probe"
version = "0.2.1"
@@ -3987,7 +4114,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af22d08a625a2213a78dbb0ffa253318c5c79ce3133d32d296655a7bdfb02095"
dependencies = [
"heck 0.5.0",
- "itertools",
+ "itertools 0.14.0",
"prost",
"prost-types",
]
@@ -4124,6 +4251,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
+[[package]]
+name = "plotters"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
+dependencies = [
+ "num-traits",
+ "plotters-backend",
+ "plotters-svg",
+ "wasm-bindgen",
+ "web-sys",
+]
+
+[[package]]
+name = "plotters-backend"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
+
+[[package]]
+name = "plotters-svg"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
+dependencies = [
+ "plotters-backend",
+]
+
[[package]]
name = "portable-atomic"
version = "1.13.0"
@@ -4271,7 +4426,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
dependencies = [
"heck 0.5.0",
- "itertools",
+ "itertools 0.14.0",
"log",
"multimap",
"petgraph",
@@ -4292,7 +4447,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
dependencies = [
"anyhow",
- "itertools",
+ "itertools 0.14.0",
"proc-macro2",
"quote",
"syn 2.0.114",
@@ -4425,7 +4580,7 @@ dependencies = [
"once_cell",
"socket2",
"tracing",
- "windows-sys 0.59.0",
+ "windows-sys 0.60.2",
]
[[package]]
@@ -4482,6 +4637,26 @@ dependencies = [
"getrandom 0.3.4",
]
+[[package]]
+name = "rayon"
+version = "1.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
+dependencies = [
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
+dependencies = [
+ "crossbeam-deque",
+ "crossbeam-utils",
+]
+
[[package]]
name = "recursive"
version = "0.1.1"
@@ -4725,7 +4900,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.11.0",
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
]
[[package]]
@@ -5363,7 +5538,7 @@ dependencies = [
"getrandom 0.3.4",
"once_cell",
"rustix 1.1.3",
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
]
[[package]]
@@ -5381,7 +5556,7 @@ dependencies = [
"etcetera",
"ferroid",
"futures",
- "itertools",
+ "itertools 0.14.0",
"log",
"memchr",
"parse-display",
@@ -5524,6 +5699,16 @@ dependencies = [
"zerovec",
]
+[[package]]
+name = "tinytemplate"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
[[package]]
name = "tinyvec"
version = "1.10.0"
@@ -6188,7 +6373,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
- "windows-sys 0.59.0",
+ "windows-sys 0.61.2",
]
[[package]]
diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto
index eb263fd2f..8660b6c72 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -58,6 +58,7 @@ message SortShuffleWriterExecNode {
uint64 buffer_size = 5;
uint64 memory_limit = 6;
double spill_threshold = 7;
+ uint64 batch_size = 8;
}
message UnresolvedShuffleExecNode {
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index 86f486262..b9a78647f 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -69,6 +69,9 @@ pub const BALLISTA_SHUFFLE_SORT_BASED_MEMORY_LIMIT: &str =
/// Configuration key for sort shuffle spill threshold (0.0-1.0).
pub const BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD: &str =
"ballista.shuffle.sort_based.spill_threshold";
+/// Configuration key for sort shuffle target batch size in rows.
+pub const BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE: &str =
+ "ballista.shuffle.sort_based.batch_size";
/// Result type for configuration parsing operations.
pub type ParseResult<T> = result::Result<T, String>;
@@ -129,7 +132,11 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(||
ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD.to_string(),
"Spill threshold as decimal fraction (0.0-1.0) of memory limit".to_string(),
DataType::Utf8,
- Some("0.8".to_string()))
+ Some("0.8".to_string())),
+ ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE.to_string(),
+ "Target batch size in rows for coalescing small batches in sort shuffle".to_string(),
+ DataType::UInt64,
+ Some((8192).to_string()))
];
entries
.into_iter()
@@ -316,6 +323,11 @@ impl BallistaConfig {
self.get_f64_setting(BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD)
}
+ /// Returns the target batch size for sort-based shuffle.
+ pub fn shuffle_sort_based_batch_size(&self) -> usize {
+ self.get_usize_setting(BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE)
+ }
+
fn get_usize_setting(&self, key: &str) -> usize {
if let Some(v) = self.settings.get(key) {
// infallible because we validate all configs in the constructor
diff --git a/ballista/core/src/execution_plans/sort_shuffle/buffer.rs b/ballista/core/src/execution_plans/sort_shuffle/buffer.rs
index 61292af2c..c2aa5e85f 100644
--- a/ballista/core/src/execution_plans/sort_shuffle/buffer.rs
+++ b/ballista/core/src/execution_plans/sort_shuffle/buffer.rs
@@ -22,6 +22,7 @@
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::physical_plan::coalesce::LimitedBatchCoalescer;
/// Buffer for accumulating record batches for a single output partition.
///
@@ -110,6 +111,61 @@ impl PartitionBuffer {
pub fn take_batches(&mut self) -> Vec<RecordBatch> {
std::mem::take(&mut self.batches)
}
+
+ /// Drains batches from the buffer, coalescing small batches into
+ /// larger ones up to `target_batch_size` rows each.
+ pub fn drain_coalesced(&mut self, target_batch_size: usize) -> Vec<RecordBatch> {
+ self.memory_used = 0;
+ self.num_rows = 0;
+ let batches = std::mem::take(&mut self.batches);
+ coalesce_batches(batches, &self.schema, target_batch_size)
+ }
+
+ /// Takes all batches, coalescing small batches into larger ones
+ /// up to `target_batch_size` rows each.
+ pub fn take_batches_coalesced(
+ &mut self,
+ target_batch_size: usize,
+ ) -> Vec<RecordBatch> {
+ let batches = std::mem::take(&mut self.batches);
+ coalesce_batches(batches, &self.schema, target_batch_size)
+ }
+}
+
+/// Coalesces small batches into larger ones up to `target_batch_size`
+/// rows each using DataFusion's `LimitedBatchCoalescer`.
+fn coalesce_batches(
+ batches: Vec<RecordBatch>,
+ schema: &SchemaRef,
+ target_batch_size: usize,
+) -> Vec<RecordBatch> {
+ if batches.len() <= 1 {
+ return batches;
+ }
+
+ let mut coalescer =
+ LimitedBatchCoalescer::new(schema.clone(), target_batch_size, None);
+ let mut result = Vec::new();
+
+ for batch in batches {
+ if batch.num_rows() == 0 {
+ continue;
+ }
+ // push_batch can only fail on schema mismatch, which won't
+ // happen here since all batches share the same schema
+ let _ = coalescer.push_batch(batch);
+ while let Some(completed) = coalescer.next_completed_batch() {
+ result.push(completed);
+ }
+ }
+
+ // Flush remaining buffered rows
+ let _ = coalescer.finish();
+ while let Some(completed) = coalescer.next_completed_batch() {
+ result.push(completed);
+ }
+
+ result
}
#[cfg(test)]
diff --git a/ballista/core/src/execution_plans/sort_shuffle/config.rs b/ballista/core/src/execution_plans/sort_shuffle/config.rs
index 7c323f78d..adb306234 100644
--- a/ballista/core/src/execution_plans/sort_shuffle/config.rs
+++ b/ballista/core/src/execution_plans/sort_shuffle/config.rs
@@ -37,6 +37,8 @@ pub struct SortShuffleConfig {
pub spill_threshold: f64,
/// Compression codec for shuffle data (default: LZ4_FRAME)
pub compression: CompressionType,
+ /// Target batch size in rows when coalescing small batches (default: 8192)
+ pub batch_size: usize,
}
impl Default for SortShuffleConfig {
@@ -47,6 +49,7 @@ impl Default for SortShuffleConfig {
memory_limit: 256 * 1024 * 1024, // 256 MB
spill_threshold: 0.8,
compression: CompressionType::LZ4_FRAME,
+ batch_size: 8192,
}
}
}
@@ -59,6 +62,7 @@ impl SortShuffleConfig {
memory_limit: usize,
spill_threshold: f64,
compression: CompressionType,
+ batch_size: usize,
) -> Self {
Self {
enabled,
@@ -66,6 +70,7 @@ impl SortShuffleConfig {
memory_limit,
spill_threshold,
compression,
+ batch_size,
}
}
diff --git a/ballista/core/src/execution_plans/sort_shuffle/spill.rs b/ballista/core/src/execution_plans/sort_shuffle/spill.rs
index 254775a1f..e832644b1 100644
--- a/ballista/core/src/execution_plans/sort_shuffle/spill.rs
+++ b/ballista/core/src/execution_plans/sort_shuffle/spill.rs
@@ -35,16 +35,17 @@ use std::path::PathBuf;
/// Manages spill files for sort-based shuffle.
///
/// When partition buffers exceed memory limits, they are spilled to disk
-/// as Arrow IPC files. During finalization, these spill files are read
-/// back and merged into the consolidated output file.
-#[derive(Debug)]
+/// as Arrow IPC files. Each output partition has at most one spill file
+/// that is appended to across multiple spill calls. During finalization,
+/// these spill files are read back and merged into the consolidated
+/// output file.
pub struct SpillManager {
/// Base directory for spill files
spill_dir: PathBuf,
- /// Spill files per output partition: partition_id -> Vec<spill_file_path>
- spill_files: HashMap<usize, Vec<PathBuf>>,
- /// Counter for generating unique spill file names
- spill_counter: usize,
+ /// Spill file path per output partition: partition_id -> spill_file_path
+ spill_files: HashMap<usize, PathBuf>,
+ /// Active writers per partition, kept open for appending
+ active_writers: HashMap<usize, StreamWriter<BufWriter<File>>>,
/// Compression codec for spill files
compression: CompressionType,
/// Total number of spills performed
@@ -81,7 +82,7 @@ impl SpillManager {
Ok(Self {
spill_dir,
spill_files: HashMap::new(),
- spill_counter: 0,
+ active_writers: HashMap::new(),
compression,
total_spills: 0,
total_bytes_spilled: 0,
@@ -90,7 +91,9 @@ impl SpillManager {
/// Spills batches for a partition to disk.
///
- /// Returns the number of bytes written.
+ /// If a spill file already exists for this partition, batches are
+ /// appended to it. Otherwise a new spill file is created.
+ /// Returns the number of bytes written (estimated from batch sizes).
pub fn spill(
&mut self,
partition_id: usize,
@@ -101,37 +104,35 @@ impl SpillManager {
return Ok(0);
}
- let spill_path = self.next_spill_path(partition_id);
debug!(
"Spilling {} batches for partition {} to {:?}",
batches.len(),
partition_id,
- spill_path
+ self.spill_path(partition_id)
);
- let file = File::create(&spill_path).map_err(BallistaError::IoError)?;
- let buffered = BufWriter::new(file);
+ // Get or create the writer for this partition
+ if !self.active_writers.contains_key(&partition_id) {
+ let spill_path = self.spill_path(partition_id);
+ let file = File::create(&spill_path).map_err(BallistaError::IoError)?;
+ let buffered = BufWriter::new(file);
- let options =
- IpcWriteOptions::default().try_with_compression(Some(self.compression))?;
+ let options = IpcWriteOptions::default()
+ .try_with_compression(Some(self.compression))?;
- let mut writer = StreamWriter::try_new_with_options(buffered, schema, options)?;
+ let writer = StreamWriter::try_new_with_options(buffered, schema, options)?;
- for batch in &batches {
- writer.write(batch)?;
+ self.active_writers.insert(partition_id, writer);
+ self.spill_files.insert(partition_id, spill_path);
}
- writer.finish()?;
-
- let bytes_written = std::fs::metadata(&spill_path)
- .map_err(BallistaError::IoError)?
- .len();
+ let writer = self.active_writers.get_mut(&partition_id).unwrap();
- // Track the spill file
- self.spill_files
- .entry(partition_id)
- .or_default()
- .push(spill_path);
+ let mut bytes_written: u64 = 0;
+ for batch in &batches {
+ bytes_written += batch.get_array_memory_size() as u64;
+ writer.write(batch)?;
+ }
self.total_spills += 1;
self.total_bytes_spilled += bytes_written;
@@ -139,39 +140,42 @@ impl SpillManager {
Ok(bytes_written)
}
- /// Returns the spill files for a partition.
- pub fn get_spill_files(&self, partition_id: usize) -> &[PathBuf] {
- self.spill_files
- .get(&partition_id)
- .map(|v| v.as_slice())
- .unwrap_or(&[])
- }
-
- /// Returns true if the partition has spill files.
+ /// Returns true if the partition has a spill file.
pub fn has_spill_files(&self, partition_id: usize) -> bool {
- self.spill_files
- .get(&partition_id)
- .is_some_and(|v| !v.is_empty())
+ self.spill_files.contains_key(&partition_id)
}
- /// Reads all spill files for a partition and returns the batches.
- pub fn read_spill_files(&self, partition_id: usize) -> Result<Vec<RecordBatch>> {
- let mut all_batches = Vec::new();
-
- for spill_path in self.get_spill_files(partition_id) {
- let file = File::open(spill_path).map_err(BallistaError::IoError)?;
- let reader = StreamReader::try_new(file, None)?;
+ /// Finishes all active writers so spill files can be read.
+ /// Must be called before `open_spill_reader`.
+ pub fn finish_writers(&mut self) -> Result<()> {
+ for (_, mut writer) in self.active_writers.drain() {
+ writer.finish()?;
+ }
+ Ok(())
+ }
- for batch_result in reader {
- all_batches.push(batch_result?);
+ /// Opens the spill file for a partition and returns a streaming
+ /// reader. `finish_writers` must be called before this method.
+ pub fn open_spill_reader(
+ &self,
+ partition_id: usize,
+ ) -> Result<Option<StreamReader<File>>> {
+ match self.spill_files.get(&partition_id) {
+ Some(spill_path) => {
+ let file = File::open(spill_path).map_err(BallistaError::IoError)?;
+ let reader = StreamReader::try_new(file, None)?;
+ Ok(Some(reader))
}
+ None => Ok(None),
}
-
- Ok(all_batches)
}
/// Cleans up all spill files.
- pub fn cleanup(&self) -> Result<()> {
+ pub fn cleanup(&mut self) -> Result<()> {
+ // Finish any active writers first
+ for (_, mut writer) in self.active_writers.drain() {
+ let _ = writer.finish();
+ }
if self.spill_dir.exists() {
std::fs::remove_dir_all(&self.spill_dir).map_err(BallistaError::IoError)?;
}
@@ -188,14 +192,21 @@ impl SpillManager {
self.total_bytes_spilled
}
- /// Generates the next spill file path for a partition.
- fn next_spill_path(&mut self, partition_id: usize) -> PathBuf {
- let path = self.spill_dir.join(format!(
- "part-{partition_id}-spill-{}.arrow",
- self.spill_counter
- ));
- self.spill_counter += 1;
- path
+ /// Returns the spill file path for a partition.
+ fn spill_path(&self, partition_id: usize) -> PathBuf {
+ self.spill_dir.join(format!("part-{partition_id}.arrow"))
+ }
+}
+
+impl std::fmt::Debug for SpillManager {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("SpillManager")
+ .field("spill_dir", &self.spill_dir)
+ .field("spill_files", &self.spill_files)
+ .field("compression", &self.compression)
+ .field("total_spills", &self.total_spills)
+ .field("total_bytes_spilled", &self.total_bytes_spilled)
+ .finish()
}
}
@@ -251,8 +262,13 @@ mod tests {
assert!(!manager.has_spill_files(1));
assert_eq!(manager.total_spills(), 1);
- // Read back
- let read_batches = manager.read_spill_files(0)?;
+ // Finish writers before reading
+ manager.finish_writers()?;
+
+ // Read back via streaming reader
+ let reader = manager.open_spill_reader(0)?.unwrap();
+ let read_batches: Vec<_> =
+ reader.into_iter().collect::<std::result::Result<_, _>>()?;
assert_eq!(read_batches.len(), 2);
assert_eq!(read_batches[0].num_rows(), 3);
assert_eq!(read_batches[1].num_rows(), 2);
@@ -273,15 +289,20 @@ mod tests {
CompressionType::LZ4_FRAME,
)?;
- // Multiple spills for same partition
+ // Multiple spills for same partition append to same file
manager.spill(0, vec![create_test_batch(&schema, vec![1, 2])], &schema)?;
manager.spill(0, vec![create_test_batch(&schema, vec![3, 4])], &schema)?;
- assert_eq!(manager.get_spill_files(0).len(), 2);
+ assert!(manager.has_spill_files(0));
assert_eq!(manager.total_spills(), 2);
- // Read all back
- let batches = manager.read_spill_files(0)?;
+ // Finish writers before reading
+ manager.finish_writers()?;
+
+ // Read all back - both batches from single file
+ let reader = manager.open_spill_reader(0)?.unwrap();
+ let batches: Vec<_> =
+ reader.into_iter().collect::<std::result::Result<_, _>>()?;
assert_eq!(batches.len(), 2);
Ok(())
diff --git a/ballista/core/src/execution_plans/sort_shuffle/writer.rs b/ballista/core/src/execution_plans/sort_shuffle/writer.rs
index dfef86086..84c7f61ec 100644
--- a/ballista/core/src/execution_plans/sort_shuffle/writer.rs
+++ b/ballista/core/src/execution_plans/sort_shuffle/writer.rs
@@ -256,11 +256,17 @@ impl SortShuffleWriterExec {
&mut spill_manager,
&schema,
config.spill_memory_threshold() / 2,
+ config.batch_size,
)?;
timer.done();
}
}
+ // Finish spill writers before reading them back
+ spill_manager
+ .finish_writers()
+ .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+
// Finalize: write consolidated output file
let timer = metrics.write_time.timer();
let (data_path, index_path, partition_stats) = finalize_output(
@@ -325,6 +331,7 @@ fn spill_largest_buffers(
spill_manager: &mut SpillManager,
schema: &SchemaRef,
target_memory: usize,
+ batch_size: usize,
) -> Result<()> {
loop {
let total_memory: usize = buffers.iter().map(|b| b.memory_used()).sum();
@@ -342,7 +349,7 @@ fn spill_largest_buffers(
match largest_idx {
Some(idx) if buffers[idx].memory_used() > 0 => {
let partition_id = buffers[idx].partition_id();
- let batches = buffers[idx].drain();
+ let batches = buffers[idx].drain_coalesced(batch_size);
spill_manager
.spill(partition_id, batches, schema)
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
@@ -405,13 +412,13 @@ fn finalize_output(
let mut partition_batches: u64 = 0;
let mut partition_bytes: u64 = 0;
- // First, write any spill files for this partition
- if spill_manager.has_spill_files(partition_id) {
- let spill_batches = spill_manager
- .read_spill_files(partition_id)
- .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
-
- for batch in spill_batches {
+ // First, stream any spill file for this partition
+ if let Some(reader) = spill_manager
+ .open_spill_reader(partition_id)
+ .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
+ {
+ for batch_result in reader {
+ let batch = batch_result?;
partition_rows += batch.num_rows() as u64;
partition_bytes += batch.get_array_memory_size() as u64;
partition_batches += 1;
@@ -419,8 +426,8 @@ fn finalize_output(
}
}
- // Then write remaining buffered data
- let buffered_batches = buffer.take_batches();
+ // Then write remaining buffered data (coalesced)
+ let buffered_batches = buffer.take_batches_coalesced(config.batch_size);
for batch in buffered_batches {
partition_rows += batch.num_rows() as u64;
partition_bytes += batch.get_array_memory_size() as u64;
diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs
index 6a5d9b896..d0608d024 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -61,6 +61,8 @@ pub struct SortShuffleWriterExecNode {
pub memory_limit: u64,
#[prost(double, tag = "7")]
pub spill_threshold: f64,
+ #[prost(uint64, tag = "8")]
+ pub batch_size: u64,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct UnresolvedShuffleExecNode {
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 13c4a53f4..7eb4ecefd 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -327,12 +327,18 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
)
})?;
+ let batch_size = if sort_shuffle_writer.batch_size > 0 {
+ sort_shuffle_writer.batch_size as usize
+ } else {
+ 8192 // default for backwards compatibility
+ };
let config = SortShuffleConfig::new(
true,
sort_shuffle_writer.buffer_size as usize,
sort_shuffle_writer.memory_limit as usize,
sort_shuffle_writer.spill_threshold,
datafusion::arrow::ipc::CompressionType::LZ4_FRAME,
+ batch_size,
);
Ok(Arc::new(SortShuffleWriterExec::try_new(
@@ -478,6 +484,7 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
buffer_size: config.buffer_size as u64,
memory_limit: config.memory_limit as u64,
spill_threshold: config.spill_threshold,
+ batch_size: config.batch_size as u64,
},
)),
};
diff --git a/ballista/scheduler/src/planner.rs b/ballista/scheduler/src/planner.rs
index fd5321ecb..d00f0bb67 100644
--- a/ballista/scheduler/src/planner.rs
+++ b/ballista/scheduler/src/planner.rs
@@ -359,6 +359,7 @@ fn create_shuffle_writer_with_config(
ballista_config.shuffle_sort_based_memory_limit(),
ballista_config.shuffle_sort_based_spill_threshold(),
datafusion::arrow::ipc::CompressionType::LZ4_FRAME,
+ ballista_config.shuffle_sort_based_batch_size(),
);
return Ok(Arc::new(SortShuffleWriterExec::try_new(
diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml
index ac6e262ef..97e4ece91 100644
--- a/benchmarks/Cargo.toml
+++ b/benchmarks/Cargo.toml
@@ -51,3 +51,8 @@ tokio = { version = "^1.44", features = [
] }
[dev-dependencies]
+criterion = { version = "0.5", features = ["html_reports"] }
+
+[[bench]]
+harness = false
+name = "sort_shuffle"
diff --git a/benchmarks/benches/sort_shuffle.rs b/benchmarks/benches/sort_shuffle.rs
new file mode 100644
index 000000000..30a7423ac
--- /dev/null
+++ b/benchmarks/benches/sort_shuffle.rs
@@ -0,0 +1,315 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Criterion benchmarks for sort-based shuffle writer.
+
+use std::sync::Arc;
+
+use ballista_core::execution_plans::SortShuffleWriterExec;
+use ballista_core::execution_plans::sort_shuffle::SortShuffleConfig;
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion::arrow::array::{
+ BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array,
+ Int64Array, StringArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array,
+};
+use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion::arrow::ipc::CompressionType;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::memory::MemorySourceConfig;
+use datafusion::datasource::source::DataSourceExec;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::physical_plan::Partitioning;
+use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+use datafusion::physical_plan::expressions::Column;
+use datafusion::prelude::SessionContext;
+use futures::TryStreamExt;
+use rand::Rng;
+use tempfile::TempDir;
+
+const BATCH_SIZE: usize = 8192;
+const NUM_OUTPUT_PARTITIONS: usize = 200;
+// 100 columns: mix of primitive types
+const NUM_INT_COLS: usize = 12; // i8, i16, i32, i64 x3 each
+const NUM_UINT_COLS: usize = 12; // u8, u16, u32, u64 x3 each
+const NUM_FLOAT_COLS: usize = 6; // f32, f64 x3 each
+const NUM_BOOL_COLS: usize = 5;
+const NUM_STRING_COLS: usize = 5;
+// Remaining to reach 100
+const NUM_EXTRA_INT32_COLS: usize =
+ 100 - NUM_INT_COLS - NUM_UINT_COLS - NUM_FLOAT_COLS - NUM_BOOL_COLS - NUM_STRING_COLS;
+
+fn build_schema() -> SchemaRef {
+ let mut fields = Vec::with_capacity(100);
+ let mut idx = 0;
+
+ for _ in 0..3 {
+ fields.push(Field::new(format!("c{idx}"), DataType::Int8, true));
+ idx += 1;
+ }
+ for _ in 0..3 {
+ fields.push(Field::new(format!("c{idx}"), DataType::Int16, true));
+ idx += 1;
+ }
+ for _ in 0..3 {
+ fields.push(Field::new(format!("c{idx}"), DataType::Int32, true));
+ idx += 1;
+ }
+ for _ in 0..3 {
+ fields.push(Field::new(format!("c{idx}"), DataType::Int64, true));
+ idx += 1;
+ }
+
+ for _ in 0..3 {
+ fields.push(Field::new(format!("c{idx}"), DataType::UInt8, true));
+ idx += 1;
+ }
+ for _ in 0..3 {
+ fields.push(Field::new(format!("c{idx}"), DataType::UInt16, true));
+ idx += 1;
+ }
+ for _ in 0..3 {
+ fields.push(Field::new(format!("c{idx}"), DataType::UInt32, true));
+ idx += 1;
+ }
+ for _ in 0..3 {
+ fields.push(Field::new(format!("c{idx}"), DataType::UInt64, true));
+ idx += 1;
+ }
+
+ for _ in 0..3 {
+ fields.push(Field::new(format!("c{idx}"), DataType::Float32, true));
+ idx += 1;
+ }
+ for _ in 0..3 {
+ fields.push(Field::new(format!("c{idx}"), DataType::Float64, true));
+ idx += 1;
+ }
+
+ for _ in 0..NUM_BOOL_COLS {
+ fields.push(Field::new(format!("c{idx}"), DataType::Boolean, true));
+ idx += 1;
+ }
+
+ for _ in 0..NUM_STRING_COLS {
+ fields.push(Field::new(format!("c{idx}"), DataType::Utf8, true));
+ idx += 1;
+ }
+
+ for _ in 0..NUM_EXTRA_INT32_COLS {
+ fields.push(Field::new(format!("c{idx}"), DataType::Int32, true));
+ idx += 1;
+ }
+
+ Arc::new(Schema::new(fields))
+}
+
+fn build_batch(schema: &SchemaRef) -> RecordBatch {
+ let mut rng = rand::rng();
+ let mut columns: Vec<Arc<dyn datafusion::arrow::array::Array>> =
+ Vec::with_capacity(100);
+
+ // i8 x3
+ for _ in 0..3 {
+ let vals: Vec<i8> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(Int8Array::from(vals)));
+ }
+ // i16 x3
+ for _ in 0..3 {
+ let vals: Vec<i16> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(Int16Array::from(vals)));
+ }
+ // i32 x3
+ for _ in 0..3 {
+ let vals: Vec<i32> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(Int32Array::from(vals)));
+ }
+ // i64 x3
+ for _ in 0..3 {
+ let vals: Vec<i64> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(Int64Array::from(vals)));
+ }
+
+ // u8 x3
+ for _ in 0..3 {
+ let vals: Vec<u8> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(UInt8Array::from(vals)));
+ }
+ // u16 x3
+ for _ in 0..3 {
+ let vals: Vec<u16> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(UInt16Array::from(vals)));
+ }
+ // u32 x3
+ for _ in 0..3 {
+ let vals: Vec<u32> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(UInt32Array::from(vals)));
+ }
+ // u64 x3
+ for _ in 0..3 {
+ let vals: Vec<u64> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(UInt64Array::from(vals)));
+ }
+
+ // f32 x3
+ for _ in 0..3 {
+ let vals: Vec<f32> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(Float32Array::from(vals)));
+ }
+ // f64 x3
+ for _ in 0..3 {
+ let vals: Vec<f64> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(Float64Array::from(vals)));
+ }
+
+ // bool
+ for _ in 0..NUM_BOOL_COLS {
+ let vals: Vec<bool> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(BooleanArray::from(vals)));
+ }
+
+ // string (short random strings)
+ for _ in 0..NUM_STRING_COLS {
+ let vals: Vec<String> = (0..BATCH_SIZE)
+ .map(|_| format!("s{}", rng.random_range(0..100_000)))
+ .collect();
+ columns.push(Arc::new(StringArray::from(vals)));
+ }
+
+ // remaining i32 columns
+ for _ in 0..NUM_EXTRA_INT32_COLS {
+ let vals: Vec<i32> = (0..BATCH_SIZE).map(|_| rng.random()).collect();
+ columns.push(Arc::new(Int32Array::from(vals)));
+ }
+
+ RecordBatch::try_new(schema.clone(), columns).unwrap()
+}
+
+fn create_input(num_batches: usize) -> Arc<dyn datafusion::physical_plan::ExecutionPlan> {
+ let schema = build_schema();
+ let batch = build_batch(&schema);
+ let partition: Vec<RecordBatch> = (0..num_batches).map(|_| batch.clone()).collect();
+ let partitions = vec![partition];
+
+ let memory_source =
+ Arc::new(MemorySourceConfig::try_new(&partitions, schema, None).unwrap());
+ Arc::new(CoalescePartitionsExec::new(Arc::new(DataSourceExec::new(
+ memory_source,
+ ))))
+}
+
+fn run_sort_shuffle(
+ input: Arc<dyn datafusion::physical_plan::ExecutionPlan>,
+ work_dir: &str,
+ memory_limit: usize,
+) {
+ let rt = tokio::runtime::Runtime::new().unwrap();
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+
+ let config = SortShuffleConfig::new(
+ true,
+ 1024 * 1024, // 1MB buffer
+ memory_limit,
+ 0.8,
+ CompressionType::LZ4_FRAME,
+ 8192,
+ );
+
+ let writer = SortShuffleWriterExec::try_new(
+ "bench_job".to_string(),
+ 1,
+ input,
+ work_dir.to_string(),
+ Partitioning::Hash(vec![Arc::new(Column::new("c0", 0))], NUM_OUTPUT_PARTITIONS),
+ config,
+ )
+ .unwrap();
+
+ rt.block_on(async {
+ let mut stream = writer.execute(0, task_ctx).unwrap();
+ while let Some(_batch) = stream.try_next().await.unwrap() {}
+ });
+}
+
+fn bench_no_spill(c: &mut Criterion) {
+ let mut group = c.benchmark_group("sort_shuffle_no_spill");
+ group.sample_size(10);
+
+ // 10 batches of 8192 rows = 81920 rows, 256MB limit
+ let input = create_input(10);
+ let work_dir = TempDir::new().unwrap();
+
+ group.bench_function("10_batches_200_partitions", |b| {
+ b.iter(|| {
+ run_sort_shuffle(
+ input.clone(),
+ work_dir.path().to_str().unwrap(),
+ 256 * 1024 * 1024,
+ );
+ });
+ });
+
+ // 50 batches
+ let input = create_input(50);
+ group.bench_function("50_batches_200_partitions", |b| {
+ b.iter(|| {
+ run_sort_shuffle(
+ input.clone(),
+ work_dir.path().to_str().unwrap(),
+ 256 * 1024 * 1024,
+ );
+ });
+ });
+
+ group.finish();
+}
+
+fn bench_with_spill(c: &mut Criterion) {
+ let mut group = c.benchmark_group("sort_shuffle_with_spill");
+ group.sample_size(10);
+
+ let work_dir = TempDir::new().unwrap();
+
+ // 50 batches with 8MB memory limit to force spilling
+ let input = create_input(50);
+ group.bench_function("50_batches_200_partitions_8mb_limit", |b| {
+ b.iter(|| {
+ run_sort_shuffle(
+ input.clone(),
+ work_dir.path().to_str().unwrap(),
+ 8 * 1024 * 1024,
+ );
+ });
+ });
+
+ // 50 batches with 2MB memory limit to force heavy spilling
+ let input = create_input(50);
+ group.bench_function("50_batches_200_partitions_2mb_limit", |b| {
+ b.iter(|| {
+ run_sort_shuffle(
+ input.clone(),
+ work_dir.path().to_str().unwrap(),
+ 2 * 1024 * 1024,
+ );
+ });
+ });
+
+ group.finish();
+}
+
+criterion_group!(benches, bench_no_spill, bench_with_spill);
+criterion_main!(benches);
diff --git a/benchmarks/src/bin/shuffle_bench.rs b/benchmarks/src/bin/shuffle_bench.rs
index 202c123c8..5ced6ee49 100644
--- a/benchmarks/src/bin/shuffle_bench.rs
+++ b/benchmarks/src/bin/shuffle_bench.rs
@@ -226,6 +226,7 @@ async fn benchmark_sort_shuffle(
memory_limit,
0.8,
CompressionType::LZ4_FRAME,
+ 8192,
);
// Create sort shuffle writer
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]