Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): Improve join argument checks #18847

Merged
merged 2 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4502,6 +4502,17 @@ def join(
msg = f"expected `other` join table to be a LazyFrame, not a {type(other).__name__!r}"
raise TypeError(msg)

uses_on = on is not None
uses_left_on = left_on is not None
uses_right_on = right_on is not None
uses_lr_on = uses_left_on or uses_right_on
if uses_on and uses_lr_on:
msg = "cannot use 'on' in conjunction with 'left_on' or 'right_on'"
raise ValueError(msg)
elif uses_left_on != uses_right_on:
msg = "'left_on' requires corresponding 'right_on'"
raise ValueError(msg)

if how == "outer":
how = "full"
issue_deprecation_warning(
Expand All @@ -4515,9 +4526,8 @@ def join(
"Use of `how='outer_coalesce'` should be replaced with `how='full', coalesce=True`.",
version="0.20.29",
)

elif how == "cross":
if left_on is not None or right_on is not None:
if uses_on or uses_lr_on:
msg = "cross join should not pass join keys"
raise ValueError(msg)
return self._from_pyldf(
Expand All @@ -4534,11 +4544,11 @@ def join(
)
)

if on is not None:
if uses_on:
pyexprs = parse_into_list_of_expressions(on)
pyexprs_left = pyexprs
pyexprs_right = pyexprs
elif left_on is not None and right_on is not None:
elif uses_lr_on:
pyexprs_left = parse_into_list_of_expressions(left_on)
pyexprs_right = parse_into_list_of_expressions(right_on)
else:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2145,7 +2145,7 @@ def test_join_suffixes() -> None:
join_strategies: list[JoinStrategy] = ["left", "inner", "full", "cross"]
for how in join_strategies:
# no need for an assert, we error if wrong
df_a.join(df_b, on="A", suffix="_y", how=how)["B_y"]
df_a.join(df_b, on="A" if how != "cross" else None, suffix="_y", how=how)["B_y"]

df_a.join_asof(df_b, on=pl.col("A").set_sorted(), suffix="_y")["B_y"]

Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_cross_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_cross_join_predicate_pushdown_block_16956() -> None:
).cast(pl.Datetime("ms", "Europe/Amsterdam"))

assert (
lf.join(lf, on="start_datetime", how="cross")
lf.join(lf, how="cross")
.filter(
pl.col.end_datetime_right.is_between(
pl.col.start_datetime, pl.col.start_datetime.dt.offset_by("132h")
Expand Down
46 changes: 46 additions & 0 deletions py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,3 +1036,49 @@ def test_join_coalesce_not_supported_warning() -> None:
)

assert_frame_equal(expect, got, check_row_order=False)


@pytest.mark.parametrize(
    "on_args",
    [
        {"on": "a", "left_on": "a"},
        {"on": "a", "right_on": "a"},
        {"on": "a", "left_on": "a", "right_on": "a"},
    ],
)
def test_join_on_and_left_right_on(on_args: dict[str, str]) -> None:
    """Check that combining `on` with `left_on`/`right_on` raises ValueError.

    Each parametrized case passes `on` together with at least one of the
    side-specific key arguments to `DataFrame.join`.
    """
    left = pl.DataFrame({"a": [1], "b": [2]})
    right = pl.DataFrame({"a": [1], "c": [3]})
    with pytest.raises(
        ValueError,
        match="cannot use 'on' in conjunction with 'left_on' or 'right_on'",
    ):
        left.join(right, **on_args)  # type: ignore[arg-type]


@pytest.mark.parametrize(
    "on_args",
    [
        {"left_on": "a"},
        {"right_on": "a"},
    ],
)
def test_join_only_left_or_right_on(on_args: dict[str, str]) -> None:
    """Check that passing only one of `left_on`/`right_on` raises ValueError.

    Each parametrized case supplies a single side-specific key argument to
    `DataFrame.join` without its counterpart.
    """
    left = pl.DataFrame({"a": [1]})
    right = pl.DataFrame({"a": [1]})
    with pytest.raises(
        ValueError,
        match="'left_on' requires corresponding 'right_on'",
    ):
        left.join(right, **on_args)  # type: ignore[arg-type]


@pytest.mark.parametrize(
    "on_args",
    [
        {"on": "a"},
        {"left_on": "a", "right_on": "a"},
    ],
)
def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None:
    """Check that a cross join given any join keys raises ValueError.

    Each parametrized case passes key arguments (`on`, or the
    `left_on`/`right_on` pair) alongside `how="cross"`.
    """
    left = pl.DataFrame({"a": [1, 2]})
    right = pl.DataFrame({"b": [3, 4]})
    with pytest.raises(
        ValueError,
        match="cross join should not pass join keys",
    ):
        left.join(right, how="cross", **on_args)  # type: ignore[arg-type]
20 changes: 4 additions & 16 deletions py-polars/tests/unit/streaming/test_streaming_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,26 +103,14 @@ def test_streaming_joins() -> None:


def test_streaming_cross_join_empty() -> None:
df1 = pl.LazyFrame(
data={
"col1": ["a"],
}
)
df1 = pl.LazyFrame(data={"col1": ["a"]})

df2 = pl.LazyFrame(
data={
"col1": [],
},
schema={
"col1": str,
},
data={"col1": []},
schema={"col1": str},
)

out = df1.join(
df2,
how="cross",
on="col1",
).collect(streaming=True)
out = df1.join(df2, how="cross").collect(streaming=True)
assert out.shape == (0, 2)
assert out.columns == ["col1", "col1_right"]

Expand Down