Merged
135 changes: 101 additions & 34 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py
@@ -193,6 +193,7 @@ def __init__(
     ) -> None:
         kwargs["client_type"] = "s3"
         kwargs["aws_conn_id"] = aws_conn_id
+        self._requester_pays = kwargs.pop("requester_pays", False)

         if transfer_config_args and not isinstance(transfer_config_args, dict):
             raise TypeError(f"transfer_config_args expected dict, got {type(transfer_config_args).__name__}.")
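This is the only constructor change: the flag is popped from kwargs before they reach the base hook. A minimal usage sketch (the connection id and flag value are illustrative):

```python
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

# With requester_pays=True, every S3 call made through this hook instance
# sends RequestPayer="requester" (the x-amz-request-payer header).
hook = S3Hook(aws_conn_id="aws_default", requester_pays=True)
```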
@@ -409,12 +410,15 @@ def list_prefixes(
         }

         paginator = self.get_conn().get_paginator("list_objects_v2")
-        response = paginator.paginate(
-            Bucket=bucket_name,
-            Prefix=prefix,
-            Delimiter=delimiter,
-            PaginationConfig=config,
-        )
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": prefix,
+            "Delimiter": delimiter,
+            "PaginationConfig": config,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)

         prefixes: list[str] = []
         for page in response:
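For reference, when the flag is set the paginate call above is equivalent to this direct boto3 usage (bucket and prefix names are placeholders):

```python
import boto3

s3 = boto3.client("s3")
paginator = s3.get_paginator("list_objects_v2")
# RequestPayer confirms the caller, not the bucket owner, pays for the request.
for page in paginator.paginate(
    Bucket="some-requester-pays-bucket",
    Prefix="data/",
    Delimiter="/",
    RequestPayer="requester",
):
    for common_prefix in page.get("CommonPrefixes", []):
        print(common_prefix["Prefix"])
```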
@@ -437,7 +441,13 @@ async def get_head_object_async(
         """
         head_object_val: dict[str, Any] | None = None
         try:
-            head_object_val = await client.head_object(Bucket=bucket_name, Key=key)
+            params = {
+                "Bucket": bucket_name,
+                "Key": key,
+            }
+            if self._requester_pays:
+                params["RequestPayer"] = "requester"
+            head_object_val = await client.head_object(**params)
             return head_object_val
         except ClientError as e:
             if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
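The async path mirrors the sync one. A self-contained sketch of the same head-or-None pattern, assuming aiobotocore (which backs the hook's async client); names are placeholders:

```python
from typing import Any

from aiobotocore.session import get_session
from botocore.exceptions import ClientError

async def head(bucket: str, key: str, requester_pays: bool = True) -> dict[str, Any] | None:
    async with get_session().create_client("s3") as client:
        params = {"Bucket": bucket, "Key": key}
        if requester_pays:
            params["RequestPayer"] = "requester"
        try:
            return await client.head_object(**params)
        except ClientError as e:
            # Missing keys surface as 404; anything else is re-raised.
            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                return None
            raise
```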
@@ -472,12 +482,15 @@ async def list_prefixes_async(
         }

         paginator = client.get_paginator("list_objects_v2")
-        response = paginator.paginate(
-            Bucket=bucket_name,
-            Prefix=prefix,
-            Delimiter=delimiter,
-            PaginationConfig=config,
-        )
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": prefix,
+            "Delimiter": delimiter,
+            "PaginationConfig": config,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)

         prefixes = []
         async for page in response:
@@ -501,7 +514,14 @@ async def get_file_metadata_async(
         prefix = re.split(r"[\[\*\?]", key, 1)[0] if key else ""
         delimiter = ""
         paginator = client.get_paginator("list_objects_v2")
-        response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": prefix,
+            "Delimiter": delimiter,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)
         async for page in response:
             if "Contents" in page:
                 for row in page["Contents"]:
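The `re.split` here computes the longest wildcard-free prefix to hand S3; for example:

```python
import re

# Everything before the first *, ?, or [ is usable as a plain S3 prefix.
key = "logs/2024-*/part-??.csv"
prefix = re.split(r"[\[\*\?]", key, maxsplit=1)[0]
assert prefix == "logs/2024-"
```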
@@ -622,14 +642,21 @@ async def get_files_async(
         prefix = re.split(r"[\[*?]", key, 1)[0]

         paginator = client.get_paginator("list_objects_v2")
-        response = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter=delimiter)
+        params = {
+            "Bucket": bucket,
+            "Prefix": prefix,
+            "Delimiter": delimiter,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)
         async for page in response:
             if "Contents" in page:
                 keys.extend(k for k in page["Contents"] if isinstance(k.get("Size"), (int, float)))
         return keys

-    @staticmethod
     async def _list_keys_async(
+        self,
         client: AioBaseClient,
         bucket_name: str | None = None,
         prefix: str | None = None,
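Note the signature change: `_list_keys_async` stops being a `@staticmethod` because its body now reads `self._requester_pays`, which requires an instance. A minimal illustration of the distinction (class and method names are hypothetical):

```python
class Example:
    def __init__(self, requester_pays: bool) -> None:
        self._requester_pays = requester_pays

    @staticmethod
    def as_static() -> bool:
        raise NotImplementedError  # no self here: instance state is unreachable

    def as_method(self) -> bool:
        return self._requester_pays  # instance methods can consult per-hook state
```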
@@ -655,12 +682,15 @@ async def _list_keys_async(
         }

         paginator = client.get_paginator("list_objects_v2")
-        response = paginator.paginate(
-            Bucket=bucket_name,
-            Prefix=prefix,
-            Delimiter=delimiter,
-            PaginationConfig=config,
-        )
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": prefix,
+            "Delimiter": delimiter,
+            "PaginationConfig": config,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)

         keys = []
         async for page in response:
@@ -863,13 +893,16 @@ def _is_in_period(input_date: datetime) -> bool:
         }

         paginator = self.get_conn().get_paginator("list_objects_v2")
-        response = paginator.paginate(
-            Bucket=bucket_name,
-            Prefix=_prefix,
-            Delimiter=delimiter,
-            PaginationConfig=config,
-            StartAfter=start_after_key,
-        )
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": _prefix,
+            "Delimiter": delimiter,
+            "PaginationConfig": config,
+            "StartAfter": start_after_key,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)

         keys: list[str] = []
         for page in response:
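A usage sketch for the surrounding `list_keys` API; bucket, prefix, and key names are placeholders, and `StartAfter` resumes listing after the given key:

```python
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default", requester_pays=True)
keys = hook.list_keys(
    bucket_name="some-requester-pays-bucket",
    prefix="data/2024/",
    start_after_key="data/2024/part-0500.parquet",  # skip everything up to and including this key
)
```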
@@ -909,7 +942,14 @@ def get_file_metadata(
         }

         paginator = self.get_conn().get_paginator("list_objects_v2")
-        response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, PaginationConfig=config)
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": prefix,
+            "PaginationConfig": config,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)

         files = []
         for page in response:
Expand All @@ -931,7 +971,13 @@ def head_object(self, key: str, bucket_name: str | None = None) -> dict | None:
:return: metadata of an object
"""
try:
return self.get_conn().head_object(Bucket=bucket_name, Key=key)
params = {
"Bucket": bucket_name,
"Key": key,
}
if self._requester_pays:
params["RequestPayer"] = "requester"
return self.get_conn().head_object(**params)
except ClientError as e:
if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
return None
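`head_object` keeps its contract of returning None for a missing key. A usage sketch (names are placeholders):

```python
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default", requester_pays=True)
metadata = hook.head_object(key="data/file.parquet", bucket_name="some-requester-pays-bucket")
if metadata is None:
    print("object does not exist")
else:
    print(metadata["ContentLength"], metadata["ETag"])
```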
@@ -975,8 +1021,11 @@ def sanitize_extra_args() -> dict[str, str]:
                 if arg_name in S3Transfer.ALLOWED_DOWNLOAD_ARGS
             }

+        params = sanitize_extra_args()
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
         obj = self.resource.Object(bucket_name, key)
-        obj.load(**sanitize_extra_args())
+        obj.load(**params)
         return obj

     @unify_bucket_name_and_key
@unify_bucket_name_and_key
Expand Down Expand Up @@ -1022,11 +1071,14 @@ def select_key(
"""
expression = expression or "SELECT * FROM S3Object"
expression_type = expression_type or "SQL"
extra_args = {}

if input_serialization is None:
input_serialization = {"CSV": {}}
if output_serialization is None:
output_serialization = {"CSV": {}}
if self._requester_pays:
extra_args["RequestPayer"] = "requester"

response = self.get_conn().select_object_content(
Bucket=bucket_name,
@@ -1035,6 +1087,7 @@ def select_key(
             ExpressionType=expression_type,
             InputSerialization=input_serialization,
             OutputSerialization=output_serialization,
+            **extra_args,
         )

         return b"".join(
@@ -1124,6 +1177,8 @@ def load_file(
             filename = filename_gz
         if acl_policy:
             extra_args["ACL"] = acl_policy
+        if self._requester_pays:
+            extra_args["RequestPayer"] = "requester"

         client = self.get_conn()
         client.upload_file(
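Because RequestPayer is among S3Transfer's allowed upload arguments, it can travel inside `extra_args` here. A `load_file` usage sketch (paths and names are placeholders):

```python
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default", requester_pays=True)
hook.load_file(
    filename="/tmp/report.csv",
    key="uploads/report.csv",
    bucket_name="some-requester-pays-bucket",
    replace=True,
)
```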
@@ -1270,6 +1325,8 @@ def _upload_file_obj(
             extra_args["ServerSideEncryption"] = "AES256"
         if acl_policy:
             extra_args["ACL"] = acl_policy
+        if self._requester_pays:
+            extra_args["RequestPayer"] = "requester"

         client = self.get_conn()
         client.upload_fileobj(
@@ -1330,6 +1387,8 @@ def copy_object(
             kwargs["ACL"] = acl_policy
         if meta_data_directive:
             kwargs["MetadataDirective"] = meta_data_directive
+        if self._requester_pays:
+            kwargs["RequestPayer"] = "requester"

         dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
             dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
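`copy_object` passes the flag on the CopyObject call itself. A usage sketch (bucket and key names are placeholders):

```python
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default", requester_pays=True)
hook.copy_object(
    source_bucket_key="raw/file.bin",
    dest_bucket_key="processed/file.bin",
    source_bucket_name="some-requester-pays-bucket",
    dest_bucket_name="my-own-bucket",
)
```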
@@ -1412,12 +1471,17 @@ def delete_objects(self, bucket: str, keys: str | list) -> None:
             keys = [keys]

         s3 = self.get_conn()
+        extra_kwargs = {}
+        if self._requester_pays:
+            extra_kwargs["RequestPayer"] = "requester"

         # We can only send a maximum of 1000 keys per request.
         # For details see:
         # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects
         for chunk in chunks(keys, chunk_size=1000):
-            response = s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]})
+            response = s3.delete_objects(
+                Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}, **extra_kwargs
+            )
             deleted_keys = [x["Key"] for x in response.get("Deleted", [])]
             self.log.info("Deleted: %s", deleted_keys)
             if "Errors" in response:
@@ -1496,9 +1560,12 @@ def download_file(
             file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore

         with file:
+            extra_args = {**self.extra_args}
+            if self._requester_pays:
+                extra_args["RequestPayer"] = "requester"
             s3_obj.download_fileobj(
                 file,
-                ExtraArgs=self.extra_args,
+                ExtraArgs=extra_args,
                 Config=self.transfer_config,
             )
         get_hook_lineage_collector().add_input_asset(
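A closing usage sketch for `download_file` from a requester-pays bucket (paths and names are placeholders):

```python
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default", requester_pays=True)
local_path = hook.download_file(
    key="data/file.parquet",
    bucket_name="some-requester-pays-bucket",
    local_path="/tmp",
)
print(f"downloaded to {local_path}")
```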