37 changes: 30 additions & 7 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py
@@ -28,7 +28,8 @@
import re
import shutil
import time
from collections.abc import AsyncIterator, Callable
import warnings
from collections.abc import AsyncIterator, Callable, Iterator
from contextlib import suppress
from copy import deepcopy
from datetime import datetime
@@ -57,7 +58,7 @@
from boto3.s3.transfer import S3Transfer, TransferConfig
from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.tags import format_tags
@@ -931,7 +932,31 @@ def get_file_metadata(
max_items: int | None = None,
) -> list:
"""
List metadata objects in a bucket under prefix.
.. deprecated:: 9.13.0
Use `iter_file_metadata` instead.

`get_file_metadata` loads all matching keys into a single list, which can cause
out-of-memory errors for large prefixes; use `iter_file_metadata` instead.
"""
warnings.warn(
"This method `get_file_metadata` is deprecated. Calling this method will result in all matching "
"keys being loaded into a single list, and can often result in out-of-memory exceptions. "
"Instead, use `iter_file_metadata`.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

return list(self.iter_file_metadata(prefix=prefix, page_size=page_size, max_items=max_items))

@provide_bucket_name
def iter_file_metadata(
self,
prefix: str,
bucket_name: str | None = None,
page_size: int | None = None,
max_items: int | None = None,
) -> Iterator:
"""
Yield metadata objects from a bucket under a prefix.

.. seealso::
- :external+boto3:py:class:`S3.Paginator.ListObjectsV2`
@@ -940,7 +965,7 @@ def get_file_metadata(
:param bucket_name: the name of the bucket
:param page_size: pagination size
:param max_items: maximum items to return
:return: a list of metadata of objects
:return: an iterator of object metadata
"""
config = {
"PageSize": page_size,
@@ -957,11 +982,9 @@
params["RequestPayer"] = "requester"
response = paginator.paginate(**params)

files = []
for page in response:
if "Contents" in page:
files += page["Contents"]
return files
yield from page["Contents"]

@unify_bucket_name_and_key
@provide_bucket_name
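For context on the hook change above, here is a minimal, hypothetical migration sketch; the connection id, bucket, and prefix are illustrative and not taken from this PR. It shows how a caller would move from the deprecated list-returning method to the new lazy iterator:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# Before (deprecated): every matching object's metadata is loaded into memory at once.
all_objects = hook.get_file_metadata(prefix="logs/", bucket_name="my-bucket")

# After: pages are fetched lazily and iteration can stop as soon as the answer is known.
for obj in hook.iter_file_metadata(prefix="logs/", bucket_name="my-bucket"):
    if obj["Size"] > 0:
        print("found a non-empty object:", obj["Key"])
        break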
35 changes: 24 additions & 11 deletions providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py
@@ -122,8 +122,19 @@ def _check_key(self, key, context: Context):
"""
if self.wildcard_match:
prefix = re.split(r"[\[*?]", key, 1)[0]
keys = self.hook.get_file_metadata(prefix, bucket_name)
key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]

key_matches: list[str] = []

# If check_fn is None, we can return True on the first wildcard match without
# iterating through every value yielded by iter_file_metadata. Otherwise, collect
# all matches into key_matches so check_fn can evaluate them.
for k in self.hook.iter_file_metadata(prefix, bucket_name):
if fnmatch.fnmatch(k["Key"], key):
if self.check_fn is None:
# Only a single match is needed, so return immediately
return True
key_matches.append(k)

if not key_matches:
return False

@@ -132,21 +143,23 @@ def _check_key(self, key, context: Context):
for f in key_matches:
metadata = {}
if "*" in self.metadata_keys:
metadata = self.hook.head_object(f["Key"], bucket_name)
metadata = self.hook.head_object(f["Key"], bucket_name) # type: ignore[index]
else:
for key in self.metadata_keys:
for mk in self.metadata_keys:
try:
metadata[key] = f[key]
metadata[mk] = f[mk] # type: ignore[index]
except KeyError:
# supplied key might be from head_object response
self.log.info("Key %s not found in response, performing head_object", key)
metadata[key] = self.hook.head_object(f["Key"], bucket_name).get(key, None)
self.log.info("Key %s not found in response, performing head_object", mk)
metadata[mk] = self.hook.head_object(f["Key"], bucket_name).get(mk, None) # type: ignore[index]
files.append(metadata)

elif self.use_regex:
keys = self.hook.get_file_metadata("", bucket_name)
key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
if not key_matches:
return False
for k in self.hook.iter_file_metadata("", bucket_name):
if re.match(pattern=key, string=k["Key"]):
return True
return False

else:
obj = self.hook.head_object(key, bucket_name)
if obj is None:
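The wildcard branch in the sensor now short-circuits when no check_fn is supplied. A standalone sketch of that pattern, using a plain generator in place of S3Hook.iter_file_metadata and made-up object names:

import fnmatch
from collections.abc import Iterator


def fake_iter_file_metadata() -> Iterator[dict]:
    # Stand-in for the hook's paginated listing; yields one object at a time.
    for name in ("other.txt", "file1.csv", "file2.csv"):
        yield {"Key": name, "Size": 0}


def first_match(pattern: str) -> bool:
    # Return on the first hit instead of building a full list of matches.
    for obj in fake_iter_file_metadata():
        if fnmatch.fnmatch(obj["Key"], pattern):
            return True
    return False


print(first_match("file*"))  # True, without touching the remaining objects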
19 changes: 12 additions & 7 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py
@@ -21,6 +21,7 @@
import inspect
import os
import re
from collections.abc import Iterator
from datetime import datetime as std_datetime, timezone
from pathlib import Path
from unittest import mock, mock as async_mock
@@ -389,22 +390,26 @@ def test_list_keys_paged(self, s3_bucket):

assert sorted(keys) == sorted(hook.list_keys(s3_bucket, delimiter="/", page_size=1))

def test_get_file_metadata(self, s3_bucket):
def test_iter_file_metadata(self, s3_bucket):
hook = S3Hook()
bucket = hook.get_bucket(s3_bucket)
bucket.put_object(Key="test", Body=b"a")

assert len(hook.get_file_metadata("t", s3_bucket)) == 1
assert hook.get_file_metadata("t", s3_bucket)[0]["Size"] is not None
assert len(hook.get_file_metadata("test", s3_bucket)) == 1
assert len(hook.get_file_metadata("a", s3_bucket)) == 0
assert isinstance(hook.iter_file_metadata("t", s3_bucket), Iterator)

def test_get_file_metadata_when_requester_pays(self, s3_bucket):
# Since iter_file_metadata returns an Iterator, it must be consumed into a `list`
# before its length can be checked
assert len(list(hook.iter_file_metadata("t", s3_bucket))) == 1
assert next(hook.iter_file_metadata("t", s3_bucket))["Size"] is not None
assert len(list(hook.iter_file_metadata("test", s3_bucket))) == 1
assert len(list(hook.iter_file_metadata("a", s3_bucket))) == 0

def test_iter_file_metadata_when_requester_pays(self, s3_bucket):
hook = S3Hook(requester_pays=True)
hook.get_conn = MagicMock()
hook.get_conn.return_value.get_paginator.return_value.paginate.return_value = []

assert hook.get_file_metadata("test", s3_bucket) == []
assert not any(hook.iter_file_metadata("test", s3_bucket)) # Empty Iterator

hook.get_conn.return_value.get_paginator.return_value.paginate.assert_called_with(
Bucket="airflow-test-s3-bucket",
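As a quick aside on the iterator-consumption patterns these tests rely on (a plain generator stands in for the hook's return value; nothing here is from the PR itself):

def gen():
    yield {"Key": "test", "Size": 1}

assert len(list(gen())) == 1              # materialise the iterator to count items
assert next(gen())["Size"] is not None    # peek at the first yielded element
assert not any(iter([]))                  # an empty iterator is falsy under any()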
39 changes: 21 additions & 18 deletions providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
@@ -243,39 +243,42 @@ def test_poke_multiple_files(self, mock_head_object):
mock_head_object.assert_any_call("file1", "test_bucket")
mock_head_object.assert_any_call("file2", "test_bucket")

@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.get_file_metadata")
def test_poke_wildcard(self, mock_get_file_metadata):
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.iter_file_metadata")
def test_poke_wildcard(self, mock_iter_file_metadata):
op = S3KeySensor(task_id="s3_key_sensor", bucket_key="s3://test_bucket/file*", wildcard_match=True)

mock_get_file_metadata.return_value = []
mock_iter_file_metadata.return_value = []
assert op.poke(None) is False
mock_get_file_metadata.assert_called_once_with("file", "test_bucket")
mock_iter_file_metadata.assert_called_once_with("file", "test_bucket")

mock_get_file_metadata.return_value = [{"Key": "dummyFile", "Size": 0}]
mock_iter_file_metadata.return_value = [{"Key": "dummyFile", "Size": 0}]
assert op.poke(None) is False

mock_get_file_metadata.return_value = [{"Key": "file1", "Size": 0}]
mock_iter_file_metadata.return_value = [{"Key": "file1", "Size": 0}]
assert op.poke(None) is True

@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.get_file_metadata")
def test_poke_wildcard_multiple_files(self, mock_get_file_metadata):
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.iter_file_metadata")
def test_poke_wildcard_multiple_files(self, mock_iter_file_metadata):
op = S3KeySensor(
task_id="s3_key_sensor",
bucket_key=["s3://test_bucket/file*", "s3://test_bucket/*.zip"],
wildcard_match=True,
)

mock_get_file_metadata.side_effect = [[{"Key": "file1", "Size": 0}], []]
mock_iter_file_metadata.side_effect = [[{"Key": "file1", "Size": 0}], []]
assert op.poke(None) is False

mock_get_file_metadata.side_effect = [[{"Key": "file1", "Size": 0}], [{"Key": "file2", "Size": 0}]]
mock_iter_file_metadata.side_effect = [[{"Key": "file1", "Size": 0}], [{"Key": "file2", "Size": 0}]]
assert op.poke(None) is False

mock_get_file_metadata.side_effect = [[{"Key": "file1", "Size": 0}], [{"Key": "test.zip", "Size": 0}]]
mock_iter_file_metadata.side_effect = [
[{"Key": "file1", "Size": 0}],
[{"Key": "test.zip", "Size": 0}],
]
assert op.poke(None) is True

mock_get_file_metadata.assert_any_call("file", "test_bucket")
mock_get_file_metadata.assert_any_call("", "test_bucket")
mock_iter_file_metadata.assert_any_call("file", "test_bucket")
mock_iter_file_metadata.assert_any_call("", "test_bucket")

@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object")
def test_poke_with_check_function(self, mock_head_object):
Expand All @@ -298,15 +301,15 @@ def check_fn(files: list) -> bool:
("test/test.csv", r"test/[a-z]+\.csv", True),
],
)
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.get_file_metadata")
def test_poke_with_use_regex(self, mock_get_file_metadata, key, pattern, expected):
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.iter_file_metadata")
def test_poke_with_use_regex(self, mock_iter_file_metadata, key, pattern, expected):
op = S3KeySensor(
task_id="s3_key_sensor_async",
bucket_key=pattern,
bucket_name="test_bucket",
use_regex=True,
)
mock_get_file_metadata.return_value = [{"Key": key, "Size": 0}]
mock_iter_file_metadata.return_value = [{"Key": key, "Size": 0}]
assert op.poke(None) is expected

@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3KeySensor.poke", return_value=False)
@@ -423,7 +426,7 @@ def check_fn(files: list) -> bool:
assert op.poke(None) is True

@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object")
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.get_file_metadata")
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.iter_file_metadata")
def test_custom_metadata_wildcard(self, mock_file_metadata, mock_head_object):
def check_fn(files: list) -> bool:
for f in files:
@@ -445,7 +448,7 @@ def check_fn(files: list) -> bool:
mock_head_object.assert_called_once()

@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object")
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.get_file_metadata")
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.iter_file_metadata")
def test_custom_metadata_wildcard_all_attributes(self, mock_file_metadata, mock_head_object):
def check_fn(files: list) -> bool:
for f in files: