27 changes: 24 additions & 3 deletions native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs
@@ -56,6 +56,10 @@ impl<const L_OUTER: bool, const R_OUTER: bool> FullJoiner<L_OUTER, R_OUTER> {
self.lindices.len() >= self.join_params.batch_size
}

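    // Returns true if `new_size` additional joined rows still fit into the current
    // output batch without exceeding the configured batch size.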
fn has_enough_room(&self, new_size: usize) -> bool {
self.lindices.len() + new_size <= self.join_params.batch_size
}

async fn flush(
mut self: Pin<&mut Self>,
cur1: &mut StreamCursor,
@@ -160,9 +164,26 @@ impl<const L_OUTER: bool, const R_OUTER: bool> Joiner for FullJoiner<L_OUTER, R_OUTER> {
continue;
}

for (&lidx, &ridx) in equal_lindices.iter().cartesian_product(&equal_rindices) {
self.lindices.push(lidx);
self.rindices.push(ridx);
let new_size = equal_lindices.len() * equal_rindices.len();
if self.has_enough_room(new_size) {
                            // the whole cross product fits within the current batch:
                            // emit it in one pass, as before
for (&lidx, &ridx) in
equal_lindices.iter().cartesian_product(&equal_rindices)
{
self.lindices.push(lidx);
self.rindices.push(ridx);
}
} else {
                            // the cross product would overflow the batch: emit pair by pair
                            // and flush whenever the output buffer reaches batch_size
for &lidx in &equal_lindices {
for &ridx in &equal_rindices {
self.lindices.push(lidx);
self.rindices.push(ridx);
if self.should_flush() {
self.as_mut().flush(cur1, cur2).await?;
}
}
}
}

if r_equal {
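
The added guard caps how much of an equal-key cross product is buffered at once: if the product of the two runs still fits in the current batch, it is appended in a single pass as before; otherwise pairs are appended one at a time and the buffer is flushed every time it reaches the batch size. A minimal standalone sketch of that pattern, using illustrative names only (`Output`, `emit_equal_run`, a `Vec`-based flush sink) and none of the real joiner's types:

// Standalone sketch of the buffering strategy above. `Output` stands in for the
// joiner's index buffers, and `flushed` stands in for its real flush() sink.
struct Output {
    lindices: Vec<u32>,
    rindices: Vec<u32>,
    batch_size: usize,
    flushed: Vec<(Vec<u32>, Vec<u32>)>, // collected "batches" in place of an output stream
}

impl Output {
    fn has_enough_room(&self, new_size: usize) -> bool {
        self.lindices.len() + new_size <= self.batch_size
    }

    fn should_flush(&self) -> bool {
        self.lindices.len() >= self.batch_size
    }

    fn flush(&mut self) {
        self.flushed.push((
            std::mem::take(&mut self.lindices),
            std::mem::take(&mut self.rindices),
        ));
    }

    fn emit_equal_run(&mut self, equal_lindices: &[u32], equal_rindices: &[u32]) {
        let new_size = equal_lindices.len() * equal_rindices.len();
        if self.has_enough_room(new_size) {
            // the whole cross product fits: append it in one pass
            for &lidx in equal_lindices {
                for &ridx in equal_rindices {
                    self.lindices.push(lidx);
                    self.rindices.push(ridx);
                }
            }
        } else {
            // the cross product would overflow the batch: append pair by pair and
            // flush as soon as the buffer fills, so it never grows far past batch_size
            for &lidx in equal_lindices {
                for &ridx in equal_rindices {
                    self.lindices.push(lidx);
                    self.rindices.push(ridx);
                    if self.should_flush() {
                        self.flush();
                    }
                }
            }
        }
    }
}

fn main() {
    let mut out = Output { lindices: vec![], rindices: vec![], batch_size: 4, flushed: vec![] };
    // a 3 x 3 equal-key run against batch_size = 4 triggers the overflow path
    out.emit_equal_run(&[0, 1, 2], &[10, 11, 12]);
    out.flush();
    let total: usize = out.flushed.iter().map(|(l, _)| l.len()).sum();
    assert_eq!(total, 9); // all 9 pairs emitted, split across 3 batches (4 + 4 + 1)
    println!("{} batches", out.flushed.len());
}

With this in place the output buffer never has to hold an entire large cross product before flushing, which is exactly what the small-batch-size tests below exercise.
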
199 changes: 198 additions & 1 deletion native-engine/datafusion-ext-plans/src/joins/test.rs
@@ -32,7 +32,7 @@ mod tests {
error::Result,
physical_expr::expressions::Column,
physical_plan::{ExecutionPlan, common, joins::utils::*, test::TestMemoryExec},
prelude::SessionContext,
prelude::{SessionConfig, SessionContext},
};

use crate::{
@@ -264,6 +264,91 @@ mod tests {
Ok((columns, batches))
}

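    // Like the plain collect helper above, except that the SessionContext is built from
    // a SessionConfig with an explicit batch size, so the join under test has to split
    // its output across multiple record batches. (DataFusion hands that setting to
    // operators through the TaskContext's SessionConfig.)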
async fn join_collect_with_batch_size(
test_type: TestType,
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
join_type: JoinType,
batch_size: usize,
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
MemManager::init(1000000);
let session_config = SessionConfig::new().with_batch_size(batch_size);
let session_ctx = SessionContext::new_with_config(session_config);
let task_ctx = session_ctx.task_ctx();
let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?;

let join: Arc<dyn ExecutionPlan> = match test_type {
SMJ => {
let sort_options = vec![SortOptions::default(); on.len()];
Arc::new(SortMergeJoinExec::try_new(
schema,
left,
right,
on,
join_type,
sort_options,
)?)
}
BHJLeftProbed => {
let right = Arc::new(BroadcastJoinBuildHashMapExec::new(
right,
on.iter().map(|(_, right_key)| right_key.clone()).collect(),
));
Arc::new(BroadcastJoinExec::try_new(
schema,
left,
right,
on,
join_type,
JoinSide::Right,
true,
None,
)?)
}
BHJRightProbed => {
let left = Arc::new(BroadcastJoinBuildHashMapExec::new(
left,
on.iter().map(|(left_key, _)| left_key.clone()).collect(),
));
Arc::new(BroadcastJoinExec::try_new(
schema,
left,
right,
on,
join_type,
JoinSide::Left,
true,
None,
)?)
}
SHJLeftProbed => Arc::new(BroadcastJoinExec::try_new(
schema,
left,
right,
on,
join_type,
JoinSide::Right,
false,
None,
)?),
SHJRightProbed => Arc::new(BroadcastJoinExec::try_new(
schema,
left,
right,
on,
join_type,
JoinSide::Left,
false,
None,
)?),
};
let columns = columns(&join.schema());
let stream = join.execute(0, task_ctx)?;
let batches = common::collect(stream).await?;
Ok((columns, batches))
}

const ALL_TEST_TYPE: [TestType; 5] = [
SMJ,
BHJLeftProbed,
@@ -428,6 +513,118 @@ mod tests {
Ok(())
}

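    // Every left row (5) and every right row (7) carries the same join key value 1, so
    // the inner join produces a single equal-key run of 5 * 7 = 35 output rows. Running
    // it with batch sizes of 2, 3, 4, 5 and 7 forces the joiner to split that run across
    // multiple output batches, which for the SMJ case exercises the overflow path added
    // in full_join.rs.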
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn join_inner_batchsize() -> Result<()> {
for test_type in ALL_TEST_TYPE {
let left = build_table(
("a1", &vec![1, 1, 1, 1, 1]),
("b1", &vec![1, 2, 3, 4, 5]),
("c1", &vec![1, 2, 3, 4, 5]),
);
let right = build_table(
("a2", &vec![1, 1, 1, 1, 1, 1, 1]),
("b2", &vec![1, 2, 3, 4, 5, 6, 7]),
("c2", &vec![1, 2, 3, 4, 5, 6, 7]),
);
let on: JoinOn = vec![(
Arc::new(Column::new_with_schema("a1", &left.schema())?),
Arc::new(Column::new_with_schema("a2", &right.schema())?),
)];
let expected = vec![
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b2 | c2 |",
"+----+----+----+----+----+----+",
"| 1 | 1 | 1 | 1 | 1 | 1 |",
"| 1 | 1 | 1 | 1 | 2 | 2 |",
"| 1 | 1 | 1 | 1 | 3 | 3 |",
"| 1 | 1 | 1 | 1 | 4 | 4 |",
"| 1 | 1 | 1 | 1 | 5 | 5 |",
"| 1 | 1 | 1 | 1 | 6 | 6 |",
"| 1 | 1 | 1 | 1 | 7 | 7 |",
"| 1 | 2 | 2 | 1 | 1 | 1 |",
"| 1 | 2 | 2 | 1 | 2 | 2 |",
"| 1 | 2 | 2 | 1 | 3 | 3 |",
"| 1 | 2 | 2 | 1 | 4 | 4 |",
"| 1 | 2 | 2 | 1 | 5 | 5 |",
"| 1 | 2 | 2 | 1 | 6 | 6 |",
"| 1 | 2 | 2 | 1 | 7 | 7 |",
"| 1 | 3 | 3 | 1 | 1 | 1 |",
"| 1 | 3 | 3 | 1 | 2 | 2 |",
"| 1 | 3 | 3 | 1 | 3 | 3 |",
"| 1 | 3 | 3 | 1 | 4 | 4 |",
"| 1 | 3 | 3 | 1 | 5 | 5 |",
"| 1 | 3 | 3 | 1 | 6 | 6 |",
"| 1 | 3 | 3 | 1 | 7 | 7 |",
"| 1 | 4 | 4 | 1 | 1 | 1 |",
"| 1 | 4 | 4 | 1 | 2 | 2 |",
"| 1 | 4 | 4 | 1 | 3 | 3 |",
"| 1 | 4 | 4 | 1 | 4 | 4 |",
"| 1 | 4 | 4 | 1 | 5 | 5 |",
"| 1 | 4 | 4 | 1 | 6 | 6 |",
"| 1 | 4 | 4 | 1 | 7 | 7 |",
"| 1 | 5 | 5 | 1 | 1 | 1 |",
"| 1 | 5 | 5 | 1 | 2 | 2 |",
"| 1 | 5 | 5 | 1 | 3 | 3 |",
"| 1 | 5 | 5 | 1 | 4 | 4 |",
"| 1 | 5 | 5 | 1 | 5 | 5 |",
"| 1 | 5 | 5 | 1 | 6 | 6 |",
"| 1 | 5 | 5 | 1 | 7 | 7 |",
"+----+----+----+----+----+----+",
];
let (_, batches) = join_collect_with_batch_size(
test_type,
left.clone(),
right.clone(),
on.clone(),
Inner,
2,
)
.await?;
assert_batches_sorted_eq!(expected, &batches);
let (_, batches) = join_collect_with_batch_size(
test_type,
left.clone(),
right.clone(),
on.clone(),
Inner,
3,
)
.await?;
assert_batches_sorted_eq!(expected, &batches);
let (_, batches) = join_collect_with_batch_size(
test_type,
left.clone(),
right.clone(),
on.clone(),
Inner,
4,
)
.await?;
assert_batches_sorted_eq!(expected, &batches);
let (_, batches) = join_collect_with_batch_size(
test_type,
left.clone(),
right.clone(),
on.clone(),
Inner,
5,
)
.await?;
assert_batches_sorted_eq!(expected, &batches);
let (_, batches) = join_collect_with_batch_size(
test_type,
left.clone(),
right.clone(),
on.clone(),
Inner,
7,
)
.await?;
assert_batches_sorted_eq!(expected, &batches);
}
Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn join_left_one() -> Result<()> {
for test_type in ALL_TEST_TYPE {