File tree Expand file tree Collapse file tree 7 files changed +72
-8
lines changed Expand file tree Collapse file tree 7 files changed +72
-8
lines changed Original file line number Diff line number Diff line change
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import asyncio
9
+ import logging
10
+ import os
11
+ import tempfile
12
+ import unittest
13
+
14
+ import torch
15
+ import torchsnapshot
16
+ from torchsnapshot .storage_plugins .fs import FSStoragePlugin
17
+
18
+ logger = logging .getLogger (__name__ )
19
+
20
+ _TENSOR_SZ = int (100_000_000 / 4 )
21
+
22
+
23
+ class FSStoragePluginTest (unittest .TestCase ):
24
+ def test_write_read_delete (self ) -> None :
25
+ with tempfile .TemporaryDirectory () as path :
26
+ logger .info (path )
27
+ plugin = FSStoragePlugin (root = path )
28
+
29
+ tensor = torch .rand ((_TENSOR_SZ ,))
30
+ tensor_path = os .path .join (path , "tensor" )
31
+ write_req = torchsnapshot .io_types .IOReq (path = tensor_path )
32
+ torch .save (tensor , write_req .buf )
33
+ asyncio .run (plugin .write (io_req = write_req ))
34
+ self .assertTrue (os .path .exists (tensor_path ))
35
+
36
+ read_req = torchsnapshot .io_types .IOReq (path = tensor_path )
37
+ asyncio .run (plugin .read (io_req = read_req ))
38
+ loaded = torch .load (read_req .buf )
39
+ self .assertTrue (torch .allclose (tensor , loaded ))
40
+
41
+ asyncio .run (plugin .delete (path = tensor_path ))
42
+ self .assertFalse (os .path .exists (tensor_path ))
43
+ plugin .close ()
Original file line number Diff line number Diff line change @@ -38,20 +38,21 @@ def test_read_write_via_snapshot(self) -> None:
38
38
self .assertTrue (torch .allclose (tensor , app_state ["state" ]["tensor" ]))
39
39
40
40
@unittest .skipIf (os .environ .get ("TORCHSNAPSHOT_ENABLE_GCP_TEST" ) is None , "" )
41
- def test_write_read (self ) -> None :
41
+ def test_write_read_delete (self ) -> None :
42
42
path = f"{ _TEST_BUCKET } /{ uuid .uuid4 ()} "
43
43
logger .info (path )
44
44
plugin = GCSStoragePlugin (root = path )
45
45
46
46
tensor = torch .rand ((_TENSOR_SZ ,))
47
- write_req = torchsnapshot .io_types .IOReq (path = os .path .join (path , "tensor" ))
47
+ path = os .path .join (path , "tensor" )
48
+ write_req = torchsnapshot .io_types .IOReq (path = path )
48
49
torch .save (tensor , write_req .buf )
49
50
asyncio .run (plugin .write (io_req = write_req ))
50
51
51
- read_req = torchsnapshot .io_types .IOReq (path = os . path . join ( path , "tensor" ) )
52
+ read_req = torchsnapshot .io_types .IOReq (path = path )
52
53
asyncio .run (plugin .read (io_req = read_req ))
53
54
loaded = torch .load (read_req .buf )
54
-
55
55
self .assertTrue (torch .allclose (tensor , loaded ))
56
56
57
+ asyncio .run (plugin .delete (path = path ))
57
58
plugin .close ()
Original file line number Diff line number Diff line change @@ -38,20 +38,21 @@ def test_read_write_via_snapshot(self) -> None:
38
38
self .assertTrue (torch .allclose (tensor , app_state ["state" ]["tensor" ]))
39
39
40
40
@unittest .skipIf (os .environ .get ("TORCHSNAPSHOT_ENABLE_AWS_TEST" ) is None , "" )
41
- def test_write_read (self ) -> None :
41
+ def test_write_read_delete (self ) -> None :
42
42
path = f"{ _TEST_BUCKET } /{ uuid .uuid4 ()} "
43
43
logger .info (path )
44
44
plugin = S3StoragePlugin (root = path )
45
45
46
46
tensor = torch .rand ((_TENSOR_SZ ,))
47
- write_req = torchsnapshot .io_types .IOReq (path = os .path .join (path , "tensor" ))
47
+ path = os .path .join (path , "tensor" )
48
+ write_req = torchsnapshot .io_types .IOReq (path = path )
48
49
torch .save (tensor , write_req .buf )
49
50
asyncio .run (plugin .write (io_req = write_req ))
50
51
51
- read_req = torchsnapshot .io_types .IOReq (path = os . path . join ( path , "tensor" ) )
52
+ read_req = torchsnapshot .io_types .IOReq (path = path )
52
53
asyncio .run (plugin .read (io_req = read_req ))
53
54
loaded = torch .load (read_req .buf )
54
-
55
55
self .assertTrue (torch .allclose (tensor , loaded ))
56
56
57
+ asyncio .run (plugin .delete (path = path ))
57
58
plugin .close ()
Original file line number Diff line number Diff line change @@ -41,6 +41,10 @@ async def write(self, io_req: IOReq) -> None:
41
41
async def read (self , io_req : IOReq ) -> None :
42
42
pass
43
43
44
+ @abc .abstractmethod
45
+ async def delete (self , path : str ) -> None :
46
+ pass
47
+
44
48
@abc .abstractmethod
45
49
def close (self ) -> None :
46
50
pass
Original file line number Diff line number Diff line change 11
11
from typing import Set
12
12
13
13
import aiofiles
14
+ import aiofiles .os
14
15
from torchsnapshot .io_types import IOReq , StoragePlugin
15
16
16
17
@@ -36,5 +37,8 @@ async def read(self, io_req: IOReq) -> None:
36
37
async with aiofiles .open (path , "rb" ) as f :
37
38
io_req .buf = io .BytesIO (await f .read ())
38
39
40
+ async def delete (self , path : str ) -> None :
41
+ await aiofiles .os .remove (path )
42
+
39
43
def close (self ) -> None :
40
44
pass
Original file line number Diff line number Diff line change @@ -57,5 +57,11 @@ async def read(self, io_req: IOReq) -> None:
57
57
)
58
58
io_req .buf .seek (0 )
59
59
60
+ async def delete (self , path : str ) -> None :
61
+ loop = asyncio .get_running_loop ()
62
+ key = os .path .join (self .root , path )
63
+ blob = self .bucket .blob (key )
64
+ await loop .run_in_executor (self .executor , blob .delete )
65
+
60
66
def close (self ) -> None :
61
67
self .executor .shutdown ()
Original file line number Diff line number Diff line change @@ -44,5 +44,10 @@ async def read(self, io_req: IOReq) -> None:
44
44
async with response ["Body" ] as stream :
45
45
io_req .buf = io .BytesIO (await stream .read ())
46
46
47
+ async def delete (self , path : str ) -> None :
48
+ async with self .session .create_client ("s3" ) as client :
49
+ key = os .path .join (self .root , path )
50
+ await client .delete_object (Bucket = self .bucket , Key = key )
51
+
47
52
def close (self ) -> None :
48
53
pass
You can’t perform that action at this time.
0 commit comments