
Commit 99cedf6

Refactor HashJoinExec to progressively accumulate dynamic filter bounds instead of computing them after data is accumulated (apache#17444) (#46)

(cherry picked from commit 5b833b9)
Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>

1 parent 1a31b79 · commit 99cedf6
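The core change: instead of concatenating all build-side batches and then running the `min_batch`/`max_batch` kernels over the result, each join key expression gets a `MinAccumulator`/`MaxAccumulator` pair that is updated as batches stream in. Below is a minimal sketch of that progressive pattern, using the same accumulator APIs the diff imports; the `Int32` column and the literal values are illustrative, not from the commit.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::Accumulator;
use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};

fn main() -> Result<()> {
    // One accumulator pair per join key expression, initialized from the
    // key's data type (as CollectLeftAccumulator::try_new does in the diff).
    let mut min = MinAccumulator::try_new(&DataType::Int32)?;
    let mut max = MaxAccumulator::try_new(&DataType::Int32)?;

    // Bounds are folded in batch by batch while the build side streams in,
    // rather than in a second pass after all data is accumulated.
    for values in [vec![3, 7, 1], vec![9, 2]] {
        let array: ArrayRef = Arc::new(Int32Array::from(values));
        min.update_batch(std::slice::from_ref(&array))?;
        max.update_batch(std::slice::from_ref(&array))?;
    }

    // Final bounds for the dynamic filter: min = 1, max = 9.
    println!("min = {}, max = {}", min.evaluate()?, max.evaluate()?);
    Ok(())
}
```

With this shape, the bounds are available as soon as the build phase finishes, with no extra pass over the concatenated build-side data.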

File tree: 3 files changed, +160 −36 lines

Cargo.lock

Lines changed: 0 additions & 1 deletion (generated file; diff not rendered)

datafusion/physical-plan/Cargo.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -53,7 +53,7 @@ datafusion-common = { workspace = true, default-features = true }
 datafusion-common-runtime = { workspace = true, default-features = true }
 datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
-datafusion-functions-aggregate-common = { workspace = true }
+datafusion-functions-aggregate = { workspace = true }
 datafusion-functions-window-common = { workspace = true }
 datafusion-physical-expr = { workspace = true, default-features = true }
 datafusion-physical-expr-common = { workspace = true }
```

datafusion/physical-plan/src/joins/hash_join/exec.rs

Lines changed: 159 additions & 34 deletions
```diff
@@ -54,20 +54,21 @@ use crate::{
     PlanProperties, SendableRecordBatchStream, Statistics,
 };
 
-use arrow::array::{Array, ArrayRef, BooleanBufferBuilder};
+use arrow::array::{ArrayRef, BooleanBufferBuilder};
 use arrow::compute::concat_batches;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use arrow::util::bit_util;
+use arrow_schema::DataType;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::utils::memory::estimate_memory_size;
 use datafusion_common::{
     internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result,
-    ScalarValue,
 };
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
-use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch};
+use datafusion_expr::Accumulator;
+use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
 use datafusion_physical_expr::equivalence::{
     join_equivalence_properties, ProjectionMapping,
 };
```
```diff
@@ -1188,29 +1189,123 @@ impl ExecutionPlan for HashJoinExec {
     }
 }
 
-/// Compute min/max bounds for each column in the given arrays
-fn compute_bounds(arrays: &[ArrayRef]) -> Result<Vec<ColumnBounds>> {
-    arrays
-        .iter()
-        .map(|array| {
-            if array.is_empty() {
-                // Return NULL values for empty arrays
-                return Ok(ColumnBounds::new(
-                    ScalarValue::try_from(array.data_type())?,
-                    ScalarValue::try_from(array.data_type())?,
-                ));
+/// Accumulator for collecting min/max bounds from build-side data during hash join.
+///
+/// This struct encapsulates the logic for progressively computing column bounds
+/// (minimum and maximum values) for a specific join key expression as batches
+/// are processed during the build phase of a hash join.
+///
+/// The bounds are used for dynamic filter pushdown optimization, where filters
+/// based on the actual data ranges can be pushed down to the probe side to
+/// eliminate unnecessary data early.
+struct CollectLeftAccumulator {
+    /// The physical expression to evaluate for each batch
+    expr: Arc<dyn PhysicalExpr>,
+    /// Accumulator for tracking the minimum value across all batches
+    min: MinAccumulator,
+    /// Accumulator for tracking the maximum value across all batches
+    max: MaxAccumulator,
+}
+
+impl CollectLeftAccumulator {
+    /// Creates a new accumulator for tracking bounds of a join key expression.
+    ///
+    /// # Arguments
+    /// * `expr` - The physical expression to track bounds for
+    /// * `schema` - The schema of the input data
+    ///
+    /// # Returns
+    /// A new `CollectLeftAccumulator` instance configured for the expression's data type
+    fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> Result<Self> {
+        /// Recursively unwraps dictionary types to get the underlying value type.
+        fn dictionary_value_type(data_type: &DataType) -> DataType {
+            match data_type {
+                DataType::Dictionary(_, value_type) => {
+                    dictionary_value_type(value_type.as_ref())
+                }
+                _ => data_type.clone(),
             }
+        }
+
+        let data_type = expr
+            .data_type(schema)
+            // Min/Max can operate on dictionary data but expect to be initialized with the underlying value type
+            .map(|dt| dictionary_value_type(&dt))?;
+        Ok(Self {
+            expr,
+            min: MinAccumulator::try_new(&data_type)?,
+            max: MaxAccumulator::try_new(&data_type)?,
+        })
+    }
 
-            // Use Arrow kernels for efficient min/max computation
-            let min_val = min_batch(array)?;
-            let max_val = max_batch(array)?;
+    /// Updates the accumulators with values from a new batch.
+    ///
+    /// Evaluates the expression on the batch and updates both min and max
+    /// accumulators with the resulting values.
+    ///
+    /// # Arguments
+    /// * `batch` - The record batch to process
+    ///
+    /// # Returns
+    /// Ok(()) if the update succeeds, or an error if expression evaluation fails
+    fn update_batch(&mut self, batch: &RecordBatch) -> Result<()> {
+        let array = self.expr.evaluate(batch)?.into_array(batch.num_rows())?;
+        self.min.update_batch(std::slice::from_ref(&array))?;
+        self.max.update_batch(std::slice::from_ref(&array))?;
+        Ok(())
+    }
 
-            Ok(ColumnBounds::new(min_val, max_val))
+    /// Finalizes the accumulation and returns the computed bounds.
+    ///
+    /// Consumes self to extract the final min and max values from the accumulators.
+    ///
+    /// # Returns
+    /// The `ColumnBounds` containing the minimum and maximum values observed
+    fn evaluate(mut self) -> Result<ColumnBounds> {
+        Ok(ColumnBounds::new(
+            self.min.evaluate()?,
+            self.max.evaluate()?,
+        ))
+    }
+}
+
+/// State for collecting the build-side data during hash join
+struct BuildSideState {
+    batches: Vec<RecordBatch>,
+    num_rows: usize,
+    metrics: BuildProbeJoinMetrics,
+    reservation: MemoryReservation,
+    bounds_accumulators: Option<Vec<CollectLeftAccumulator>>,
+}
+
+impl BuildSideState {
+    /// Create a new BuildSideState with optional accumulators for bounds computation
+    fn try_new(
+        metrics: BuildProbeJoinMetrics,
+        reservation: MemoryReservation,
+        on_left: Vec<Arc<dyn PhysicalExpr>>,
+        schema: &SchemaRef,
+        should_compute_bounds: bool,
+    ) -> Result<Self> {
+        Ok(Self {
+            batches: Vec::new(),
+            num_rows: 0,
+            metrics,
+            reservation,
+            bounds_accumulators: should_compute_bounds
+                .then(|| {
+                    on_left
+                        .iter()
+                        .map(|expr| {
+                            CollectLeftAccumulator::try_new(Arc::clone(expr), schema)
+                        })
+                        .collect::<Result<Vec<_>>>()
+                })
+                .transpose()?,
         })
-        .collect()
+    }
 }
 
-#[expect(clippy::too_many_arguments)]
 /// Collects all batches from the left (build) side stream and creates a hash map for joining.
 ///
 /// This function is responsible for:
```
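Why the nested `dictionary_value_type` helper: `MinAccumulator`/`MaxAccumulator` can consume dictionary-encoded arrays but must be initialized with the dictionary's underlying value type. Here is a standalone sketch of the same recursive unwrap; the nested-dictionary type below is a contrived illustration, not data from the commit.

```rust
use arrow::datatypes::DataType;

/// Recursively unwrap dictionary types to the underlying value type,
/// mirroring the helper nested inside CollectLeftAccumulator::try_new.
fn dictionary_value_type(data_type: &DataType) -> DataType {
    match data_type {
        DataType::Dictionary(_, value_type) => dictionary_value_type(value_type.as_ref()),
        _ => data_type.clone(),
    }
}

fn main() {
    // A dictionary-encoded (even doubly-encoded) string column still has
    // Utf8 min/max bounds.
    let dict = DataType::Dictionary(
        Box::new(DataType::Int16),
        Box::new(DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Utf8),
        )),
    );
    assert_eq!(dictionary_value_type(&dict), DataType::Utf8);
}
```

The same hunk also introduces the `should_compute_bounds.then(|| ...).transpose()?` idiom in `BuildSideState::try_new`: the accumulators are constructed only when dynamic filter bounds are requested, while any construction error is still propagated.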
```diff
@@ -1239,6 +1334,7 @@
 /// # Returns
 /// `JoinLeftData` containing the hash map, consolidated batch, join key values,
 /// visited indices bitmap, and computed bounds (if requested).
+#[allow(clippy::too_many_arguments)]
 async fn collect_left_input(
     random_state: RandomState,
     left_stream: SendableRecordBatchStream,
```
```diff
@@ -1254,24 +1350,48 @@
     // This operation performs 2 steps at once:
     // 1. creates a [JoinHashMap] of all batches from the stream
     // 2. stores the batches in a vector.
-    let initial = (Vec::new(), 0, metrics, reservation);
-    let (batches, num_rows, metrics, mut reservation) = left_stream
-        .try_fold(initial, |mut acc, batch| async {
+    let initial = BuildSideState::try_new(
+        metrics,
+        reservation,
+        on_left.clone(),
+        &schema,
+        should_compute_bounds,
+    )?;
+
+    let state = left_stream
+        .try_fold(initial, |mut state, batch| async move {
+            // Update accumulators if computing bounds
+            if let Some(ref mut accumulators) = state.bounds_accumulators {
+                for accumulator in accumulators {
+                    accumulator.update_batch(&batch)?;
+                }
+            }
+
+            // Decide if we spill or not
             let batch_size = get_record_batch_memory_size(&batch);
             // Reserve memory for incoming batch
-            acc.3.try_grow(batch_size)?;
+            state.reservation.try_grow(batch_size)?;
             // Update metrics
-            acc.2.build_mem_used.add(batch_size);
-            acc.2.build_input_batches.add(1);
-            acc.2.build_input_rows.add(batch.num_rows());
+            state.metrics.build_mem_used.add(batch_size);
+            state.metrics.build_input_batches.add(1);
+            state.metrics.build_input_rows.add(batch.num_rows());
             // Update row count
-            acc.1 += batch.num_rows();
+            state.num_rows += batch.num_rows();
             // Push batch to output
-            acc.0.push(batch);
-            Ok(acc)
+            state.batches.push(batch);
+            Ok(state)
         })
         .await?;
 
+    // Extract fields from state
+    let BuildSideState {
+        batches,
+        num_rows,
+        metrics,
+        mut reservation,
+        bounds_accumulators,
+    } = state;
+
     // Estimation of memory size, required for hashtable, prior to allocation.
     // Final result can be verified using `RawTable.allocation_info()`
     let fixed_size_u32 = size_of::<JoinHashMapU32>();
```
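The mechanical heart of the refactor is this fold: the positional tuple accumulator (`acc.0`..`acc.3`) becomes a named `BuildSideState` threaded through `TryStreamExt::try_fold`, with the bounds accumulators updated on each batch before memory accounting. A self-contained sketch of that fold pattern, using only the `futures` crate; `Batch`, `BuildState`, and the row counts are stand-ins, not the commit's types:

```rust
use futures::executor::block_on;
use futures::stream::{self, TryStreamExt};

// Stand-in for a RecordBatch: just its row count.
type Batch = usize;

#[derive(Debug, Default)]
struct BuildState {
    batches: Vec<Batch>,
    num_rows: usize,
}

fn main() -> Result<(), String> {
    let left_stream = stream::iter([Ok::<Batch, String>(3), Ok(5), Ok(2)]);

    // Fold a named state struct through the stream; any Err short-circuits,
    // just like the per-batch `?` inside collect_left_input's closure.
    let state = block_on(left_stream.try_fold(
        BuildState::default(),
        |mut state, batch| async move {
            state.num_rows += batch; // named field instead of acc.1
            state.batches.push(batch); // ... and acc.0
            Ok(state)
        },
    ))?;

    assert_eq!(state.num_rows, 10);
    Ok(())
}
```

Named fields make the later destructuring (`let BuildSideState { batches, num_rows, .. } = state;`) self-documenting, where the tuple version relied on remembering each index's meaning.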
```diff
@@ -1338,10 +1458,15 @@
         .collect::<Result<Vec<_>>>()?;
 
     // Compute bounds for dynamic filter if enabled
-    let bounds = if should_compute_bounds && num_rows > 0 {
-        Some(compute_bounds(&left_values)?)
-    } else {
-        None
+    let bounds = match bounds_accumulators {
+        Some(accumulators) if num_rows > 0 => {
+            let bounds = accumulators
+                .into_iter()
+                .map(CollectLeftAccumulator::evaluate)
+                .collect::<Result<Vec<_>>>()?;
+            Some(bounds)
+        }
+        _ => None,
     };
 
     let data = JoinLeftData::new(
```