Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closes #2108. New boto session for each thread #2136

Merged
merged 8 commits into from
Mar 11, 2020
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

### Enhancements

- - Add examples to Interactive API Docs [#2122](https://github.com/PrefectHQ/prefect/pull/2122)
- Add examples to Interactive API Docs [#2122](https://github.com/PrefectHQ/prefect/pull/2122)
- Use a new boto3 session per thread when using S3ResultHandlers [#2108](https://github.com/PrefectHQ/prefect/issues/2108)

### Task Library

Expand Down
13 changes: 11 additions & 2 deletions src/prefect/engine/result_handlers/s3_result_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import cloudpickle
import pendulum

import prefect
from prefect.client import Secret
from prefect.engine.result_handlers import ResultHandler

Expand Down Expand Up @@ -54,7 +55,10 @@ def initialize_client(self) -> None:
aws_access_key = aws_credentials["ACCESS_KEY"]
aws_secret_access_key = aws_credentials["SECRET_ACCESS_KEY"]

s3_client = boto3.client(
# use a new boto session when initializing in case we are in a new thread
# see https://boto3.amazonaws.com/v1/documentation/api/latest/guide/resources.html?#multithreading-multiprocessing
session = boto3.session.Session()
s3_client = session.client(
"s3",
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_access_key,
Expand All @@ -63,8 +67,13 @@ def initialize_client(self) -> None:

@property
def client(self) -> "boto3.client":
if not hasattr(self, "_client"):
"""
Initializes a client if we believe we are in a new thread.
We consider ourselves in a new thread if we haven't stored a client yet in the current context.
"""
if not prefect.context.get("boto3client"):
self.initialize_client()
prefect.context["boto3client"] = self._client
return self._client

@client.setter
Expand Down
31 changes: 18 additions & 13 deletions tests/engine/result_handlers/test_result_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,32 +166,31 @@ def __getstate__(self):
@pytest.mark.xfail(raises=ImportError, reason="aws extras not installed.")
class TestS3ResultHandler:
@pytest.fixture
def s3_client(self, monkeypatch):
def session(self, monkeypatch):
import boto3

client = MagicMock()
with patch.dict("sys.modules", {"boto3": MagicMock(client=client)}):
yield client
session = MagicMock()
with patch.dict("sys.modules", {"boto3": MagicMock(session=session)}):
yield session

def test_s3_client_init_uses_secrets(self, s3_client):
def test_s3_client_init_uses_secrets(self, session):
handler = S3ResultHandler(
bucket="bob", aws_credentials_secret="AWS_CREDENTIALS"
)
assert handler.bucket == "bob"
assert s3_client.called is False
assert session.Session().client.called is False

with prefect.context(
secrets=dict(AWS_CREDENTIALS=dict(ACCESS_KEY=1, SECRET_ACCESS_KEY=42))
):
with set_temporary_config({"cloud.use_local_secrets": True}):
handler.initialize_client()

assert s3_client.call_args[1] == {
assert session.Session().client.call_args[1] == {
"aws_access_key_id": 1,
"aws_secret_access_key": 42,
}

def test_s3_client_init_uses_custom_secrets(self, s3_client):
def test_s3_client_init_uses_custom_secrets(self, session):
handler = S3ResultHandler(bucket="bob", aws_credentials_secret="MY_FOO")

with prefect.context(
Expand All @@ -201,12 +200,12 @@ def test_s3_client_init_uses_custom_secrets(self, s3_client):
handler.initialize_client()

assert handler.bucket == "bob"
assert s3_client.call_args[1] == {
assert session.Session().client.call_args[1] == {
"aws_access_key_id": 1,
"aws_secret_access_key": 999,
}

def test_s3_writes_to_blob_prefixed_by_date_suffixed_by_prefect(self, s3_client):
def test_s3_writes_to_blob_prefixed_by_date_suffixed_by_prefect(self, session):
handler = S3ResultHandler(bucket="foo")

with prefect.context(
Expand All @@ -215,7 +214,9 @@ def test_s3_writes_to_blob_prefixed_by_date_suffixed_by_prefect(self, s3_client)
with set_temporary_config({"cloud.use_local_secrets": True}):
uri = handler.write("so-much-data")

used_uri = s3_client.return_value.upload_fileobj.call_args[1]["Key"]
used_uri = session.Session().client.return_value.upload_fileobj.call_args[1][
"Key"
]

assert used_uri == uri
assert used_uri.startswith(pendulum.now("utc").format("Y/M/D"))
Expand All @@ -229,7 +230,11 @@ def __init__(self, *args, **kwargs):
def __getstate__(self):
raise ValueError("I cannot be pickled.")

with patch.dict("sys.modules", {"boto3": MagicMock(client=client)}):
import boto3

with patch.dict("sys.modules", {"boto3": MagicMock()}):
boto3.session.Session().client = client

with prefect.context(
secrets=dict(AWS_CREDENTIALS=dict(ACCESS_KEY=1, SECRET_ACCESS_KEY=42))
):
Expand Down