Skip to content

Commit

Permalink
result cache.
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #173

result cache.

Reviewed By: wat3rBro

Differential Revision: D46511387

fbshipit-source-id: 7d9a85594a8679201a110a1eb21e3742720520f9
  • Loading branch information
Peizhao Zhang authored and facebook-github-bot committed Jul 25, 2023
1 parent 9de93ff commit 4673743
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
41 changes: 41 additions & 0 deletions mobile_cv/torch/tests/utils_pytorch/test_result_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.

import shutil
import tempfile
import unittest

from mobile_cv.common.utils_io import get_path_manager
from mobile_cv.torch.utils_pytorch import comm, distributed_helper as dh

from mobile_cv.torch.utils_pytorch.result_cache import ResultCache


class TestUtilsPytorchResultCache(unittest.TestCase):
@dh.launch_deco(num_processes=2)
def test_result_cache(self):
"""
buck2 run @mode/dev-nosan //mobile-vision/mobile_cv/mobile_cv/torch/tests:utils_pytorch_test_result_cache
"""
path_manager = get_path_manager()
if comm.is_main_process():
cache_dir = tempfile.mkdtemp()
else:
cache_dir = None
cache_dir = comm.all_gather(cache_dir)[0]

rc = ResultCache(
cache_dir, "test_cache", logger=None, path_manager=path_manager
)
self.assertFalse(rc.has_cache())
rc.save({"data": f"data_{comm.get_rank()}"})
comm.synchronize()
self.assertTrue(rc.has_cache())
out = rc.load(gather=True)
self.assertEqual(len(out), 2)
self.assertEqual(out[0]["data"], "data_0")
self.assertEqual(out[1]["data"], "data_1")

if comm.is_main_process():
shutil.rmtree(cache_dir)
comm.synchronize()
52 changes: 52 additions & 0 deletions mobile_cv/torch/utils_pytorch/result_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
from typing import Any, Optional

import torch
from mobile_cv.torch.utils_pytorch import comm


class ResultCache(object):
def __init__(self, cache_dir: Optional[str], cache_name: str, logger, path_manager):
"""A utility class to handle save/load cache data across processes"""
self.cache_dir = cache_dir
self.cache_name = cache_name
self.logger = logger
self.path_manager = path_manager

@property
def cache_file(self):
if self.cache_dir is None:
return None
return os.path.join(
self.cache_dir, f"_result_cache_{self.cache_name}.{comm.get_rank()}.pkl"
)

def has_cache(self):
return self.path_manager.isfile(self.cache_file)

def load(self, gather=False):
"""
Load cache results.
gather (bool): gather cache results arcoss ranks to a list
"""
if self.cache_file is None or not self.path_manager.exists(self.cache_file):
ret = None
else:
with self.path_manager.open(self.cache_file, "rb") as fp:
ret = torch.load(fp)
if self.logger is not None:
self.logger.info(f"Loaded from checkpoint {self.cache_file}")

if gather:
ret = comm.all_gather(ret)
return ret

def save(self, data: Any):
if self.cache_file is None:
return

self.path_manager.mkdirs(os.path.dirname(self.cache_file))
with self.path_manager.open(self.cache_file, "wb") as fp:
torch.save(data, fp)
if self.logger is not None:
self.logger.info(f"Saved checkpoint to {self.cache_file}")

0 comments on commit 4673743

Please sign in to comment.