7 changes: 6 additions & 1 deletion dask_expr/_expr.py
@@ -417,7 +417,12 @@ def combine_similar(

if changed_dependency:
expr = type(expr)(*new_operands)
changed = True
if isinstance(expr, Projection):
# We might have introduced stacked Projections (e.g. through Merge),
# so get rid of them again here.
expr_simplify_down = expr._simplify_down()
if expr_simplify_down is not None:
expr = expr_simplify_down
if update_root:
root = expr
continue
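For context on the new branch above: combine_similar can leave two Projection nodes stacked directly on top of each other (a merge rewrite is one way this happens), and Projection's _simplify_down is expected to collapse them into one. A minimal sketch of that situation, using only APIs that already appear in this PR's tests (the data and column names are made up):

```python
import pandas as pd
from dask_expr import from_pandas

pdf = pd.DataFrame({"a": [1, 2, 3], "b": 1, "c": 1})
df = from_pandas(pdf, npartitions=2)

# Selecting columns twice builds Projection(Projection(...)) in the
# expression graph; after optimization the two projections should be
# collapsed into the equivalent of a single df[["a"]].
stacked = df[["a", "b"]][["a"]]
optimized = stacked.optimize(fuse=False)
```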
147 changes: 117 additions & 30 deletions dask_expr/_merge.py
@@ -23,6 +23,7 @@
from dask_expr._util import _convert_to_list

_HASH_COLUMN_NAME = "__hash_partition"
_PARTITION_COLUMN = "_partitions"


class Merge(Expr):
@@ -61,6 +62,10 @@ class Merge(Expr):
"shuffle_backend": None,
}

# Operation classes used by the combine_similar machinery
_skip_ops = (Filter, AssignPartitioningIndex, Shuffle)
_remove_ops = (Projection,)

def __str__(self):
return f"Merge({self._name[-7:]})"

@@ -254,52 +259,134 @@ def _simplify_up(self, parent):
return type(parent)(result)
return result[parent_columns]

def _validate_same_operations(self, common, op, remove_ops, skip_ops):
def _validate_same_operations(self, common, op, remove="both"):
# Traverse left and right to check whether we can find the same operation
# more than once. We have to account for potential projections on both sides.
name = common._name
if name == op._name:
return True
op_left, _ = self._remove_operations(op.left, remove_ops, skip_ops)
op_right, _ = self._remove_operations(op.right, remove_ops, skip_ops)
return type(op)(op_left, op_right, *op.operands[2:])._name == name
return True, op.left.columns, op.right.columns

columns_left, columns_right = None, None
op_left, op_right = op.left, op.right
if remove in ("both", "left"):
op_left, columns_left = self._remove_operations(
op.left, self._remove_ops, self._skip_ops
)
if remove in ("both", "right"):
op_right, columns_right = self._remove_operations(
op.right, self._remove_ops, self._skip_ops
)

return (
type(op)(op_left, op_right, *op.operands[2:])._name == name,
columns_left,
columns_right,
)
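As a rough illustration of the check above (a toy sketch, not dask_expr's internals): two merges are considered "the same work" when, after stripping column projections from the requested side(s) and rebuilding the merge, the deterministic names match. The node classes below are hypothetical stand-ins:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Table:
    label: str


@dataclass(frozen=True)
class Project:
    frame: object
    columns: tuple


@dataclass(frozen=True)
class Join:
    left: object
    right: object
    on: str


def strip_projections(node):
    # Walk past any Project wrappers down to the underlying frame.
    while isinstance(node, Project):
        node = node.frame
    return node


def same_join_modulo_projections(a, b):
    # Rebuild both joins without projections and compare canonical
    # representations, mirroring the "remove ops, rebuild, compare _name"
    # pattern in _validate_same_operations.
    a_core = Join(strip_projections(a.left), strip_projections(a.right), a.on)
    b_core = Join(strip_projections(b.left), strip_projections(b.right), b.on)
    return repr(a_core) == repr(b_core)


df, df2 = Table("df"), Table("df2")
j1 = Join(Project(df, ("a", "b")), df2, on="a")
j2 = Join(df, Project(df2, ("a", "x")), on="a")
assert same_join_modulo_projections(j1, j2)  # same underlying merge
```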

@staticmethod
def _flatten_columns(expr, columns, side):
if len(columns) == 0:
return getattr(expr, side).columns
else:
return list(set(flatten(columns)))
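A quick illustration of the helper's fallback: an empty list means no projection columns were recorded for that side, so the side's full column list is used; otherwise the nested lists are flattened and de-duplicated (order is not guaranteed because of the set). The snippet below uses itertools.chain as a stand-in, assuming flatten merely removes nesting:

```python
from itertools import chain

recorded = [["a"], ["b", "a"]]
flattened = list(set(chain.from_iterable(recorded)))  # e.g. ["a", "b"] in any order
```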

def _combine_similar(self, root: Expr):
# Push projections back up to avoid performing the same merge multiple times
skip_ops = (Filter, AssignPartitioningIndex, Shuffle)
remove_ops = (Projection,)

def _flatten_columns(columns, side):
if len(columns) == 0:
return getattr(self, side).columns
else:
return list(set(flatten(columns)))

left, columns_left = self._remove_operations(self.left, remove_ops, skip_ops)
columns_left = _flatten_columns(columns_left, "left")
right, columns_right = self._remove_operations(self.right, remove_ops, skip_ops)
columns_right = _flatten_columns(columns_right, "right")
left, columns_left = self._remove_operations(
self.left, self._remove_ops, self._skip_ops
)
columns_left = self._flatten_columns(self, columns_left, "left")
right, columns_right = self._remove_operations(
self.right, self._remove_ops, self._skip_ops
)
columns_right = self._flatten_columns(self, columns_right, "right")

if left._name == self.left._name and right._name == self.right._name:
# There aren't any ops we can remove, so bail
return

common = type(self)(left, right, *self.operands[2:])
# We cannot remove Projections on both sides at once, because only
# one side might need the push-back-up step. So first try removing
# Projections on either side alone before removing them on both sides at once.

common_left = type(self)(self.left, right, *self.operands[2:])
common_right = type(self)(left, self.right, *self.operands[2:])
common_both = type(self)(left, right, *self.operands[2:])

push_up_op = False
for op in self._find_similar_operations(root, ignore=self._parameters):
if self._validate_same_operations(common, op, remove_ops, skip_ops):
push_up_op = True
columns, left_sub, right_sub = None, None, None

for op in self._find_similar_operations(root, ignore=["left", "right"]):
if op._name in (common_right._name, common_left._name, common_both._name):
continue

validation = self._validate_same_operations(common_right, op, "left")
if validation[0]:
left_sub = self._flatten_columns(op, validation[1], side="left")
columns = self.right.columns.copy()
columns += [col for col in self.left.columns if col not in columns]
break

if push_up_op:
columns = columns_left.copy()
columns += [col for col in columns_right if col not in columns_left]
if sorted(common.columns) != sorted(columns):
common = common[columns]
c = common._simplify_down()
common = c if c is not None else common
return common
validation = self._validate_same_operations(common_left, op, "right")
if validation[0]:
right_sub = self._flatten_columns(op, validation[2], side="right")
columns = self.left.columns.copy()
columns += [col for col in self.right.columns if col not in columns]
break

validation = self._validate_same_operations(common_both, op)
if validation[0]:
left_sub = self._flatten_columns(op, validation[1], side="left")
right_sub = self._flatten_columns(op, validation[2], side="right")
columns = columns_left.copy()
columns += [col for col in columns_right if col not in columns_left]
break

if columns is not None:
expr = self
if _PARTITION_COLUMN in columns:
columns.remove(_PARTITION_COLUMN)

if left_sub is not None:
left_sub.extend([col for col in columns_left if col not in left_sub])
left = self._replace_projections(self.left, left_sub)
expr = expr.substitute(self.left, left)

if right_sub is not None:
right_sub.extend([col for col in columns_right if col not in right_sub])
right = self._replace_projections(self.right, right_sub)
expr = expr.substitute(self.right, right)

if sorted(expr.columns) != sorted(columns):
expr = expr[columns]
if expr._name == self._name:
return None
return expr
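An end-to-end sketch of what this rewrite is for, restricted to APIs already used in this module's tests (the frames and columns are made up): two consumers that each project different columns off the same merge should, after optimization, share one merge followed by cheap column selections instead of planning the join twice.

```python
import pandas as pd
from dask_expr import from_pandas

pdf = pd.DataFrame({"a": [1, 2, 3, 4], "b": 1, "c": 1})
pdf2 = pd.DataFrame({"a": [1, 2, 3, 4], "x": 1})

df = from_pandas(pdf, npartitions=2)
df2 = from_pandas(pdf2, npartitions=2)

joined = df.merge(df2)
# Two different projections of the same merge; combine_similar is expected
# to push them back above a single shared merge.
q = joined["b"] + joined["x"]
result = q.optimize(fuse=False)
```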

def _replace_projections(self, frame, new_columns):
# This branch might contain a number of Projections whose columns differ
# from our new columns. Replace those projections accordingly.

operations = []
while isinstance(frame, self._remove_ops + self._skip_ops):
if isinstance(frame, self._remove_ops):
# TODO: Shuffle and AssignPartitioningIndex being 2 different ops
# causes all kinds of pain
if isinstance(frame.frame, AssignPartitioningIndex):
new_cols = new_columns
else:
new_cols = [col for col in new_columns if col != _PARTITION_COLUMN]

# Skip the Projection if new_columns == frame.frame.columns
if sorted(new_cols) != sorted(frame.frame.columns):
operations.append((type(frame), [new_cols]))
else:
operations.append((type(frame), frame.operands[1:]))
frame = frame.frame

for op_type, operands in reversed(operations):
frame = op_type(frame, *operands)
return frame
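The rebuild pattern used here, record each wrapper while unwrapping, swap the projection's column list, then reapply the recorded operations in reverse, sketched with hypothetical toy classes (not dask_expr's):

```python
class Base:
    def __repr__(self):
        return "Base()"


class Wrapper:
    def __init__(self, frame, *args):
        self.frame, self.args = frame, args

    def __repr__(self):
        return f"{type(self).__name__}({self.frame!r}, {self.args!r})"


class ToyFilter(Wrapper):
    pass


class ToySelect(Wrapper):
    pass


def replace_select_columns(node, new_columns):
    ops = []
    while isinstance(node, (ToyFilter, ToySelect)):
        if isinstance(node, ToySelect):
            ops.append((ToySelect, (new_columns,)))  # swap in the new column list
        else:
            ops.append((type(node), node.args))      # keep other ops unchanged
        node = node.frame
    for op_type, args in reversed(ops):
        node = op_type(node, *args)
    return node


chain = ToySelect(ToyFilter(Base(), "mask"), ["a"])
print(replace_select_columns(chain, ["a", "b"]))
# ToySelect(ToyFilter(Base(), ('mask',)), (['a', 'b'],))
```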


class HashJoinP2P(Merge, PartitionsFiltered):
35 changes: 34 additions & 1 deletion dask_expr/tests/test_merge.py
@@ -1,7 +1,9 @@
import pytest
from dask.dataframe.utils import assert_eq

from dask_expr import from_pandas
from dask_expr import Merge, from_pandas
from dask_expr._expr import Projection
from dask_expr._shuffle import Shuffle
from dask_expr.tests._util import _backend_library

# Set DataFrame backend for this module
@@ -206,3 +208,34 @@ def test_merge_combine_similar(npartitions_left, npartitions_right):
expected["new"] = expected.b + expected.c
expected = expected.groupby(["a", "e", "x"]).new.sum()
assert_eq(query, expected)


def test_merge_combine_similar_intermediate_projections():
Review comment (Member):
Is there a reproducer for the behavior being targeted by this PR? This test seems to pass on main for me.

Reply (Collaborator, Author):
Yeah, the bug was blocked by the additional layer in #333 but shows up now.

pdf = lib.DataFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"b": 1,
"c": 1,
}
)
pdf2 = lib.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "x": 1})
pdf3 = lib.DataFrame({"d": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "e": 1, "y": 1})

df = from_pandas(pdf, npartitions=2)
df2 = from_pandas(pdf2, npartitions=3)
df3 = from_pandas(pdf3, npartitions=3)

q = df.merge(df2).merge(df3, left_on="b", right_on="d")[["b", "x", "y"]]
q["new"] = q.b + q.x
result = q.optimize(fuse=False)
# Check that we have intermediate projections dropping unnecessary columns
assert isinstance(result.expr.frame, Projection)
assert isinstance(result.expr.frame.frame, Merge)
assert isinstance(result.expr.frame.frame.left, Projection)
assert isinstance(result.expr.frame.frame.left.frame, Shuffle)

pd_result = pdf.merge(pdf2).merge(pdf3, left_on="b", right_on="d")[["b", "x", "y"]]
pd_result["new"] = pd_result.b + pd_result.x

assert sorted(result.expr.frame.frame.left.operand("columns")) == ["b", "x"]
assert_eq(result, pd_result, check_index=False)