Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Support null safe equal in joins #3161

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,7 @@ class PyMicroPartition:
right: PyMicroPartition,
left_on: list[PyExpr],
right_on: list[PyExpr],
null_equals_nulls: list[bool] | None,
how: JoinType,
) -> PyMicroPartition: ...
def pivot(
Expand Down
2 changes: 2 additions & 0 deletions daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
class HashJoin(SingleOutputInstruction):
left_on: ExpressionsProjection
right_on: ExpressionsProjection
null_equals_nulls: list[bool] | None
how: JoinType
is_swapped: bool

Expand All @@ -810,6 +811,7 @@ def _hash_join(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
right,
left_on=self.left_on,
right_on=self.right_on,
null_equals_nulls=self.null_equals_nulls,
how=self.how,
)
return [result]
Expand Down
15 changes: 14 additions & 1 deletion daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def hash_join(
right_plan: InProgressPhysicalPlan[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
null_equals_nulls: None | list[bool],
how: JoinType,
) -> InProgressPhysicalPlan[PartitionT]:
"""Hash-based pairwise join the partitions from `left_child_plan` and `right_child_plan` together."""
Expand Down Expand Up @@ -387,6 +388,7 @@ def hash_join(
instruction=execution_step.HashJoin(
left_on=left_on,
right_on=right_on,
null_equals_nulls=null_equals_nulls,
how=how,
is_swapped=False,
)
Expand Down Expand Up @@ -432,6 +434,7 @@ def _create_broadcast_join_step(
receiver_part: SingleOutputPartitionTask[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
null_equals_nulls: None | list[bool],
how: JoinType,
is_swapped: bool,
) -> PartitionTaskBuilder[PartitionT]:
Expand Down Expand Up @@ -477,6 +480,7 @@ def _create_broadcast_join_step(
instruction=execution_step.BroadcastJoin(
left_on=left_on,
right_on=right_on,
null_equals_nulls=null_equals_nulls,
how=how,
is_swapped=is_swapped,
)
Expand All @@ -488,6 +492,7 @@ def broadcast_join(
receiver_plan: InProgressPhysicalPlan[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
null_equals_nulls: None | list[bool],
how: JoinType,
is_swapped: bool,
) -> InProgressPhysicalPlan[PartitionT]:
Expand Down Expand Up @@ -530,7 +535,15 @@ def broadcast_join(
# Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop.
while receiver_requests and receiver_requests[0].done():
receiver_part = receiver_requests.popleft()
yield _create_broadcast_join_step(broadcaster_parts, receiver_part, left_on, right_on, how, is_swapped)
yield _create_broadcast_join_step(
broadcaster_parts,
receiver_part,
left_on,
right_on,
null_equals_nulls,
how,
is_swapped,
)

# Execute single child step to pull in more input partitions.
try:
Expand Down
4 changes: 4 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def hash_join(
right: physical_plan.InProgressPhysicalPlan[PartitionT],
left_on: list[PyExpr],
right_on: list[PyExpr],
null_equals_nulls: list[bool] | None,
join_type: JoinType,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on])
Expand All @@ -253,6 +254,7 @@ def hash_join(
left_on=left_on_expr_proj,
right_on=right_on_expr_proj,
how=join_type,
null_equals_nulls=null_equals_nulls,
)


Expand Down Expand Up @@ -303,6 +305,7 @@ def broadcast_join(
receiver: physical_plan.InProgressPhysicalPlan[PartitionT],
left_on: list[PyExpr],
right_on: list[PyExpr],
null_equals_nulls: list[bool] | None,
join_type: JoinType,
is_swapped: bool,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
Expand All @@ -315,6 +318,7 @@ def broadcast_join(
right_on=right_on_expr_proj,
how=join_type,
is_swapped=is_swapped,
null_equals_nulls=null_equals_nulls,
)


Expand Down
9 changes: 8 additions & 1 deletion daft/table/micropartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def hash_join(
right: MicroPartition,
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
null_equals_nulls: list[bool] | None = None,
how: JoinType = JoinType.Inner,
) -> MicroPartition:
if len(left_on) != len(right_on):
Expand All @@ -262,7 +263,13 @@ def hash_join(
right_exprs = [e._expr for e in right_on]

return MicroPartition._from_pymicropartition(
self._micropartition.hash_join(right._micropartition, left_on=left_exprs, right_on=right_exprs, how=how)
self._micropartition.hash_join(
right._micropartition,
left_on=left_exprs,
right_on=right_exprs,
null_equals_nulls=null_equals_nulls,
how=how,
)
)

def sort_merge_join(
Expand Down
10 changes: 5 additions & 5 deletions src/daft-core/src/array/ops/arrow2/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,17 @@ pub fn build_is_equal(
pub fn build_multi_array_is_equal(
left: &[Series],
right: &[Series],
nulls_equal: bool,
nan_equal: bool,
nulls_equal: &[bool],
nans_equal: &[bool],
) -> DaftResult<Box<dyn Fn(usize, usize) -> bool + Send + Sync>> {
let mut fn_list = Vec::with_capacity(left.len());

for (l, r) in left.iter().zip(right.iter()) {
for (idx, (l, r)) in left.iter().zip(right.iter()).enumerate() {
fn_list.push(build_is_equal(
l.to_arrow().as_ref(),
r.to_arrow().as_ref(),
nulls_equal,
nan_equal,
nulls_equal[idx],
nans_equal[idx],
)?);
}

Expand Down
8 changes: 7 additions & 1 deletion src/daft-micropartition/src/ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,17 @@ impl MicroPartition {
right: &Self,
left_on: &[ExprRef],
right_on: &[ExprRef],
null_equals_nulls: Option<Vec<bool>>,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure whether this is the right type definition — should we define it as something like `Option<&[bool]>` instead?

how: JoinType,
) -> DaftResult<Self> {
let io_stats = IOStatsContext::new("MicroPartition::hash_join");
let null_equals_nulls = null_equals_nulls.unwrap_or_else(|| vec![false; left_on.len()]);
let table_join =
|lt: &Table, rt: &Table, lo: &[ExprRef], ro: &[ExprRef], _how: JoinType| {
Table::hash_join(lt, rt, lo, ro, null_equals_nulls.as_slice(), _how)
};

self.join(right, io_stats, left_on, right_on, how, Table::hash_join)
self.join(right, io_stats, left_on, right_on, how, table_join)
}

pub fn sort_merge_join(
Expand Down
2 changes: 2 additions & 0 deletions src/daft-micropartition/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ impl PyMicroPartition {
left_on: Vec<PyExpr>,
right_on: Vec<PyExpr>,
how: JoinType,
null_equals_nulls: Option<Vec<bool>>,
) -> PyResult<Self> {
let left_exprs: Vec<daft_dsl::ExprRef> =
left_on.into_iter().map(std::convert::Into::into).collect();
Expand All @@ -272,6 +273,7 @@ impl PyMicroPartition {
&right.inner,
left_exprs.as_slice(),
right_exprs.as_slice(),
null_equals_nulls,
how,
)?
.into())
Expand Down
25 changes: 25 additions & 0 deletions src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,12 +460,37 @@ impl LogicalPlanBuilder {
join_strategy: Option<JoinStrategy>,
join_suffix: Option<&str>,
join_prefix: Option<&str>,
) -> DaftResult<Self> {
self.join_with_null_safe_equal(
right,
left_on,
right_on,
None,
join_type,
join_strategy,
join_suffix,
join_prefix,
)
}

#[allow(clippy::too_many_arguments)]
pub fn join_with_null_safe_equal<Right: Into<LogicalPlanRef>>(
&self,
right: Right,
left_on: Vec<ExprRef>,
right_on: Vec<ExprRef>,
null_equals_nulls: Option<Vec<bool>>,
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
join_suffix: Option<&str>,
join_prefix: Option<&str>,
) -> DaftResult<Self> {
let logical_plan: LogicalPlan = logical_ops::Join::try_new(
self.plan.clone(),
right.into(),
left_on,
right_on,
null_equals_nulls,
join_type,
join_strategy,
join_suffix,
Expand Down
3 changes: 2 additions & 1 deletion src/daft-plan/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,11 @@ Project1 --> Limit0
.build();

let plan = LogicalPlanBuilder::new(subplan, None)
.join(
.join_with_null_safe_equal(
subplan2,
vec![col("id")],
vec![col("id")],
Some(vec![true]),
JoinType::Inner,
None,
None,
Expand Down
21 changes: 21 additions & 0 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub struct Join {

pub left_on: Vec<ExprRef>,
pub right_on: Vec<ExprRef>,
pub null_equals_nulls: Option<Vec<bool>>,
pub join_type: JoinType,
pub join_strategy: Option<JoinStrategy>,
pub output_schema: SchemaRef,
Expand All @@ -40,6 +41,7 @@ impl std::hash::Hash for Join {
std::hash::Hash::hash(&self.right, state);
std::hash::Hash::hash(&self.left_on, state);
std::hash::Hash::hash(&self.right_on, state);
std::hash::Hash::hash(&self.null_equals_nulls, state);
std::hash::Hash::hash(&self.join_type, state);
std::hash::Hash::hash(&self.join_strategy, state);
std::hash::Hash::hash(&self.output_schema, state);
Expand All @@ -53,6 +55,7 @@ impl Join {
right: Arc<LogicalPlan>,
left_on: Vec<ExprRef>,
right_on: Vec<ExprRef>,
null_equals_nulls: Option<Vec<bool>>,
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
join_suffix: Option<&str>,
Expand Down Expand Up @@ -92,6 +95,16 @@ impl Join {
}
}

if let Some(null_equals_null) = &null_equals_nulls {
if null_equals_null.len() != left_on.len() {
return Err(DaftError::ValueError(
"null_equals_nulls must have the same length as left_on or right_on"
.to_string(),
))
.context(CreationSnafu);
}
}

if matches!(join_type, JoinType::Anti | JoinType::Semi) {
// The output schema is the same as the left input schema for anti and semi joins.

Expand All @@ -102,6 +115,7 @@ impl Join {
right,
left_on,
right_on,
null_equals_nulls,
join_type,
join_strategy,
output_schema,
Expand Down Expand Up @@ -188,6 +202,7 @@ impl Join {
right,
left_on,
right_on,
null_equals_nulls,
join_type,
join_strategy,
output_schema,
Expand Down Expand Up @@ -276,6 +291,12 @@ impl Join {
));
}
}
if let Some(null_equals_nulls) = &self.null_equals_nulls {
res.push(format!(
"Null equals Nulls = [{}]",
null_equals_nulls.iter().map(|b| b.to_string()).join(", ")
));
}
res.push(format!(
"Output schema = {}",
self.output_schema.short_string()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ fn find_inner_join(
left: left_input,
right: right_input,
left_on: left_keys,
null_equals_nulls: None,
right_on: right_keys,
join_type: JoinType::Inner,
join_strategy: None,
Expand All @@ -327,6 +328,7 @@ fn find_inner_join(
right,
left_on: vec![],
right_on: vec![],
null_equals_nulls: None,
join_type: JoinType::Inner,
join_strategy: None,
output_schema: Arc::new(join_schema),
Expand Down
Loading
Loading