@@ -209,6 +209,16 @@ def _fix_dtypes(df: pd.DataFrame, file_format: FILE_FORMAT) -> None:
df[col] = np.where(df[col].isnull(), None, df[col]) # type: ignore[call-overload]
df[col] = df[col].astype(pd.Float64Dtype())

@staticmethod
def _strip_suffixes(
path: str,
) -> str:
suffixes = [".json.gz", ".csv.gz", ".json", ".csv", ".parquet"]
for suffix in sorted(suffixes, key=len, reverse=True):
if path.endswith(suffix):
return path[: -len(suffix)]
return path

def execute(self, context: Context) -> None:
sql_hook = self._get_hook()
s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
@@ -224,9 +234,15 @@ def execute(self, context: Context) -> None:
for group_name, df in self._partition_dataframe(df=data_df):
buf = io.BytesIO()
self.log.info("Writing data to in-memory buffer")
object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
clean_key = self._strip_suffixes(self.s3_key)
object_key = (
f"{clean_key}_{group_name}{file_options.suffix}"
if group_name
else f"{clean_key}{file_options.suffix}"
)

if self.df_kwargs.get("compression") == "gzip":
if self.file_format != FILE_FORMAT.PARQUET and self.df_kwargs.get("compression") == "gzip":
object_key += ".gz"
df_kwargs = {k: v for k, v in self.df_kwargs.items() if k != "compression"}
with gzip.GzipFile(fileobj=buf, mode="wb", filename=object_key) as gz:
getattr(df, file_options.function)(gz, **df_kwargs)
@@ -237,6 +253,7 @@ def execute(self, context: Context) -> None:
text_buf = io.TextIOWrapper(buf, encoding="utf-8", write_through=True)
getattr(df, file_options.function)(text_buf, **self.df_kwargs)
text_buf.flush()

buf.seek(0)

self.log.info("Uploading data to S3")
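
For reference, a minimal sketch of how the new key-building logic composes end to end: the existing data-file suffix on s3_key is stripped, the format suffix is re-appended after any group name, and ".gz" is added only for non-Parquet gzip output. This is illustrative only; the helper name, the plain-string format argument, and the assumption that file_options.suffix resolves to ".csv"/".json"/".parquet" are not part of the provider API.

def build_object_key(s3_key, group_name, file_format, compression):
    # Mirror of _strip_suffixes: drop a known data-file suffix if present.
    suffixes = [".json.gz", ".csv.gz", ".json", ".csv", ".parquet"]
    clean_key = s3_key
    for suffix in sorted(suffixes, key=len, reverse=True):
        if s3_key.endswith(suffix):
            clean_key = s3_key[: -len(suffix)]
            break
    # Re-append the format suffix after the optional group name.
    fmt_suffix = {"csv": ".csv", "json": ".json", "parquet": ".parquet"}[file_format]
    object_key = f"{clean_key}_{group_name}{fmt_suffix}" if group_name else f"{clean_key}{fmt_suffix}"
    # ".gz" is appended only when gzip is applied externally; Parquet compresses
    # internally, so its key is left untouched.
    if file_format != "parquet" and compression == "gzip":
        object_key += ".gz"
    return object_key

build_object_key("data.csv", None, "csv", "gzip")               # -> "data.csv.gz"
build_object_key("data", "group1", "json", None)                # -> "data_group1.json"
build_object_key("data.parquet", "group1", "parquet", "gzip")   # -> "data_group1.parquet"
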
providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py (201 changes: 123 additions & 78 deletions)
@@ -30,9 +30,23 @@


class TestSqlToS3Operator:
@pytest.mark.parametrize("dtype_backend", ["numpy_nullable", "pyarrow"])
@pytest.mark.parametrize(
"file_format, dtype_backend, df_kwargs, expected_key_suffix",
[
("csv", "numpy_nullable", {"index": False, "header": False}, ".csv"),
("csv", "pyarrow", {"index": False, "header": False}, ".csv"),
("parquet", "numpy_nullable", {}, ".parquet"),
("parquet", "pyarrow", {}, ".parquet"),
(
"json",
None,
{"date_format": "iso", "lines": True, "orient": "records"},
".json",
),
],
)
@mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook")
def test_execute_csv(self, mock_s3_hook, dtype_backend):
def test_execute_formats(self, mock_s3_hook, file_format, dtype_backend, df_kwargs, expected_key_suffix):
query = "query"
s3_bucket = "bucket"
s3_key = "key"
@@ -42,43 +56,7 @@ def test_execute_csv(self, mock_s3_hook, dtype_backend):
get_df_mock = mock_dbapi_hook.return_value.get_df
get_df_mock.return_value = test_df

op = SqlToS3Operator(
query=query,
s3_bucket=s3_bucket,
s3_key=s3_key,
sql_conn_id="mysql_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
replace=True,
read_kwargs={"dtype_backend": dtype_backend},
df_kwargs={"index": False, "header": False},
dag=None,
)
op._get_hook = mock_dbapi_hook
op.execute(None)

mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
get_df_mock.assert_called_once_with(
sql=query, parameters=None, df_type="pandas", dtype_backend=dtype_backend
)
file_obj = mock_s3_hook.return_value.load_file_obj.call_args[1]["file_obj"]
assert isinstance(file_obj, io.BytesIO)
mock_s3_hook.return_value.load_file_obj.assert_called_once_with(
file_obj=file_obj, key=s3_key, bucket_name=s3_bucket, replace=True
)

@pytest.mark.parametrize("dtype_backend", ["numpy_nullable", "pyarrow"])
@mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook")
def test_execute_parquet(self, mock_s3_hook, dtype_backend):
query = "query"
s3_bucket = "bucket"
s3_key = "key"

mock_dbapi_hook = mock.Mock()
test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1])

get_df_mock = mock_dbapi_hook.return_value.get_df
get_df_mock.return_value = test_df
read_df_kwargs = {"dtype_backend": dtype_backend} if dtype_backend else {}

op = SqlToS3Operator(
query=query,
@@ -87,57 +65,35 @@ def test_execute_parquet(self, mock_s3_hook, dtype_backend):
sql_conn_id="mysql_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
read_kwargs={"dtype_backend": dtype_backend},
file_format="parquet",
file_format=file_format,
replace=True,
read_kwargs=read_df_kwargs,
df_kwargs=df_kwargs,
dag=None,
)
op._get_hook = mock_dbapi_hook
op.execute(None)

mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
get_df_mock.assert_called_once_with(
sql=query, parameters=None, df_type="pandas", dtype_backend=dtype_backend
)

file_obj = mock_s3_hook.return_value.load_file_obj.call_args[1]["file_obj"]
assert isinstance(file_obj, io.BytesIO)
mock_s3_hook.return_value.load_file_obj.assert_called_once_with(
file_obj=file_obj, key=s3_key, bucket_name=s3_bucket, replace=True
)

@mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook")
def test_execute_json(self, mock_s3_hook):
query = "query"
s3_bucket = "bucket"
s3_key = "key"
expected_df_kwargs = {
"sql": query,
"parameters": None,
"df_type": "pandas",
}
if dtype_backend:
expected_df_kwargs["dtype_backend"] = dtype_backend

mock_dbapi_hook = mock.Mock()
test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1])
get_df_mock = mock_dbapi_hook.return_value.get_df
get_df_mock.return_value = test_df
get_df_mock.assert_called_once_with(**expected_df_kwargs)

op = SqlToS3Operator(
query=query,
s3_bucket=s3_bucket,
s3_key=s3_key,
sql_conn_id="mysql_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
file_format="json",
replace=True,
df_kwargs={"date_format": "iso", "lines": True, "orient": "records"},
dag=None,
)
op._get_hook = mock_dbapi_hook
op.execute(None)

mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
get_df_mock.assert_called_once_with(sql=query, parameters=None, df_type="pandas")
file_obj = mock_s3_hook.return_value.load_file_obj.call_args[1]["file_obj"]
assert isinstance(file_obj, io.BytesIO)

mock_s3_hook.return_value.load_file_obj.assert_called_once_with(
file_obj=file_obj, key=s3_key, bucket_name=s3_bucket, replace=True
file_obj=file_obj,
key=f"{s3_key}{expected_key_suffix}",
bucket_name=s3_bucket,
replace=True,
)

@mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook")
@@ -420,7 +376,7 @@ def test_hook_params(self, mock_get_conn):
def test_execute_with_df_type(self, mock_s3_hook, df_type_param, expected_df_type):
query = "query"
s3_bucket = "bucket"
s3_key = "key"
s3_key = "key.csv"

mock_dbapi_hook = mock.Mock()
test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1])
@@ -653,3 +609,92 @@ def test_deprecated_kwargs_priority_behavior(

assert op.read_kwargs == expected_read_kwargs
assert op.df_kwargs == expected_df_kwargs

@pytest.mark.parametrize(
"fmt, df_kwargs, expected_key",
[
("csv", {"compression": "gzip", "index": False}, "data.csv.gz"),
("csv", {"index": False}, "data.csv"),
("json", {"compression": "gzip"}, "data.json.gz"),
("json", {}, "data.json"),
("parquet", {"compression": "gzip"}, "data.parquet"),
("parquet", {}, "data.parquet"),
],
)
@mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook")
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook")
def test_file_format_handling(self, mock_dbapi_hook, mock_s3_hook, fmt, df_kwargs, expected_key):
s3_bucket = "bucket"
s3_key = "data." + fmt
test_df = pd.DataFrame({"x": [1, 2]})
mock_dbapi_hook.return_value.get_df.return_value = test_df

op = SqlToS3Operator(
query="SELECT * FROM test",
s3_bucket=s3_bucket,
s3_key=s3_key,
sql_conn_id="sqlite_conn",
aws_conn_id="aws_default",
task_id="task_id",
file_format=fmt,
df_kwargs=df_kwargs,
replace=True,
dag=None,
)
op._get_hook = lambda: mock_dbapi_hook.return_value
op.execute(context=None)

uploaded_key = mock_s3_hook.return_value.load_file_obj.call_args[1]["key"]
assert uploaded_key == expected_key

@pytest.mark.parametrize(
"file_format, df_kwargs, expected_suffix",
[
("csv", {"compression": "gzip", "index": False}, ".csv.gz"),
("csv", {"index": False}, ".csv"),
("json", {"compression": "gzip"}, ".json.gz"),
("json", {}, ".json"),
("parquet", {"compression": "gzip"}, ".parquet"),
("parquet", {}, ".parquet"),
],
)
@mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook")
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook")
def test_file_format_handling_with_groupby(
self, mock_dbapi_hook, mock_s3_hook, file_format, df_kwargs, expected_suffix
):
s3_bucket = "bucket"
s3_key = "data"

# Input DataFrame with groups
test_data = pd.DataFrame(
{"x": [1, 2, 3, 4, 5, 6], "group": ["group1", "group1", "group2", "group2", "group3", "group4"]}
)

mock_dbapi_hook.return_value.get_df.return_value = test_data

op = SqlToS3Operator(
query="SELECT * FROM test",
s3_bucket=s3_bucket,
s3_key=s3_key,
sql_conn_id="sqlite_conn",
aws_conn_id="aws_default",
task_id="task_id",
file_format=file_format,
df_kwargs=df_kwargs,
groupby_kwargs={"by": "group"},
replace=True,
dag=None,
)

op._get_hook = lambda: mock_dbapi_hook.return_value
op.execute(context=None)

expected_groups = test_data["group"].unique()
assert mock_s3_hook.return_value.load_file_obj.call_count == len(expected_groups)

called_keys = [call.kwargs["key"] for call in mock_s3_hook.return_value.load_file_obj.call_args_list]

for group in expected_groups:
expected_key = f"{s3_key}_{group}{expected_suffix}"
assert expected_key in called_keys, f"Missing expected key: {expected_key}"
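
For context, a hypothetical DAG-level usage of the operator after this change; the connection ids, bucket, query, and group values below are placeholders, not taken from the PR.

from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

upload = SqlToS3Operator(
    task_id="sql_to_s3",
    query="SELECT * FROM my_table",     # placeholder query
    s3_bucket="my-bucket",              # placeholder bucket
    s3_key="exports/data.csv",          # existing suffix is stripped, then re-added per format
    sql_conn_id="my_sql_conn",          # placeholder connection id
    aws_conn_id="aws_default",
    file_format="csv",
    df_kwargs={"compression": "gzip", "index": False},
    groupby_kwargs={"by": "country"},   # placeholder grouping column
    replace=True,
)
# With groups "us" and "de" in the result set, the uploaded keys would be
# "exports/data_us.csv.gz" and "exports/data_de.csv.gz".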