Skip to content

Commit

Permalink
chore: add backend_version param to polars dtype translation utility (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
AlessandroMiola authored Nov 24, 2024
1 parent 8dc9f0a commit 43cd8bf
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
12 changes: 6 additions & 6 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __getattr__(self, attr: str) -> Any:
if attr == "schema":
schema = self._native_frame.schema
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
for name, dtype in schema.items()
}

Expand Down Expand Up @@ -113,12 +113,12 @@ def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.nd
def collect_schema(self) -> dict[str, DType]:
if self._backend_version < (1,):
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
for name, dtype in self._native_frame.schema.items()
}
else:
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
for name, dtype in self._native_frame.collect_schema().items()
}

Expand Down Expand Up @@ -351,19 +351,19 @@ def columns(self) -> list[str]:
def schema(self) -> dict[str, Any]:
schema = self._native_frame.schema
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
for name, dtype in schema.items()
}

def collect_schema(self) -> dict[str, DType]:
if self._backend_version < (1,):
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
for name, dtype in self._native_frame.schema.items()
}
else:
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
for name, dtype in self._native_frame.collect_schema().items()
}

Expand Down
4 changes: 3 additions & 1 deletion narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def name(self) -> str:

@property
def dtype(self: Self) -> DType:
return native_to_narwhals_dtype(self._native_series.dtype, self._dtypes)
return native_to_narwhals_dtype(
self._native_series.dtype, self._dtypes, self._backend_version
)

@overload
def __getitem__(self, item: int) -> Any: ...
Expand Down
24 changes: 16 additions & 8 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from narwhals.dtypes import DType
from narwhals.typing import DTypes

from narwhals.utils import parse_version


def extract_native(obj: Any) -> Any:
from narwhals._polars.dataframe import PolarsDataFrame
Expand All @@ -34,7 +32,11 @@ def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, An
return args, kwargs


def native_to_narwhals_dtype(dtype: pl.DataType, dtypes: DTypes) -> DType:
def native_to_narwhals_dtype(
dtype: pl.DataType,
dtypes: DTypes,
backend_version: tuple[int, ...],
) -> DType:
import polars as pl # ignore-banned-import()

if dtype == pl.Float64:
Expand Down Expand Up @@ -79,20 +81,26 @@ def native_to_narwhals_dtype(dtype: pl.DataType, dtypes: DTypes) -> DType:
if dtype == pl.Struct:
return dtypes.Struct(
[
dtypes.Field(field_name, native_to_narwhals_dtype(field_type, dtypes))
dtypes.Field(
field_name,
native_to_narwhals_dtype(field_type, dtypes, backend_version),
)
for field_name, field_type in dtype # type: ignore[attr-defined]
]
)
if dtype == pl.List:
return dtypes.List(native_to_narwhals_dtype(dtype.inner, dtypes)) # type: ignore[attr-defined]
return dtypes.List(native_to_narwhals_dtype(dtype.inner, dtypes, backend_version)) # type: ignore[attr-defined]
if dtype == pl.Array:
if parse_version(pl.__version__) < (0, 20, 30): # pragma: no cover
if backend_version < (0, 20, 30): # pragma: no cover
return dtypes.Array(
native_to_narwhals_dtype(dtype.inner, dtypes), # type: ignore[attr-defined]
native_to_narwhals_dtype(dtype.inner, dtypes, backend_version), # type: ignore[attr-defined]
dtype.width, # type: ignore[attr-defined]
)
else:
return dtypes.Array(native_to_narwhals_dtype(dtype.inner, dtypes), dtype.size) # type: ignore[attr-defined]
return dtypes.Array(
native_to_narwhals_dtype(dtype.inner, dtypes, backend_version), # type: ignore[attr-defined]
dtype.size, # type: ignore[attr-defined]
)
return dtypes.Unknown()


Expand Down

0 comments on commit 43cd8bf

Please sign in to comment.