-
Notifications
You must be signed in to change notification settings - Fork 160
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #173 result cache. Reviewed By: wat3rBro Differential Revision: D46511387 fbshipit-source-id: 7d9a85594a8679201a110a1eb21e3742720520f9
- Loading branch information
1 parent
9de93ff
commit 4673743
Showing
2 changed files
with
93 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. | ||
|
||
import shutil | ||
import tempfile | ||
import unittest | ||
|
||
from mobile_cv.common.utils_io import get_path_manager | ||
from mobile_cv.torch.utils_pytorch import comm, distributed_helper as dh | ||
|
||
from mobile_cv.torch.utils_pytorch.result_cache import ResultCache | ||
|
||
|
||
class TestUtilsPytorchResultCache(unittest.TestCase):
    """Round-trip test of ``ResultCache`` save/load/gather across two processes."""

    @dh.launch_deco(num_processes=2)
    def test_result_cache(self):
        """
        Each rank saves its own payload, then all payloads are gathered and checked.

        buck2 run @mode/dev-nosan //mobile-vision/mobile_cv/mobile_cv/torch/tests:utils_pytorch_test_result_cache
        """
        path_manager = get_path_manager()

        # Only the main process creates a temp directory; every rank then
        # adopts the first gathered entry so all ranks share one location.
        local_dir = tempfile.mkdtemp() if comm.is_main_process() else None
        gathered_dirs = comm.all_gather(local_dir)
        cache_dir = gathered_dirs[0]

        rc = ResultCache(
            cache_dir,
            "test_cache",
            logger=None,
            path_manager=path_manager,
        )

        # Nothing has been written yet for this rank.
        self.assertFalse(rc.has_cache())

        payload = {"data": f"data_{comm.get_rank()}"}
        rc.save(payload)
        comm.synchronize()
        self.assertTrue(rc.has_cache())

        # Gathering yields one entry per process, indexed by rank.
        out = rc.load(gather=True)
        self.assertEqual(len(out), 2)
        for rank, expected in enumerate(("data_0", "data_1")):
            self.assertEqual(out[rank]["data"], expected)

        # The main process owns the temp directory, so it removes it.
        if comm.is_main_process():
            shutil.rmtree(cache_dir)
        comm.synchronize()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
from typing import Any, Optional | ||
|
||
import torch | ||
from mobile_cv.torch.utils_pytorch import comm | ||
|
||
|
||
class ResultCache(object):
    """Save and load per-rank cached results across processes.

    Every rank reads and writes its own file named
    ``_result_cache_<cache_name>.<rank>.pkl`` inside ``cache_dir``.
    When ``cache_dir`` is ``None`` caching is disabled: ``has_cache`` returns
    ``False``, ``load`` yields ``None`` for this rank and ``save`` is a no-op.
    """

    def __init__(self, cache_dir: Optional[str], cache_name: str, logger, path_manager):
        """A utility class to handle save/load cache data across processes

        Args:
            cache_dir: directory that holds the cache files, or ``None`` to
                disable caching.
            cache_name: name embedded in the cache file name.
            logger: object with an ``info`` method, or ``None`` to skip logging.
            path_manager: object providing ``isfile``, ``open`` and ``mkdirs``
                for the storage backend in use.
        """
        self.cache_dir = cache_dir
        self.cache_name = cache_name
        self.logger = logger
        self.path_manager = path_manager

    @property
    def cache_file(self):
        """Path of the current rank's cache file, or ``None`` if caching is disabled."""
        if self.cache_dir is None:
            return None
        return os.path.join(
            self.cache_dir, f"_result_cache_{self.cache_name}.{comm.get_rank()}.pkl"
        )

    def has_cache(self) -> bool:
        """Return whether the current rank's cache file exists.

        Returns ``False`` when caching is disabled instead of handing ``None``
        to the path manager.
        """
        cache_file = self.cache_file
        if cache_file is None:
            return False
        return self.path_manager.isfile(cache_file)

    def load(self, gather=False):
        """
        Load cache results.

        Args:
            gather (bool): gather cache results across ranks to a list

        Returns:
            The object stored for the current rank, or ``None`` when caching
            is disabled or no cache file is present. With ``gather=True``, the
            result of ``comm.all_gather`` over each rank's value (ranks without
            a cache contribute ``None``).
        """
        # Read the property once; each access recomputes the path.
        cache_file = self.cache_file
        # Use the same ``isfile`` check as ``has_cache`` so the two always agree.
        if cache_file is None or not self.path_manager.isfile(cache_file):
            ret = None
        else:
            with self.path_manager.open(cache_file, "rb") as fp:
                # NOTE(review): torch.load unpickles the file contents; only
                # point cache_dir at locations whose contents are trusted.
                ret = torch.load(fp)
            if self.logger is not None:
                self.logger.info(f"Loaded from checkpoint {cache_file}")

        if gather:
            # Reached on every path so that all ranks take part in the gather,
            # including those without a cache file.
            ret = comm.all_gather(ret)
        return ret

    def save(self, data: Any):
        """Write ``data`` to the current rank's cache file.

        Does nothing when caching is disabled. The parent directory is created
        through the path manager before writing.
        """
        cache_file = self.cache_file
        if cache_file is None:
            return

        self.path_manager.mkdirs(os.path.dirname(cache_file))
        with self.path_manager.open(cache_file, "wb") as fp:
            torch.save(data, fp)
        if self.logger is not None:
            self.logger.info(f"Saved checkpoint to {cache_file}")