Skip to content

Commit c6a91b6

Browse files
committed
Add tests
1 parent 43d9fb2 commit c6a91b6

File tree

3 files changed

+128
-2
lines changed

3 files changed

+128
-2
lines changed

datafusion/physical-expr/src/expressions/dynamic_filters.rs

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ struct Inner {
7878
/// This is used for [`PhysicalExpr::snapshot_generation`] to have a cheap check for changes.
7979
generation: u64,
8080
expr: Arc<dyn PhysicalExpr>,
81+
/// Flag for quick synchronous check if filter is complete.
82+
/// This is redundant with the watch channel state, but allows us to return immediately
83+
/// from `wait_complete()` without subscribing if already complete.
84+
is_complete: bool,
8185
}
8286

8387
impl Inner {
@@ -87,6 +91,7 @@ impl Inner {
8791
// This is not currently used anywhere but it seems useful to have this simple distinction.
8892
generation: 1,
8993
expr,
94+
is_complete: false,
9095
}
9196
}
9297

@@ -231,6 +236,7 @@ impl DynamicFilterPhysicalExpr {
231236
*current = Inner {
232237
generation: new_generation,
233238
expr: new_expr,
239+
is_complete: current.is_complete,
234240
};
235241
drop(current); // Release the lock before broadcasting
236242

@@ -246,7 +252,11 @@ impl DynamicFilterPhysicalExpr {
246252
/// This signals that all expected updates have been received.
247253
/// Waiters using [`Self::wait_complete`] will be notified.
248254
pub fn mark_complete(&self) {
249-
let current_generation = self.inner.read().generation;
255+
let mut current = self.inner.write();
256+
let current_generation = current.generation;
257+
current.is_complete = true;
258+
drop(current);
259+
250260
// Broadcast completion to all waiters
251261
let _ = self.state_watch.send(FilterState::Complete {
252262
generation: current_generation,
@@ -274,8 +284,11 @@ impl DynamicFilterPhysicalExpr {
274284
/// Unlike [`Self::wait_update`], this method guarantees that when it returns,
275285
/// the filter is fully complete with no more updates expected.
276286
pub async fn wait_complete(&self) {
287+
if self.inner.read().is_complete {
288+
return;
289+
}
290+
277291
let mut rx = self.state_watch.subscribe();
278-
// Wait until the state becomes Complete
279292
let _ = rx
280293
.wait_for(|state| matches!(state, FilterState::Complete { .. }))
281294
.await;
@@ -580,4 +593,18 @@ mod test {
580593
"Expected err when evaluate is called after changing the expression."
581594
);
582595
}
596+
597+
#[tokio::test]
598+
async fn test_wait_complete_already_complete() {
599+
let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(
600+
vec![],
601+
lit(42) as Arc<dyn PhysicalExpr>,
602+
));
603+
604+
// Mark as complete immediately
605+
dynamic_filter.mark_complete();
606+
607+
// wait_complete should return immediately
608+
dynamic_filter.wait_complete().await;
609+
}
583610
}

datafusion/physical-plan/src/joins/hash_join/exec.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4494,4 +4494,55 @@ mod tests {
44944494
fn columns(schema: &Schema) -> Vec<String> {
44954495
schema.fields().iter().map(|f| f.name().clone()).collect()
44964496
}
4497+
4498+
/// This test verifies that the dynamic filter is marked as complete after HashJoinExec finishes building the hash table.
4499+
#[tokio::test]
4500+
async fn test_hash_join_marks_filter_complete() -> Result<()> {
4501+
let task_ctx = Arc::new(TaskContext::default());
4502+
let left = build_table(
4503+
("a1", &vec![1, 2, 3]),
4504+
("b1", &vec![4, 5, 6]),
4505+
("c1", &vec![7, 8, 9]),
4506+
);
4507+
let right = build_table(
4508+
("a2", &vec![10, 20, 30]),
4509+
("b1", &vec![4, 5, 6]),
4510+
("c2", &vec![70, 80, 90]),
4511+
);
4512+
4513+
let on = vec![(
4514+
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
4515+
Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
4516+
)];
4517+
4518+
// Create a dynamic filter manually
4519+
let dynamic_filter = HashJoinExec::create_dynamic_filter(&on);
4520+
let dynamic_filter_clone = Arc::clone(&dynamic_filter);
4521+
4522+
// Create HashJoinExec with the dynamic filter
4523+
let mut join = HashJoinExec::try_new(
4524+
left,
4525+
right,
4526+
on,
4527+
None,
4528+
&JoinType::Inner,
4529+
None,
4530+
PartitionMode::CollectLeft,
4531+
NullEquality::NullEqualsNothing,
4532+
)?;
4533+
join.dynamic_filter = Some(HashJoinExecDynamicFilter {
4534+
filter: dynamic_filter,
4535+
bounds_accumulator: OnceLock::new(),
4536+
});
4537+
4538+
// Execute the join
4539+
let stream = join.execute(0, task_ctx)?;
4540+
let _batches = common::collect(stream).await?;
4541+
4542+
// After the join completes, the dynamic filter should be marked as complete
4543+
// wait_complete() should return immediately
4544+
dynamic_filter_clone.wait_complete().await;
4545+
4546+
Ok(())
4547+
}
44974548
}

datafusion/physical-plan/src/topk/mod.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,4 +1201,52 @@ mod tests {
12011201

12021202
Ok(())
12031203
}
1204+
1205+
/// This test verifies that the dynamic filter is marked as complete after TopK processing finishes.
1206+
#[tokio::test]
1207+
async fn test_topk_marks_filter_complete() -> Result<()> {
1208+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1209+
1210+
let sort_expr = PhysicalSortExpr {
1211+
expr: col("a", schema.as_ref())?,
1212+
options: SortOptions::default(),
1213+
};
1214+
1215+
let full_expr = LexOrdering::from([sort_expr.clone()]);
1216+
let prefix = vec![sort_expr];
1217+
1218+
// Create a dummy runtime environment and metrics
1219+
let runtime = Arc::new(RuntimeEnv::default());
1220+
let metrics = ExecutionPlanMetricsSet::new();
1221+
1222+
// Create a dynamic filter that we'll check for completion
1223+
let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(vec![], lit(true)));
1224+
let dynamic_filter_clone = Arc::clone(&dynamic_filter);
1225+
1226+
// Create a TopK instance
1227+
let mut topk = TopK::try_new(
1228+
0,
1229+
Arc::clone(&schema),
1230+
prefix,
1231+
full_expr,
1232+
2,
1233+
10,
1234+
runtime,
1235+
&metrics,
1236+
Arc::new(RwLock::new(TopKDynamicFilters::new(dynamic_filter))),
1237+
)?;
1238+
1239+
let array: ArrayRef = Arc::new(Int32Array::from(vec![Some(3), Some(1), Some(2)]));
1240+
let batch = RecordBatch::try_new(Arc::clone(&schema), vec![array])?;
1241+
topk.insert_batch(batch)?;
1242+
1243+
// Call emit to finish TopK processing
1244+
let _results: Vec<_> = topk.emit()?.try_collect().await?;
1245+
1246+
// After emit is called, the dynamic filter should be marked as complete
1247+
// wait_complete() should return immediately
1248+
dynamic_filter_clone.wait_complete().await;
1249+
1250+
Ok(())
1251+
}
12041252
}

0 commit comments

Comments
 (0)