Skip to content

Commit 18c581c

Browse files
author
QP Hou
authored
fix join column handling logic for On and Using constraints (#605)
* fix join column handling logic for `On` and `Using` constraints * handling join column expansion during USING JOIN planning get rid of shared field and move column expansion logic into plan builder and optimizer. * add more comments & fix clippy * add more comment * reduce duplicate code in join predicate pushdown
1 parent 3664766 commit 18c581c

File tree

23 files changed

+836
-500
lines changed

23 files changed

+836
-500
lines changed

ballista/rust/core/proto/ballista.proto

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,18 @@ enum JoinType {
378378
ANTI = 5;
379379
}
380380

381+
enum JoinConstraint {
382+
ON = 0;
383+
USING = 1;
384+
}
385+
381386
message JoinNode {
382387
LogicalPlanNode left = 1;
383388
LogicalPlanNode right = 2;
384389
JoinType join_type = 3;
385-
repeated Column left_join_column = 4;
386-
repeated Column right_join_column = 5;
390+
JoinConstraint join_constraint = 4;
391+
repeated Column left_join_column = 5;
392+
repeated Column right_join_column = 6;
387393
}
388394

389395
message LimitNode {

ballista/rust/core/src/serde/logical_plan/from_proto.rs

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ use datafusion::logical_plan::window_frames::{
2626
};
2727
use datafusion::logical_plan::{
2828
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
29-
sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinType, LogicalPlan,
30-
LogicalPlanBuilder, Operator,
29+
sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinConstraint, JoinType,
30+
LogicalPlan, LogicalPlanBuilder, Operator,
3131
};
3232
use datafusion::physical_plan::aggregates::AggregateFunction;
3333
use datafusion::physical_plan::csv::CsvReadOptions;
@@ -257,23 +257,32 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
257257
join.join_type
258258
))
259259
})?;
260-
let join_type = match join_type {
261-
protobuf::JoinType::Inner => JoinType::Inner,
262-
protobuf::JoinType::Left => JoinType::Left,
263-
protobuf::JoinType::Right => JoinType::Right,
264-
protobuf::JoinType::Full => JoinType::Full,
265-
protobuf::JoinType::Semi => JoinType::Semi,
266-
protobuf::JoinType::Anti => JoinType::Anti,
267-
};
268-
LogicalPlanBuilder::from(convert_box_required!(join.left)?)
269-
.join(
260+
let join_constraint = protobuf::JoinConstraint::from_i32(
261+
join.join_constraint,
262+
)
263+
.ok_or_else(|| {
264+
proto_error(format!(
265+
"Received a JoinNode message with unknown JoinConstraint {}",
266+
join.join_constraint
267+
))
268+
})?;
269+
270+
let builder = LogicalPlanBuilder::from(convert_box_required!(join.left)?);
271+
let builder = match join_constraint.into() {
272+
JoinConstraint::On => builder.join(
270273
&convert_box_required!(join.right)?,
271-
join_type,
274+
join_type.into(),
272275
left_keys,
273276
right_keys,
274-
)?
275-
.build()
276-
.map_err(|e| e.into())
277+
)?,
278+
JoinConstraint::Using => builder.join_using(
279+
&convert_box_required!(join.right)?,
280+
join_type.into(),
281+
left_keys,
282+
)?,
283+
};
284+
285+
builder.build().map_err(|e| e.into())
277286
}
278287
}
279288
}

