Skip to content

Commit 5d609a7

Browse files
committed
refactor: Implement CompliantNamespace.(all|col|exclude|nth)
All backends (besides `polars`) now share the same implementation #2202 (comment)
1 parent 554ef8d commit 5d609a7

File tree

10 files changed

+66
-212
lines changed

10 files changed

+66
-212
lines changed

narwhals/_arrow/namespace.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from __future__ import annotations
22

33
import operator
4-
from functools import partial
54
from functools import reduce
65
from itertools import chain
76
from typing import TYPE_CHECKING
87
from typing import Any
98
from typing import Callable
10-
from typing import Container
119
from typing import Iterable
1210
from typing import Literal
1311
from typing import Sequence
@@ -28,10 +26,7 @@
2826
from narwhals._expression_parsing import combine_alias_output_names
2927
from narwhals._expression_parsing import combine_evaluate_output_names
3028
from narwhals.utils import Implementation
31-
from narwhals.utils import exclude_column_names
32-
from narwhals.utils import get_column_names
3329
from narwhals.utils import import_dtypes_module
34-
from narwhals.utils import passthrough_column_names
3530

3631
if TYPE_CHECKING:
3732
from typing import Callable
@@ -67,20 +62,6 @@ def __init__(
6762
self._version = version
6863

6964
# --- selection ---
70-
def col(self: Self, *column_names: str) -> ArrowExpr:
71-
return self._expr.from_column_names(
72-
passthrough_column_names(column_names), function_name="col", context=self
73-
)
74-
75-
def exclude(self: Self, excluded_names: Container[str]) -> ArrowExpr:
76-
return self._expr.from_column_names(
77-
partial(exclude_column_names, names=excluded_names),
78-
function_name="exclude",
79-
context=self,
80-
)
81-
82-
def nth(self: Self, *column_indices: int) -> ArrowExpr:
83-
return self._expr.from_column_indices(*column_indices, context=self)
8465

8566
def len(self: Self) -> ArrowExpr:
8667
# coverage bug? this is definitely hit
@@ -98,11 +79,6 @@ def len(self: Self) -> ArrowExpr:
9879
version=self._version,
9980
)
10081

101-
def all(self: Self) -> ArrowExpr:
102-
return self._expr.from_column_names(
103-
get_column_names, function_name="all", context=self
104-
)
105-
10682
def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr:
10783
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
10884
arrow_series = ArrowSeries._from_iterable(

narwhals/_compliant/expr.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,18 @@ def __narwhals_expr__(self) -> None: ...
9090
def __narwhals_namespace__(
9191
self,
9292
) -> CompliantNamespace[CompliantFrameT, Self]: ...
93+
@classmethod
94+
def from_column_names(
95+
cls,
96+
evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]],
97+
/,
98+
*,
99+
function_name: str,
100+
context: _FullContext,
101+
) -> Self: ...
102+
@classmethod
103+
def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self: ...
104+
93105
def is_null(self) -> Self: ...
94106
def abs(self) -> Self: ...
95107
def all(self) -> Self: ...
@@ -330,22 +342,6 @@ def _from_series(cls, series: EagerSeriesT) -> Self:
330342
version=series._version,
331343
)
332344

