Commit 85b26d8

add filter
1 parent 092d2bf commit 85b26d8

File tree

4 files changed (+118, -23 lines)


datafusion/core/src/physical_planner.rs

Lines changed: 6 additions & 8 deletions
@@ -1143,14 +1143,12 @@ impl DefaultPhysicalPlanner {
             .iter()
             .map(|dynamic_column| {
                 let column = dynamic_column.column();
-                let index = join.schema().index_of(column.name())?;
-                let physical_column = Arc::new(
-                    datafusion_physical_expr::expressions::Column::new(
-                        &column.name,
-                        index,
-                    ),
-                );
-                let build_side_name = dynamic_column.build_name().to_owned();
+                let index =
+                    join.schema().index_of(column.name())?;
+                let physical_column =
+                    Arc::new(Column::new(&column.name, index));
+                let build_side_name =
+                    dynamic_column.build_name().to_owned();
                 Ok((physical_column, build_side_name))
             })
             .collect::<Result<_, DataFusionError>>()?;
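
The planner hunk above is a cosmetic reflow; the underlying technique, resolving a logical column name to a positional index before constructing a physical Column expression, is unchanged. A minimal sketch of that resolution using only the arrow Schema API (the two-field schema here is illustrative, not the join schema from the patch):

use arrow::datatypes::{DataType, Field, Schema};

fn main() -> Result<(), arrow::error::ArrowError> {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]);
    // index_of maps a column name to its position: the `index` that
    // Column::new(&column.name, index) is constructed with above.
    let index = schema.index_of("b")?;
    assert_eq!(index, 1);
    Ok(())
}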

datafusion/physical-plan/src/joins/dynamic_filters.rs

Lines changed: 41 additions & 5 deletions
@@ -14,8 +14,7 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
-
-use arrow::array::AsArray;
+use arrow::array::{AsArray, BooleanBuilder};
 use arrow::array::{
     Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
     UInt16Array, UInt32Array, UInt64Array, UInt8Array,
@@ -24,10 +23,10 @@ use arrow::compute::filter_record_batch;
 use arrow::compute::kernels::aggregate::{max, max_string, min, min_string};
 use arrow::datatypes::DataType;
 use arrow::record_batch::RecordBatch;
-use arrow_array::Array;
 use arrow_array::ArrowNativeTypeOp;
 use arrow_array::StringArray;
-use datafusion_common::{exec_err, DataFusionError, ScalarValue};
+use arrow_array::{Array, ArrayRef};
+use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, ScalarValue};
 use datafusion_expr::Operator;
 use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
 use datafusion_physical_expr::PhysicalExpr;
@@ -36,6 +35,8 @@ use parking_lot::Mutex;
 use std::fmt;
 use std::sync::Arc;
 
+use super::utils::JoinHashMap;
+
 pub struct DynamicFilterInfo {
     columns: Vec<Arc<Column>>,
     build_side_names: Vec<String>,
@@ -184,7 +185,7 @@ impl DynamicFilterInfo {
         records: &RecordBatch,
     ) -> Result<RecordBatch, DataFusionError> {
         let filter_expr = match self.inner.lock().final_expr.as_ref() {
-            Some(expr) => Arc::<dyn datafusion_physical_expr::PhysicalExpr>::clone(expr),
+            Some(expr) => Arc::<dyn PhysicalExpr>::clone(expr),
             None => {
                 return exec_err!(
                     "Filter expression should have been created before calling filter_batch"
@@ -354,4 +355,39 @@ impl PartitionedDynamicFilterInfo {
         self.dynamic_filter_info
             .merge_batch_and_check_finalized(records, self.partition)
     }
+
+    pub fn filter_probe_batch(
+        &self,
+        batch: &RecordBatch,
+        hashes: &[u64],
+        hash_map: &JoinHashMap,
+    ) -> Result<(RecordBatch, Vec<u64>), DataFusionError> {
+        let left_hash_set = hash_map.extract_unique_keys();
+
+        let mut mask_builder = BooleanBuilder::new();
+        for hash in hashes.iter() {
+            mask_builder.append_value(left_hash_set.contains(hash));
+        }
+        let mask = mask_builder.finish();
+
+        let filtered_columns = batch
+            .columns()
+            .iter()
+            .map(|col| {
+                arrow::compute::filter(col, &mask).map_err(|e| arrow_datafusion_err!(e))
+            })
+            .collect::<Result<Vec<ArrayRef>, DataFusionError>>()?;
+
+        let filtered_batch = RecordBatch::try_new(batch.schema(), filtered_columns)?;
+
+        let filtered_hashes = hashes
+            .iter()
+            .zip(mask.iter())
+            .filter_map(|(hash, keep)| {
+                keep.and_then(|k| if k { Some(*hash) } else { None })
+            })
+            .collect();
+
+        Ok((filtered_batch, filtered_hashes))
+    }
 }
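
The new filter_probe_batch prunes probe-side rows whose join-key hash never occurs on the build side: it builds one boolean mask from hash-set membership, filters every column with it, and keeps the surviving hashes aligned with the surviving rows. A self-contained sketch of the same technique, using filter_record_batch (already imported by this file) in place of the per-column filter loop; the function name and data are illustrative:

use std::collections::HashSet;
use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanBuilder, Int32Array};
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

fn filter_by_hashes(
    batch: &RecordBatch,
    hashes: &[u64],
    build_side: &HashSet<u64>,
) -> Result<(RecordBatch, Vec<u64>), ArrowError> {
    // One mask bit per probe row: true iff its hash exists on the build side.
    let mut mask_builder = BooleanBuilder::new();
    for hash in hashes {
        mask_builder.append_value(build_side.contains(hash));
    }
    let mask = mask_builder.finish();

    // Apply the mask to all columns at once.
    let filtered = filter_record_batch(batch, &mask)?;

    // Keep the hashes of surviving rows so downstream probing stays aligned.
    let kept = hashes
        .iter()
        .zip(mask.iter())
        .filter_map(|(h, keep)| (keep == Some(true)).then_some(*h))
        .collect();

    Ok((filtered, kept))
}

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
    )?;
    let build_side: HashSet<u64> = [10, 30].into_iter().collect();
    let (filtered, kept) = filter_by_hashes(&batch, &[10, 20, 30], &build_side)?;
    assert_eq!(filtered.num_rows(), 2);
    assert_eq!(kept, vec![10, 30]);
    Ok(())
}

Note that hash collisions can let non-matching rows through the mask; that is safe, since the join probe still verifies actual key equality afterwards.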

datafusion/physical-plan/src/joins/hash_join.rs

Lines changed: 17 additions & 8 deletions
@@ -49,11 +49,6 @@ use crate::{
     Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
     Statistics,
 };
-use std::fmt;
-use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::Arc;
-use std::task::Poll;
-use std::{any::Any, vec};
 
 use arrow::array::{
     Array, ArrayRef, BooleanArray, BooleanBufferBuilder, UInt32Array, UInt64Array,
@@ -1395,7 +1390,9 @@ impl HashJoinStream {
                 self.state = HashJoinStreamState::ExhaustedProbeSide;
             }
             Some(Ok(batch)) => {
-                // Precalculate hash values for fetched batch
+                let left_data = Arc::<JoinLeftData>::clone(
+                    &self.build_side.try_as_ready()?.left_data,
+                );
                 let keys_values = self
                     .on_right
                     .iter()
@@ -1406,12 +1403,24 @@
                 self.hashes_buffer.resize(batch.num_rows(), 0);
                 create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?;
 
+                let (filtered_batch, filtered_hashes) =
+                    if let Some(dynamic_filter) = &self.dynamic_filter_info {
+                        dynamic_filter.filter_probe_batch(
+                            &batch,
+                            &self.hashes_buffer,
+                            &left_data.hash_map,
+                        )?
+                    } else {
+                        (batch, std::mem::take(&mut self.hashes_buffer))
+                    };
+
                 self.join_metrics.input_batches.add(1);
-                self.join_metrics.input_rows.add(batch.num_rows());
+                self.join_metrics.input_rows.add(filtered_batch.num_rows());
 
+                self.hashes_buffer = filtered_hashes;
                 self.state =
                     HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
-                        batch,
+                        batch: filtered_batch,
                         offset: (0, None),
                         joined_probe_idx: None,
                     });
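
On the probe path, the stream now clones the build side's left_data up front, then either routes the batch and its hashes through filter_probe_batch or moves the hash buffer out wholesale with std::mem::take, so the unfiltered path copies nothing. A minimal sketch of that take-or-filter handoff (the function and predicate are hypothetical, not DataFusion API):

fn take_or_filter(buffer: &mut Vec<u64>, filter: Option<&dyn Fn(u64) -> bool>) -> Vec<u64> {
    match filter {
        // Filtered path: collect the survivors into a fresh vector.
        Some(pred) => buffer.iter().copied().filter(|h| pred(*h)).collect(),
        // Unfiltered path: std::mem::take moves the buffer out and leaves
        // an empty Vec behind, so no elements are copied.
        None => std::mem::take(buffer),
    }
}

fn main() {
    let mut hashes = vec![10, 20, 30];
    let keep = |h: u64| h >= 20;
    assert_eq!(take_or_filter(&mut hashes, Some(&keep)), vec![20, 30]);
    assert_eq!(take_or_filter(&mut hashes, None), vec![10, 20, 30]);
}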

datafusion/physical-plan/src/joins/utils.rs

Lines changed: 54 additions & 2 deletions
@@ -140,6 +140,18 @@ impl JoinHashMap {
             next: vec![0; capacity],
         }
     }
+
+    /// Extract all unique keys of this join hash map
+    pub fn extract_unique_keys(&self) -> HashSet<u64> {
+        let mut unique_keys = HashSet::new();
+        unsafe {
+            self.map.iter().for_each(|entry| {
+                let (hash, _) = entry.as_ref();
+                unique_keys.insert(hash.to_owned());
+            })
+        };
+        unique_keys
+    }
 }
 
 // Type of offsets for obtaining indices from JoinHashMap.
@@ -371,8 +383,48 @@ impl JoinHashMapType for JoinHashMap {
 }
 
 impl Debug for JoinHashMap {
-    fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
-        Ok(())
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        writeln!(f, "JoinHashMap {{")?;
+        writeln!(f, "  map:")?;
+        writeln!(f, "  ----------")?;
+
+        let mut entries: Vec<_> = unsafe { self.map.iter().collect() };
+        entries.sort_by_key(|bucket| unsafe { bucket.as_ref().0 });
+
+        for bucket in entries {
+            let mut indices = Vec::new();
+            let mut curr_idx = unsafe { bucket.as_ref().1 };
+
+            while curr_idx > 0 {
+                indices.push(curr_idx - 1);
+                curr_idx = self.next[(curr_idx - 1) as usize];
+            }
+
+            indices.reverse();
+
+            writeln!(
+                f,
+                "  | {:3} | {} | -> {:?}",
+                unsafe { bucket.as_ref().0 },
+                unsafe { bucket.as_ref().1 },
+                indices
+            )?;
+        }
+
+        writeln!(f, "  ----------")?;
+        writeln!(f, "\n  next:")?;
+        writeln!(f, "  ---------------------")?;
+        write!(f, "  |")?;
+        for &next_idx in self.next.iter() {
+            write!(f, " {:2} |", next_idx)?;
+        }
+        writeln!(f)?;
+        write!(f, "  |")?;
+        for i in 0..self.next.len() {
+            write!(f, " {:2} |", i)?;
+        }
+        writeln!(f, "\n  ---------------------")?;
+        writeln!(f, "}}")
     }
 }
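
The rewritten Debug impl decodes JoinHashMap's chained layout: each bucket maps a hash to the 1-based index of the most recently inserted row with that hash, 0 terminates a chain, and next[i - 1] links each row to the previous one with the same key. A standalone sketch of just that chain decoding (decode_chain is an illustrative name, not part of the patch):

fn decode_chain(head: u64, next: &[u64]) -> Vec<u64> {
    let mut indices = Vec::new();
    let mut curr = head;
    while curr > 0 {
        indices.push(curr - 1); // convert the 1-based link back to a row index
        curr = next[(curr - 1) as usize];
    }
    indices.reverse(); // rows are prepended on insert, so reverse into insertion order
    indices
}

fn main() {
    // Rows 0, 2, and 4 share one key: the bucket stores head 5 (row 4),
    // and next links row 4 -> row 2 -> row 0 -> end.
    let next = [0, 0, 1, 0, 3];
    assert_eq!(decode_chain(5, &next), vec![0, 2, 4]);
}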
