Skip to content

Commit d36637b

Browse files
tangbinh authored and
facebook-github-bot committed
Add a delete function to FSStoragePlugin, S3StoragePlugin and GCSStoragePlugin
Summary: We add a `delete` function to `FSStoragePlugin`, `S3StoragePlugin` and `GCSStoragePlugin` (and the abstract `StoragePlugin` interface) in order to implement a delete API for TorchSnapshot. Integration with `Snapshot` will be added in subsequent diffs. Reviewed By: yifuwang. Differential Revision: D36948787. fbshipit-source-id: fc37336e5738a9122bb72f68afe192491faceb13
1 parent 7eba5e1 commit d36637b

File tree

7 files changed

+72
-8
lines changed

7 files changed

+72
-8
lines changed

tests/test_fs_storage_plugin.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import asyncio
9+
import logging
10+
import os
11+
import tempfile
12+
import unittest
13+
14+
import torch
15+
import torchsnapshot
16+
from torchsnapshot.storage_plugins.fs import FSStoragePlugin
17+
18+
logger = logging.getLogger(__name__)
19+
20+
_TENSOR_SZ = int(100_000_000 / 4)
21+
22+
23+
class FSStoragePluginTest(unittest.TestCase):
    """Round-trip test for ``FSStoragePlugin`` backed by a temporary directory."""

    def test_write_read_delete(self) -> None:
        """Write a tensor, read it back, then delete it.

        The file's presence on disk is checked after the write and its
        absence after the delete; the tensor read back must match the one
        that was written.
        """
        with tempfile.TemporaryDirectory() as path:
            logger.info(path)
            plugin = FSStoragePlugin(root=path)
            try:
                tensor = torch.rand((_TENSOR_SZ,))
                tensor_path = os.path.join(path, "tensor")

                # Serialize the tensor into the request buffer, then persist it.
                write_req = torchsnapshot.io_types.IOReq(path=tensor_path)
                torch.save(tensor, write_req.buf)
                asyncio.run(plugin.write(io_req=write_req))
                self.assertTrue(os.path.exists(tensor_path))

                # Read the same path back through the plugin and compare.
                read_req = torchsnapshot.io_types.IOReq(path=tensor_path)
                asyncio.run(plugin.read(io_req=read_req))
                loaded = torch.load(read_req.buf)
                self.assertTrue(torch.allclose(tensor, loaded))

                asyncio.run(plugin.delete(path=tensor_path))
                self.assertFalse(os.path.exists(tensor_path))
            finally:
                # Close the plugin even when a step above raises, so a
                # failing assertion does not skip the cleanup.
                plugin.close()

tests/test_gcs_storage_plugin.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,21 @@ def test_read_write_via_snapshot(self) -> None:
3838
self.assertTrue(torch.allclose(tensor, app_state["state"]["tensor"]))
3939

4040
@unittest.skipIf(os.environ.get("TORCHSNAPSHOT_ENABLE_GCP_TEST") is None, "")
41-
def test_write_read(self) -> None:
41+
def test_write_read_delete(self) -> None:
4242
path = f"{_TEST_BUCKET}/{uuid.uuid4()}"
4343
logger.info(path)
4444
plugin = GCSStoragePlugin(root=path)
4545

4646
tensor = torch.rand((_TENSOR_SZ,))
47-
write_req = torchsnapshot.io_types.IOReq(path=os.path.join(path, "tensor"))
47+
path = os.path.join(path, "tensor")
48+
write_req = torchsnapshot.io_types.IOReq(path=path)
4849
torch.save(tensor, write_req.buf)
4950
asyncio.run(plugin.write(io_req=write_req))
5051

51-
read_req = torchsnapshot.io_types.IOReq(path=os.path.join(path, "tensor"))
52+
read_req = torchsnapshot.io_types.IOReq(path=path)
5253
asyncio.run(plugin.read(io_req=read_req))
5354
loaded = torch.load(read_req.buf)
54-
5555
self.assertTrue(torch.allclose(tensor, loaded))
5656

57+
asyncio.run(plugin.delete(path=path))
5758
plugin.close()

tests/test_s3_storage_plugin.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,21 @@ def test_read_write_via_snapshot(self) -> None:
3838
self.assertTrue(torch.allclose(tensor, app_state["state"]["tensor"]))
3939

