Skip to content

Commit 2757045

Browse files
author
Ryan Li
committed
add s3path
1 parent 2d95165 commit 2757045

File tree

4 files changed

+450
-0
lines changed

4 files changed

+450
-0
lines changed

s3torchconnector/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ classifiers = [
2424
dependencies = [
2525
"torch >= 2.0.1, != 2.5.0",
2626
"s3torchconnectorclient >= 1.3.0",
27+
"pathlib_abc >= 0.3.1"
2728
]
2829

2930
[project.optional-dependencies]

s3torchconnector/src/s3torchconnector/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .s3iterable_dataset import S3IterableDataset
1111
from .s3map_dataset import S3MapDataset
1212
from .s3checkpoint import S3Checkpoint
13+
from .s3path import S3Path
1314
from ._version import __version__
1415
from ._s3client import S3ClientConfig
1516

@@ -21,5 +22,6 @@
2122
"S3Writer",
2223
"S3Exception",
2324
"S3ClientConfig",
25+
"S3Path",
2426
"__version__",
2527
]
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
3+
import errno
4+
import io
5+
import logging
6+
import os
7+
import posixpath
8+
import stat
9+
from pathlib import PurePosixPath
10+
from types import SimpleNamespace
11+
from typing import Optional
12+
from urllib.parse import urlparse
13+
from pathlib_abc import ParserBase, PathBase, UnsupportedOperation
14+
15+
16+
from s3torchconnectorclient._mountpoint_s3_client import S3Exception
17+
from ._s3client import S3Client, S3ClientConfig
18+
19+
logger = logging.getLogger(__name__)
20+
21+
ENV_S3_TORCH_CONNECTOR_REGION = "S3_TORCH_CONNECTOR_REGION"
22+
ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS = (
23+
"S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS"
24+
)
25+
ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB = "S3_TORCH_CONNECTOR_PART_SIZE_MB"
26+
DRIVE = "s3://"
27+
28+
29+
def _get_default_bucket_region():
    """Resolve the default S3 bucket region from the environment.

    Checks, in order of precedence, the connector-specific override and
    the standard AWS region variables; returns the first one that is set,
    or None when none are.
    """
    candidates = (
        ENV_S3_TORCH_CONNECTOR_REGION,
        "AWS_DEFAULT_REGION",
        "AWS_REGION",
        "REGION",
    )
    return next(
        (os.environ[name] for name in candidates if name in os.environ), None
    )
38+
39+
40+
def _get_default_throughput_target_gbps():
    """Return the configured throughput target in Gbps, or None when unset."""
    raw = os.environ.get(ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS)
    return float(raw) if raw is not None else None
43+
44+
45+
def _get_default_part_size():
    """Return the configured multipart part size in bytes, or None when unset.

    The environment variable is expressed in MiB; it is converted here.
    """
    raw = os.environ.get(ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB)
    if raw is None:
        return None
    return int(raw) * 1024 * 1024
48+
49+
50+
class S3Parser(ParserBase):
    """Path parser for ``s3://bucket/key`` URIs used by :class:`S3Path`.

    Implements the pathlib_abc ``ParserBase`` interface on top of
    ``urllib.parse`` and ``posixpath``: the ``s3://bucket`` portion acts as
    the drive, and keys are split with POSIX "/" semantics.
    """

    @classmethod
    def _unsupported_msg(cls, attribute):
        # Message text for ParserBase operations that S3 does not support.
        return f"{cls.__name__}.{attribute} is unsupported"

    @property
    def sep(self):
        # S3 keys use "/" as their separator, exactly like POSIX paths.
        return "/"

    def join(self, path, *paths):
        return posixpath.join(path, *paths)

    def split(self, path):
        """Split off the final path component, keeping the s3://bucket
        prefix attached to the parent part."""
        scheme, bucket, prefix, _, _, _ = urlparse(path)
        parent, _, name = prefix.lstrip("/").rpartition("/")
        if not bucket:
            return bucket, name
        return f"{scheme}://{bucket}/{parent}", name

    def splitdrive(self, path):
        """Return ``(s3://bucket, key)`` — the bucket URI is the drive."""
        scheme, bucket, prefix, _, _, _ = urlparse(path)
        return f"{scheme}://{bucket}", prefix.lstrip("/")

    def splitext(self, path):
        return posixpath.splitext(path)

    def normcase(self, path):
        # S3 keys are case-sensitive; posixpath.normcase is a no-op here.
        return posixpath.normcase(path)

    def isabs(self, path):
        # A path is absolute when it carries a scheme, e.g. "s3://...".
        return "://" in os.fspath(path)
