Skip to content

Commit

Permalink
Databricks pandas types casting (dyvenia#701)
Browse files Browse the repository at this point in the history
* ✨ Added columns cast in `Databricks` source

* ✨ Added object cast to str in `_cast_df_cols()`

* ✅ Added new unit test to databricks

* ✅ Added test `_cast_df_cols()` in utils

* 📝 Added entry to `CHANGELOG.md`

* ✨ Added ability to select types to convert

* 📝 Updated docstrings and added entry to CHANGELOG.md
  • Loading branch information
djagoda881 authored May 30, 2023
1 parent cc05c79 commit b565ee4
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed a bug in `Databricks.create_table_from_pandas()`. The function that converts column names to snake_case was not used in every case. (#672)
- Added `howto_migrate_sources_tasks_and_flows.md` document. This document will assist the DEs with the viadot 1 -> viadot 2 migration process.
- `RedshiftSpectrum.from_df()` now automatically creates a folder for the table if not specified in `to_path`
- Fixed a bug in `Databricks.create_table_from_pandas()`. The function now automatically casts DataFrame types. (#681)


### Changed
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
]
ADDITIONAL_DATA_NEW_FIELD_DF = pd.DataFrame(ADDITIONAL_TEST_DATA)
ADDITIONAL_DATA_DF = ADDITIONAL_DATA_NEW_FIELD_DF.copy().drop("NewField", axis=1)
# Single object-dtype column mixing str/float/int/bool values — exercises the
# dtype casting path in create_table_from_pandas().
MIXED_TYPES_DATA = pd.DataFrame({"test": ["a", "b", 1.1, 1, True]})


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -272,6 +273,20 @@ def test_snakecase_column_names(databricks):
databricks.drop_schema(TEST_SCHEMA)


def test_create_table_from_pandas_handles_mixed_types(databricks):
    """Creating a table from a DataFrame with a mixed-type column should succeed
    (the column is cast before the table is created)."""
    assert not databricks._check_if_table_exists(schema=TEST_SCHEMA, table=TEST_TABLE)

    databricks.create_schema(TEST_SCHEMA)
    table_created = databricks.create_table_from_pandas(
        df=MIXED_TYPES_DATA, schema=TEST_SCHEMA, table=TEST_TABLE
    )
    assert table_created

    # Clean up so subsequent tests start from an empty catalog.
    databricks.drop_table(schema=TEST_SCHEMA, table=TEST_TABLE)
    databricks.drop_schema(TEST_SCHEMA)


# @pytest.mark.dependency(depends=["test_create_table", "test_drop_table", "test_to_df"])
# def test_insert_into_append(databricks):

Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
gen_bulk_insert_query_from_df,
add_viadot_metadata_columns,
handle_api_request,
_cast_df_cols,
)
import pandas as pd
import json
Expand Down Expand Up @@ -108,3 +109,28 @@ def to_df(self):
testing_instance = TestingClass()
df = testing_instance.to_df()
assert "_viadot_source" in df.columns


def test___cast_df_cols():
    """_cast_df_cols() with all four conversions requested should yield
    nullable Int64 for bool/int columns and the pandas string dtype for
    datetime-derived and object columns."""
    input_df = pd.DataFrame(
        {
            "bool_column": [True, False, True, False],
            # NOTE(review): mixes tz-naive and tz-aware strings; pandas >= 2.0
            # may raise from to_datetime() without utc=True — confirm the
            # pinned pandas version tolerates this.
            "datetime_column": [
                "2023-05-25 10:30:00",
                "2023-05-20 ",
                "2023-05-15 10:30",
                "2023-05-10 10:30:00+00:00 ",
            ],
            "int_column": [5, 10, 15, 20],
            "object_column": ["apple", "banana", "melon", "orange"],
        }
    )
    input_df["datetime_column"] = pd.to_datetime(input_df["datetime_column"])

    result_df = _cast_df_cols(
        input_df, types_to_convert=["datetime", "bool", "int", "object"]
    )

    assert result_df["bool_column"].dtype == pd.Int64Dtype()
    assert result_df["datetime_column"].dtype == pd.StringDtype()
    assert result_df["int_column"].dtype == pd.Int64Dtype()
    assert result_df["object_column"].dtype == pd.StringDtype()
14 changes: 12 additions & 2 deletions viadot/sources/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@

from ..config import get_source_credentials
from ..exceptions import TableAlreadyExists, TableDoesNotExist
from ..utils import (add_viadot_metadata_columns, build_merge_query,
df_snakecase_column_names)
from ..utils import (
add_viadot_metadata_columns,
build_merge_query,
df_snakecase_column_names,
_cast_df_cols,
)
from .base import Source


Expand Down Expand Up @@ -266,6 +270,7 @@ def create_table_from_pandas(
if_empty: Literal["warn", "skip", "fail"] = "warn",
if_exists: Literal["replace", "skip", "fail"] = "fail",
snakecase_column_names: bool = True,
cast_df_columns: bool = True,
) -> bool:
"""
Create a table using a pandas `DataFrame`.
Expand All @@ -280,6 +285,8 @@ def create_table_from_pandas(
Defaults to 'fail'.
snakecase_column_names (bool, optional): Whether to convert column names to snake case.
Defaults to True.
cast_df_columns (bool, optional): Converts column types in DataFrame using utils._cast_df_cols().
This param exists because of possible errors with object cols. Defaults to True.
Example:
```python
Expand All @@ -306,6 +313,9 @@ def create_table_from_pandas(
if snakecase_column_names:
df = df_snakecase_column_names(df)

if cast_df_columns:
df = _cast_df_cols(df, types_to_convert=["object"])

fqn = f"{schema}.{table}"
success_message = f"Table {fqn} has been created successfully."

Expand Down
42 changes: 32 additions & 10 deletions viadot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,22 +200,44 @@ def get_sql_server_table_dtypes(
return dtypes


def _cast_df_cols(df):
def _cast_df_cols(
df: pd.DataFrame,
types_to_convert: List[Literal["datetime", "bool", "int", "object"]] = [
"datetime",
"bool",
"int",
],
) -> pd.DataFrame:
"""
Cast the data types of columns in a DataFrame.
Args:
df (pd.DataFrame): The input DataFrame.
types_to_convert (Literal[datetime, bool, int, object], optional): List of types to be converted.
Defaults to ["datetime", "bool", "int"].
Returns:
pd.DataFrame: A DataFrame with modified data types.
"""
df = df.replace({"False": False, "True": True})

datetime_cols = (col for col, dtype in df.dtypes.items() if dtype.kind == "M")
bool_cols = (col for col, dtype in df.dtypes.items() if dtype.kind == "b")
int_cols = (col for col, dtype in df.dtypes.items() if dtype.kind == "i")

for col in datetime_cols:
df[col] = df[col].dt.strftime("%Y-%m-%d %H:%M:%S+00:00")

for col in bool_cols:
df[col] = df[col].astype(pd.Int64Dtype())

for col in int_cols:
df[col] = df[col].astype(pd.Int64Dtype())
object_cols = (col for col, dtype in df.dtypes.items() if dtype.kind == "O")

if "datetime" in types_to_convert:
for col in datetime_cols:
df[col] = df[col].dt.strftime("%Y-%m-%d %H:%M:%S+00:00")
if "bool" in types_to_convert:
for col in bool_cols:
df[col] = df[col].astype(pd.Int64Dtype())
if "int" in types_to_convert:
for col in int_cols:
df[col] = df[col].astype(pd.Int64Dtype())
if "object" in types_to_convert:
for col in object_cols:
df[col] = df[col].astype(pd.StringDtype())

return df

Expand Down

0 comments on commit b565ee4

Please sign in to comment.