Skip to content

Commit 50e073c

Browse files
authored
refactor: Use BufferedBatchState enum for SMJ spilling (#17429)
* refactor: Use `BufferedBatchState` enum for SMJ spilling
1 parent f5bdc2d commit 50e073c

File tree

1 file changed

+36
-33
lines changed
  • datafusion/physical-plan/src/joins/sort_merge_join

1 file changed

+36
-33
lines changed

datafusion/physical-plan/src/joins/sort_merge_join/stream.rs

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,12 @@ impl StreamedBatch {
185185
}
186186

187187
/// A buffered batch that contains contiguous rows with same join key
188+
///
189+
/// `BufferedBatch` can exist as either an in-memory `RecordBatch` or a `RefCountedTempFile` on disk.
188190
#[derive(Debug)]
189191
pub(super) struct BufferedBatch {
190-
/// The buffered record batch
191-
/// None if the batch spilled to disk th
192-
pub batch: Option<RecordBatch>,
192+
/// Represents in memory or spilled record batch
193+
pub batch: BufferedBatchState,
193194
/// The range in which the rows share the same join key
194195
pub range: Range<usize>,
195196
/// Array refs of the join key
@@ -207,10 +208,6 @@ pub(super) struct BufferedBatch {
207208
/// but if batch is spilled to disk this property is preferable
208209
/// and less expensive
209210
pub num_rows: usize,
210-
/// An optional temp spill file name on the disk if the batch spilled
211-
/// None by default
212-
/// Some(fileName) if the batch spilled to the disk
213-
pub spill_file: Option<RefCountedTempFile>,
214211
}
215212

216213
impl BufferedBatch {
@@ -238,18 +235,27 @@ impl BufferedBatch {
238235

239236
let num_rows = batch.num_rows();
240237
BufferedBatch {
241-
batch: Some(batch),
238+
batch: BufferedBatchState::InMemory(batch),
242239
range,
243240
join_arrays,
244241
null_joined: vec![],
245242
size_estimation,
246243
join_filter_not_matched_map: HashMap::new(),
247244
num_rows,
248-
spill_file: None,
249245
}
250246
}
251247
}
252248

249+
// TODO: Spill join arrays (https://github.com/apache/datafusion/pull/17429)
250+
// Used to represent whether the buffered data is currently in memory or written to disk
251+
#[derive(Debug)]
252+
pub(super) enum BufferedBatchState {
253+
// In memory record batch
254+
InMemory(RecordBatch),
255+
// Spilled temp file
256+
Spilled(RefCountedTempFile),
257+
}
258+
253259
/// Sort-Merge join stream that consumes streamed and buffered data streams
254260
/// and produces joined output stream.
255261
pub(super) struct SortMergeJoinStream {
@@ -849,11 +855,10 @@ impl SortMergeJoinStream {
849855

850856
fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> {
851857
// Shrink memory usage for in-memory batches only
852-
if buffered_batch.spill_file.is_none() && buffered_batch.batch.is_some() {
858+
if let BufferedBatchState::InMemory(_) = buffered_batch.batch {
853859
self.reservation
854860
.try_shrink(buffered_batch.size_estimation)?;
855861
}
856-
857862
Ok(())
858863
}
859864

@@ -867,21 +872,21 @@ impl SortMergeJoinStream {
867872
}
868873
Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
869874
// Spill buffered batch to disk
870-
if let Some(batch) = buffered_batch.batch {
871-
let spill_file = self
872-
.spill_manager
873-
.spill_record_batch_and_finish(
874-
&[batch],
875-
"sort_merge_join_buffered_spill",
876-
)?
877-
.unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled
878-
879-
buffered_batch.spill_file = Some(spill_file);
880-
buffered_batch.batch = None;
881-
882-
Ok(())
883-
} else {
884-
internal_err!("Buffered batch has empty body")
875+
876+
match buffered_batch.batch {
877+
BufferedBatchState::InMemory(batch) => {
878+
let spill_file = self
879+
.spill_manager
880+
.spill_record_batch_and_finish(
881+
&[batch],
882+
"sort_merge_join_buffered_spill",
883+
)?
884+
.unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled
885+
886+
buffered_batch.batch = BufferedBatchState::Spilled(spill_file);
887+
Ok(())
888+
}
889+
_ => internal_err!("Buffered batch has empty body"),
885890
}
886891
}
887892
Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()),
@@ -1741,16 +1746,16 @@ fn fetch_right_columns_from_batch_by_idxs(
17411746
buffered_batch: &BufferedBatch,
17421747
buffered_indices: &UInt64Array,
17431748
) -> Result<Vec<ArrayRef>> {
1744-
match (&buffered_batch.spill_file, &buffered_batch.batch) {
1749+
match &buffered_batch.batch {
17451750
// In memory batch
1746-
(None, Some(batch)) => Ok(batch
1751+
BufferedBatchState::InMemory(batch) => Ok(batch
17471752
.columns()
17481753
.iter()
17491754
.map(|column| take(column, &buffered_indices, None))
17501755
.collect::<Result<Vec<_>, ArrowError>>()
17511756
.map_err(Into::<DataFusionError>::into)?),
17521757
// If the batch was spilled to disk, less likely
1753-
(Some(spill_file), None) => {
1758+
BufferedBatchState::Spilled(spill_file) => {
17541759
let mut buffered_cols: Vec<ArrayRef> =
17551760
Vec::with_capacity(buffered_indices.len());
17561761

@@ -1763,10 +1768,8 @@ fn fetch_right_columns_from_batch_by_idxs(
17631768
});
17641769
}
17651770

1766-
Ok(buffered_cols)
1767-
}
1768-
// Invalid combination
1769-
(spill, batch) => internal_err!("Unexpected buffered batch spill status. Spill exists: {}. In-memory exists: {}", spill.is_some(), batch.is_some()),
1771+
Ok(buffered_cols)
1772+
}
17701773
}
17711774
}
17721775

0 commit comments

Comments
 (0)