ballista/rust/core/src/serde/logical_plan/to_proto.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUn
2626
use datafusion::datasource::CsvFile;
2727
use datafusion::logical_plan::{
2828
window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits},
29-
Column, Expr, JoinType, LogicalPlan,
29+
Column, Expr, JoinConstraint, JoinType, LogicalPlan,
3030
};
3131
use datafusion::physical_plan::aggregates::AggregateFunction;
3232
use datafusion::physical_plan::functions::BuiltinScalarFunction;
@@ -804,26 +804,23 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
804804
right,
805805
on,
806806
join_type,
807+
join_constraint,
807808
..
808809
} => {
809810
let left: protobuf::LogicalPlanNode = left.as_ref().try_into()?;
810811
let right: protobuf::LogicalPlanNode = right.as_ref().try_into()?;
811-
let join_type = match join_type {
812-
JoinType::Inner => protobuf::JoinType::Inner,
813-
JoinType::Left => protobuf::JoinType::Left,
814-
JoinType::Right => protobuf::JoinType::Right,
815-
JoinType::Full => protobuf::JoinType::Full,
816-
JoinType::Semi => protobuf::JoinType::Semi,
817-
JoinType::Anti => protobuf::JoinType::Anti,
818-
};
819812
let (left_join_column, right_join_column) =
820813
on.iter().map(|(l, r)| (l.into(), r.into())).unzip();
814+
let join_type: protobuf::JoinType = join_type.to_owned().into();
815+
let join_constraint: protobuf::JoinConstraint =
816+
join_constraint.to_owned().into();
821817
Ok(protobuf::LogicalPlanNode {
822818
logical_plan_type: Some(LogicalPlanType::Join(Box::new(
823819
protobuf::JoinNode {
824820
left: Some(Box::new(left)),
825821
right: Some(Box::new(right)),
826822
join_type: join_type.into(),
823+
join_constraint: join_constraint.into(),
827824
left_join_column,
828825
right_join_column,
829826
},

ballista/rust/core/src/serde/mod.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
2121
use std::{convert::TryInto, io::Cursor};
2222

23-
use datafusion::logical_plan::Operator;
23+
use datafusion::logical_plan::{JoinConstraint, JoinType, Operator};
2424
use datafusion::physical_plan::aggregates::AggregateFunction;
2525
use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
2626

@@ -291,3 +291,47 @@ impl Into<datafusion::arrow::datatypes::DataType> for protobuf::PrimitiveScalarT
291291
}
292292
}
293293
}
294+
295+
impl From<protobuf::JoinType> for JoinType {
296+
fn from(t: protobuf::JoinType) -> Self {
297+
match t {
298+
protobuf::JoinType::Inner => JoinType::Inner,
299+
protobuf::JoinType::Left => JoinType::Left,
300+
protobuf::JoinType::Right => JoinType::Right,
301+
protobuf::JoinType::Full => JoinType::Full,
302+
protobuf::JoinType::Semi => JoinType::Semi,
303+
protobuf::JoinType::Anti => JoinType::Anti,
304+
}
305+
}
306+
}
307+
308+
impl From<JoinType> for protobuf::JoinType {
309+
fn from(t: JoinType) -> Self {
310+
match t {
311+
JoinType::Inner => protobuf::JoinType::Inner,
312+
JoinType::Left => protobuf::JoinType::Left,
313+
JoinType::Right => protobuf::JoinType::Right,
314+
JoinType::Full => protobuf::JoinType::Full,
315+
JoinType::Semi => protobuf::JoinType::Semi,
316+
JoinType::Anti => protobuf::JoinType::Anti,
317+
}
318+
}
319+
}
320+
321+
impl From<protobuf::JoinConstraint> for JoinConstraint {
322+
fn from(t: protobuf::JoinConstraint) -> Self {
323+
match t {
324+
protobuf::JoinConstraint::On => JoinConstraint::On,
325+
protobuf::JoinConstraint::Using => JoinConstraint::Using,
326+
}
327+
}
328+
}
329+
330+
impl From<JoinConstraint> for protobuf::JoinConstraint {
331+
fn from(t: JoinConstraint) -> Self {
332+
match t {
333+
JoinConstraint::On => protobuf::JoinConstraint::On,
334+
JoinConstraint::Using => protobuf::JoinConstraint::Using,
335+
}
336+
}
337+
}

ballista/rust/core/src/serde/physical_plan/from_proto.rs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ use datafusion::catalog::catalog::{
3535
use datafusion::execution::context::{
3636
ExecutionConfig, ExecutionContextState, ExecutionProps,
3737
};
38-
use datafusion::logical_plan::{window_frames::WindowFrame, DFSchema, Expr};
38+
use datafusion::logical_plan::{
39+
window_frames::WindowFrame, DFSchema, Expr, JoinConstraint, JoinType,
40+
};
3941
use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction};
4042
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
4143
use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
@@ -57,7 +59,6 @@ use datafusion::physical_plan::{
5759
filter::FilterExec,
5860
functions::{self, BuiltinScalarFunction, ScalarFunctionExpr},
5961
hash_join::HashJoinExec,
60-
hash_utils::JoinType,
6162
limit::{GlobalLimitExec, LocalLimitExec},
6263
parquet::ParquetExec,
6364
projection::ProjectionExec,
@@ -348,14 +349,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
348349
hashjoin.join_type
349350
))
350351
})?;
351-
let join_type = match join_type {
352-
protobuf::JoinType::Inner => JoinType::Inner,
353-
protobuf::JoinType::Left => JoinType::Left,
354-
protobuf::JoinType::Right => JoinType::Right,
355-
protobuf::JoinType::Full => JoinType::Full,
356-
protobuf::JoinType::Semi => JoinType::Semi,
357-
protobuf::JoinType::Anti => JoinType::Anti,
358-
};
352+
359353
let partition_mode =
360354
protobuf::PartitionMode::from_i32(hashjoin.partition_mode)
361355
.ok_or_else(|| {
@@ -372,7 +366,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
372366
left,
373367
right,
374368
on,
375-
&join_type,
369+
&join_type.into(),
376370
partition_mode,
377371
)?))
378372
}

