# tmp_thop.py
import torch
import torch.nn as nn

from thop.profile import register_hooks, prRed
from thop.vision.basic_hooks import count_parameters


def myprofile(model: nn.Module, inputs, custom_ops=None, verbose=True):
    """Variant of thop's profile() that forwards keyword inputs: model(**inputs)."""
    handler_collection = {}
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}

    def add_hooks(m: nn.Module):
        # Per-module accumulators that the counting hooks write into.
        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

        m_type = type(m)
        fn = None
        if m_type in custom_ops:
            # If a rule is defined in both maps, custom_ops takes precedence.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Customize rule %s() for %s." % (fn.__qualname__, m_type))
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
        else:
            if m_type not in types_collection and verbose:
                prRed("[WARN] Cannot find rule for %s. Treat it as zero MACs and zero params." % m_type)
        if fn is not None:
            handler_collection[m] = (
                m.register_forward_hook(fn),
                m.register_forward_hook(count_parameters),
            )
        types_collection.add(m_type)

    prev_training_status = model.training
    model.eval()
    model.apply(add_hooks)

    with torch.no_grad():
        model(**inputs)

    def dfs_count(module: nn.Module, prefix="\t") -> tuple:
        # Sum counts bottom-up: read the hook-filled buffers at hooked leaf
        # modules, and recurse through containers such as nn.Sequential.
        total_ops, total_params = 0, 0
        for m in module.children():
            if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)):
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params = dfs_count(m, prefix=prefix + "\t")
            total_ops += m_ops
            total_params += m_params
        return total_ops, total_params

    total_ops, total_params = dfs_count(model)

    # Reset the model to its original mode and remove the hooks and buffers.
    model.train(prev_training_status)
    for m, (op_handler, params_handler) in handler_collection.items():
        op_handler.remove()
        params_handler.remove()
        m._buffers.pop("total_ops")
        m._buffers.pop("total_params")

    return total_ops, total_params
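

# --- Illustrative sketch, not part of the original file ---
# myprofile() accepts custom_ops, a {module type: counting hook} map that
# overrides or extends thop's built-in rules. The Swish module and the
# count_swish() hook below are hypothetical examples: thop-style hooks
# receive (module, inputs, output) and accumulate into module.total_ops.
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


def count_swish(m: nn.Module, inputs, output):
    # Count one multiply per output element; the sigmoid's own cost is
    # ignored here, a deliberate simplification for illustration.
    m.total_ops += torch.DoubleTensor([output.numel()])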
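

# A minimal end-to-end usage sketch. TinyNet, the input shape, and the
# keyword name "x" are assumptions for demonstration; since myprofile()
# calls model(**inputs), inputs must be a dict keyed by the forward()
# argument names.
if __name__ == "__main__":
    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.act = Swish()
            self.fc = nn.Linear(8 * 32 * 32, 10)

        def forward(self, x):
            y = self.act(self.conv(x))
            return self.fc(y.flatten(1))

    macs, params = myprofile(
        TinyNet(),
        inputs={"x": torch.randn(1, 3, 32, 32)},
        custom_ops={Swish: count_swish},
    )
    print("MACs: %.3fM  Params: %.3fK" % (macs / 1e6, params / 1e3))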