Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ def unique(
raise error
if keep == "none":
subset = subset or self.columns
token = generate_temporary_column_name(n_bytes=8, columns=subset)
token = generate_temporary_column_name(
n_bytes=8, columns=subset, prefix="count_"
)
ser = self.native.groupby(subset).size().rename(token)
ser = ser[ser == 1]
unique = ser.reset_index().drop(columns=token)
Expand Down Expand Up @@ -335,7 +337,7 @@ def _join_full(

def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
key_token = generate_temporary_column_name(
n_bytes=8, columns=(*self.columns, *other.columns)
n_bytes=8, columns=(*self.columns, *other.columns), prefix="cross_join_key_"
)
return (
self.native.assign(**{key_token: 0})
Expand Down Expand Up @@ -365,7 +367,7 @@ def _join_anti(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
) -> dd.DataFrame:
indicator_token = generate_temporary_column_name(
n_bytes=8, columns=(*self.columns, *other.columns)
n_bytes=8, columns=(*self.columns, *other.columns), prefix="join_indicator_"
)
other_native = self._join_filter_rename(
other=other,
Expand Down Expand Up @@ -477,7 +479,9 @@ def tail(self, n: int) -> Self: # pragma: no cover
raise NotImplementedError(msg)

def gather_every(self, n: int, offset: int) -> Self:
row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
row_index_token = generate_temporary_column_name(
n_bytes=8, columns=self.columns, prefix="row_index_"
)
plx = self.__narwhals_namespace__()
return (
self.with_row_index(row_index_token, order_by=None)
Expand Down
8 changes: 6 additions & 2 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,9 @@ def func(expr: dx.Series, quantile: float) -> dx.Series:
def is_first_distinct(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
col_token = generate_temporary_column_name(
n_bytes=8, columns=[_name], prefix="row_index_"
)
frame = add_row_index(expr.to_frame(), col_token)
first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
return frame[col_token].isin(first_distinct_index)
Expand All @@ -539,7 +541,9 @@ def func(expr: dx.Series) -> dx.Series:
def is_last_distinct(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
col_token = generate_temporary_column_name(
n_bytes=8, columns=[_name], prefix="row_index_"
)
frame = add_row_index(expr.to_frame(), col_token)
last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
return frame[col_token].isin(last_distinct_index)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def unique(
subset_ = subset or self.columns
if error := self._check_columns_exist(subset_):
raise error
tmp_name = generate_temporary_column_name(8, self.columns)
tmp_name = generate_temporary_column_name(8, self.columns, prefix="row_index_")
if order_by and keep == "last":
descending = [True] * len(order_by)
nulls_last = [True] * len(order_by)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_ibis/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def unique(
subset_ = subset or self.columns
if error := self._check_columns_exist(subset_):
raise error
tmp_name = generate_temporary_column_name(8, self.columns)
tmp_name = generate_temporary_column_name(8, self.columns, prefix="row_index_")
if order_by and keep == "last":
order_by_ = IbisExpr._sort(*order_by, descending=True, nulls_last=True)
elif order_by:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def unique(
order_by: Sequence[str] | None = None,
) -> Self:
if order_by and maintain_order:
token = generate_temporary_column_name(8, self.columns)
token = generate_temporary_column_name(8, self.columns, prefix="row_index_")
res = (
self.native.with_row_index(token)
.sort(order_by, nulls_last=False)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def unique(
subset_ = subset or self.columns
if error := self._check_columns_exist(subset_):
raise error
tmp_name = generate_temporary_column_name(8, self.columns)
tmp_name = generate_temporary_column_name(8, self.columns, prefix="row_index_")
window = self._Window.partitionBy(subset_)
if order_by and keep == "last":
window = window.orderBy(*[self._F.desc_nulls_last(x) for x in order_by])
Expand Down
18 changes: 12 additions & 6 deletions narwhals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,17 +1239,19 @@ def is_ordered_categorical(series: Series[Any]) -> bool:


def generate_unique_token(
n_bytes: int, columns: Container[str]
n_bytes: int, columns: Container[str], prefix: str = "nw"
) -> str: # pragma: no cover
msg = (
"Use `generate_temporary_column_name` instead. `generate_unique_token` is "
"deprecated and it will be removed in future versions"
)
issue_deprecation_warning(msg, _version="1.13.0")
return generate_temporary_column_name(n_bytes=n_bytes, columns=columns)
return generate_temporary_column_name(n_bytes=n_bytes, columns=columns, prefix=prefix)


def generate_temporary_column_name(n_bytes: int, columns: Container[str]) -> str:
def generate_temporary_column_name(
n_bytes: int, columns: Container[str], prefix: str = "nw"
) -> str:
"""Generates a unique column name that is not present in the given list of columns.

It relies on [python secrets token_hex](https://docs.python.org/3/library/secrets.html#secrets.token_hex)
Expand All @@ -1258,6 +1260,7 @@ def generate_temporary_column_name(n_bytes: int, columns: Container[str]) -> str
Arguments:
n_bytes: The number of bytes to generate for the token.
columns: The list of columns to check for uniqueness.
prefix: prefix with which the temporary column name should start with.

Returns:
A unique token that is not present in the given list of columns.
Expand All @@ -1270,12 +1273,15 @@ def generate_temporary_column_name(n_bytes: int, columns: Container[str]) -> str
>>> columns = ["abc", "xyz"]
>>> nw.generate_temporary_column_name(n_bytes=8, columns=columns) not in columns
True
>>> temp_name = nw.generate_temporary_column_name(
... n_bytes=8, columns=columns, prefix="foo"
... )
>>> temp_name not in columns and temp_name.startswith("foo")
True
"""
counter = 0
while True:
# Prepend `'nw'` to ensure it always starts with a character
# https://github.com/narwhals-dev/narwhals/issues/2510
token = f"nw{token_hex(n_bytes - 1)}"
token = f"{prefix}{token_hex(n_bytes - 1)}"
Comment on lines -1276 to +1284
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we ensure that prefix is not an empty string?

if token not in columns:
return token

Expand Down
27 changes: 27 additions & 0 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,17 @@ def test_generate_temporary_column_name(n_bytes: int) -> None:
assert temp_col_name not in columns


@pytest.mark.parametrize("_idx", [1, 2])
def test_generate_temporary_column_name_prefix(_idx: int) -> None:
columns = ["abc", "XYZ"]
prefix = columns[0][:_idx]

temp_col_name = nw.generate_temporary_column_name(
n_bytes=2, columns=columns, prefix=prefix
)
assert temp_col_name not in columns


def test_generate_temporary_column_name_raise() -> None:
from itertools import product

Expand All @@ -299,6 +310,22 @@ def test_generate_temporary_column_name_raise() -> None:
nw.generate_temporary_column_name(n_bytes=1, columns=columns)


def test_generate_temporary_column_name_pr_3118_example() -> None:
from tests.utils import DUCKDB_VERSION

if DUCKDB_VERSION < (1, 3, 0):
pytest.skip()

import duckdb

conn = duckdb.connect()
conn.sql("""CREATE TABLE df (a int64, b int64);""")

df = nw.from_native(conn.table("df"))
sql = df.unique("b", keep="any").select("a").to_native().sql_query()
assert "AS row_index_" in sql


@pytest.mark.parametrize(
("version", "expected"),
[
Expand Down
Loading