Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,15 +916,6 @@ def get_key(self, key: str, bucket_name: str | None = None) -> S3ResourceObject:
:param bucket_name: the name of the bucket
:return: the key object from the bucket
"""

def sanitize_extra_args() -> dict[str, str]:
"""Parse extra_args and return a dict with only the args listed in ALLOWED_DOWNLOAD_ARGS."""
return {
arg_name: arg_value
for (arg_name, arg_value) in self.extra_args.items()
if arg_name in S3Transfer.ALLOWED_DOWNLOAD_ARGS
}

s3_resource = self.get_session().resource(
"s3",
endpoint_url=self.conn_config.endpoint_url,
Expand All @@ -933,7 +924,7 @@ def sanitize_extra_args() -> dict[str, str]:
)
obj = s3_resource.Object(bucket_name, key)

obj.load(**sanitize_extra_args())
obj.load()
return obj

@unify_bucket_name_and_key
Expand Down Expand Up @@ -1367,6 +1358,15 @@ def download_file(
Default: True.
:return: the file name.
"""

def sanitize_extra_args() -> dict[str, str]:
"""Parse extra_args and return a dict with only the args listed in ALLOWED_DOWNLOAD_ARGS."""
return {
arg_name: arg_value
for (arg_name, arg_value) in self.extra_args.items()
if arg_name in S3Transfer.ALLOWED_DOWNLOAD_ARGS
}

self.log.info(
"This function shadows the 'download_file' method of S3 API, but it is not the same. If you "
"want to use the original method from S3 API, please call "
Expand Down Expand Up @@ -1404,7 +1404,7 @@ def download_file(
with file:
s3_obj.download_fileobj(
file,
ExtraArgs=self.extra_args,
ExtraArgs=sanitize_extra_args(),
Config=self.transfer_config,
)

Expand Down
13 changes: 9 additions & 4 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import re
import unittest
from unittest import mock, mock as async_mock
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import ANY, MagicMock, Mock, patch
from urllib.parse import parse_qs

import boto3
Expand Down Expand Up @@ -1073,9 +1073,14 @@ def test_download_file_with_extra_args_sanitizes_values(self, mock_session):
s3_hook.download_file(key=s3_key, bucket_name=bucket)

mock_obj.assert_called_once_with(bucket, s3_key)
mock_obj.return_value.load.assert_called_once_with(
SSECustomerKey=encryption_key,
SSECustomerAlgorithm=encryption_algorithm,
mock_obj.return_value.load.assert_called_once_with()
mock_obj().download_fileobj.assert_called_once_with(
ANY, # File-like object
ExtraArgs={
"SSECustomerKey": encryption_key,
"SSECustomerAlgorithm": encryption_algorithm,
},
Config=s3_hook.transfer_config,
)

def test_generate_presigned_url(self, s3_bucket):
Expand Down