Skip to content

Commit

Permalink
Merge pull request #2136 from PrefectHQ/issue-2108-new-boto-session-f…
Browse files Browse the repository at this point in the history
…or-each-thread

Closes #2108. New boto session for each thread
  • Loading branch information
cicdw authored Mar 11, 2020
2 parents 974625c + b16042c commit 7f12d0a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 16 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

### Enhancements

- - Add examples to Interactive API Docs [#2122](https://github.com/PrefectHQ/prefect/pull/2122)
- Add examples to Interactive API Docs [#2122](https://github.com/PrefectHQ/prefect/pull/2122)
- Use a new boto3 session per thread when using S3ResultHandlers [#2108](https://github.com/PrefectHQ/prefect/issues/2108)

### Task Library

Expand Down
13 changes: 11 additions & 2 deletions src/prefect/engine/result_handlers/s3_result_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import cloudpickle
import pendulum

import prefect
from prefect.client import Secret
from prefect.engine.result_handlers import ResultHandler

Expand Down Expand Up @@ -54,7 +55,10 @@ def initialize_client(self) -> None:
aws_access_key = aws_credentials["ACCESS_KEY"]
aws_secret_access_key = aws_credentials["SECRET_ACCESS_KEY"]

s3_client = boto3.client(
# use a new boto session when initializing in case we are in a new thread
# see https://boto3.amazonaws.com/v1/documentation/api/latest/guide/resources.html?#multithreading-multiprocessing
session = boto3.session.Session()
s3_client = session.client(
"s3",
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_access_key,
Expand All @@ -63,8 +67,13 @@ def initialize_client(self) -> None:

@property
def client(self) -> "boto3.client":
if not hasattr(self, "_client"):
"""
Initializes a client if we believe we are in a new thread.
We consider ourselves in a new thread if we haven't stored a client yet in the current context.
"""
if not prefect.context.get("boto3client"):
self.initialize_client()
prefect.context["boto3client"] = self._client
return self._client

@client.setter
Expand Down
31 changes: 18 additions & 13 deletions tests/engine/result_handlers/test_result_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,32 +166,31 @@ def __getstate__(self):
@pytest.mark.xfail(raises=ImportError, reason="aws extras not installed.")
class TestS3ResultHandler:
@pytest.fixture
def s3_client(self, monkeypatch):
def session(self, monkeypatch):
import boto3

client = MagicMock()
with patch.dict("sys.modules", {"boto3": MagicMock(client=client)}):
yield client
session = MagicMock()
with patch.dict("sys.modules", {"boto3": MagicMock(session=session)}):
yield session

def test_s3_client_init_uses_secrets(self, s3_client):
def test_s3_client_init_uses_secrets(self, session):
handler = S3ResultHandler(
bucket="bob", aws_credentials_secret="AWS_CREDENTIALS"
)
assert handler.bucket == "bob"
assert s3_client.called is False
assert session.Session().client.called is False

with prefect.context(
secrets=dict(AWS_CREDENTIALS=dict(ACCESS_KEY=1, SECRET_ACCESS_KEY=42))
):
with set_temporary_config({"cloud.use_local_secrets": True}):
handler.initialize_client()

assert s3_client.call_args[1] == {
assert session.Session().client.call_args[1] == {
"aws_access_key_id": 1,
"aws_secret_access_key": 42,
}

def test_s3_client_init_uses_custom_secrets(self, s3_client):
def test_s3_client_init_uses_custom_secrets(self, session):
handler = S3ResultHandler(bucket="bob", aws_credentials_secret="MY_FOO")

with prefect.context(
Expand All @@ -201,12 +200,12 @@ def test_s3_client_init_uses_custom_secrets(self, s3_client):
handler.initialize_client()

assert handler.bucket == "bob"
assert s3_client.call_args[1] == {
assert session.Session().client.call_args[1] == {
"aws_access_key_id": 1,
"aws_secret_access_key": 999,
}

def test_s3_writes_to_blob_prefixed_by_date_suffixed_by_prefect(self, s3_client):
def test_s3_writes_to_blob_prefixed_by_date_suffixed_by_prefect(self, session):
handler = S3ResultHandler(bucket="foo")

with prefect.context(
Expand All @@ -215,7 +214,9 @@ def test_s3_writes_to_blob_prefixed_by_date_suffixed_by_prefect(self, s3_client)
with set_temporary_config({"cloud.use_local_secrets": True}):
uri = handler.write("so-much-data")

used_uri = s3_client.return_value.upload_fileobj.call_args[1]["Key"]
used_uri = session.Session().client.return_value.upload_fileobj.call_args[1][
"Key"
]

assert used_uri == uri
assert used_uri.startswith(pendulum.now("utc").format("Y/M/D"))
Expand All @@ -229,7 +230,11 @@ def __init__(self, *args, **kwargs):
def __getstate__(self):
raise ValueError("I cannot be pickled.")

with patch.dict("sys.modules", {"boto3": MagicMock(client=client)}):
import boto3

with patch.dict("sys.modules", {"boto3": MagicMock()}):
boto3.session.Session().client = client

with prefect.context(
secrets=dict(AWS_CREDENTIALS=dict(ACCESS_KEY=1, SECRET_ACCESS_KEY=42))
):
Expand Down

0 comments on commit 7f12d0a

Please sign in to comment.