-
Notifications
You must be signed in to change notification settings - Fork 793
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: make pandas and NumPy optional dependencies, don't require PyArrow for plotting with Polars/Modin/cuDF #3452
Changes from 23 commits
ef2a10e
6da3e3b
e91ed4d
f6d639e
7c052c0
81da742
c72fb9b
1078d38
bb84f22
b111fe9
b2118f9
84bda85
110f848
6be087a
795b464
3063fdf
1c8c5a3
b0ca54d
d0417df
5d54bc4
b52eaca
8b4b3db
cd81385
bec9bc2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,12 +27,11 @@ | |
from operator import itemgetter | ||
|
||
import jsonschema | ||
import pandas as pd | ||
import numpy as np | ||
from pandas.api.types import infer_dtype | ||
import narwhals.stable.v1 as nw | ||
from narwhals.dependencies import is_pandas_dataframe, get_polars | ||
from narwhals.typing import IntoDataFrame | ||
|
||
from altair.utils.schemapi import SchemaBase, Undefined | ||
from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame | ||
|
||
if sys.version_info >= (3, 10): | ||
from typing import ParamSpec | ||
|
@@ -43,11 +42,14 @@ | |
if TYPE_CHECKING: | ||
from types import ModuleType | ||
import typing as t | ||
from pandas.core.interchange.dataframe_protocol import Column as PandasColumn | ||
import pyarrow as pa | ||
from altair.vegalite.v5.schema._typing import StandardType_T as InferredVegaLiteType | ||
from altair.utils._dfi_types import DataFrame as DfiDataFrame | ||
from narwhals.typing import IntoExpr | ||
import pandas as pd | ||
|
||
V = TypeVar("V") | ||
P = ParamSpec("P") | ||
TIntoDataFrame = TypeVar("TIntoDataFrame", bound=IntoDataFrame) | ||
|
||
|
||
@runtime_checkable | ||
|
@@ -198,10 +200,7 @@ def __dataframe__( | |
] | ||
|
||
|
||
InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"] | ||
|
||
|
||
def infer_vegalite_type( | ||
def infer_vegalite_type_for_pandas( | ||
data: object, | ||
) -> InferredVegaLiteType | tuple[InferredVegaLiteType, list[Any]]: | ||
""" | ||
|
@@ -212,6 +211,9 @@ def infer_vegalite_type( | |
---------- | ||
data: object | ||
""" | ||
# This is safe to import here, as this function is only called on pandas input. | ||
from pandas.api.types import infer_dtype | ||
|
||
typ = infer_dtype(data, skipna=False) | ||
|
||
if typ in { | ||
|
@@ -297,13 +299,16 @@ def sanitize_geo_interface(geo: t.MutableMapping[Any, Any]) -> dict[str, Any]: | |
|
||
|
||
def numpy_is_subtype(dtype: Any, subtype: Any) -> bool: | ||
# This is only called on `numpy` inputs, so it's safe to import it here. | ||
import numpy as np | ||
|
||
try: | ||
return np.issubdtype(dtype, subtype) | ||
except (NotImplementedError, TypeError): | ||
return False | ||
|
||
|
||
def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame: | ||
def sanitize_pandas_dataframe(df: pd.DataFrame) -> pd.DataFrame: | ||
"""Sanitize a DataFrame to prepare it for serialization. | ||
|
||
* Make a copy | ||
|
@@ -320,6 +325,11 @@ def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame: | |
* convert dedicated string column to objects and replace NaN with None | ||
* Raise a ValueError for TimeDelta dtypes | ||
""" | ||
# This is safe to import here, as this function is only called on pandas input. | ||
# NumPy is a required dependency of pandas so is also safe to import. | ||
import pandas as pd | ||
import numpy as np | ||
|
||
df = df.copy() | ||
|
||
if isinstance(df.columns, pd.RangeIndex): | ||
|
@@ -429,30 +439,54 @@ def to_list_if_array(val): | |
return df | ||
|
||
|
||
def sanitize_arrow_table(pa_table: pa.Table) -> pa.Table: | ||
"""Sanitize arrow table for JSON serialization""" | ||
import pyarrow as pa | ||
import pyarrow.compute as pc | ||
|
||
arrays = [] | ||
schema = pa_table.schema | ||
for name in schema.names: | ||
array = pa_table[name] | ||
dtype_name = str(schema.field(name).type) | ||
if dtype_name.startswith(("timestamp", "date")): | ||
arrays.append(pc.strftime(array)) | ||
elif dtype_name.startswith("duration"): | ||
def sanitize_narwhals_dataframe( | ||
data: nw.DataFrame[TIntoDataFrame], | ||
) -> nw.DataFrame[TIntoDataFrame]: | ||
"""Sanitize narwhals.DataFrame for JSON serialization""" | ||
schema = data.schema | ||
columns: list[IntoExpr] = [] | ||
# See https://github.com/vega/altair/issues/1027 for why this is necessary. | ||
local_iso_fmt_string = "%Y-%m-%dT%H:%M:%S" | ||
for name, dtype in schema.items(): | ||
if dtype == nw.Date and nw.get_native_namespace(data) is get_polars(): | ||
# Polars doesn't allow formatting `Date` with time directives. | ||
# The date -> datetime cast is extremely fast compared with `to_string` | ||
columns.append( | ||
nw.col(name).cast(nw.Datetime).dt.to_string(local_iso_fmt_string) | ||
) | ||
elif dtype == nw.Date: | ||
columns.append(nw.col(name).dt.to_string(local_iso_fmt_string)) | ||
elif dtype == nw.Datetime: | ||
columns.append(nw.col(name).dt.to_string(f"{local_iso_fmt_string}%.f")) | ||
elif dtype == nw.Duration: | ||
msg = ( | ||
f'Field "{name}" has type "{dtype_name}" which is ' | ||
f'Field "{name}" has type "{dtype}" which is ' | ||
"not supported by Altair. Please convert to " | ||
"either a timestamp or a numerical value." | ||
"" | ||
) | ||
raise ValueError(msg) | ||
else: | ||
arrays.append(array) | ||
columns.append(name) | ||
return data.select(columns) | ||
|
||
|
||
return pa.Table.from_arrays(arrays, names=schema.names) | ||
def to_eager_narwhals_dataframe(data: IntoDataFrame) -> nw.DataFrame[Any]: | ||
"""Wrap `data` in `narwhals.DataFrame`. | ||
|
||
If `data` is not supported by Narwhals, but it is convertible | ||
to a PyArrow table, then first convert to a PyArrow Table, | ||
and then wrap in `narwhals.DataFrame`. | ||
""" | ||
data_nw = nw.from_native(data, eager_or_interchange_only=True) | ||
if nw.get_level(data_nw) == "interchange": | ||
# If Narwhals' support for `data`'s class is only metadata-level, then we | ||
# use the interchange protocol to convert to a PyArrow Table. | ||
from altair.utils.data import arrow_table_from_dfi_dataframe | ||
|
||
pa_table = arrow_table_from_dfi_dataframe(data) # type: ignore[arg-type] | ||
data_nw = nw.from_native(pa_table, eager_only=True) | ||
return data_nw | ||
|
||
|
||
def parse_shorthand( | ||
|
@@ -498,6 +532,7 @@ def parse_shorthand( | |
|
||
Examples | ||
-------- | ||
>>> import pandas as pd | ||
>>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'], | ||
... 'bar': [1, 2, 3, 4]}) | ||
|
||
|
@@ -537,7 +572,7 @@ def parse_shorthand( | |
>>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'} | ||
True | ||
""" | ||
from altair.utils._importers import pyarrow_available | ||
from altair.utils.data import is_data_type | ||
|
||
if not shorthand: | ||
return {} | ||
|
@@ -597,39 +632,22 @@ def parse_shorthand( | |
attrs["type"] = "temporal" | ||
|
||
# if data is specified and type is not, infer type from data | ||
if "type" not in attrs: | ||
if pyarrow_available() and data is not None and isinstance(data, DataFrameLike): | ||
dfi = data.__dataframe__() | ||
if "field" in attrs: | ||
unescaped_field = attrs["field"].replace("\\", "") | ||
if unescaped_field in dfi.column_names(): | ||
column = dfi.get_column_by_name(unescaped_field) | ||
try: | ||
attrs["type"] = infer_vegalite_type_for_dfi_column(column) | ||
except (NotImplementedError, AttributeError, ValueError): | ||
# Fall back to pandas-based inference. | ||
# Note: The AttributeError catch is a workaround for | ||
# https://github.com/pandas-dev/pandas/issues/55332 | ||
if isinstance(data, pd.DataFrame): | ||
attrs["type"] = infer_vegalite_type(data[unescaped_field]) | ||
else: | ||
raise | ||
|
||
if isinstance(attrs["type"], tuple): | ||
attrs["sort"] = attrs["type"][1] | ||
attrs["type"] = attrs["type"][0] | ||
elif isinstance(data, pd.DataFrame): | ||
# Fallback if pyarrow is not installed or if pandas is older than 1.5 | ||
# | ||
# Remove escape sequences so that types can be inferred for columns with special characters | ||
if "field" in attrs and attrs["field"].replace("\\", "") in data.columns: | ||
attrs["type"] = infer_vegalite_type( | ||
data[attrs["field"].replace("\\", "")] | ||
) | ||
# ordered categorical dataframe columns return the type and sort order as a tuple | ||
if isinstance(attrs["type"], tuple): | ||
attrs["sort"] = attrs["type"][1] | ||
attrs["type"] = attrs["type"][0] | ||
if "type" not in attrs and is_data_type(data): | ||
unescaped_field = attrs["field"].replace("\\", "") | ||
data_nw = nw.from_native(data, eager_or_interchange_only=True) | ||
schema = data_nw.schema | ||
if unescaped_field in schema: | ||
column = data_nw[unescaped_field] | ||
if schema[unescaped_field] in { | ||
nw.Object, | ||
nw.Unknown, | ||
} and is_pandas_dataframe(nw.to_native(data_nw)): | ||
attrs["type"] = infer_vegalite_type_for_pandas(nw.to_native(column)) | ||
else: | ||
attrs["type"] = infer_vegalite_type_for_narwhals(column) | ||
if isinstance(attrs["type"], tuple): | ||
attrs["sort"] = attrs["type"][1] | ||
attrs["type"] = attrs["type"][0] | ||
|
||
# If an unescaped colon is still present, it's often due to an incorrect data type specification | ||
# but could also be due to using a column name with ":" in it. | ||
|
@@ -650,41 +668,23 @@ def parse_shorthand( | |
return attrs | ||
|
||
|
||
def infer_vegalite_type_for_dfi_column( | ||
column: Column | PandasColumn, | ||
) -> InferredVegaLiteType | tuple[InferredVegaLiteType, list[Any]]: | ||
from pyarrow.interchange.from_dataframe import column_to_array | ||
|
||
try: | ||
kind = column.dtype[0] | ||
except NotImplementedError as e: | ||
# Edge case hack: | ||
# dtype access fails for pandas column with datetime64[ns, UTC] type, | ||
# but all we need to know is that its temporal, so check the | ||
# error message for the presence of datetime64. | ||
# | ||
# See https://github.com/pandas-dev/pandas/issues/54239 | ||
if "datetime64" in e.args[0] or "timestamp" in e.args[0]: | ||
return "temporal" | ||
raise e | ||
|
||
def infer_vegalite_type_for_narwhals( | ||
column: nw.Series, | ||
) -> InferredVegaLiteType | tuple[InferredVegaLiteType, list]: | ||
dtype = column.dtype | ||
if ( | ||
kind == DtypeKind.CATEGORICAL | ||
and column.describe_categorical["is_ordered"] | ||
and column.describe_categorical["categories"] is not None | ||
nw.is_ordered_categorical(column) | ||
and not (categories := column.cat.get_categories()).is_empty() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, nice. I had missed that dropping Python 3.7 means we can use the walrus operator! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😄 there's a few others if you're brave enough to try out auto-walrus |
||
): | ||
# Treat ordered categorical column as Vega-Lite ordinal | ||
categories_column = column.describe_categorical["categories"] | ||
categories_array = column_to_array(categories_column) | ||
return "ordinal", categories_array.to_pylist() | ||
if kind in {DtypeKind.STRING, DtypeKind.CATEGORICAL, DtypeKind.BOOL}: | ||
return "ordinal", categories.to_list() | ||
if dtype in {nw.String, nw.Categorical, nw.Boolean}: | ||
return "nominal" | ||
elif kind in {DtypeKind.INT, DtypeKind.UINT, DtypeKind.FLOAT}: | ||
elif dtype.is_numeric(): | ||
return "quantitative" | ||
elif kind == DtypeKind.DATETIME: | ||
elif dtype in {nw.Datetime, nw.Date}: | ||
return "temporal" | ||
else: | ||
msg = f"Unexpected DtypeKind: {kind}" | ||
msg = f"Unexpected DtypeKind: {dtype}" | ||
raise ValueError(msg) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really nice cleanup!