Skip to content

Commit

Permalink
feat(python)!: Constrain access to globals from `df.sql` in favour of top-level `pl.sql` (pola-rs#16598)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored and Wouittone committed Jun 22, 2024
1 parent f99f790 commit d981b85
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
10 changes: 3 additions & 7 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4728,13 +4728,9 @@ def sql(self, query: str, *, table_name: str = "self") -> Self:
issue_unstable_warning(
"`sql` is considered **unstable** (although it is close to being considered stable)."
)
with SQLContext(
register_globals=True,
eager=True,
) as ctx:
frames = {table_name: self} if table_name else {}
frames["self"] = self
ctx.register_many(frames)
with SQLContext(register_globals=False, eager=True) as ctx:
name = table_name if table_name else "self"
ctx.register(name=name, frame=self)
return ctx.execute(query) # type: ignore[return-value]

def top_k(
Expand Down
10 changes: 3 additions & 7 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,13 +1367,9 @@ def sql(self, query: str, *, table_name: str = "self") -> Self:
issue_unstable_warning(
"`sql` is considered **unstable** (although it is close to being considered stable)."
)
with SQLContext(
register_globals=True,
eager=False,
) as ctx:
frames = {table_name: self} if table_name else {}
frames["self"] = self
ctx.register_many(frames)
with SQLContext(register_globals=False, eager=False) as ctx:
name = table_name if table_name else "self"
ctx.register(name=name, frame=self)
return ctx.execute(query) # type: ignore[return-value]

def top_k(
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/unit/sql/test_miscellaneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ def test_distinct() -> None:
ctx.execute("SELECT * FROM df")


def test_frame_sql_globals_error() -> None:
    """Frame-level `sql` must not resolve other frames; top-level `pl.sql` does.

    The query joins two frames by their variable names. Running it through
    `df1.sql` is expected to raise a `ComputeError` whose message contains
    "not found", whereas running the same text through `pl.sql` is expected
    to succeed and return the joined rows.
    """
    # NOTE: `df1` and `df2` are referenced by name inside the SQL text, so
    # these two local names must stay exactly as they are.
    df1 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    df2 = pl.DataFrame({"a": [2, 3, 4], "b": [7, 6, 5]})  # noqa: F841

    join_sql = """
SELECT df1.a, df2.b
FROM df2 JOIN df1 ON df1.a = df2.a
ORDER BY b DESC
"""
    expected_rows = {"a": [2, 3], "b": [7, 6]}

    # Frame-level entry point: the query names a frame other than the one
    # the method is called on, and the test expects that lookup to fail.
    with pytest.raises(ComputeError, match=".*not found.*"):
        df1.sql(query=join_sql)

    # Top-level entry point: the same query text is expected to run and
    # produce the join result.
    joined = pl.sql(query=join_sql, eager=True)
    assert joined.to_dict(as_series=False) == expected_rows


def test_in_no_ops_11946() -> None:
lf = pl.LazyFrame(
[
Expand Down

0 comments on commit d981b85

Please sign in to comment.