forked from huggingface/pytorch-image-models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_utils.py
57 lines (49 loc) · 1.93 KB
/
test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from torch.nn.modules.batchnorm import BatchNorm2d
from torchvision.ops.misc import FrozenBatchNorm2d
import timm
from timm.utils.model import freeze, unfreeze
def test_freeze_unfreeze():
    """Check that ``freeze`` / ``unfreeze`` toggle gradients and swap BatchNorm layers.

    Builds a ``resnet18`` via ``timm.create_model`` and asserts, for each scenario
    below, that ``requires_grad`` on the affected weights and the type of the
    affected ``bn1`` layer (``BatchNorm2d`` vs. ``FrozenBatchNorm2d``) match the
    expected frozen / unfrozen state:

    * freezing and unfreezing the whole model,
    * freezing and unfreezing a subset of submodules given by dotted names,
    * freezing and unfreezing a single BatchNorm layer, addressed both from the
      root model and from its direct parent module.
    """
    model = timm.create_model('resnet18')

    # Freeze all
    freeze(model)
    # Check top level module
    assert not model.fc.weight.requires_grad
    # Check submodule
    assert not model.layer1[0].conv1.weight.requires_grad
    # Check BN
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)

    # Unfreeze all
    unfreeze(model)
    # Check top level module
    assert model.fc.weight.requires_grad
    # Check submodule
    assert model.layer1[0].conv1.weight.requires_grad
    # Check BN
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)

    # Freeze some
    freeze(model, ['layer1', 'layer2.0'])
    # Check frozen
    assert not model.layer1[0].conv1.weight.requires_grad
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    assert not model.layer2[0].conv1.weight.requires_grad
    # Check not frozen: untouched stage, and the sibling block of 'layer2.0'
    assert model.layer3[0].conv1.weight.requires_grad
    assert isinstance(model.layer3[0].bn1, BatchNorm2d)
    assert model.layer2[1].conv1.weight.requires_grad

    # Unfreeze some
    unfreeze(model, ['layer1', 'layer2.0'])
    # Check not frozen
    assert model.layer1[0].conv1.weight.requires_grad
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
    assert model.layer2[0].conv1.weight.requires_grad

    # Freeze/unfreeze BN
    # From root
    freeze(model, ['layer1.0.bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model, ['layer1.0.bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
    # From direct parent
    freeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)