From ce72a0ac27d10335c8a95bdb409b342ed9ff4f80 Mon Sep 17 00:00:00 2001 From: squalou Date: Fri, 23 Aug 2024 09:39:47 +0200 Subject: [PATCH] fix: set columns numeric datatypes when exporting to excel (#27229) Co-authored-by: Elizabeth Thompson --- superset/common/query_actions.py | 2 +- superset/common/query_context.py | 4 +- superset/common/query_context_processor.py | 6 +- superset/utils/excel.py | 21 +++++-- tests/unit_tests/utils/excel_tests.py | 65 +++++++++++++++++++++- 5 files changed, 90 insertions(+), 8 deletions(-) diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index bdbccc78dbe2c..9e61de6e1aaa2 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -107,7 +107,7 @@ def _get_full( payload["colnames"] = list(df.columns) payload["indexnames"] = list(df.index) payload["coltypes"] = extract_dataframe_dtypes(df, datasource) - payload["data"] = query_context.get_data(df) + payload["data"] = query_context.get_data(df, payload["coltypes"]) payload["result_format"] = query_context.result_format del payload["df"] diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 48b5abfbecdea..a04e3944603fb 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -28,6 +28,7 @@ ) from superset.common.query_object import QueryObject from superset.models.slice import Slice +from superset.utils.core import GenericDataType if TYPE_CHECKING: from superset.connectors.sqla.models import BaseDatasource @@ -88,8 +89,9 @@ def __init__( # pylint: disable=too-many-arguments def get_data( self, df: pd.DataFrame, + coltypes: list[GenericDataType], ) -> str | list[dict[str, Any]]: - return self._processor.get_data(df) + return self._processor.get_data(df, coltypes) def get_payload( self, diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 26935a4d96783..762ed30997194 100644 --- 
a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -56,6 +56,7 @@ DTTM_ALIAS, error_msg_from_exception, FilterOperator, + GenericDataType, get_base_axis_labels, get_column_names_from_columns, get_column_names_from_metrics, @@ -641,7 +642,9 @@ def generate_join_column( return str(value) - def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]: + def get_data( + self, df: pd.DataFrame, coltypes: list[GenericDataType] + ) -> str | list[dict[str, Any]]: if self._query_context.result_format in ChartDataResultFormat.table_like(): include_index = not isinstance(df.index, pd.RangeIndex) columns = list(df.columns) @@ -655,6 +658,7 @@ def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]: df, index=include_index, **config["CSV_EXPORT"] ) elif self._query_context.result_format == ChartDataResultFormat.XLSX: + excel.apply_column_types(df, coltypes) result = excel.df_to_excel(df, **config["EXCEL_EXPORT"]) return result or "" diff --git a/superset/utils/excel.py b/superset/utils/excel.py index ccbeadee5ecec..8609be5b43e6b 100644 --- a/superset/utils/excel.py +++ b/superset/utils/excel.py @@ -19,16 +19,29 @@ import pandas as pd +from superset.utils.core import GenericDataType + def df_to_excel(df: pd.DataFrame, **kwargs: Any) -> Any: output = io.BytesIO() - # timezones are not supported - for column in df.select_dtypes(include=["datetimetz"]).columns: - df[column] = df[column].astype(str) - # pylint: disable=abstract-class-instantiated with pd.ExcelWriter(output, engine="xlsxwriter") as writer: df.to_excel(writer, **kwargs) return output.getvalue() + + +def apply_column_types( + df: pd.DataFrame, column_types: list[GenericDataType] +) -> pd.DataFrame: + for column, column_type in zip(df.columns, column_types): + if column_type == GenericDataType.NUMERIC: + try: + df[column] = pd.to_numeric(df[column]) + except ValueError: + df[column] = df[column].astype(str) + elif 
pd.api.types.is_datetime64tz_dtype(df[column]): + # timezones are not supported + df[column] = df[column].astype(str) + return df diff --git a/tests/unit_tests/utils/excel_tests.py b/tests/unit_tests/utils/excel_tests.py index c15f69a0c62a3..745beff5052af 100644 --- a/tests/unit_tests/utils/excel_tests.py +++ b/tests/unit_tests/utils/excel_tests.py @@ -18,8 +18,10 @@ from datetime import datetime, timezone import pandas as pd +from pandas.api.types import is_numeric_dtype -from superset.utils.excel import df_to_excel +from superset.utils.core import GenericDataType +from superset.utils.excel import apply_column_types, df_to_excel def test_timezone_conversion() -> None: @@ -27,5 +29,66 @@ def test_timezone_conversion() -> None: Test that columns with timezones are converted to a string. """ df = pd.DataFrame({"dt": [datetime(2023, 1, 1, 0, 0, tzinfo=timezone.utc)]}) + apply_column_types(df, [GenericDataType.TEMPORAL]) contents = df_to_excel(df) assert pd.read_excel(contents)["dt"][0] == "2023-01-01 00:00:00+00:00" + + +def test_column_data_types_with_one_numeric_column(): + df = pd.DataFrame( + { + "col0": ["123", "1", "2", "3"], + "col1": ["456", "5.67", "0", ".45"], + "col2": [ + datetime(2023, 1, 1, 0, 0, tzinfo=timezone.utc), + datetime(2023, 1, 2, 0, 0, tzinfo=timezone.utc), + datetime(2023, 1, 3, 0, 0, tzinfo=timezone.utc), + datetime(2023, 1, 4, 0, 0, tzinfo=timezone.utc), + ], + "col3": ["True", "False", "True", "False"], + } + ) + coltypes: list[GenericDataType] = [ + GenericDataType.STRING, + GenericDataType.NUMERIC, + GenericDataType.TEMPORAL, + GenericDataType.BOOLEAN, + ] + + # only col1 should be converted to numeric, according to coltypes definition + assert not is_numeric_dtype(df["col1"]) + apply_column_types(df, coltypes) + assert not is_numeric_dtype(df["col0"]) + assert is_numeric_dtype(df["col1"]) + assert not is_numeric_dtype(df["col2"]) + assert not is_numeric_dtype(df["col3"]) + + +def test_column_data_types_with_failing_conversion(): + df = 
pd.DataFrame( + { + "col0": ["123", "1", "2", "3"], + "col1": ["456", "non_numeric_value", "0", ".45"], + "col2": [ + datetime(2023, 1, 1, 0, 0, tzinfo=timezone.utc), + datetime(2023, 1, 2, 0, 0, tzinfo=timezone.utc), + datetime(2023, 1, 3, 0, 0, tzinfo=timezone.utc), + datetime(2023, 1, 4, 0, 0, tzinfo=timezone.utc), + ], + "col3": ["True", "False", "True", "False"], + } + ) + coltypes: list[GenericDataType] = [ + GenericDataType.STRING, + GenericDataType.NUMERIC, + GenericDataType.TEMPORAL, + GenericDataType.BOOLEAN, + ] + + # should neither fail nor convert + assert not is_numeric_dtype(df["col1"]) + apply_column_types(df, coltypes) + assert not is_numeric_dtype(df["col0"]) + assert not is_numeric_dtype(df["col1"]) + assert not is_numeric_dtype(df["col2"]) + assert not is_numeric_dtype(df["col3"])