Update SqlToS3Operator to support Polars and deprecate read_pd_kwargs #54195
```diff
@@ -20,16 +20,18 @@
 import enum
 import gzip
 import io
+import warnings
 from collections import namedtuple
 from collections.abc import Iterable, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Literal, cast

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.version_compat import BaseHook, BaseOperator

 if TYPE_CHECKING:
     import pandas as pd
+    import polars as pl

     from airflow.providers.common.sql.hooks.sql import DbApiHook
     from airflow.utils.context import Context
```
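The new `polars` import lands inside the existing `if TYPE_CHECKING:` block, so it is only evaluated by static type checkers and does not become a hard runtime dependency. A minimal sketch of that pattern outside the operator (the `describe_frame` helper is made up for illustration):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers; neither library is imported at runtime here.
    import pandas as pd
    import polars as pl


def describe_frame(df: pd.DataFrame | pl.DataFrame) -> str:
    # With ``from __future__ import annotations`` the hint above stays a string,
    # so this module imports cleanly even if polars is not installed.
    return f"{type(df).__name__} with {len(df)} rows"
```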
```diff
@@ -69,7 +71,8 @@ class SqlToS3Operator(BaseOperator):
     :param sql_hook_params: Extra config params to be passed to the underlying hook.
         Should match the desired hook constructor params.
     :param parameters: (optional) the parameters to render the SQL query with.
-    :param read_pd_kwargs: arguments to include in DataFrame when ``pd.read_sql()`` is called.
+    :param read_kwargs: arguments to include in DataFrame when reading from SQL (supports both pandas and polars).
+    :param df_type: the type of DataFrame to use ('pandas' or 'polars'). Defaults to 'pandas'.
     :param aws_conn_id: reference to a specific S3 connection
     :param verify: Whether or not to verify SSL certificates for S3 connection.
         By default SSL certificates are verified.
```
```diff
@@ -84,7 +87,7 @@ class SqlToS3Operator(BaseOperator):
     :param max_rows_per_file: (optional) argument to set destination file number of rows limit, if source data
         is larger than that, it will be dispatched into multiple files.
         Will be ignored if ``groupby_kwargs`` argument is specified.
-    :param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
+    :param df_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
     :param groupby_kwargs: argument to include in DataFrame ``groupby()``.
     """
```
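As a rough illustration of the renamed parameters (this is not code from the PR; the connection IDs, bucket, and query are placeholders):

```python
from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

# Default pandas path, with kwargs forwarded to the read and write calls.
orders_to_s3 = SqlToS3Operator(
    task_id="orders_to_s3",
    query="SELECT * FROM orders",              # placeholder query
    sql_conn_id="postgres_default",            # placeholder connection id
    s3_bucket="example-bucket",                # placeholder bucket
    s3_key="exports/orders.csv",
    read_kwargs={"dtype_backend": "pyarrow"},  # replaces read_pd_kwargs
    df_kwargs={"index": False},                # replaces pd_kwargs
)

# Same transfer, but asking the SQL hook for a Polars DataFrame.
orders_to_s3_polars = SqlToS3Operator(
    task_id="orders_to_s3_polars",
    query="SELECT * FROM orders",
    sql_conn_id="postgres_default",
    s3_bucket="example-bucket",
    s3_key="exports/orders.parquet",
    file_format="parquet",
    df_type="polars",                          # new in this PR
)
```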
```diff
@@ -97,8 +100,9 @@ class SqlToS3Operator(BaseOperator):
     template_ext: Sequence[str] = (".sql",)
     template_fields_renderers = {
         "query": "sql",
+        "df_kwargs": "json",
         "pd_kwargs": "json",
-        "read_pd_kwargs": "json",
+        "read_kwargs": "json",
     }

     def __init__(
```
```diff
@@ -110,12 +114,15 @@ def __init__(
         sql_conn_id: str,
         sql_hook_params: dict | None = None,
         parameters: None | Mapping[str, Any] | list | tuple = None,
+        read_kwargs: dict | None = None,
         read_pd_kwargs: dict | None = None,
+        df_type: Literal["pandas", "polars"] = "pandas",
         replace: bool = False,
         aws_conn_id: str | None = "aws_default",
         verify: bool | str | None = None,
         file_format: Literal["csv", "json", "parquet"] = "csv",
         max_rows_per_file: int = 0,
+        df_kwargs: dict | None = None,
         pd_kwargs: dict | None = None,
         groupby_kwargs: dict | None = None,
         **kwargs,
```
```diff
@@ -128,14 +135,30 @@ def __init__(
         self.aws_conn_id = aws_conn_id
         self.verify = verify
         self.replace = replace
-        self.pd_kwargs = pd_kwargs or {}
         self.parameters = parameters
-        self.read_pd_kwargs = read_pd_kwargs or {}
         self.max_rows_per_file = max_rows_per_file
         self.groupby_kwargs = groupby_kwargs or {}
         self.sql_hook_params = sql_hook_params
+        self.df_type = df_type

-        if "path_or_buf" in self.pd_kwargs:
+        if read_pd_kwargs is not None:
+            warnings.warn(
+                "The 'read_pd_kwargs' parameter is deprecated. Use 'read_kwargs' instead.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+        self.read_kwargs = read_kwargs if read_kwargs is not None else read_pd_kwargs or {}
+
+        if pd_kwargs is not None:
+            warnings.warn(
+                "The 'pd_kwargs' parameter is deprecated. Use 'df_kwargs' instead.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+
+        self.df_kwargs = df_kwargs if df_kwargs is not None else pd_kwargs or {}
+
+        if "path_or_buf" in self.df_kwargs:
             raise AirflowException("The argument path_or_buf is not allowed, please remove it")

         if self.max_rows_per_file and self.groupby_kwargs:
```
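The constructor keeps accepting the old names but warns and maps them onto the new attributes. A sketch of how that back-compat path could be exercised (a hypothetical test, not one from the PR):

```python
import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator


def test_deprecated_kwargs_still_map_to_new_attributes():
    with pytest.warns(AirflowProviderDeprecationWarning):
        op = SqlToS3Operator(
            task_id="legacy_kwargs",
            query="SELECT 1",
            sql_conn_id="sql_default",              # placeholder connection id
            s3_bucket="example-bucket",             # placeholder bucket
            s3_key="out.csv",
            read_pd_kwargs={"coerce_float": True},  # deprecated spelling
            pd_kwargs={"index": False},             # deprecated spelling
        )
    # Per the mapping in __init__, the deprecated values end up on the new attributes.
    assert op.read_kwargs == {"coerce_float": True}
    assert op.df_kwargs == {"index": False}
```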
```diff
@@ -190,28 +213,29 @@ def execute(self, context: Context) -> None:
         sql_hook = self._get_hook()
         s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
         data_df = sql_hook.get_df(
-            sql=self.query, parameters=self.parameters, df_type="pandas", **self.read_pd_kwargs
+            sql=self.query, parameters=self.parameters, df_type=self.df_type, **self.read_kwargs
         )
         self.log.info("Data from SQL obtained")
-        if ("dtype_backend", "pyarrow") not in self.read_pd_kwargs.items():
-            self._fix_dtypes(data_df, self.file_format)
+        # Only apply dtype fixes to pandas DataFrames since Polars doesn't have the same NaN/None inconsistencies as pandas
+        if ("dtype_backend", "pyarrow") not in self.read_kwargs.items() and self.df_type == "pandas":
+            self._fix_dtypes(data_df, self.file_format)  # type: ignore[arg-type]
         file_options = FILE_OPTIONS_MAP[self.file_format]

         for group_name, df in self._partition_dataframe(df=data_df):
             buf = io.BytesIO()
             self.log.info("Writing data to in-memory buffer")
             object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key

-            if self.pd_kwargs.get("compression") == "gzip":
-                pd_kwargs = {k: v for k, v in self.pd_kwargs.items() if k != "compression"}
+            if self.df_kwargs.get("compression") == "gzip":
+                df_kwargs = {k: v for k, v in self.df_kwargs.items() if k != "compression"}
                 with gzip.GzipFile(fileobj=buf, mode="wb", filename=object_key) as gz:
-                    getattr(df, file_options.function)(gz, **pd_kwargs)
+                    getattr(df, file_options.function)(gz, **df_kwargs)
             else:
                 if self.file_format == FILE_FORMAT.PARQUET:
-                    getattr(df, file_options.function)(buf, **self.pd_kwargs)
+                    getattr(df, file_options.function)(buf, **self.df_kwargs)
                 else:
                     text_buf = io.TextIOWrapper(buf, encoding="utf-8", write_through=True)
-                    getattr(df, file_options.function)(text_buf, **self.pd_kwargs)
+                    getattr(df, file_options.function)(text_buf, **self.df_kwargs)
                     text_buf.flush()
             buf.seek(0)
```
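When `df_kwargs` requests gzip, the operator strips the `compression` key and compresses the in-memory buffer itself, so the DataFrame writer emits plain bytes into the gzip stream rather than compressing twice. A standalone pandas-only sketch of that pattern (the file name and data are illustrative):

```python
import gzip
import io

import pandas as pd

df = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
df_kwargs = {"compression": "gzip", "index": False}

buf = io.BytesIO()
if df_kwargs.get("compression") == "gzip":
    # Drop "compression" so to_csv() writes uncompressed text into the gzip stream.
    writer_kwargs = {k: v for k, v in df_kwargs.items() if k != "compression"}
    with gzip.GzipFile(fileobj=buf, mode="wb", filename="orders.csv") as gz:
        df.to_csv(gz, **writer_kwargs)

buf.seek(0)  # rewind so the gzipped bytes can be uploaded, e.g. via S3Hook.load_file_obj
```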
```diff
@@ -220,17 +244,23 @@ def execute(self, context: Context) -> None:
                 file_obj=buf, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
             )

-    def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
+    def _partition_dataframe(self, df: pd.DataFrame | pl.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
         """Partition dataframe using pandas groupby() method."""
         try:
             import secrets
             import string

             import numpy as np
+            import polars as pl
         except ImportError:
             pass

+        if isinstance(df, pl.DataFrame):
+            df = df.to_pandas()
```
Comment on lines +258 to +259

Contributor: @guan404ming Taking a closer look. Is this right?

Member (Author): You're right,
```diff
         # if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
         # added to the dataframe. This column is used to dispatch the dataframe into smaller ones using groupby()
         random_column_name = ""
         if self.max_rows_per_file and not self.groupby_kwargs:
             random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
```
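Because the partitioning itself still goes through pandas `groupby()`, a Polars result is converted with `to_pandas()` beforehand, which is exactly what the review thread above is questioning. A simplified sketch of the max-rows chunking idea that the temporary random column enables (the exact grouping expression in the operator may differ):

```python
import secrets
import string

import numpy as np
import pandas as pd

df = pd.DataFrame({"id": range(10)})
max_rows_per_file = 4

# A temporary column with an unlikely-to-collide name, used only for grouping.
random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
df[random_column_name] = np.arange(len(df)) // max_rows_per_file

for group_id, chunk in df.groupby(random_column_name):
    chunk = chunk.drop(columns=[random_column_name])
    print(f"group {group_id}: {len(chunk)} rows")  # yields chunks of 4, 4, and 2 rows
```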