
Commit 9cded51
Author: Ryan Li
Message: add s3path
Parent: 2d95165

File tree: 4 files changed (+558, -0 lines)

s3torchconnector/pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ classifiers = [
 dependencies = [
     "torch >= 2.0.1, != 2.5.0",
     "s3torchconnectorclient >= 1.3.0",
+    "pathlib_abc >= 0.3.1"
 ]

 [project.optional-dependencies]

s3torchconnector/src/s3torchconnector/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,7 @@
 from .s3iterable_dataset import S3IterableDataset
 from .s3map_dataset import S3MapDataset
 from .s3checkpoint import S3Checkpoint
+from .s3path import S3Path
 from ._version import __version__
 from ._s3client import S3ClientConfig

@@ -21,5 +22,6 @@
     "S3Writer",
     "S3Exception",
     "S3ClientConfig",
+    "S3Path",
     "__version__",
 ]
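
With this export, S3Path becomes importable from the package root. A minimal sketch, assuming a hypothetical bucket and a region configured via S3_TORCH_CONNECTOR_REGION or AWS_REGION:

from s3torchconnector import S3Path

path = S3Path("s3://my-bucket/data/train.csv")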
s3torchconnector/src/s3torchconnector/s3path.py (new file)

Lines changed: 305 additions & 0 deletions
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
import errno
import io
import logging
import os
import posixpath
import stat
import time
from types import SimpleNamespace
from typing import Optional

from pathlib import PurePosixPath
from pathlib_abc import ParserBase, PathBase, UnsupportedOperation
from urllib.parse import urlparse

from s3torchconnectorclient._mountpoint_s3_client import S3Exception
from ._s3client import S3Client, S3ClientConfig

logger = logging.getLogger(__name__)

ENV_S3_TORCH_CONNECTOR_REGION = "S3_TORCH_CONNECTOR_REGION"
ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS = (
    "S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS"
)
ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB = "S3_TORCH_CONNECTOR_PART_SIZE_MB"
DRIVE = "s3://"


def _get_default_bucket_region():
    for var in [
        ENV_S3_TORCH_CONNECTOR_REGION,
        "AWS_DEFAULT_REGION",
        "AWS_REGION",
        "REGION",
    ]:
        if var in os.environ:
            return os.environ[var]


def _get_default_throughput_target_gbps():
    if ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS in os.environ:
        return float(os.environ[ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS])


def _get_default_part_size():
    if ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB in os.environ:
        return int(os.environ[ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB]) * 1024 * 1024
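
# Hedged usage sketch (hypothetical values, not part of the committed file):
# the helpers above default the region, throughput target, and multipart part
# size from the environment when the caller does not pass them explicitly.
#
#   os.environ["S3_TORCH_CONNECTOR_REGION"] = "us-east-1"
#   os.environ["S3_TORCH_CONNECTOR_PART_SIZE_MB"] = "16"
#   _get_default_bucket_region()  # -> "us-east-1"
#   _get_default_part_size()      # -> 16777216 (16 MiB in bytes)
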
class S3Parser(ParserBase):
    @classmethod
    def _unsupported_msg(cls, attribute):
        return f"{cls.__name__}.{attribute} is unsupported"

    @property
    def sep(self):
        return "/"

    def join(self, path, *paths):
        return posixpath.join(path, *paths)

    def split(self, path):
        scheme, bucket, prefix, _, _, _ = urlparse(path)
        parent, _, name = prefix.lstrip("/").rpartition("/")
        if not bucket:
            return bucket, name
        return (scheme + "://" + bucket + "/" + parent, name)

    def splitdrive(self, path):
        scheme, bucket, prefix, _, _, _ = urlparse(path)
        drive = f"{scheme}://{bucket}"
        return drive, prefix.lstrip("/")

    def splitext(self, path):
        return posixpath.splitext(path)

    def normcase(self, path):
        return posixpath.normcase(path)

    def isabs(self, path):
        s = os.fspath(path)
        scheme_tail = s.split("://", 1)
        return len(scheme_tail) == 2
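
# Hedged sketch of the parser on a hypothetical URI: the s3://bucket portion
# acts as the drive, the object key as the path.
#
#   parser = S3Parser()
#   parser.splitdrive("s3://my-bucket/a/b.txt")  # ("s3://my-bucket", "a/b.txt")
#   parser.split("s3://my-bucket/a/b.txt")       # ("s3://my-bucket/a", "b.txt")
#   parser.isabs("s3://my-bucket/a/b.txt")       # True
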
class S3Path(PathBase):
    __slots__ = ("_region", "_s3_client_config", "_client", "_raw_path")
    parser = S3Parser()
    _stat_cache_ttl_seconds = 1
    _stat_cache_size = 1024
    _stat_cache = {}

    def __init__(
        self,
        *pathsegments,
        client: Optional[S3Client] = None,
        region=None,
        s3_client_config=None,
    ):
        super().__init__(*pathsegments)
        if not self.drive.startswith(DRIVE):
            raise ValueError("Should pass in S3 uri")
        self._region = region or _get_default_bucket_region()
        self._s3_client_config = s3_client_config or S3ClientConfig(
            throughput_target_gbps=_get_default_throughput_target_gbps(),
            part_size=_get_default_part_size(),
        )
        self._client = client or S3Client(
            region=self._region,
            s3client_config=self._s3_client_config,
        )

    def __repr__(self):
        return f"{type(self).__name__}({str(self)!r})"

    def __hash__(self):
        return hash(str(self))

    def __eq__(self, other):
        if not isinstance(other, S3Path):
            return NotImplemented
        return str(self) == str(other)

    def with_segments(self, *pathsegments):
        path = str("/".join(pathsegments)).lstrip("/")
        if not path.startswith(self.anchor):
            path = f"{self.anchor}{path}"
        return type(self)(
            path,
            client=self._client,
            region=self._region,
            s3_client_config=self._s3_client_config,
        )

    @property
    def bucket(self):
        if self.is_absolute() and self.drive.startswith(DRIVE):
            return self.drive[5:]
        return ""

    @property
    def key(self):
        if self.is_absolute() and len(self.parts) > 1:
            return self.parser.sep.join(self.parts[1:])
        return ""
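
    # Hedged sketch (hypothetical path; assumes PathBase populates `parts`
    # the way standard pathlib does):
    #
    #   p = S3Path("s3://my-bucket/data/train.csv")
    #   p.bucket  # "my-bucket"
    #   p.key     # "data/train.csv"
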
    def open(self, mode="r", buffering=-1, encoding=None, errors=None, newline=None):
        if buffering != -1:
            raise ValueError("Only default buffering (-1) is supported.")
        if not self.is_absolute():
            raise ValueError("S3Path must be absolute.")
        action = "".join(c for c in mode if c not in "btU")
        if action == "r":
            try:
                fileobj = self._client.get_object(self.bucket, self.key)
            except S3Exception:
                raise FileNotFoundError(errno.ENOENT, "Not found", str(self)) from None
            except:
                raise
        elif action == "w":
            try:
                fileobj = self._client.put_object(self.bucket, self.key)
            except S3Exception:
                raise
            except:
                raise
        else:
            raise UnsupportedOperation()
        if "b" not in mode:
            fileobj = io.TextIOWrapper(fileobj, encoding, errors, newline)
        return fileobj
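
    # Hedged usage sketch (hypothetical bucket/key); only plain read/write
    # modes with default buffering are supported:
    #
    #   p = S3Path("s3://my-bucket/data/sample.txt")
    #   with p.open("w") as f:
    #       f.write("hello")
    #   text = p.open("r").read()
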
    def stat(self, *, follow_symlinks=True):
        cache_key = (self.bucket, self.key.rstrip("/"))
        cached_result = self._stat_cache.get(cache_key)
        if cached_result:
            result, timestamp = cached_result
            if time.time() - timestamp < self._stat_cache_ttl_seconds:
                return result
            del self._stat_cache[cache_key]
        try:
            info = self._client.head_object(self.bucket, self.key.rstrip("/"))
            mode = stat.S_IFREG
        except S3Exception as e:
            listobj = next(self._list_objects(max_keys=2))

            if len(listobj.object_info) > 0 or len(listobj.common_prefixes) > 0:
                info = SimpleNamespace(size=0, last_modified=None)
                mode = stat.S_IFDIR
            else:
                error_msg = f"No stats available for {self}; it may not exist."
                raise FileNotFoundError(error_msg) from e

        result = os.stat_result(
            (
                mode,  # mode
                None,  # ino
                DRIVE,  # dev
                None,  # nlink
                None,  # uid
                None,  # gid
                info.size,  # size
                None,  # atime
                info.last_modified or 0,  # mtime
                None,  # ctime
            )
        )
        if len(self._stat_cache) >= self._stat_cache_size:
            self._stat_cache.pop(next(iter(self._stat_cache)))

        self._stat_cache[cache_key] = (result, time.time())
        return result
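
    # Note: stat results are cached per (bucket, key) for
    # _stat_cache_ttl_seconds, so repeated exists()/is_file() probes within
    # the TTL reuse a single HEAD request. Hedged sketch (hypothetical path):
    #
    #   st = S3Path("s3://my-bucket/data/train.csv").stat()
    #   st.st_size  # object size in bytes
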
    def iterdir(self):
        if not self.is_dir():
            raise NotADirectoryError("not a s3 folder")
        key = "" if not self.key else self.key.rstrip("/") + "/"
        for page in self._list_objects():
            for prefix in page.common_prefixes:
                # yield directories first
                yield self.with_segments(prefix.rstrip("/"))
            for info in page.object_info:
                if info.key != key:
                    yield self.with_segments(info.key)
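
    # Hedged sketch (hypothetical bucket): iterdir() pages through a
    # delimiter-based listing, yielding common prefixes before objects:
    #
    #   for child in S3Path("s3://my-bucket/data").iterdir():
    #       print(child)
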
    def mkdir(self, mode=0o777, parents=False, exist_ok=False):
        if self.is_dir():
            if exist_ok:
                return
            raise FileExistsError(f"S3 folder {self} already exists.")
        with self._client.put_object(self.bucket, self.key.rstrip("/") + "/"):
            pass

    def unlink(self, missing_ok=False):
        if self.is_dir():
            if missing_ok:
                return
            raise IsADirectoryError(
                f"Path {self} is a directory; call rmdir instead of unlink."
            )
        self._client.delete_object(self.bucket, self.key)

    def rmdir(self):
        if not self.is_dir():
            raise NotADirectoryError(f"{self} is not an s3 folder")
        listobj = next(self._list_objects(max_keys=2))
        if len(listobj.object_info) > 1:
            raise Exception(f"{self} is not empty")
        self._client.delete_object(self.bucket, self.key.rstrip("/") + "/")

    def glob(self, pattern, *, case_sensitive=None, recurse_symlinks=True):
        if ".." in pattern:
            raise NotImplementedError(
                "Relative paths with '..' not supported in glob patterns"
            )
        if pattern.startswith(self.anchor) or pattern.startswith("/"):
            raise NotImplementedError("Non-relative patterns are unsupported")

        parts = list(PurePosixPath(pattern).parts)
        select = self._glob_selector(parts, case_sensitive, recurse_symlinks)
        return select(self)
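
    # Hedged sketch (hypothetical bucket); matching is delegated to
    # PathBase's glob selector, and only path-relative patterns are allowed:
    #
    #   root = S3Path("s3://my-bucket/data")
    #   list(root.glob("*.csv"))      # direct children with a .csv suffix
    #   list(root.glob("**/*.json"))  # recursive match
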
    def with_name(self, name):
        """Return a new path with the file name changed."""
        split = self.parser.split
        if split(name)[0]:
            # Ensure that the provided name does not contain any path separators
            raise ValueError(f"Invalid name {name!r}")
        return self.with_segments(str(self.parent), name)

    def _list_objects(self, max_keys: int = 1000):
        try:
            key = "" if not self.key else self.key.rstrip("/") + "/"
            pages = iter(
                self._client.list_objects(
                    self.bucket, key, delimiter="/", max_keys=max_keys
                )
            )
            for page in pages:
                yield page
        except S3Exception as e:
            raise RuntimeError(f"Failed to list contents of {self}") from e

    def __getstate__(self):
        state = {
            slot: getattr(self, slot, None)
            for cls in self.__class__.__mro__
            for slot in getattr(cls, "__slots__", [])
            if slot
            not in [
                "_client",
            ]
        }
        return (None, state)

    def __setstate__(self, state):
        _, state_dict = state
        for slot, value in state_dict.items():
            if slot not in ["_client"]:
                setattr(self, slot, value)
        self._client = S3Client(
            region=self._region,
            s3client_config=self._s3_client_config,
        )
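
Because __getstate__ drops the client and __setstate__ rebuilds it, S3Path
instances can be pickled into worker processes (for example, DataLoader
workers). A quick check, assuming a hypothetical bucket and a region
configured in the environment:

import pickle
from s3torchconnector import S3Path

p = S3Path("s3://my-bucket/data/train.csv")
restored = pickle.loads(pickle.dumps(p))
assert restored == p  # equality compares the string form; the client was rebuilt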
