Skip to content

Commit

Permalink
[FEAT] Support intersect as a DataFrame API (#3134)
Browse files Browse the repository at this point in the history
This commit leverages null-safe equality support in joins (see #3069 and
#3161) to support the intersect API.

Partially fixes #3122.
  • Loading branch information
advancedxy authored Nov 12, 2024
1 parent e37c6d3 commit a547bd3
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 0 deletions.
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,7 @@ class LogicalPlanBuilder:
join_suffix: str | None = None,
) -> LogicalPlanBuilder: ...
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: ...
def table_write(
self,
Expand Down
30 changes: 30 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2474,6 +2474,36 @@ def pivot(
builder = self._builder.pivot(group_by_expr, pivot_col_expr, value_col_expr, agg_expr, names)
return DataFrame(builder)

@DataframePublicAPI
def intersect(self, other: "DataFrame") -> "DataFrame":
    """Returns the distinct rows of this DataFrame that also appear in ``other``.

    Columns are paired by position rather than by name, and the result keeps this
    DataFrame's column names. Both DataFrames must have the same number of columns
    with matching types. Duplicate rows are removed from the result, and a null on
    one side matches a null on the other.

    Example:
        >>> import daft
        >>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
        >>> df2 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 8, 6]})
        >>> df1.intersect(df2).collect()
        ╭───────┬───────╮
        │ a     ┆ b     │
        │ ---   ┆ ---   │
        │ Int64 ┆ Int64 │
        ╞═══════╪═══════╡
        │ 1     ┆ 4     │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 3     ┆ 6     │
        ╰───────┴───────╯
        <BLANKLINE>
        (Showing first 2 of 2 rows)

    Args:
        other (DataFrame): DataFrame to intersect with

    Returns:
        DataFrame: DataFrame with the intersection of the two DataFrames
    """
    return DataFrame(self._builder.intersect(other._builder))

def _materialize_results(self) -> None:
"""Materializes the results of for this DataFrame and hold a pointer to the results."""
context = get_context()
Expand Down
4 changes: 4 additions & 0 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: # type: igno
builder = self._builder.concat(other._builder)
return LogicalPlanBuilder(builder)

def intersect(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
    """Return a new builder whose plan is the distinct intersection of this plan and ``other``."""
    # The second argument is the native builder's ``is_all`` flag; only the
    # distinct (non-ALL) form is requested from here.
    return LogicalPlanBuilder(self._builder.intersect(other._builder, False))

def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder:
builder = self._builder.add_monotonically_increasing_id(column_name)
return LogicalPlanBuilder(builder)
Expand Down
11 changes: 11 additions & 0 deletions src/daft-logical-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,13 @@ impl LogicalPlanBuilder {
Ok(self.with_new_plan(logical_plan))
}

/// Returns a builder for the intersection of this plan with `other`.
///
/// The inputs are validated by `ops::Intersect::try_new`, and the resulting
/// node is immediately lowered with `to_optimized_join`, so the returned plan
/// does not contain an `Intersect` node.
///
/// # Errors
///
/// Propagates the error from either step: incompatible schemas, or
/// `is_all == true`, which the lowering currently rejects.
pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
    let intersect_op = ops::Intersect::try_new(self.plan.clone(), other.plan.clone(), is_all)?;
    let logical_plan: LogicalPlan = intersect_op.to_optimized_join()?;
    Ok(self.with_new_plan(logical_plan))
}

pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
ops::MonotonicallyIncreasingId::new(self.plan.clone(), column_name).into();
Expand Down Expand Up @@ -785,6 +792,10 @@ impl PyLogicalPlanBuilder {
Ok(self.builder.concat(&other.builder)?.into())
}

/// Python-facing wrapper: intersects the wrapped builder with `other`'s builder
/// and wraps the result again.
pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
    let intersected = self.builder.intersect(&other.builder, is_all)?;
    Ok(intersected.into())
}

pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> PyResult<Self> {
Ok(self
.builder
Expand Down
9 changes: 9 additions & 0 deletions src/daft-logical-plan/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub enum LogicalPlan {
Aggregate(Aggregate),
Pivot(Pivot),
Concat(Concat),
Intersect(Intersect),
Join(Join),
Sink(Sink),
Sample(Sample),
Expand Down Expand Up @@ -58,6 +59,7 @@ impl LogicalPlan {
Self::Aggregate(Aggregate { output_schema, .. }) => output_schema.clone(),
Self::Pivot(Pivot { output_schema, .. }) => output_schema.clone(),
Self::Concat(Concat { input, .. }) => input.schema(),
Self::Intersect(Intersect { lhs, .. }) => lhs.schema(),
Self::Join(Join { output_schema, .. }) => output_schema.clone(),
Self::Sink(Sink { schema, .. }) => schema.clone(),
Self::Sample(Sample { input, .. }) => input.schema(),
Expand Down Expand Up @@ -162,6 +164,7 @@ impl LogicalPlan {
.collect();
vec![left, right]
}
Self::Intersect(_) => vec![IndexSet::new(), IndexSet::new()],
Self::Source(_) => todo!(),
Self::Sink(_) => todo!(),
}
Expand All @@ -183,6 +186,7 @@ impl LogicalPlan {
Self::Pivot(..) => "Pivot",
Self::Concat(..) => "Concat",
Self::Join(..) => "Join",
Self::Intersect(..) => "Intersect",
Self::Sink(..) => "Sink",
Self::Sample(..) => "Sample",
Self::MonotonicallyIncreasingId(..) => "MonotonicallyIncreasingId",
Expand All @@ -205,6 +209,7 @@ impl LogicalPlan {
Self::Aggregate(aggregate) => aggregate.multiline_display(),
Self::Pivot(pivot) => pivot.multiline_display(),
Self::Concat(_) => vec!["Concat".to_string()],
Self::Intersect(inner) => inner.multiline_display(),
Self::Join(join) => join.multiline_display(),
Self::Sink(sink) => sink.multiline_display(),
Self::Sample(sample) => {
Expand All @@ -231,6 +236,7 @@ impl LogicalPlan {
Self::Concat(Concat { input, other }) => vec![input, other],
Self::Join(Join { left, right, .. }) => vec![left, right],
Self::Sink(Sink { input, .. }) => vec![input],
Self::Intersect(Intersect { lhs, rhs, .. }) => vec![lhs, rhs],
Self::Sample(Sample { input, .. }) => vec![input],
Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { input, .. }) => {
vec![input]
Expand Down Expand Up @@ -259,11 +265,13 @@ impl LogicalPlan {
Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot { input: input.clone(), ids: ids.clone(), values: values.clone(), variable_name: variable_name.clone(), value_name: value_name.clone(), output_schema: output_schema.clone() }),
Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)),
Self::Concat(_) => panic!("Concat ops should never have only one input, but got one"),
Self::Intersect(_) => panic!("Intersect ops should never have only one input, but got one"),
Self::Join(_) => panic!("Join ops should never have only one input, but got one"),
},
[input1, input2] => match self {
Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"),
Self::Concat(_) => Self::Concat(Concat::try_new(input1.clone(), input2.clone()).unwrap()),
Self::Intersect(inner) => Self::Intersect(Intersect::try_new(input1.clone(), input2.clone(), inner.is_all).unwrap()),
Self::Join(Join { left_on, right_on, null_equals_nulls, join_type, join_strategy, .. }) => Self::Join(Join::try_new(
input1.clone(),
input2.clone(),
Expand Down Expand Up @@ -361,6 +369,7 @@ impl_from_data_struct_for_logical_plan!(Distinct);
impl_from_data_struct_for_logical_plan!(Aggregate);
impl_from_data_struct_for_logical_plan!(Pivot);
impl_from_data_struct_for_logical_plan!(Concat);
impl_from_data_struct_for_logical_plan!(Intersect);
impl_from_data_struct_for_logical_plan!(Join);
impl_from_data_struct_for_logical_plan!(Sink);
impl_from_data_struct_for_logical_plan!(Sample);
Expand Down
2 changes: 2 additions & 0 deletions src/daft-logical-plan/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod pivot;
mod project;
mod repartition;
mod sample;
mod set_operations;
mod sink;
mod sort;
mod source;
Expand All @@ -29,6 +30,7 @@ pub use pivot::Pivot;
pub use project::Project;
pub use repartition::Repartition;
pub use sample::Sample;
pub use set_operations::Intersect;
pub use sink::Sink;
pub use sort::Sort;
pub use source::Source;
Expand Down
109 changes: 109 additions & 0 deletions src/daft-logical-plan/src/ops/set_operations.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use std::sync::Arc;

use common_error::DaftError;
use daft_core::join::JoinType;
use daft_dsl::col;
use snafu::ResultExt;

use crate::{logical_plan, logical_plan::CreationSnafu, LogicalPlan};

/// Logical operator describing the set intersection of two input plans.
///
/// The node's schema is taken from `lhs`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Intersect {
    // Upstream nodes.
    /// Left input; its schema is the schema of the intersect.
    pub lhs: Arc<LogicalPlan>,
    /// Right input; must match `lhs` positionally in field count and dtypes.
    pub rhs: Arc<LogicalPlan>,
    /// `true` for the "Intersect All" variant, `false` for the distinct variant.
    /// Only the distinct variant can currently be lowered to a join.
    pub is_all: bool,
}

impl Intersect {
pub(crate) fn try_new(
lhs: Arc<LogicalPlan>,
rhs: Arc<LogicalPlan>,
is_all: bool,
) -> logical_plan::Result<Self> {
let lhs_schema = lhs.schema();
let rhs_schema = rhs.schema();
if lhs_schema.len() != rhs_schema.len() {
return Err(DaftError::SchemaMismatch(format!(
"Both plans must have the same num of fields to intersect, \
but got[lhs: {} v.s rhs: {}], lhs schema: {}, rhs schema: {}",
lhs_schema.len(),
rhs_schema.len(),
lhs_schema,
rhs_schema
)))
.context(CreationSnafu);
}
// lhs and rhs should have the same type for each field to intersect
if lhs_schema
.fields
.values()
.zip(rhs_schema.fields.values())
.any(|(l, r)| l.dtype != r.dtype)
{
return Err(DaftError::SchemaMismatch(format!(
"Both plans' schemas should have the same type for each field to intersect, \
but got lhs schema: {}, rhs schema: {}",
lhs_schema, rhs_schema
)))
.context(CreationSnafu);
}
Ok(Self { lhs, rhs, is_all })
}

/// intersect distinct could be represented as a semi join + distinct
/// the following intersect operator:
/// ```sql
/// select a1, a2 from t1 intersect select b1, b2 from t2
/// ```
/// is the same as:
/// ```sql
/// select distinct a1, a2 from t1 left semi join t2
/// on t1.a1 <> t2.b1 and t1.a2 <> t2.b2
/// ```
/// TODO: Move this logical to logical optimization rules
pub(crate) fn to_optimized_join(&self) -> logical_plan::Result<LogicalPlan> {
if self.is_all {
Err(logical_plan::Error::CreationError {
source: DaftError::InternalError("intersect all is not supported yet".to_string()),
})
} else {
let left_on = self
.lhs
.schema()
.fields
.keys()
.map(|k| col(k.clone()))
.collect();
let right_on = self
.rhs
.schema()
.fields
.keys()
.map(|k| col(k.clone()))
.collect();
let join = logical_plan::Join::try_new(
self.lhs.clone(),
self.rhs.clone(),
left_on,
right_on,
Some(vec![true; self.lhs.schema().fields.len()]),
JoinType::Semi,
None,
None,
None,
);
join.map(|j| logical_plan::Distinct::new(j.into()).into())
}
}

pub fn multiline_display(&self) -> Vec<String> {
let mut res = vec![];
if self.is_all {
res.push("Intersect All:".to_string());
} else {
res.push("Intersect:".to_string());
}
res
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,11 @@ impl PushDownProjection {
// since Distinct implicitly requires all parent columns.
Ok(Transformed::no(plan))
}
LogicalPlan::Intersect(_) => {
// Cannot push down past an Intersect,
// since Intersect implicitly requires all parent columns.
Ok(Transformed::no(plan))
}
LogicalPlan::Pivot(_) | LogicalPlan::MonotonicallyIncreasingId(_) => {
// Cannot push down past a Pivot/MonotonicallyIncreasingId because it changes the schema.
Ok(Transformed::no(plan))
Expand Down
3 changes: 3 additions & 0 deletions src/daft-physical-plan/src/physical_planner/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,9 @@ pub(super) fn translate_single_logical_node(
.arced(),
)
}
LogicalPlan::Intersect(_) => Err(DaftError::InternalError(
"Intersect should already be optimized away".to_string(),
)),
}
}

Expand Down
47 changes: 47 additions & 0 deletions tests/dataframe/test_intersect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import daft
from daft import col


def test_simple_intersect(make_df):
df1 = make_df({"foo": [1, 2, 3]})
df2 = make_df({"bar": [2, 3, 4]})
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, 3]}


def test_intersect_with_duplicate(make_df):
df1 = make_df({"foo": [1, 2, 2, 3]})
df2 = make_df({"bar": [2, 3, 3]})
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, 3]}


def test_self_intersect(make_df):
df = make_df({"foo": [1, 2, 3]})
result = df.intersect(df).sort(by="foo")
assert result.to_pydict() == {"foo": [1, 2, 3]}


def test_intersect_empty(make_df):
    """Intersecting with an empty frame yields an empty result with the left schema."""
    left = make_df({"foo": [1, 2, 3]})
    # The empty column is cast to int64 explicitly; presumably it would not
    # otherwise have the same type as "foo", which intersect requires.
    empty_right = make_df({"bar": []}).select(col("bar").cast(daft.DataType.int64()))
    intersection = left.intersect(empty_right)
    assert intersection.to_pydict() == {"foo": []}


def test_intersect_with_nulls(make_df):
df1 = make_df({"foo": [1, 2, None]})
df1_without_mull = make_df({"foo": [1, 2]})
df2 = make_df({"bar": [2, 3, None]})
df2_without_null = make_df({"bar": [2, 3]})

result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, None]}

result = df1_without_mull.intersect(df2)
assert result.to_pydict() == {"foo": [2]}

result = df1.intersect(df2_without_null)
assert result.to_pydict() == {"foo": [2]}

0 comments on commit a547bd3

Please sign in to comment.