ballista/rust/core/src/serde/physical_plan/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ mod roundtrip_tests {
2727
compute::kernels::sort::SortOptions,
2828
datatypes::{DataType, Field, Schema},
2929
},
30-
logical_plan::Operator,
30+
logical_plan::{JoinType, Operator},
3131
physical_plan::{
3232
empty::EmptyExec,
3333
expressions::{binary, col, lit, InListExpr, NotExpr},
3434
expressions::{Avg, Column, PhysicalSortExpr},
3535
filter::FilterExec,
3636
hash_aggregate::{AggregateMode, HashAggregateExec},
3737
hash_join::{HashJoinExec, PartitionMode},
38-
hash_utils::JoinType,
3938
limit::{GlobalLimitExec, LocalLimitExec},
4039
sort::SortExec,
4140
AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning,

ballista/rust/core/src/serde/physical_plan/to_proto.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use std::{
2626
sync::Arc,
2727
};
2828

29+
use datafusion::logical_plan::JoinType;
2930
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
3031
use datafusion::physical_plan::csv::CsvExec;
3132
use datafusion::physical_plan::expressions::{
@@ -35,7 +36,6 @@ use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr};
3536
use datafusion::physical_plan::filter::FilterExec;
3637
use datafusion::physical_plan::hash_aggregate::AggregateMode;
3738
use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode};
38-
use datafusion::physical_plan::hash_utils::JoinType;
3939
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
4040
use datafusion::physical_plan::parquet::ParquetExec;
4141
use datafusion::physical_plan::projection::ProjectionExec;
@@ -135,18 +135,13 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
135135
}),
136136
})
137137
.collect();
138-
let join_type = match exec.join_type() {
139-
JoinType::Inner => protobuf::JoinType::Inner,
140-
JoinType::Left => protobuf::JoinType::Left,
141-
JoinType::Right => protobuf::JoinType::Right,
142-
JoinType::Full => protobuf::JoinType::Full,
143-
JoinType::Semi => protobuf::JoinType::Semi,
144-
JoinType::Anti => protobuf::JoinType::Anti,
145-
};
138+
let join_type: protobuf::JoinType = exec.join_type().to_owned().into();
139+
146140
let partition_mode = match exec.partition_mode() {
147141
PartitionMode::CollectLeft => protobuf::PartitionMode::CollectLeft,
148142
PartitionMode::Partitioned => protobuf::PartitionMode::Partitioned,
149143
};
144+
150145
Ok(protobuf::PhysicalPlanNode {
151146
physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new(
152147
protobuf::HashJoinExecNode {

benchmarks/queries/q7.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,4 @@ group by
3636
order by
3737
supp_nation,
3838
cust_nation,
39-
l_year;
39+
l_year;

datafusion/src/execution/context.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,96 @@ mod tests {
12781278
Ok(())
12791279
}
12801280

1281+
#[tokio::test]
1282+
async fn left_join_using() -> Result<()> {
1283+
let results = execute(
1284+
"SELECT t1.c1, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2",
1285+
1,
1286+
)
1287+
.await?;
1288+
assert_eq!(results.len(), 1);
1289+
1290+
let expected = vec![
1291+
"+----+----+",
1292+
"| c1 | c2 |",
1293+
"+----+----+",
1294+
"| 0 | 1 |",
1295+
"| 0 | 2 |",
1296+
"| 0 | 3 |",
1297+
"| 0 | 4 |",
1298+
"| 0 | 5 |",
1299+
"| 0 | 6 |",
1300+
"| 0 | 7 |",
1301+
"| 0 | 8 |",
1302+
"| 0 | 9 |",
1303+
"| 0 | 10 |",
1304+
"+----+----+",
1305+
];
1306+
1307+
assert_batches_eq!(expected, &results);
1308+
Ok(())
1309+
}
1310+
1311+
#[tokio::test]
1312+
async fn left_join_using_join_key_projection() -> Result<()> {
1313+
let results = execute(
1314+
"SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2",
1315+
1,
1316+
)
1317+
.await?;
1318+
assert_eq!(results.len(), 1);
1319+
1320+
let expected = vec![
1321+
"+----+----+----+",
1322+
"| c1 | c2 | c2 |",
1323+
"+----+----+----+",
1324+
"| 0 | 1 | 1 |",
1325+
"| 0 | 2 | 2 |",
1326+
"| 0 | 3 | 3 |",
1327+
"| 0 | 4 | 4 |",
1328+
"| 0 | 5 | 5 |",
1329+
"| 0 | 6 | 6 |",
1330+
"| 0 | 7 | 7 |",
1331+
"| 0 | 8 | 8 |",
1332+
"| 0 | 9 | 9 |",
1333+
"| 0 | 10 | 10 |",
1334+
"+----+----+----+",
1335+
];
1336+
1337+
assert_batches_eq!(expected, &results);
1338+
Ok(())
1339+
}
1340+
1341+
#[tokio::test]
1342+
async fn left_join() -> Result<()> {
1343+
let results = execute(
1344+
"SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 ON t1.c2 = t2.c2 ORDER BY t1.c2",
1345+
1,
1346+
)
1347+
.await?;
1348+
assert_eq!(results.len(), 1);
1349+
1350+
let expected = vec![
1351+
"+----+----+----+",
1352+
"| c1 | c2 | c2 |",
1353+
"+----+----+----+",
1354+
"| 0 | 1 | 1 |",
1355+
"| 0 | 2 | 2 |",
1356+
"| 0 | 3 | 3 |",
1357+
"| 0 | 4 | 4 |",
1358+
"| 0 | 5 | 5 |",
1359+
"| 0 | 6 | 6 |",
1360+
"| 0 | 7 | 7 |",
1361+
"| 0 | 8 | 8 |",
1362+
"| 0 | 9 | 9 |",
1363+
"| 0 | 10 | 10 |",
1364+
"+----+----+----+",
1365+
];
1366+
1367+
assert_batches_eq!(expected, &results);
1368+
Ok(())
1369+
}
1370+
12811371
#[tokio::test]
12821372
async fn window() -> Result<()> {
12831373
let results = execute(

0 commit comments

Comments
 (0)