Skip to content

Commit 452cb40

Browse files
authored
Update the S3 client initialization to explicitly use boto3.Session() (#461)
* Refactor S3 client initialization to use boto3 session for improved configuration handling * update client * update tests to use boto3 session
1 parent a147df6 commit 452cb40

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

src/litdata/streaming/client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,19 @@ def _create_client(self) -> None:
3838
)
3939

4040
if has_shared_credentials_file or not _IS_IN_STUDIO or self._storage_options:
41-
self._client = boto3.client(
41+
session = boto3.Session()
42+
self._client = session.client(
4243
"s3",
4344
**{
4445
"config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
45-
**self._storage_options,
46+
**self._storage_options, # If additional options are provided
4647
},
4748
)
4849
else:
4950
provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5))
5051
credentials = provider.load()
51-
self._client = boto3.client(
52+
session = boto3.Session()
53+
self._client = session.client(
5254
"s3",
5355
aws_access_key_id=credentials.access_key,
5456
aws_secret_access_key=credentials.secret_key,

tests/streaming/test_client.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88

99

1010
def test_s3_client_with_storage_options(monkeypatch):
11-
boto3 = mock.MagicMock()
11+
boto3_session = mock.MagicMock()
12+
boto3 = mock.MagicMock(Session=boto3_session)
1213
monkeypatch.setattr(client, "boto3", boto3)
1314

1415
botocore = mock.MagicMock()
1516
monkeypatch.setattr(client, "botocore", botocore)
1617

18+
# Create S3Client with storage options
1719
storage_options = {
1820
"region_name": "us-west-2",
1921
"endpoint_url": "https://custom.endpoint",
@@ -23,24 +25,27 @@ def test_s3_client_with_storage_options(monkeypatch):
2325

2426
assert s3_client.client
2527

26-
boto3.client.assert_called_with(
28+
boto3_session().client.assert_called_with(
2729
"s3",
2830
region_name="us-west-2",
2931
endpoint_url="https://custom.endpoint",
3032
config=botocore.config.Config(retries={"max_attempts": 100}),
3133
)
3234

35+
# Create S3Client without storage options
3336
s3_client = client.S3Client()
34-
3537
assert s3_client.client
3638

37-
boto3.client.assert_called_with(
38-
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
39+
# Verify that boto3.Session().client was called with the default parameters
40+
boto3_session().client.assert_called_with(
41+
"s3",
42+
config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
3943
)
4044

4145

4246
def test_s3_client_without_cloud_space_id(monkeypatch):
43-
boto3 = mock.MagicMock()
47+
boto3_session = mock.MagicMock()
48+
boto3 = mock.MagicMock(Session=boto3_session)
4449
monkeypatch.setattr(client, "boto3", boto3)
4550

4651
botocore = mock.MagicMock()
@@ -59,13 +64,14 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
5964
assert s3.client
6065
assert s3.client
6166

62-
boto3.client.assert_called_once()
67+
boto3_session().client.assert_called_once()
6368

6469

6570
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows")
6671
@pytest.mark.parametrize("use_shared_credentials", [False, True, None])
6772
def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch):
68-
boto3 = mock.MagicMock()
73+
boto3_session = mock.MagicMock()
74+
boto3 = mock.MagicMock(Session=boto3_session)
6975
monkeypatch.setattr(client, "boto3", boto3)
7076

7177
botocore = mock.MagicMock()
@@ -85,14 +91,14 @@ def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch):
8591
s3 = client.S3Client(1)
8692
assert s3.client
8793
assert s3.client
88-
boto3.client.assert_called_once()
94+
boto3_session().client.assert_called_once()
8995
sleep(1 - (time() - s3._last_time))
9096
assert s3.client
9197
assert s3.client
92-
assert len(boto3.client._mock_mock_calls) == 6
98+
assert len(boto3_session().client._mock_mock_calls) == 6
9399
sleep(1 - (time() - s3._last_time))
94100
assert s3.client
95101
assert s3.client
96-
assert len(boto3.client._mock_mock_calls) == 9
102+
assert len(boto3_session().client._mock_mock_calls) == 9
97103

98104
assert instance_metadata_provider._mock_call_count == 0 if use_shared_credentials else 3

0 commit comments

Comments
 (0)