forked from noahcao/Pixel2Mesh
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gpu_mem_track.py
113 lines (98 loc) · 4.27 KB
/
gpu_mem_track.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gc
import datetime
import inspect
import torch
import numpy as np
# Bytes occupied by a single element of each supported torch dtype.
# Values are written as bits / 8 and are therefore floats (e.g. 4.0).
# Several torch names are aliases for the same dtype object
# (torch.double/torch.float64, torch.short/torch.int16, ...), so an aliased
# pair is a single dictionary key and the later entry wins -- both entries of
# a pair must therefore carry the same value.
dtype_memory_size_dict = {
    torch.float64: 64 / 8,
    torch.double: 64 / 8,
    torch.float32: 32 / 8,
    torch.float: 32 / 8,
    torch.float16: 16 / 8,
    torch.half: 16 / 8,
    torch.int64: 64 / 8,
    torch.long: 64 / 8,
    torch.int32: 32 / 8,
    torch.int: 32 / 8,
    torch.int16: 16 / 8,
    torch.short: 16 / 8,  # 16 bits -> 2 bytes (same dtype object as torch.int16)
    torch.uint8: 8 / 8,
    torch.int8: 8 / 8,
}

# Compatibility with torch 1.0: these dtypes do not exist in every torch
# release, so they are only registered when the attribute is present.
if getattr(torch, "bfloat16", None) is not None:
    dtype_memory_size_dict[torch.bfloat16] = 16 / 8
if getattr(torch, "bool", None) is not None:
    # pytorch uses 1 byte for a bool, see https://github.com/pytorch/pytorch/issues/41571
    dtype_memory_size_dict[torch.bool] = 8 / 8
def get_mem_space(x):
    """Return the per-element size in bytes registered for dtype ``x``.

    The size is looked up in the module-level ``dtype_memory_size_dict``.

    Arguments:
        x: a torch dtype (any hashable key is accepted).

    Returns:
        The registered number of bytes per element, or ``0`` when ``x`` is not
        in the table. In that case a message is printed and the tensor simply
        contributes nothing to the computed totals, instead of aborting the
        whole tracking pass with an ``UnboundLocalError``.
    """
    try:
        return dtype_memory_size_dict[x]
    except KeyError:
        print(f"dtype {x} is not supported!")
        return 0
class MemTracker(object):
    """
    Class used to track pytorch memory usage

    Every call to :meth:`track` appends a report to a text log file: the
    call site, the memory summed over the CUDA tensors found through the
    garbage collector, the memory reported by ``torch.cuda.memory_allocated``
    and, optionally, which groups of tensors appeared or disappeared since
    the previous call.

    Arguments:
        detail(bool, default True): whether the function shows the detail gpu memory usage
        path(str): where to save log file; the string is prepended verbatim to
            the generated file name (no path separator is inserted)
        verbose(bool, default False): whether show the trivial exception
        device(int): GPU number, default is 0
    """

    def __init__(self, detail=True, path='', verbose=False, device=0):
        self.print_detail = detail
        # Tensor-group records written by the previous track() call; used to
        # compute the "+" / "-" lines of the next report.
        self.last_tensor_sizes = set()
        # Log file name: <path><timestamp>-gpu_mem_track.txt
        self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt'
        self.verbose = verbose
        # True until the first track() call has written the header line.
        self.begin = True
        # NOTE(review): stored but not consulted by any method of this class;
        # the measurements below use torch's default/current device.
        self.device = device

    def get_tensors(self):
        """Yield the gc-tracked objects that look like CUDA tensors.

        An object qualifies when it is a tensor, or exposes a tensor through a
        ``data`` attribute, and its ``is_cuda`` attribute is truthy. Any
        exception raised while probing an object is swallowed (and printed
        when ``verbose`` is set) so one odd object cannot abort the scan.
        """
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                    # NOTE(review): the object itself is yielded even when only
                    # its ``.data`` is a tensor -- confirm callers can cope.
                    tensor = obj
                else:
                    continue
                if tensor.is_cuda:
                    yield tensor
            except Exception as e:
                if self.verbose:
                    print('A trivial exception occured: {}'.format(e))

    def get_tensor_usage(self):
        """Return the summed size of all tensors from get_tensors().

        Each tensor contributes (number of elements) * (bytes per element as
        given by ``get_mem_space``); the total is divided by 1024**2.
        """
        sizes = [np.prod(np.array(tensor.size())) * get_mem_space(tensor.dtype) for tensor in self.get_tensors()]
        return np.sum(sizes) / 1024**2

    def get_allocate_usage(self):
        """Return ``torch.cuda.memory_allocated()`` divided by 1024**2."""
        return torch.cuda.memory_allocated() / 1024**2

    def clear_cache(self):
        """Run a garbage collection, then empty torch's CUDA cache."""
        gc.collect()
        torch.cuda.empty_cache()

    def print_all_gpu_tensor(self, file=None):
        """Print size, dtype and computed memory of every tracked tensor.

        Arguments:
            file: passed through to ``print``; ``None`` means standard output.
        """
        for x in self.get_tensors():
            print(x.size(), x.dtype, np.prod(np.array(x.size()))*get_mem_space(x.dtype)/1024**2, file=file)

    def track(self):
        """
        Track the GPU memory usage

        Appends one report to the log file ``self.gpu_profile_fn``. The first
        call also writes a header line. When detail output is enabled, tensors
        are grouped by (size, dtype) and a "+" line is written for every group
        record that is new since the previous call and a "-" line for every
        record that is gone. The report ends with the caller's file, line and
        function together with the two memory totals.
        """
        # Frame 1 of the stack is the code that called track().
        frameinfo = inspect.stack()[1]
        where_str = frameinfo.filename + ' line ' + str(frameinfo.lineno) + ': ' + frameinfo.function

        with open(self.gpu_profile_fn, 'a+') as f:

            if self.begin:
                f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
                        f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
                        f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
                self.begin = False

            if self.print_detail:
                # Walk the gc heap once and keep only plain metadata, so the
                # counts and the records describe the same set of tensors and
                # no tensor reference is retained by the snapshot.
                snapshot = [(type(x), tuple(x.size()), x.dtype) for x in self.get_tensors()]

                # Number of tensors per (size, dtype), built in one pass
                # rather than with a list.count() call per tensor.
                group_counts = {}
                for _, shape, dtype in snapshot:
                    key = (shape, dtype)
                    group_counts[key] = group_counts.get(key, 0) + 1

                # Record layout: (python type, size, group count,
                # memory of ONE such tensor, dtype).
                new_tensor_sizes = {(kind,
                                     shape,
                                     group_counts[(shape, dtype)],
                                     np.prod(np.array(shape))*get_mem_space(dtype)/1024**2,
                                     dtype) for kind, shape, dtype in snapshot}

                changes = (('+', new_tensor_sizes - self.last_tensor_sizes),
                           ('-', self.last_tensor_sizes - new_tensor_sizes))
                for sign, records in changes:
                    for t, s, n, m, data_type in records:
                        # Format numerically: slicing str(value) to six
                        # characters cut off exponents (1.5e-05 was shown as
                        # "1.5258") and trailing digits of large values.
                        f.write(f'{sign} | {str(n)} * Size:{str(s):<20} | Memory: {m*n:.6g} M | {str(t):<20} | {data_type}\n')

                self.last_tensor_sizes = new_tensor_sizes

            f.write(f"\nAt {where_str:<50}"
                    f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
                    f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")