-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathdump_manager.py
35 lines (29 loc) · 887 Bytes
/
dump_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from abc import ABC, abstractmethod
import numpy as np
from utils.misc import Singleton
import os
import shutil
import uuid
class DumpManager(metaclass=Singleton):
    """Save tensors to ``.npy`` files in a dump directory while enabled.

    Dumping is off by default and is switched on inside a ``with`` block
    using this manager. Each file is named ``<name>_<tag>`` (``np.save``
    appends the ``.npy`` extension), so calling :meth:`set_tag` between runs
    keeps dumps of the same tensor from different runs side by side.

    NOTE(review): the ``Singleton`` metaclass comes from ``utils.misc``;
    presumably it makes every ``DumpManager(...)`` call return one shared
    instance -- confirm there.
    """

    def __init__(self, dump_dir=None):
        """Create a fresh, empty dump directory.

        Args:
            dump_dir: Directory the dumps are written to. If it already
                exists it is deleted together with all of its contents and
                recreated empty.

        Raises:
            ValueError: If ``dump_dir`` is not given. (A subclass of
                ``Exception``, so existing ``except Exception`` handlers
                still catch it.)
        """
        if dump_dir is None:
            raise ValueError('dump_dir must be provided')
        self.dump_dir = dump_dir
        # Start from a clean directory so dumps left over from an earlier run
        # cannot be mistaken for dumps from this one.
        if os.path.exists(dump_dir):
            shutil.rmtree(dump_dir)
        os.makedirs(dump_dir)
        self.enabled = False
        self.tag = ''
        # Number of currently open ``with`` blocks; lets a nested block exit
        # without disabling dumping for the block that encloses it.
        self._depth = 0

    def __enter__(self):
        """Enable dumping and return this manager."""
        self._depth += 1
        self.enabled = True
        return self

    def __exit__(self, *args):
        """Leave a ``with`` block; dumping stays on while an outer block is open.

        Returns None, so exceptions raised inside the block propagate.
        """
        self._depth = max(self._depth - 1, 0)
        self.enabled = self._depth > 0

    def set_tag(self, tag):
        """Set the suffix appended to every subsequent dump file name."""
        self.tag = tag

    def dump(self, tensor, name):
        """Save ``tensor`` to ``<dump_dir>/<name>_<tag>.npy`` when enabled.

        Does nothing while dumping is disabled.

        Args:
            tensor: A ``torch.Tensor`` (any device, with or without autograd
                history) or anything ``np.asarray`` accepts. bfloat16 tensors
                are saved as float32 because numpy has no bfloat16 dtype.
            name: Base file name for the dump.
        """
        if not self.enabled:
            return
        f = os.path.join(self.dump_dir, f"{name}_{self.tag}")
        print("dumping: %s" % f)
        if isinstance(tensor, torch.Tensor):
            # detach() first: .numpy() raises on tensors that require grad.
            data = tensor.detach().cpu()
            if data.dtype == torch.bfloat16:
                # numpy cannot represent bfloat16; upcast losslessly.
                data = data.float()
            array = data.numpy()
        else:
            array = np.asarray(tensor)
        np.save(f, array)