32 changes: 23 additions & 9 deletions airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -80,6 +80,7 @@ class SqlToS3Operator(BaseOperator):
CA cert bundle than the one used by botocore.
:param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
:param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
:param groupby_kwargs: arguments to include in DataFrame ``groupby()``.
"""

template_fields: Sequence[str] = (
@@ -107,6 +108,7 @@ def __init__(
verify: bool | str | None = None,
file_format: Literal["csv", "json", "parquet"] = "csv",
pd_kwargs: dict | None = None,
groupby_kwargs: dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -119,6 +121,7 @@ def __init__(
self.replace = replace
self.pd_kwargs = pd_kwargs or {}
self.parameters = parameters
self.groupby_kwargs = groupby_kwargs or {}

if "path_or_buf" in self.pd_kwargs:
raise AirflowException("The argument path_or_buf is not allowed, please remove it")
@@ -170,15 +173,26 @@ def execute(self, context: Context) -> None:
self._fix_dtypes(data_df, self.file_format)
file_options = FILE_OPTIONS_MAP[self.file_format]

with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:

self.log.info("Writing data to temp file")
getattr(data_df, file_options.function)(tmp_file.name, **self.pd_kwargs)

self.log.info("Uploading data to S3")
s3_conn.load_file(
filename=tmp_file.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=self.replace
)
for group_name, df in self._partition_dataframe(df=data_df):
with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:

self.log.info("Writing data to temp file")
getattr(df, file_options.function)(tmp_file.name, **self.pd_kwargs)

self.log.info("Uploading data to S3")
object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
s3_conn.load_file(
filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
)

def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]:
"""Partition dataframe using pandas groupby() method"""
if not self.groupby_kwargs:
yield "", df
else:
grouped_df = df.groupby(**self.groupby_kwargs)
for group_label in grouped_df.groups.keys():
yield group_label, grouped_df.get_group(group_label).reset_index(drop=True)

def _get_hook(self) -> DbApiHook:
self.log.debug("Get connection for %s", self.sql_conn_id)
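For context on what the new ``execute()`` flow does per group, here is a minimal standalone sketch (plain pandas, no Airflow) of the partitioning and key-naming logic that ``_partition_dataframe`` and the upload loop implement. The sample DataFrame and the ``exports/data.csv`` key are invented for illustration.

```python
import pandas as pd

# Illustrative stand-in for the query result; the data and key below are made up.
data_df = pd.DataFrame(
    {
        "color": ["red", "red", "blue"],
        "value": [1, 2, 3],
    }
)

s3_key = "exports/data.csv"  # hypothetical destination key
groupby_kwargs = {"by": "color"}

# Mirrors _partition_dataframe(): one (group_name, frame) pair per group,
# with the index reset so each chunk can be written as a standalone file.
grouped = data_df.groupby(**groupby_kwargs)
for group_name in grouped.groups:
    chunk = grouped.get_group(group_name).reset_index(drop=True)
    # The operator uploads each chunk under "<s3_key>_<group_name>".
    object_key = f"{s3_key}_{group_name}" if group_name else s3_key
    print(object_key, len(chunk))

# Prints "exports/data.csv_blue 1" and "exports/data.csv_red 2";
# without groupby_kwargs, the whole frame would go to "exports/data.csv".
```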
13 changes: 13 additions & 0 deletions docs/apache-airflow-providers-amazon/transfer/sql_to_s3.rst
@@ -50,6 +50,19 @@ Example usage:
:start-after: [START howto_transfer_sql_to_s3]
:end-before: [END howto_transfer_sql_to_s3]

Grouping
--------

You can group the data in the table by passing the ``groupby_kwargs`` parameter. This parameter accepts a ``dict`` that is passed to pandas `groupby() <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.groupby.html#pandas.DataFrame.groupby>`_ as keyword arguments; each resulting group is uploaded to S3 as a separate object, with the group name appended to the S3 key.

Example usage:

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sql_to_s3.py
:language: python
:dedent: 4
:start-after: [START howto_transfer_sql_to_s3_with_groupby_param]
:end-before: [END howto_transfer_sql_to_s3_with_groupby_param]

Reference
---------

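As an illustration of how ``groupby_kwargs`` is forwarded, the sketch below configures the operator so that both the ``by`` column and pandas' ``sort`` flag reach ``DataFrame.groupby()``. The connection IDs, query, bucket, and key are placeholders rather than values from the example DAG.

```python
from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

# Placeholder identifiers: adjust to your own connections, bucket, and query.
grouped_export = SqlToS3Operator(
    task_id="export_orders_grouped_by_color",
    sql_conn_id="my_sql_conn",   # assumed connection ID
    aws_conn_id="my_aws_conn",   # assumed connection ID
    query="SELECT color, value FROM orders",
    s3_bucket="my-bucket",
    s3_key="exports/orders.csv",
    replace=True,
    file_format="csv",
    pd_kwargs={"index": False},
    # Forwarded verbatim to DataFrame.groupby(); one S3 object is written per group.
    groupby_kwargs={"by": "color", "sort": False},
)
```

With this configuration, one object is written per distinct ``color`` value, e.g. ``exports/orders.csv_red``.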
94 changes: 94 additions & 0 deletions tests/providers/amazon/aws/transfers/test_sql_to_s3.py
@@ -175,3 +175,97 @@ def test_invalid_file_format(self):
file_format="invalid_format",
dag=None,
)

def test_with_groupby_kwarg(self):
"""
Test the operator when groupby_kwargs is specified
"""
query = "query"
s3_bucket = "bucket"
s3_key = "key"

op = SqlToS3Operator(
query=query,
s3_bucket=s3_bucket,
s3_key=s3_key,
sql_conn_id="mysql_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
replace=True,
pd_kwargs={"index": False, "header": False},
groupby_kwargs={"by": "Team"},
dag=None,
)
example = {
"Team": ["Australia", "Australia", "India", "India"],
"Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
"Runs": [345, 490, 672, 560],
}

df = pd.DataFrame(example)
data = []
for group_name, df in op._partition_dataframe(df):
data.append((group_name, df))
data.sort(key=lambda d: d[0])
team, df = data[0]
assert df.equals(
pd.DataFrame(
{
"Team": ["Australia", "Australia"],
"Player": ["Ricky", "David Warner"],
"Runs": [345, 490],
}
)
)
team, df = data[1]
assert df.equals(
pd.DataFrame(
{
"Team": ["India", "India"],
"Player": ["Virat Kohli", "Rohit Sharma"],
"Runs": [672, 560],
}
)
)

def test_without_groupby_kwarg(self):
"""
Test the operator when groupby_kwargs is not specified
"""
query = "query"
s3_bucket = "bucket"
s3_key = "key"

op = SqlToS3Operator(
query=query,
s3_bucket=s3_bucket,
s3_key=s3_key,
sql_conn_id="mysql_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
replace=True,
pd_kwargs={"index": False, "header": False},
dag=None,
)
example = {
"Team": ["Australia", "Australia", "India", "India"],
"Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
"Runs": [345, 490, 672, 560],
}

df = pd.DataFrame(example)
data = []
for group_name, df in op._partition_dataframe(df):
data.append((group_name, df))

assert len(data) == 1
team, df = data[0]
assert df.equals(
pd.DataFrame(
{
"Team": ["Australia", "Australia", "India", "India"],
"Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
"Runs": [345, 490, 672, 560],
}
)
)
13 changes: 13 additions & 0 deletions tests/system/providers/amazon/aws/example_sql_to_s3.py
@@ -173,6 +173,18 @@ def delete_security_group(sec_group_id: str, sec_group_name: str):
)
# [END howto_transfer_sql_to_s3]

# [START howto_transfer_sql_to_s3_with_groupby_param]
sql_to_s3_task_with_groupby = SqlToS3Operator(
task_id="sql_to_s3_with_groupby_task",
sql_conn_id=conn_id_name,
query=SQL_QUERY,
s3_bucket=bucket_name,
s3_key=key,
replace=True,
groupby_kwargs={"by": "color"},
)
# [END howto_transfer_sql_to_s3_with_groupby_param]

delete_bucket = S3DeleteBucketOperator(
task_id="delete_bucket",
bucket_name=bucket_name,
@@ -202,6 +214,7 @@ def delete_security_group(sec_group_id: str, sec_group_name: str):
insert_data,
# TEST BODY
sql_to_s3_task,
sql_to_s3_task_with_groupby,
# TEST TEARDOWN
delete_bucket,
delete_cluster,