Skip to content

Commit

Permalink
fix: set columns numeric datatypes when exporting to excel (#27229)
Browse files Browse the repository at this point in the history
Co-authored-by: Elizabeth Thompson <eschutho@gmail.com>
  • Loading branch information
squalou and eschutho authored Aug 23, 2024
1 parent 9d5268a commit ce72a0a
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 8 deletions.
2 changes: 1 addition & 1 deletion superset/common/query_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _get_full(
payload["colnames"] = list(df.columns)
payload["indexnames"] = list(df.index)
payload["coltypes"] = extract_dataframe_dtypes(df, datasource)
payload["data"] = query_context.get_data(df)
payload["data"] = query_context.get_data(df, payload["coltypes"])
payload["result_format"] = query_context.result_format
del payload["df"]

Expand Down
4 changes: 3 additions & 1 deletion superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from superset.common.query_object import QueryObject
from superset.models.slice import Slice
from superset.utils.core import GenericDataType

if TYPE_CHECKING:
from superset.connectors.sqla.models import BaseDatasource
Expand Down Expand Up @@ -88,8 +89,9 @@ def __init__( # pylint: disable=too-many-arguments
def get_data(
    self,
    df: pd.DataFrame,
    coltypes: list[GenericDataType],
) -> str | list[dict[str, Any]]:
    """
    Serialize a query result by delegating to the query context processor.

    :param df: the query result to serialize
    :param coltypes: generic data types for the result; forwarded unchanged
        to the processor
    :returns: whatever the processor's ``get_data`` returns
    """
    # Only the two-argument call is kept: the older ``get_data(df)`` call
    # would return first and silently drop ``coltypes``.
    return self._processor.get_data(df, coltypes)

def get_payload(
self,
Expand Down
6 changes: 5 additions & 1 deletion superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
DTTM_ALIAS,
error_msg_from_exception,
FilterOperator,
GenericDataType,
get_base_axis_labels,
get_column_names_from_columns,
get_column_names_from_metrics,
Expand Down Expand Up @@ -641,7 +642,9 @@ def generate_join_column(

return str(value)

def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]:
def get_data(
self, df: pd.DataFrame, coltypes: list[GenericDataType]
) -> str | list[dict[str, Any]]:
if self._query_context.result_format in ChartDataResultFormat.table_like():
include_index = not isinstance(df.index, pd.RangeIndex)
columns = list(df.columns)
Expand All @@ -655,6 +658,7 @@ def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]:
df, index=include_index, **config["CSV_EXPORT"]
)
elif self._query_context.result_format == ChartDataResultFormat.XLSX:
excel.apply_column_types(df, coltypes)
result = excel.df_to_excel(df, **config["EXCEL_EXPORT"])
return result or ""

Expand Down
21 changes: 17 additions & 4 deletions superset/utils/excel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,29 @@

import pandas as pd

from superset.utils.core import GenericDataType


def df_to_excel(df: pd.DataFrame, **kwargs: Any) -> Any:
    """
    Serialize a DataFrame to an in-memory XLSX workbook.

    Timezone-aware datetime columns are written as strings because Excel
    cannot store timezones. The conversion happens on a shallow copy, so the
    caller's DataFrame is left untouched.

    :param df: the data to export
    :param kwargs: extra keyword arguments passed to ``DataFrame.to_excel``
    :returns: the workbook contents as bytes
    """
    tz_positions = [
        position
        for position, dtype in enumerate(df.dtypes)
        if isinstance(dtype, pd.DatetimeTZDtype)
    ]
    if tz_positions:
        # timezones are not supported; a shallow copy shares the untouched
        # columns with the original and only the replaced ones are new
        df = df.copy(deep=False)
        for position in tz_positions:
            df.isetitem(position, df.iloc[:, position].astype(str))

    output = io.BytesIO()

    # pylint: disable=abstract-class-instantiated
    with pd.ExcelWriter(output, engine="xlsxwriter") as writer:
        df.to_excel(writer, **kwargs)

    return output.getvalue()


def apply_column_types(
    df: pd.DataFrame, column_types: list[GenericDataType]
) -> pd.DataFrame:
    """
    Coerce the columns of ``df`` in place so they export cleanly to Excel.

    Columns are matched to ``column_types`` by position:

    - a column declared ``GenericDataType.NUMERIC`` is converted with
      ``pd.to_numeric``; if that fails, the column is stringified instead
    - any other column with a timezone-aware datetime dtype is stringified,
      because Excel does not support timezones; this also applies to columns
      that have no entry in ``column_types``

    :param df: the DataFrame to adjust; it is modified in place
    :param column_types: generic data types, positionally aligned with the
        columns of ``df``; surplus entries are ignored
    :returns: the same ``df`` object, for convenience
    """
    numeric_positions = {
        position
        for position, column_type in enumerate(column_types)
        if column_type == GenericDataType.NUMERIC
    }

    # ``df.dtypes`` is evaluated once, before any column is replaced. Columns
    # are addressed by position so duplicate labels are handled one by one.
    for position, dtype in enumerate(df.dtypes):
        values = df.iloc[:, position]
        if position in numeric_positions:
            try:
                converted = pd.to_numeric(values)
            except (TypeError, ValueError):
                # ValueError: unparsable strings; TypeError: objects that
                # to_numeric rejects outright (e.g. lists, timestamps)
                converted = values.astype(str)
        elif isinstance(dtype, pd.DatetimeTZDtype):
            # timezones are not supported
            converted = values.astype(str)
        else:
            continue
        df.isetitem(position, converted)
    return df
65 changes: 64 additions & 1 deletion tests/unit_tests/utils/excel_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,77 @@
from datetime import datetime, timezone

import pandas as pd
from pandas.api.types import is_numeric_dtype

from superset.utils.excel import df_to_excel
from superset.utils.core import GenericDataType
from superset.utils.excel import apply_column_types, df_to_excel


def test_timezone_conversion() -> None:
    """
    Test that columns with timezones are converted to a string.
    """
    # pylint: disable=import-outside-toplevel
    from io import BytesIO

    df = pd.DataFrame({"dt": [datetime(2023, 1, 1, 0, 0, tzinfo=timezone.utc)]})
    apply_column_types(df, [GenericDataType.TEMPORAL])
    contents = df_to_excel(df)
    # read_excel deprecates raw bytes input, so hand it a file-like object
    workbook = pd.read_excel(BytesIO(contents))
    assert workbook["dt"][0] == "2023-01-01 00:00:00+00:00"


def test_column_data_types_with_one_numeric_column() -> None:
    """
    Test that only the column declared as numeric gets a numeric dtype.
    """
    df = pd.DataFrame(
        {
            "col0": ["123", "1", "2", "3"],
            "col1": ["456", "5.67", "0", ".45"],
            "col2": [
                datetime(2023, 1, 1, 0, 0, tzinfo=timezone.utc),
                datetime(2023, 1, 2, 0, 0, tzinfo=timezone.utc),
                datetime(2023, 1, 3, 0, 0, tzinfo=timezone.utc),
                datetime(2023, 1, 4, 0, 0, tzinfo=timezone.utc),
            ],
            "col3": ["True", "False", "True", "False"],
        }
    )
    coltypes: list[GenericDataType] = [
        GenericDataType.STRING,
        GenericDataType.NUMERIC,
        GenericDataType.TEMPORAL,
        GenericDataType.BOOLEAN,
    ]

    # only col1 should be converted to numeric, according to coltypes definition
    assert not is_numeric_dtype(df["col1"])
    apply_column_types(df, coltypes)
    assert is_numeric_dtype(df["col1"])
    for name in ("col0", "col2", "col3"):
        assert not is_numeric_dtype(df[name]), name

    # numeric-looking strings in a STRING column must be left exactly as they were
    assert df["col0"].tolist() == ["123", "1", "2", "3"]
    assert df["col3"].tolist() == ["True", "False", "True", "False"]


def test_column_data_types_with_failing_conversion() -> None:
    """
    Test that a numeric column holding a non-numeric value is neither
    converted nor causes an error.
    """
    df = pd.DataFrame(
        {
            "col0": ["123", "1", "2", "3"],
            "col1": ["456", "non_numeric_value", "0", ".45"],
            "col2": [
                datetime(2023, 1, 1, 0, 0, tzinfo=timezone.utc),
                datetime(2023, 1, 2, 0, 0, tzinfo=timezone.utc),
                datetime(2023, 1, 3, 0, 0, tzinfo=timezone.utc),
                datetime(2023, 1, 4, 0, 0, tzinfo=timezone.utc),
            ],
            "col3": ["True", "False", "True", "False"],
        }
    )
    coltypes: list[GenericDataType] = [
        GenericDataType.STRING,
        GenericDataType.NUMERIC,
        GenericDataType.TEMPORAL,
        GenericDataType.BOOLEAN,
    ]

    # the conversion should neither raise nor change the dtype of col1
    assert not is_numeric_dtype(df["col1"])
    apply_column_types(df, coltypes)
    for name in ("col0", "col1", "col2", "col3"):
        assert not is_numeric_dtype(df[name]), name

    # the values that could not be converted are kept as the original strings
    assert df["col1"].tolist() == ["456", "non_numeric_value", "0", ".45"]

0 comments on commit ce72a0a

Please sign in to comment.