84+
85+
86+
class S3Path(PathBase):
    """A pathlib-style path addressing S3 objects via ``s3://bucket/key`` URIs.

    Directories are modelled as zero-byte "marker" objects whose key ends
    in "/" (created by :meth:`mkdir`) or as implicit prefixes discovered by
    listing. The underlying :class:`S3Client` is excluded from pickling and
    rebuilt on unpickle.
    """

    __slots__ = ("_region", "_s3_client_config", "_client", "_raw_path")
    parser = S3Parser()

    def __init__(
        self,
        *pathsegments,
        client: Optional[S3Client] = None,
        region=None,
        s3_client_config=None,
    ):
        """Create an S3Path from one or more segments forming an S3 URI.

        Args:
            *pathsegments: path segments; the combined path must start with
                ``s3://``.
            client: optional pre-built S3Client to reuse.
            region: bucket region; defaults to the environment configuration.
            s3_client_config: optional S3ClientConfig; defaults to
                environment-derived throughput/part-size settings.

        Raises:
            ValueError: when the combined path is not an ``s3://`` URI.
        """
        super().__init__(*pathsegments)
        if not self.drive.startswith(DRIVE):
            raise ValueError("Should pass in S3 uri")
        self._region = region or _get_default_bucket_region()
        self._s3_client_config = s3_client_config or S3ClientConfig(
            throughput_target_gbps=_get_default_throughput_target_gbps(),
            part_size=_get_default_part_size(),
        )
        self._client = client or S3Client(
            region=self._region,
            s3client_config=self._s3_client_config,
        )

    def __repr__(self):
        return f"{type(self).__name__}({str(self)!r})"

    def __hash__(self):
        return hash(str(self))

    def __eq__(self, other):
        if not isinstance(other, S3Path):
            return NotImplemented
        return str(self) == str(other)

    def with_segments(self, *pathsegments):
        """Build a new S3Path from segments, reusing client/region/config.

        Relative keys (e.g. those returned by list_objects) are re-anchored
        under this path's ``s3://bucket`` anchor.
        """
        path = "/".join(pathsegments)
        if path.startswith("/"):
            path = path[1:]
        if not path.startswith(self.anchor):
            # Fix: the anchor may not carry a trailing "/"; without this
            # joiner the result would be "s3://bucketkey" (bucket and key
            # fused), which parses to the wrong bucket.
            joiner = "" if self.anchor.endswith("/") else "/"
            path = f"{self.anchor}{joiner}{path}"
        return type(self)(
            path,
            client=self._client,
            region=self._region,
            s3_client_config=self._s3_client_config,
        )

    @property
    def bucket(self):
        """The bucket name, or "" for non-absolute paths."""
        if self.is_absolute() and self.drive.startswith("s3://"):
            return self.drive[5:]
        return ""

    @property
    def key(self):
        """The object key (everything after the bucket), or ""."""
        if self.is_absolute() and len(self.parts) > 1:
            return self.parser.sep.join(self.parts[1:])
        return ""

    def open(self, mode="r", buffering=-1, encoding=None, errors=None, newline=None):
        """Open this object for reading ("r"/"rb") or writing ("w"/"wb").

        Custom buffering and relative paths are unsupported. Text modes
        wrap the raw stream in a TextIOWrapper.

        Raises:
            UnsupportedOperation: for unsupported modes/buffering/paths.
            FileNotFoundError: when reading a key that does not exist.
        """
        if buffering != -1 or not self.is_absolute():
            raise UnsupportedOperation()
        # Strip binary/text/universal-newline flags to get the core action.
        action = "".join(c for c in mode if c not in "btU")
        if action == "r":
            try:
                fileobj = self._client.get_object(self.bucket, self.key)
            except S3Exception:
                raise FileNotFoundError(errno.ENOENT, "Not found", str(self)) from None
        elif action == "w":
            # S3 errors during put propagate unchanged (removed the
            # original no-op except/raise clauses).
            fileobj = self._client.put_object(self.bucket, self.key)
        else:
            raise UnsupportedOperation()
        if "b" not in mode:
            fileobj = io.TextIOWrapper(fileobj, encoding, errors, newline)
        return fileobj

    def stat(self, *, follow_symlinks=True):
        """Return an os.stat_result for this object or prefix.

        Resolution order: HEAD on the exact key (keys ending in "/" are
        directories); then HEAD on key + "/" (directory marker); then a
        non-empty listing under the key (implicit directory).

        Raises:
            FileNotFoundError: when none of the lookups succeed.
        """
        try:
            info = self._client.head_object(self.bucket, self.key)
            mode = stat.S_IFDIR if info.key.endswith("/") else stat.S_IFREG
        except S3Exception as e:
            error_msg = f"No stats available for {self}; it may not exist."
            try:
                info = self._client.head_object(self.bucket, self.key + "/")
                mode = stat.S_IFDIR
            except S3Exception:
                try:
                    listobjs = list(self._client.list_objects(self.bucket, self.key))
                    if len(listobjs) > 0 and len(listobjs[0].object_info) > 0:
                        # Implicit directory: no marker object, so size and
                        # mtime are synthesized.
                        info = SimpleNamespace(size=0, last_modified=None)
                        mode = stat.S_IFDIR
                    else:
                        raise FileNotFoundError(error_msg) from e
                except S3Exception:
                    raise FileNotFoundError(error_msg) from e
        # NOTE(review): only st_mode, st_size and st_mtime are meaningful
        # for S3; the remaining fields are filler.
        return os.stat_result(
            (
                mode,  # st_mode
                None,  # st_ino
                "s3://",  # st_dev
                None,  # st_nlink
                None,  # st_uid
                None,  # st_gid
                info.size,  # st_size
                None,  # st_atime
                info.last_modified,  # st_mtime
                None,  # st_ctime
            )
        )

    def iterdir(self):
        """Yield children: common prefixes (sub-directories) first, then objects.

        Raises:
            NotADirectoryError: when this path is not an S3 directory.
        """
        if not self.is_dir():
            raise NotADirectoryError("not a s3 folder")
        key = self.key if (not self.key or self.key.endswith("/")) else self.key + "/"
        # S3 errors propagate unchanged (removed the original no-op
        # except/raise wrapper).
        for page in self._client.list_objects(self.bucket, key, delimiter="/"):
            for prefix in page.common_prefixes:
                # Yield directories first.
                yield self.with_segments(prefix.rstrip("/"))
            for info in page.object_info:
                # Skip the directory-marker entry for this prefix itself.
                if info.key != key:
                    yield self.with_segments(info.key)

    def mkdir(self, mode=0o777, parents=False, exist_ok=False):
        """Create a zero-byte directory-marker object (key ending in "/").

        ``mode`` and ``parents`` are accepted for pathlib compatibility but
        ignored: S3 has no permission bits and prefixes need no parents.

        Raises:
            FileExistsError: when the directory exists and exist_ok is False.
        """
        if self.is_dir():
            if exist_ok:
                return
            raise FileExistsError("S3 folder already exists.")
        writer = None
        try:
            writer = self._client.put_object(
                self.bucket, self.key if self.key.endswith("/") else self.key + "/"
            )
        finally:
            # Always close the writer so the (empty) marker upload completes.
            if writer is not None:
                writer.close()

    def unlink(self, missing_ok=False):
        """Delete the object at this key.

        Directories always raise, regardless of ``missing_ok`` — the
        original silently returned when missing_ok was set, leaving the
        caller believing a deletion happened. A missing key never raises:
        S3 deletes are idempotent.

        Raises:
            IsADirectoryError: when this path is a directory (use rmdir()).
        """
        if self.is_dir():
            raise IsADirectoryError(
                f"Path {self} is a directory; call rmdir instead of unlink."
            )
        self._client.delete_object(self.bucket, self.key)

    def rmdir(self):
        """Remove this S3 directory; it must be empty.

        Fixes two defects in the original: an empty directory leaked
        StopIteration (and was never deleted), and calling rmdir on a
        regular object deleted it instead of raising.

        Raises:
            NotADirectoryError: when this path is not a directory
                (propagated from iterdir on first next()).
            OSError: (ENOTEMPTY) when the directory is not empty.
        """
        children = self.iterdir()
        try:
            next(children)
        except StopIteration:
            # Empty: delete the directory-marker object. mkdir() creates
            # markers with a trailing "/", so delete that key.
            key = self.key if self.key.endswith("/") else self.key + "/"
            self._client.delete_object(self.bucket, key)
        else:
            raise OSError(errno.ENOTEMPTY, f"Path {self} is not empty")

    def glob(self, pattern, *, case_sensitive=None, recurse_symlinks=True):
        """Yield paths under this directory matching a relative glob pattern.

        Raises:
            NotImplementedError: for patterns containing ".." or anchored
                (non-relative) patterns.
        """
        if ".." in pattern:
            raise NotImplementedError(
                "Relative paths with '..' not supported in glob patterns"
            )
        if pattern.startswith(self.anchor) or pattern.startswith("/"):
            raise NotImplementedError("Non-relative patterns are unsupported")

        parts = list(PurePosixPath(pattern).parts)
        select = self._glob_selector(parts, case_sensitive, recurse_symlinks)
        return select(self)

    def with_name(self, name):
        """Return a new path with the file name changed."""
        split = self.parser.split
        if split(name)[0]:
            raise ValueError(f"Invalid name {name!r}")
        # _raw_path is maintained by the pathlib_abc base class.
        return self.with_segments(split(self._raw_path)[0], name)

    def __getstate__(self):
        """Pickle support: capture slot values, excluding the live client."""
        state = {
            slot: getattr(self, slot, None)
            for cls in self.__class__.__mro__
            for slot in getattr(cls, "__slots__", [])
            if slot
            not in [
                "_client",
            ]
        }
        return (None, state)

    def __setstate__(self, state):
        """Pickle support: restore slots and rebuild a fresh S3Client."""
        _, state_dict = state
        for slot, value in state_dict.items():
            if slot not in ["_client"]:
                setattr(self, slot, value)
        self._client = S3Client(
            region=self._region,
            s3client_config=self._s3_client_config,
        )

0 commit comments

Comments
 (0)