Skip to content

Make combine_similar less greedy for merge #334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Oct 16, 2023
7 changes: 6 additions & 1 deletion dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,12 @@ def combine_similar(

if changed_dependency:
expr = type(expr)(*new_operands)
changed = True
if isinstance(expr, Projection):
# We might introduce stacked Projections (merge for example).
# So get rid of them here again
expr_simplify_down = expr._simplify_down()
if expr_simplify_down is not None:
expr = expr_simplify_down
if update_root:
root = expr
continue
Expand Down
147 changes: 117 additions & 30 deletions dask_expr/_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from dask_expr._util import _convert_to_list

_HASH_COLUMN_NAME = "__hash_partition"
_PARTITION_COLUMN = "_partitions"


class Merge(Expr):
Expand Down Expand Up @@ -61,6 +62,10 @@ class Merge(Expr):
"shuffle_backend": None,
}

# combine similar variables
_skip_ops = (Filter, AssignPartitioningIndex, Shuffle)
_remove_ops = (Projection,)

def __str__(self):
return f"Merge({self._name[-7:]})"

Expand Down Expand Up @@ -254,52 +259,134 @@ def _simplify_up(self, parent):
return type(parent)(result)
return result[parent_columns]

def _validate_same_operations(self, common, op, remove_ops, skip_ops):
def _validate_same_operations(self, common, op, remove="both"):
# Travers left and right to check if we can find the same operation
# more than once. We have to account for potential projections on both sides
name = common._name
if name == op._name:
return True
op_left, _ = self._remove_operations(op.left, remove_ops, skip_ops)
op_right, _ = self._remove_operations(op.right, remove_ops, skip_ops)
return type(op)(op_left, op_right, *op.operands[2:])._name == name
return True, op.left.columns, op.right.columns

columns_left, columns_right = None, None
op_left, op_right = op.left, op.right
if remove in ("both", "left"):
op_left, columns_left = self._remove_operations(
op.left, self._remove_ops, self._skip_ops
)
if remove in ("both", "right"):
op_right, columns_right = self._remove_operations(
op.right, self._remove_ops, self._skip_ops
)

return (
type(op)(op_left, op_right, *op.operands[2:])._name == name,
columns_left,
columns_right,
)

@staticmethod
def _flatten_columns(expr, columns, side):
if len(columns) == 0:
return getattr(expr, side).columns
else:
return list(set(flatten(columns)))

def _combine_similar(self, root: Expr):
# Push projections back up to avoid performing the same merge multiple times
skip_ops = (Filter, AssignPartitioningIndex, Shuffle)
remove_ops = (Projection,)

def _flatten_columns(columns, side):
if len(columns) == 0:
return getattr(self, side).columns
else:
return list(set(flatten(columns)))

left, columns_left = self._remove_operations(self.left, remove_ops, skip_ops)
columns_left = _flatten_columns(columns_left, "left")
right, columns_right = self._remove_operations(self.right, remove_ops, skip_ops)
columns_right = _flatten_columns(columns_right, "right")
left, columns_left = self._remove_operations(
self.left, self._remove_ops, self._skip_ops
)
columns_left = self._flatten_columns(self, columns_left, "left")
right, columns_right = self._remove_operations(
self.right, self._remove_ops, self._skip_ops
)
columns_right = self._flatten_columns(self, columns_right, "right")

if left._name == self.left._name and right._name == self.right._name:
# There aren't any ops we can remove, so bail
return

common = type(self)(left, right, *self.operands[2:])
# We can not remove Projections on both sides at once, because only
# one side might need the push back up step. So try if removing Projections
# on either side works before removing them on both sides at once.

common_left = type(self)(self.left, right, *self.operands[2:])
common_right = type(self)(left, self.right, *self.operands[2:])
common_both = type(self)(left, right, *self.operands[2:])

push_up_op = False
for op in self._find_similar_operations(root, ignore=self._parameters):
if self._validate_same_operations(common, op, remove_ops, skip_ops):
push_up_op = True
columns, left_sub, right_sub = None, None, None

for op in self._find_similar_operations(root, ignore=["left", "right"]):
if op._name in (common_right._name, common_left._name, common_both._name):
continue

validation = self._validate_same_operations(common_right, op, "left")
if validation[0]:
left_sub = self._flatten_columns(op, validation[1], side="left")
columns = self.right.columns.copy()
columns += [col for col in self.left.columns if col not in columns]
break