333-
@classmethod
334-
def from_column_names(
335-
cls,
336-
evaluate_column_names: Callable[[EagerDataFrameT], Sequence[str]],
337-
/,
338-
*,
339-
function_name: str,
340-
context: _FullContext,
341-
) -> Self: ...
342-
@classmethod
343-
def from_column_indices(
344-
cls,
345-
*column_indices: int,
346-
context: _FullContext,
347-
) -> Self: ...
348-
349345
def _reuse_series(
350346
self: Self,
351347
method_name: str,

narwhals/_compliant/namespace.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from functools import partial
34
from typing import TYPE_CHECKING
45
from typing import Any
56
from typing import Container
@@ -13,21 +14,46 @@
1314
from narwhals._compliant.typing import EagerExprT
1415
from narwhals._compliant.typing import EagerSeriesT_co
1516
from narwhals.utils import deprecated
17+
from narwhals.utils import exclude_column_names
18+
from narwhals.utils import get_column_names
19+
from narwhals.utils import passthrough_column_names
1620

1721
if TYPE_CHECKING:
1822
from narwhals._compliant.selectors import CompliantSelectorNamespace
1923
from narwhals.dtypes import DType
24+
from narwhals.utils import Implementation
25+
from narwhals.utils import Version
2026

2127
__all__ = ["CompliantNamespace", "EagerNamespace"]
2228

2329

2430
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
25-
def col(self, *column_names: str) -> CompliantExprT: ...
26-
def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ...
27-
def exclude(self, excluded_names: Container[str]) -> CompliantExprT: ...
28-
def nth(self, *column_indices: int) -> CompliantExprT: ...
31+
_implementation: Implementation
32+
_backend_version: tuple[int, ...]
33+
_version: Version
34+
35+
def all(self) -> CompliantExprT:
36+
return self._expr.from_column_names(
37+
get_column_names, function_name="all", context=self
38+
)
39+
40+
def col(self, *column_names: str) -> CompliantExprT:
41+
return self._expr.from_column_names(
42+
passthrough_column_names(column_names), function_name="col", context=self
43+
)
44+
45+
def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
46+
return self._expr.from_column_names(
47+
partial(exclude_column_names, names=excluded_names),
48+
function_name="exclude",
49+
context=self,
50+
)
51+
52+
def nth(self, *column_indices: int) -> CompliantExprT:
53+
return self._expr.from_column_indices(*column_indices, context=self)
54+
2955
def len(self) -> CompliantExprT: ...
30-
def all(self) -> CompliantExprT: ...
56+
def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ...
3157
def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
3258
def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
3359
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...

narwhals/_dask/expr.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from narwhals._dask.namespace import DaskNamespace
4040
from narwhals.dtypes import DType
4141
from narwhals.utils import Version
42+
from narwhals.utils import _FullContext
4243

4344

4445
class DaskExpr(LazyExpr["DaskLazyFrame", "dx.Series"]):
@@ -100,8 +101,7 @@ def from_column_names(
100101
/,
101102
*,
102103
function_name: str,
103-
backend_version: tuple[int, ...],
104-
version: Version,
104+
context: _FullContext,
105105
) -> Self:
106106
def func(df: DaskLazyFrame) -> list[dx.Series]:
107107
try:
@@ -124,16 +124,13 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
124124
function_name=function_name,
125125
evaluate_output_names=evaluate_column_names,
126126
alias_output_names=None,
127-
backend_version=backend_version,
128-
version=version,
127+
backend_version=context._backend_version,
128+
version=context._version,
129129
)
130130

131131
@classmethod
132132
def from_column_indices(
133-
cls: type[Self],
134-
*column_indices: int,
135-
backend_version: tuple[int, ...],
136-
version: Version,
133+
cls: type[Self], *column_indices: int, context: _FullContext
137134
) -> Self:
138135
def func(df: DaskLazyFrame) -> list[dx.Series]:
139136
return [
@@ -146,8 +143,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
146143
function_name="nth",
147144
evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
148145
alias_output_names=None,
149-
backend_version=backend_version,
150-
version=version,
146+
backend_version=context._backend_version,
147+
version=context._version,
151148
)
152149

153150
def _from_call(

narwhals/_dask/namespace.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from __future__ import annotations
22

33
import operator
4-
from functools import partial
54
from functools import reduce
65
from typing import TYPE_CHECKING
76
from typing import Any
87
from typing import Callable
9-
from typing import Container
108
from typing import Iterable
119
from typing import Literal
1210
from typing import Sequence
@@ -26,9 +24,6 @@
2624
from narwhals._expression_parsing import combine_alias_output_names
2725
from narwhals._expression_parsing import combine_evaluate_output_names
2826
from narwhals.utils import Implementation
29-
from narwhals.utils import exclude_column_names
30-
from narwhals.utils import get_column_names
31-
from narwhals.utils import passthrough_column_names
3227

3328
if TYPE_CHECKING:
3429
from typing_extensions import Self
@@ -59,35 +54,6 @@ def __init__(
5954
self._backend_version = backend_version
6055
self._version = version
6156

62-
def all(self: Self) -> DaskExpr:
63-
return self._expr.from_column_names(
64-
get_column_names,
65-
function_name="all",
66-
backend_version=self._backend_version,
67-
version=self._version,
68-
)
69-
70-
def col(self: Self, *column_names: str) -> DaskExpr:
71-
return self._expr.from_column_names(
72-
passthrough_column_names(column_names),
73-
function_name="col",
74-
backend_version=self._backend_version,
75-
version=self._version,
76-
)
77-
78-
def exclude(self: Self, excluded_names: Container[str]) -> DaskExpr:
79-
return self._expr.from_column_names(
80-
partial(exclude_column_names, names=excluded_names),
81-
function_name="exclude",
82-
backend_version=self._backend_version,
83-
version=self._version,
84-
)
85-
86-
def nth(self: Self, *column_indices: int) -> DaskExpr:
87-
return self._expr.from_column_indices(
88-
*column_indices, backend_version=self._backend_version, version=self._version
89-
)
90-
9157
def lit(self: Self, value: Any, dtype: DType | None) -> DaskExpr:
9258
def func(df: DaskLazyFrame) -> list[dx.Series]:
9359
if dtype is not None:

narwhals/_duckdb/expr.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from narwhals._duckdb.namespace import DuckDBNamespace
3636
from narwhals.dtypes import DType
3737
from narwhals.utils import Version
38+
from narwhals.utils import _FullContext
3839

3940

4041
class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]):
@@ -85,8 +86,7 @@ def from_column_names(
8586
/,
8687
*,
8788
function_name: str,
88-
backend_version: tuple[int, ...],
89-
version: Version,
89+
context: _FullContext,
9090
) -> Self:
9191
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
9292
return [ColumnExpression(col_name) for col_name in evaluate_column_names(df)]
@@ -96,16 +96,13 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
9696
function_name=function_name,
9797
evaluate_output_names=evaluate_column_names,
9898
alias_output_names=None,
99-
backend_version=backend_version,
100-
version=version,
99+
backend_version=context._backend_version,
100+
version=context._version,
101101
)
102102

