-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathmark_relu.py
29 lines (25 loc) · 934 Bytes
/
mark_relu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.data_parallel import DataParallel
from torchvision.models.resnet import Bottleneck, BasicBlock


def mark_bottlenetck_before_relu(model):
    """Set ``before_relu = True`` on the conv/BN layers of every ``Bottleneck``.

    Every torchvision ``Bottleneck`` found anywhere in ``model`` is marked,
    including ``model`` itself when it is a ``Bottleneck``. Its ``conv1``,
    ``bn1``, ``conv2`` and ``bn2`` are marked. ``conv3``/``bn3`` are
    deliberately left alone. Presumably this is because the residual addition
    sits between ``bn3`` and the block's final ReLU; confirm against the
    torchvision version in use.

    The misspelled name ("bottlenetck") is kept so existing callers still work.

    Args:
        model: ``nn.Module`` to mark in place. This may be a whole network, a
            parallel wrapper around one, or a single ``Bottleneck``.

    Returns:
        None. Layers are mutated in place.
    """
    # modules() yields ``model`` itself plus every descendant (each shared
    # module once). The previous children()-only recursion never marked
    # ``model`` when it was itself a Bottleneck.
    for m in model.modules():
        if isinstance(m, Bottleneck):
            for layer in (m.conv1, m.bn1, m.conv2, m.bn2):
                layer.before_relu = True


def mark_basicblock_before_relu(model):
    """Set ``before_relu = True`` on the conv/BN layers of every ``BasicBlock``.

    Every torchvision ``BasicBlock`` found anywhere in ``model`` is marked,
    including ``model`` itself when it is a ``BasicBlock``. Only ``conv1`` and
    ``bn1`` are marked. ``conv2``/``bn2`` are left alone, presumably because
    the residual addition precedes the block's final ReLU; confirm against
    the torchvision version in use.

    Args:
        model: ``nn.Module`` to mark in place. This may be a whole network, a
            parallel wrapper around one, or a single ``BasicBlock``.

    Returns:
        None. Layers are mutated in place.
    """
    # modules() covers ``model`` itself as well as all nested submodules,
    # unlike the earlier children()-only recursion.
    for m in model.modules():
        if isinstance(m, BasicBlock):
            for layer in (m.conv1, m.bn1):
                layer.before_relu = True


def resnet_mark_before_relu(model):
    """Set ``before_relu = True`` on the layers of a torchvision ResNet that feed a ReLU.

    This marks the stem ``conv1`` plus the relevant conv/BN layers of every
    ``Bottleneck`` and ``BasicBlock``, by delegating to
    ``mark_bottlenetck_before_relu`` and ``mark_basicblock_before_relu``.

    Args:
        model: A ResNet, or a ``DataParallel`` / ``DistributedDataParallel``
            wrapper around one. It is modified in place.

    Returns:
        None.
    """
    # Parallel wrappers hide the real network behind ``.module``. Only
    # DataParallel was unwrapped before, so a DDP-wrapped ResNet raised
    # AttributeError on ``.conv1``.
    wrappers = (DataParallel, DistributedDataParallel)
    net = model.module if isinstance(model, wrappers) else model
    # NOTE(review): the stem's bn1 appears to precede a ReLU too, yet unlike
    # bn1/bn2 inside the residual blocks it is not marked. Confirm intended.
    net.conv1.before_relu = True
    # The helpers walk model.modules(), so passing the wrapper is fine.
    mark_bottlenetck_before_relu(model)
    mark_basicblock_before_relu(model)