4040
@unittest.skipIf(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, "")
41-
def test_write_read(self) -> None:
41+
def test_write_read_delete(self) -> None:
4242
path = f"{_TEST_BUCKET}/{uuid.uuid4()}"
4343
logger.info(path)
4444
plugin = S3StoragePlugin(root=path)
4545

4646
tensor = torch.rand((_TENSOR_SZ,))
47-
write_req = torchsnapshot.io_types.IOReq(path=os.path.join(path, "tensor"))
47+
path = os.path.join(path, "tensor")
48+
write_req = torchsnapshot.io_types.IOReq(path=path)
4849
torch.save(tensor, write_req.buf)
4950
asyncio.run(plugin.write(io_req=write_req))
5051

51-
read_req = torchsnapshot.io_types.IOReq(path=os.path.join(path, "tensor"))
52+
read_req = torchsnapshot.io_types.IOReq(path=path)
5253
asyncio.run(plugin.read(io_req=read_req))
5354
loaded = torch.load(read_req.buf)
54-
5555
self.assertTrue(torch.allclose(tensor, loaded))
5656

57+
asyncio.run(plugin.delete(path=path))
5758
plugin.close()

torchsnapshot/io_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ async def write(self, io_req: IOReq) -> None:
4141
async def read(self, io_req: IOReq) -> None:
4242
pass
4343

44+
@abc.abstractmethod
async def delete(self, path: str) -> None:
    """Delete the object stored at ``path``.

    Abstract hook: concrete storage plugins supply the implementation.

    Args:
        path: Location of the object to delete.
    """
47+
4448
@abc.abstractmethod
4549
def close(self) -> None:
4650
pass

torchsnapshot/storage_plugins/fs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Set
1212

1313
import aiofiles
14+
import aiofiles.os
1415
from torchsnapshot.io_types import IOReq, StoragePlugin
1516

1617

@@ -36,5 +37,8 @@ async def read(self, io_req: IOReq) -> None:
3637
async with aiofiles.open(path, "rb") as f:
3738
io_req.buf = io.BytesIO(await f.read())
3839

40+
async def delete(self, path: str) -> None:
    """Remove the file at ``path``, resolved against this plugin's root.

    The GCS and S3 plugins build their object key with
    ``os.path.join(self.root, path)``; this method now resolves ``path``
    the same way instead of removing it relative to the process working
    directory. An absolute ``path`` is used unchanged, because
    ``os.path.join`` discards the preceding components when a later one
    is absolute.

    Args:
        path: File to remove, relative to the plugin root or absolute.
    """
    # Imported locally because the module's top-level imports are not
    # fully visible from this hunk.
    import os

    # NOTE(review): assumes the constructor stores its ``root`` argument as
    # ``self.root``, as the GCS/S3 plugins do -- confirm.
    await aiofiles.os.remove(os.path.join(self.root, path))
42+
3943
def close(self) -> None:
4044
pass

torchsnapshot/storage_plugins/gcs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,11 @@ async def read(self, io_req: IOReq) -> None:
5757
)
5858
io_req.buf.seek(0)
5959

60+
async def delete(self, path: str) -> None:
    """Delete the blob stored under ``path`` relative to the plugin root.

    The blocking ``blob.delete`` call is handed to ``self.executor`` via
    the running event loop and awaited.

    Args:
        path: Blob location, joined onto ``self.root`` to form the key.
    """
    target = self.bucket.blob(os.path.join(self.root, path))
    event_loop = asyncio.get_running_loop()
    await event_loop.run_in_executor(self.executor, target.delete)
65+
6066
def close(self) -> None:
6167
self.executor.shutdown()

torchsnapshot/storage_plugins/s3.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,10 @@ async def read(self, io_req: IOReq) -> None:
4444
async with response["Body"] as stream:
4545
io_req.buf = io.BytesIO(await stream.read())
4646

47+
async def delete(self, path: str) -> None:
    """Delete the object stored under ``path`` relative to the plugin root.

    A client is created from ``self.session`` for the duration of this
    call and used to issue ``delete_object`` against ``self.bucket``.

    Args:
        path: Object location, joined onto ``self.root`` to form the key.
    """
    object_key = os.path.join(self.root, path)
    async with self.session.create_client("s3") as s3_client:
        await s3_client.delete_object(Bucket=self.bucket, Key=object_key)
51+
4752
def close(self) -> None:
4853
pass

0 commit comments

Comments
 (0)