Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from aiobotocore.client import AioBaseClient

from asgiref.sync import sync_to_async
from boto3.s3.transfer import TransferConfig
from boto3.s3.transfer import S3Transfer, TransferConfig
from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
Expand Down Expand Up @@ -912,14 +912,24 @@ def get_key(self, key: str, bucket_name: str | None = None) -> S3ResourceObject:
:param bucket_name: the name of the bucket
:return: the key object from the bucket
"""

def sanitize_extra_args() -> dict[str, str]:
    """Filter the hook's ``extra_args`` down to the permitted download arguments.

    :return: a new dict holding only the entries of ``self.extra_args`` whose
        name is listed in ``S3Transfer.ALLOWED_DOWNLOAD_ARGS``; all other
        entries are left out.
    """
    sanitized = {}
    for name, value in self.extra_args.items():
        # Anything not on the allow-list is silently dropped.
        if name in S3Transfer.ALLOWED_DOWNLOAD_ARGS:
            sanitized[name] = value
    return sanitized

s3_resource = self.get_session().resource(
"s3",
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
obj = s3_resource.Object(bucket_name, key)
obj.load()

obj.load(**sanitize_extra_args())
return obj

@unify_bucket_name_and_key
Expand Down
28 changes: 28 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,34 @@ def test_download_file_with_preserve_name_file_already_exists(self, tmp_path):
use_autogenerated_subdir=False,
)

@mock.patch.object(S3Hook, "get_session")
def test_download_file_with_extra_args_sanitizes_values(self, mock_session):
    """Check which ``extra_args`` reach ``Object.load`` during ``download_file``.

    The hook is created with two SSE arguments plus an ``invalid_arg`` entry;
    the test asserts that ``load`` is called with the two SSE arguments only.
    """
    bucket_name = "test_bucket"
    object_key = "test_key"
    # AES256 is the only algorithm currently supported.
    expected_load_kwargs = {
        "SSECustomerKey": "abcd123",
        "SSECustomerAlgorithm": "AES256",
    }

    hook = S3Hook(extra_args={**expected_load_kwargs, "invalid_arg": "should be dropped"})

    # Wire the patched session so that session.resource(...).Object(...) yields mocks.
    object_factory = Mock(name="MockedS3Object")
    resource_factory = Mock(name="MockedBoto3Resource")
    resource_factory.return_value.Object = object_factory
    mock_session.return_value.resource = resource_factory

    hook.download_file(key=object_key, bucket_name=bucket_name)

    object_factory.assert_called_once_with(bucket_name, object_key)
    object_factory.return_value.load.assert_called_once_with(**expected_load_kwargs)

def test_generate_presigned_url(self, s3_bucket):
hook = S3Hook()
presigned_url = hook.generate_presigned_url(
Expand Down