48 changes: 43 additions & 5 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -29,9 +29,10 @@
from inspect import signature
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from tempfile import NamedTemporaryFile, gettempdir
from typing import Any, Callable, List, TypeVar, cast
from urllib.parse import urlparse
from uuid import uuid4

from boto3.s3.transfer import S3Transfer, TransferConfig
from botocore.exceptions import ClientError
@@ -879,17 +880,38 @@ def delete_objects(self, bucket: str, keys: str | list) -> None:

@provide_bucket_name
@unify_bucket_name_and_key
def download_file(self, key: str, bucket_name: str | None = None, local_path: str | None = None) -> str:
def download_file(
self,
key: str,
bucket_name: str | None = None,
local_path: str | None = None,
preserve_file_name: bool = False,
use_autogenerated_subdir: bool = True,
) -> str:
"""
Downloads a file from the S3 location to the local file system.

:param key: The key path in S3.
:param bucket_name: The specific bucket to use.
:param local_path: The local path to the downloaded file. If no path is provided it will use the
system's temporary directory.
:param preserve_file_name: If set to True, the downloaded file keeps the same name it has in S3.
When set to False, a random filename is generated instead.
Default: False.
:param use_autogenerated_subdir: Pairs with 'preserve_file_name = True' to download the file into a
randomly generated folder inside the 'local_path'; useful to avoid collisions between tasks
that might download the same file name. Set it to 'False' if you want a predictable path instead.
Default: True.
:return: the path to the downloaded file.
:rtype: str
"""
self.log.info(
Contributor Author:

@Taragolis I've added a log message here to show that this function shadows boto's method, hope that's fine :)
"This function shadows the 'download_file' method of S3 API, but it is not the same. If you "
"want to use the original method from S3 API, please call "
"'S3Hook.get_conn().download_file()'"
)

self.log.info("Downloading source S3 file from Bucket %s with path %s", bucket_name, key)

try:
@@ -902,14 +924,30 @@ def download_file(self, key: str, bucket_name: str | None = None, local_path: str | None = None) -> str:
else:
raise e

with NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False) as local_tmp_file:
if preserve_file_name:
local_dir = local_path if local_path else gettempdir()
Member:

I'm a bit worried that using the temp dir directly with a predictable file name may cause a vulnerability. I don't have concrete examples, but the combination is sort of a red flag.

Contributor Author:

WDYM?
I've tested it locally in Airflow docker and in Breeze, and the gettempdir() method returns /tmp on all Linux envs.
When the dir parameter is not provided to NamedTemporaryFile, it is also called directly:
https://github.com/python/cpython/blob/0d68879104dfb392d31e52e25dcb0661801a0249/Lib/tempfile.py#L126

I do not quite understand why it may cause a vulnerability.

Do you think it's better to stay with the older implementation of renaming the file after it's already been created? I think this way is a bit cleaner, but I'm also ok with "reverting" to the old flow.
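A quick sketch of the behaviour being referenced (the exact directory varies by platform; /tmp is typical on Linux):

from tempfile import NamedTemporaryFile, gettempdir

print(gettempdir())  # typically /tmp on Linux

# When dir= is not passed, NamedTemporaryFile falls back to gettempdir().
with NamedTemporaryFile(prefix="airflow_tmp_") as f:
    print(f.name)  # e.g. /tmp/airflow_tmp_abc123xy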

Contributor Author:

Also, in all flows, the user can still have a file kept in S3 with a name that could cause a vulnerability, and it would later be saved in the temp directory generated by the same function (if we don't provide a local_path).

This is no different; does it mean we can't implement this feature at all?

Member:

To follow up on @uranusjr's thought:

Say we are using LocalExecutor or CeleryExecutor, so two users' jobs can be executed on the same host.

Here you have filename_in_s3 = s3_obj.key.rsplit('/', 1)[-1]. So if user A has file .../A/data.json and user B has .../B/data.json, there may be a conflict, right?

But this is just vague thinking and very likely I missed something. Please feel free to point it out.
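A minimal sketch of the collision being described, with two hypothetical keys that share a trailing file name:

# Two different S3 keys with the same trailing file name map to the
# same local path when only the last path segment is kept.
key_a = "users/A/data.json"
key_b = "users/B/data.json"
local_dir = "/tmp"

path_a = f"{local_dir}/{key_a.rsplit('/', 1)[-1]}"  # /tmp/data.json
path_b = f"{local_dir}/{key_b.rsplit('/', 1)[-1]}"  # /tmp/data.json
assert path_a == path_b  # both downloads would target the same file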

Member:

The risk can be greatly reduced if the file is put in a subdirectory instead of directly inside the temp directory root (so the full path the file is downloaded to remains unpredictable), but that may lead to additional cleanup issues since directories are more finicky than files. I'd be happy if it works though.
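A sketch of the mitigation being suggested here, using a hypothetical file name; this mirrors what the merged code below ends up doing:

from pathlib import Path
from tempfile import gettempdir
from uuid import uuid4

# A per-download random subdirectory keeps the final path unguessable
# even though the file name itself is preserved.
target = Path(gettempdir(), f"airflow_tmp_dir_{uuid4().hex[:8]}", "data.json")
target.parent.mkdir(parents=True, exist_ok=True)
print(target)  # e.g. /tmp/airflow_tmp_dir_3fa4c2d1/data.json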

Contributor:

I do also like option 1; it solves many of the complexities you pointed out (or at least bubbles them up to the user) and also allows the user to create a path that is predictable, so this is probably my preference. But they should be able to provide a full sub-path within tmp so that they can organize files with similar names to their preference.

Although, option 2 would be perfectly serviceable as well.

Contributor Author:

In the end, I did two things:

  • Added a check for whether the file exists before writing it, failing the task if it already exists, to bubble the issue up to the user.
  • Added another parameter, use_autogenerated_subdir, True by default, which creates a new sub-directory. The user can disable it to control the target file location, but it's on by default.

@o-nikolas @uranusjr @XD-DENG Will appreciate your review of the latest additions to this flow 🙏

Contributor:

Hey @alexkruc, thanks for sticking with this! The method stub is a little complicated now, but I think it's a decent middle ground given all the constraints that came up in the discussions here 👍

Contributor Author:

@o-nikolas Thanks!
It seems like this PR is beginning to go a bit stale. Do you think we should do anything else? Or is it ok to approve and merge?

Contributor:

Hey @alexkruc,

I think this is good to merge, but unfortunately I'm not a committer. CCing @eladkal

subdir = f"airflow_tmp_dir_{uuid4().hex[0:8]}" if use_autogenerated_subdir else ""
filename_in_s3 = s3_obj.key.rsplit("/", 1)[-1]
file_path = Path(local_dir, subdir, filename_in_s3)

if file_path.is_file():
self.log.error("file '%s' already exists. Failing the task and not overwriting it", file_path)
raise FileExistsError

file_path.parent.mkdir(exist_ok=True, parents=True)

file = open(file_path, "wb")
else:
file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False) # type: ignore

with file:
s3_obj.download_fileobj(
local_tmp_file,
file,
ExtraArgs=self.extra_args,
Config=self.transfer_config,
)

return local_tmp_file.name
return file.name

def generate_presigned_url(
self,
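A minimal usage sketch of the new parameters (the connection id, bucket, and key here are hypothetical, not part of the diff):

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# Keeps the S3 file name inside a randomly named subdirectory, e.g.
# /tmp/data/airflow_tmp_dir_3fa4c2d1/report.csv
path = hook.download_file(
    key="exports/report.csv",
    bucket_name="my-bucket",
    local_path="/tmp/data",
    preserve_file_name=True,
)

# Fully predictable path (/tmp/data/report.csv); raises FileExistsError
# if a file is already there.
path = hook.download_file(
    key="exports/report.csv",
    bucket_name="my-bucket",
    local_path="/tmp/data",
    preserve_file_name=True,
    use_autogenerated_subdir=False,
)

# The raw S3 API method mentioned in the new log message is still
# reachable through the boto3 client:
hook.get_conn().download_file(
    Bucket="my-bucket", Key="exports/report.csv", Filename="/tmp/report.csv"
)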
89 changes: 80 additions & 9 deletions tests/providers/amazon/aws/hooks/test_s3.py
@@ -20,6 +20,7 @@
import gzip as gz
import os
import tempfile
from pathlib import Path
from unittest import mock
from unittest.mock import Mock

@@ -532,24 +533,94 @@ def test_function_with_test_key(self, test_key, bucket_name=None):

@mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
def test_download_file(self, mock_temp_file):
mock_temp_file.return_value.__enter__ = Mock(return_value=mock_temp_file)
with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as temp_file:
mock_temp_file.return_value = temp_file
s3_hook = S3Hook(aws_conn_id="s3_test")
s3_hook.check_for_key = Mock(return_value=True)
s3_obj = Mock()
s3_obj.download_fileobj = Mock(return_value=None)
s3_hook.get_key = Mock(return_value=s3_obj)
key = "test_key"
bucket = "test_bucket"

output_file = s3_hook.download_file(key=key, bucket_name=bucket)

s3_hook.get_key.assert_called_once_with(key, bucket)
s3_obj.download_fileobj.assert_called_once_with(
temp_file,
Config=s3_hook.transfer_config,
ExtraArgs=s3_hook.extra_args,
)

assert temp_file.name == output_file

@mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
def test_download_file_with_preserve_name(self, mock_open):
file_name = "test.log"
bucket = "test_bucket"
key = f"test_key/{file_name}"
local_folder = "/tmp"

s3_hook = S3Hook(aws_conn_id="s3_test")
s3_hook.check_for_key = Mock(return_value=True)
s3_obj = Mock()
s3_obj.key = f"s3://{bucket}/{key}"
s3_obj.download_fileobj = Mock(return_value=None)
s3_hook.get_key = Mock(return_value=s3_obj)
key = "test_key"
bucket = "test_bucket"
s3_hook.download_file(
key=key,
bucket_name=bucket,
local_path=local_folder,
preserve_file_name=True,
use_autogenerated_subdir=False,
)

s3_hook.download_file(key=key, bucket_name=bucket)
mock_open.assert_called_once_with(Path(local_folder, file_name), "wb")

s3_hook.get_key.assert_called_once_with(key, bucket)
s3_obj.download_fileobj.assert_called_once_with(
mock_temp_file,
Config=s3_hook.transfer_config,
ExtraArgs=s3_hook.extra_args,
@mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open):
file_name = "test.log"
bucket = "test_bucket"
key = f"test_key/{file_name}"
local_folder = "/tmp"

s3_hook = S3Hook(aws_conn_id="s3_test")
s3_hook.check_for_key = Mock(return_value=True)
s3_obj = Mock()
s3_obj.key = f"s3://{bucket}/{key}"
s3_obj.download_fileobj = Mock(return_value=None)
s3_hook.get_key = Mock(return_value=s3_obj)
result_file = s3_hook.download_file(
key=key,
bucket_name=bucket,
local_path=local_folder,
preserve_file_name=True,
use_autogenerated_subdir=True,
)

assert result_file.rsplit("/", 1)[-2].startswith("airflow_tmp_dir_")

def test_download_file_with_preserve_name_file_already_exists(self):
with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as file:
file_name = file.name.rsplit("/", 1)[-1]
bucket = "test_bucket"
key = f"test_key/{file_name}"
local_folder = "/tmp"
s3_hook = S3Hook(aws_conn_id="s3_test")
s3_hook.check_for_key = Mock(return_value=True)
s3_obj = Mock()
s3_obj.key = f"s3://{bucket}/{key}"
s3_obj.download_fileobj = Mock(return_value=None)
s3_hook.get_key = Mock(return_value=s3_obj)
with pytest.raises(FileExistsError):
s3_hook.download_file(
key=key,
bucket_name=bucket,
local_path=local_folder,
preserve_file_name=True,
use_autogenerated_subdir=False,
)

def test_generate_presigned_url(self, s3_bucket):
hook = S3Hook()
presigned_url = hook.generate_presigned_url(