@@ -20,16 +20,18 @@
import enum
import gzip
import io
import warnings
from collections import namedtuple
from collections.abc import Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, cast

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.version_compat import BaseHook, BaseOperator

if TYPE_CHECKING:
import pandas as pd
import polars as pl

from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.context import Context
@@ -69,7 +71,8 @@ class SqlToS3Operator(BaseOperator):
:param sql_hook_params: Extra config params to be passed to the underlying hook.
Should match the desired hook constructor params.
:param parameters: (optional) the parameters to render the SQL query with.
:param read_pd_kwargs: arguments to include in DataFrame when ``pd.read_sql()`` is called.
:param read_kwargs: arguments passed to the DataFrame read call when reading from SQL (supports both pandas and polars).
:param df_type: the type of DataFrame to use ('pandas' or 'polars'). Defaults to 'pandas'.
:param aws_conn_id: reference to a specific S3 connection
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
@@ -84,7 +87,7 @@
:param max_rows_per_file: (optional) argument to set destination file number of rows limit, if source data
is larger than that, it will be dispatched into multiple files.
Will be ignored if ``groupby_kwargs`` argument is specified.
:param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
:param df_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
:param groupby_kwargs: argument to include in DataFrame ``groupby()``.
"""

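To make the renamed parameters concrete, here is a minimal usage sketch; the task id, query, bucket, key, and connection IDs are hypothetical placeholders, not values from this PR:

```python
# Minimal sketch of SqlToS3Operator with the renamed parameters.
# task_id, query, bucket, key, and connection IDs below are made-up placeholders.
from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

export_orders = SqlToS3Operator(
    task_id="export_orders",
    query="SELECT * FROM orders",
    s3_bucket="my-example-bucket",
    s3_key="exports/orders.csv",
    sql_conn_id="my_postgres_conn",
    file_format="csv",
    df_type="pandas",                          # new: "pandas" (default) or "polars"
    read_kwargs={"dtype_backend": "pyarrow"},  # replaces the deprecated read_pd_kwargs
    df_kwargs={"index": False},                # replaces the deprecated pd_kwargs
)
```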
@@ -97,8 +100,9 @@ class SqlToS3Operator(BaseOperator):
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {
"query": "sql",
"df_kwargs": "json",
"pd_kwargs": "json",
"read_pd_kwargs": "json",
"read_kwargs": "json",
}

def __init__(
@@ -110,12 +114,15 @@ def __init__(
sql_conn_id: str,
sql_hook_params: dict | None = None,
parameters: None | Mapping[str, Any] | list | tuple = None,
read_kwargs: dict | None = None,
read_pd_kwargs: dict | None = None,
df_type: Literal["pandas", "polars"] = "pandas",
replace: bool = False,
aws_conn_id: str | None = "aws_default",
verify: bool | str | None = None,
file_format: Literal["csv", "json", "parquet"] = "csv",
max_rows_per_file: int = 0,
df_kwargs: dict | None = None,
pd_kwargs: dict | None = None,
groupby_kwargs: dict | None = None,
**kwargs,
@@ -128,14 +135,30 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.verify = verify
self.replace = replace
self.pd_kwargs = pd_kwargs or {}
self.parameters = parameters
self.read_pd_kwargs = read_pd_kwargs or {}
self.max_rows_per_file = max_rows_per_file
self.groupby_kwargs = groupby_kwargs or {}
self.sql_hook_params = sql_hook_params
self.df_type = df_type

if "path_or_buf" in self.pd_kwargs:
if read_pd_kwargs is not None:
warnings.warn(
"The 'read_pd_kwargs' parameter is deprecated. Use 'read_kwargs' instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.read_kwargs = read_kwargs if read_kwargs is not None else read_pd_kwargs or {}

if pd_kwargs is not None:
warnings.warn(
"The 'pd_kwargs' parameter is deprecated. Use 'df_kwargs' instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

self.df_kwargs = df_kwargs if df_kwargs is not None else pd_kwargs or {}

if "path_or_buf" in self.df_kwargs:
raise AirflowException("The argument path_or_buf is not allowed, please remove it")

if self.max_rows_per_file and self.groupby_kwargs:
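To illustrate the backward-compatibility path above, a test-style sketch; the operator arguments are placeholders, and the assertion mirrors the fallback assignment shown in this hunk:

```python
# Sketch: passing the deprecated pd_kwargs still works, but now emits
# AirflowProviderDeprecationWarning and the value is carried over to df_kwargs.
import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

with pytest.warns(AirflowProviderDeprecationWarning):
    op = SqlToS3Operator(
        task_id="legacy_kwargs",
        query="SELECT 1",
        s3_bucket="my-example-bucket",   # placeholder
        s3_key="exports/one_row.csv",    # placeholder
        sql_conn_id="my_postgres_conn",  # placeholder
        pd_kwargs={"index": False},      # deprecated spelling
    )

assert op.df_kwargs == {"index": False}  # value lands on the new attribute
```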
@@ -190,28 +213,29 @@ def execute(self, context: Context) -> None:
sql_hook = self._get_hook()
s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
data_df = sql_hook.get_df(
sql=self.query, parameters=self.parameters, df_type="pandas", **self.read_pd_kwargs
sql=self.query, parameters=self.parameters, df_type=self.df_type, **self.read_kwargs
)
self.log.info("Data from SQL obtained")
if ("dtype_backend", "pyarrow") not in self.read_pd_kwargs.items():
self._fix_dtypes(data_df, self.file_format)
# Only apply dtype fixes to pandas DataFrames, since Polars doesn't have the same NaN/None inconsistencies as pandas
if ("dtype_backend", "pyarrow") not in self.read_kwargs.items() and self.df_type == "pandas":
self._fix_dtypes(data_df, self.file_format) # type: ignore[arg-type]
file_options = FILE_OPTIONS_MAP[self.file_format]

for group_name, df in self._partition_dataframe(df=data_df):
buf = io.BytesIO()
self.log.info("Writing data to in-memory buffer")
object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key

if self.pd_kwargs.get("compression") == "gzip":
pd_kwargs = {k: v for k, v in self.pd_kwargs.items() if k != "compression"}
if self.df_kwargs.get("compression") == "gzip":
df_kwargs = {k: v for k, v in self.df_kwargs.items() if k != "compression"}
with gzip.GzipFile(fileobj=buf, mode="wb", filename=object_key) as gz:
getattr(df, file_options.function)(gz, **pd_kwargs)
getattr(df, file_options.function)(gz, **df_kwargs)
else:
if self.file_format == FILE_FORMAT.PARQUET:
getattr(df, file_options.function)(buf, **self.pd_kwargs)
getattr(df, file_options.function)(buf, **self.df_kwargs)
else:
text_buf = io.TextIOWrapper(buf, encoding="utf-8", write_through=True)
getattr(df, file_options.function)(text_buf, **self.pd_kwargs)
getattr(df, file_options.function)(text_buf, **self.df_kwargs)
text_buf.flush()
buf.seek(0)

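The gzip branch above compresses in memory rather than writing a temporary file; a standalone sketch of that pattern with made-up sample data:

```python
# Sketch of the in-memory gzip pattern: write the DataFrame through GzipFile
# into a BytesIO buffer, then rewind the buffer so it can be uploaded as-is.
import gzip
import io

import pandas as pd

df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})  # made-up sample data

buf = io.BytesIO()
with gzip.GzipFile(fileobj=buf, mode="wb", filename="orders.csv") as gz:
    df.to_csv(gz, index=False)  # compression is handled by GzipFile, not pandas
buf.seek(0)

# buf now holds gzip-compressed CSV bytes, ready for S3Hook.load_file_obj(...)
```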
@@ -220,17 +244,23 @@ def execute(self, context: Context) -> None:
file_obj=buf, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
)

def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
def _partition_dataframe(self, df: pd.DataFrame | pl.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
"""Partition dataframe using pandas groupby() method."""
try:
import secrets
import string

import numpy as np
import polars as pl
except ImportError:
pass

if isinstance(df, pl.DataFrame):
df = df.to_pandas()
Comment on lines +258 to +259

@eladkal (Contributor) commented on Aug 17, 2025:

@guan404ming Taking a closer look. Is this right?
This means that a user who uses polars must also install pandas.

@guan404ming (Member, Author) replied on Aug 17, 2025:

You're right, to_pandas() still needs pandas installed. That means this function needs to be re-implemented to support the two libraries differently; I will open a PR for that later.


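As the author notes, to_pandas() pulls pandas back in for polars users. Purely as an illustration of the follow-up he describes (not part of this PR), a polars-native row-count partition could look roughly like this:

```python
# Illustrative sketch only -- not this PR's implementation. A polars-native
# variant of the max_rows_per_file partitioning, avoiding to_pandas().
from collections.abc import Iterable

import polars as pl


def partition_polars(df: pl.DataFrame, max_rows_per_file: int) -> Iterable[tuple[str, pl.DataFrame]]:
    """Yield (suffix, chunk) pairs with at most max_rows_per_file rows per chunk."""
    if not max_rows_per_file or df.height <= max_rows_per_file:
        yield "", df
        return
    for i, offset in enumerate(range(0, df.height, max_rows_per_file)):
        yield str(i), df.slice(offset, max_rows_per_file)
```

Group-by partitioning could similarly use polars' DataFrame.partition_by(), which returns one frame per group.
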
# if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
# added to the dataframe. This column is used to dispatch the dataframe into smaller ones using groupby()

random_column_name = ""
if self.max_rows_per_file and not self.groupby_kwargs:
random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))