if push_up_op:
columns = columns_left.copy()
columns += [col for col in columns_right if col not in columns_left]
if sorted(common.columns) != sorted(columns):
common = common[columns]
c = common._simplify_down()
common = c if c is not None else common
return common
validation = self._validate_same_operations(common_left, op, "right")
if validation[0]:
right_sub = self._flatten_columns(op, validation[2], side="right")
columns = self.left.columns.copy()
columns += [col for col in self.right.columns if col not in columns]
break

validation = self._validate_same_operations(common_both, op)
if validation[0]:
left_sub = self._flatten_columns(op, validation[1], side="left")
right_sub = self._flatten_columns(op, validation[2], side="right")
columns = columns_left.copy()
columns += [col for col in columns_right if col not in columns_left]
break

if columns is not None:
expr = self
if _PARTITION_COLUMN in columns:
columns.remove(_PARTITION_COLUMN)

if left_sub is not None:
left_sub.extend([col for col in columns_left if col not in left_sub])
left = self._replace_projections(self.left, left_sub)
expr = expr.substitute(self.left, left)

if right_sub is not None:
right_sub.extend([col for col in columns_right if col not in right_sub])
right = self._replace_projections(self.right, right_sub)
expr = expr.substitute(self.right, right)

if sorted(expr.columns) != sorted(columns):
expr = expr[columns]
if expr._name == self._name:
return None
return expr

def _replace_projections(self, frame, new_columns):
# This branch might have a number of Projections that differ from our
# new columns. We replace those projections appropriately

operations = []
while isinstance(frame, self._remove_ops + self._skip_ops):
if isinstance(frame, self._remove_ops):
# TODO: Shuffle and AssignPartitioningIndex being 2 different ops
# causes all kinds of pain
if isinstance(frame.frame, AssignPartitioningIndex):
new_cols = new_columns
else:
new_cols = [col for col in new_columns if col != _PARTITION_COLUMN]

# Ignore Projection if new_columns = frame.frame.columns
if sorted(new_cols) != sorted(frame.frame.columns):
operations.append((type(frame), [new_cols]))
else:
operations.append((type(frame), frame.operands[1:]))
frame = frame.frame

for op_type, operands in reversed(operations):
frame = op_type(frame, *operands)
return frame


class HashJoinP2P(Merge, PartitionsFiltered):
Expand Down
35 changes: 34 additions & 1 deletion dask_expr/tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
from dask.dataframe.utils import assert_eq

from dask_expr import from_pandas
from dask_expr import Merge, from_pandas
from dask_expr._expr import Projection
from dask_expr._shuffle import Shuffle
from dask_expr.tests._util import _backend_library

# Set DataFrame backend for this module
Expand Down Expand Up @@ -206,3 +208,34 @@ def test_merge_combine_similar(npartitions_left, npartitions_right):
expected["new"] = expected.b + expected.c
expected = expected.groupby(["a", "e", "x"]).new.sum()
assert_eq(query, expected)


def test_merge_combine_similar_intermediate_projections():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reproducer for the behavior being targeted by this PR? This test seems to pass on main for me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the bug was blocked by the additional layer in #333 but shows up now

pdf = lib.DataFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"b": 1,
"c": 1,
}
)
pdf2 = lib.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "x": 1})
pdf3 = lib.DataFrame({"d": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "e": 1, "y": 1})

df = from_pandas(pdf, npartitions=2)
df2 = from_pandas(pdf2, npartitions=3)
df3 = from_pandas(pdf3, npartitions=3)

q = df.merge(df2).merge(df3, left_on="b", right_on="d")[["b", "x", "y"]]
q["new"] = q.b + q.x
result = q.optimize(fuse=False)
# Check that we have intermediate projections dropping unnecessary columns
assert isinstance(result.expr.frame, Projection)
assert isinstance(result.expr.frame.frame, Merge)
assert isinstance(result.expr.frame.frame.left, Projection)
assert isinstance(result.expr.frame.frame.left.frame, Shuffle)

pd_result = pdf.merge(pdf2).merge(pdf3, left_on="b", right_on="d")[["b", "x", "y"]]
pd_result["new"] = pd_result.b + pd_result.x

assert sorted(result.expr.frame.frame.left.operand("columns")) == ["b", "x"]
assert_eq(result, pd_result, check_index=False)