103103
@classmethod
104104
def from_column_indices(
105-
cls: type[Self],
106-
*column_indices: int,
107-
backend_version: tuple[int, ...],
108-
version: Version,
105+
cls: type[Self], *column_indices: int, context: _FullContext
109106
) -> Self:
110107
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
111108
columns = df.columns
@@ -117,8 +114,8 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
117114
function_name="nth",
118115
evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
119116
alias_output_names=None,
120-
backend_version=backend_version,
121-
version=version,
117+
backend_version=context._backend_version,
118+
version=context._version,
122119
)
123120

124121
def _from_call(

narwhals/_duckdb/namespace.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from __future__ import annotations
22

33
import operator
4-
from functools import partial
54
from functools import reduce
65
from typing import TYPE_CHECKING
76
from typing import Any
87
from typing import Callable
9-
from typing import Container
108
from typing import Iterable
119
from typing import Literal
1210
from typing import Sequence
@@ -26,9 +24,6 @@
2624
from narwhals._expression_parsing import combine_alias_output_names
2725
from narwhals._expression_parsing import combine_evaluate_output_names
2826
from narwhals.utils import Implementation
29-
from narwhals.utils import exclude_column_names
30-
from narwhals.utils import get_column_names
31-
from narwhals.utils import passthrough_column_names
3227

3328
if TYPE_CHECKING:
3429
import duckdb
@@ -56,14 +51,6 @@ def selectors(self: Self) -> DuckDBSelectorNamespace:
5651
def _expr(self) -> type[DuckDBExpr]:
5752
return DuckDBExpr
5853

59-
def all(self: Self) -> DuckDBExpr:
60-
return self._expr.from_column_names(
61-
get_column_names,
62-
function_name="all",
63-
backend_version=self._backend_version,
64-
version=self._version,
65-
)
66-
6754
def concat(
6855
self: Self,
6956
items: Iterable[DuckDBLazyFrame],
@@ -236,27 +223,6 @@ def when(self: Self, predicate: DuckDBExpr) -> DuckDBWhen:
236223
version=self._version,
237224
)
238225

239-
def col(self: Self, *column_names: str) -> DuckDBExpr:
240-
return self._expr.from_column_names(
241-
passthrough_column_names(column_names),
242-
function_name="col",
243-
backend_version=self._backend_version,
244-
version=self._version,
245-
)
246-
247-
def exclude(self: Self, excluded_names: Container[str]) -> DuckDBExpr:
248-
return self._expr.from_column_names(
249-
partial(exclude_column_names, names=excluded_names),
250-
function_name="exclude",
251-
backend_version=self._backend_version,
252-
version=self._version,
253-
)
254-
255-
def nth(self: Self, *column_indices: int) -> DuckDBExpr:
256-
return self._expr.from_column_indices(
257-
*column_indices, backend_version=self._backend_version, version=self._version
258-
)
259-
260226
def lit(self: Self, value: Any, dtype: DType | None) -> DuckDBExpr:
261227
def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]:
262228
if dtype is not None:

0 commit comments

Comments
 (0)