Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,10 @@ async def get_files_async(
delimiter: str | None = "/",
) -> list[Any]:
"""Get a list of files in the bucket."""
# Validate that bucket_keys is in fact a list, otherwise, the characters will be split
if isinstance(bucket_keys, str):
bucket_keys = [bucket_keys]

keys: list[Any] = []
for key in bucket_keys:
prefix = key
Expand All @@ -652,7 +656,9 @@ async def get_files_async(
response = paginator.paginate(**params)
async for page in response:
if "Contents" in page:
keys.extend(k for k in page["Contents"] if isinstance(k.get("Size"), (int, float)))
keys.extend(
k.get("Key") for k in page["Contents"] if isinstance(k.get("Size"), (int, float))
)
return keys

async def _list_keys_async(
Expand Down
49 changes: 47 additions & 2 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ async def test_s3_key_hook_get_files_without_wildcard_async(self):
mock_paginator.paginate.assert_called_with(
Bucket="test_bucket",
Delimiter="/",
Prefix="t",
Prefix="test.txt",
RequestPayer="requester",
)

Expand Down Expand Up @@ -746,9 +746,54 @@ async def test_s3_key_hook_get_files_with_wildcard_async(self):
mock_paginator.paginate.assert_called_with(
Bucket="test_bucket",
Delimiter="/",
Prefix="t",
Prefix="test.txt",
)

@pytest.mark.asyncio
@pytest.mark.parametrize(
"mock_bucket_keys, mock_response_bucket_keys",
[
(["test.txt"], ["test.txt"]),
(["test_key"], ["test_key", "test_key2"]),
],
)
async def test_s3_key_hook_get_files_bucket_keys_list(self, mock_bucket_keys, mock_response_bucket_keys):
test_resp_iter = [
{
"Contents": [
{
"Key": mock_response_bucket_key,
"Size": 0,
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
}
for mock_response_bucket_key in mock_response_bucket_keys
]
}
]

mock_paginator = mock.Mock()
mock_paginate = mock.MagicMock()
mock_paginate.__aiter__.return_value = test_resp_iter
mock_paginator.paginate.return_value = mock_paginate

s3_hook_async = S3Hook(client_type="S3", resource_type="S3", requester_pays=True)
mock_client = AsyncMock()
mock_client.get_paginator = mock.Mock(return_value=mock_paginator)
response = await s3_hook_async.get_files_async(
client=mock_client, bucket="test_bucket", bucket_keys=mock_bucket_keys, wildcard_match=False
)

assert response == mock_response_bucket_keys

for test_bucket_key in mock_bucket_keys:
mock_paginator.paginate.assert_called_with(
Bucket="test_bucket",
Delimiter="/",
Prefix=test_bucket_key,
RequestPayer="requester",
)

@pytest.mark.asyncio
async def test_s3_key_hook_list_keys_async(self):
"""
Expand Down