Skip to content

Commit 531fd8d

Browse files
committed
modify guide training
1 parent 10558cf commit 531fd8d

File tree

8 files changed

+521
-11
lines changed

8 files changed

+521
-11
lines changed

graffiti/float32touint8.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
from net import net_quantize_guide
3+
from torchvision import models
4+
5+
# coding=utf-8
6+
# Print the parameter/buffer names of the guide-quantized ResNet-18 so they
# can be inspected by eye.
model = net_quantize_guide.resnet18()
print(model.state_dict().keys())

# Re-bind `model` to the stock pretrained torchvision ResNet-18, cast every
# tensor of its state dict to uint8 and write the result to disk.
# NOTE(review): this is a bare dtype cast with no scale / zero-point, so
# fractional float weights lose almost all information -- confirm that this
# is the intended "float32 to uint8" conversion.
model = models.resnet18(pretrained=True)
state_dict = {
    name: tensor.to(torch.uint8)
    for name, tensor in model.state_dict().items()
}
torch.save(state_dict, "nowgood.pth")

graffiti/merge_conv_bn.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# coding=utf-8
2+
import torch
3+
from torchvision import models
4+
import numpy as np
5+
import os
6+
from net import net_bn_conv_merge, net_bn_conv_merge_quantize
7+
from utils.data_loader import load_val_data
8+
from utils.train_val import validate
9+
10+
# Stabiliser added to the running variance before the square root when the
# BatchNorm statistics are folded into the convolution.
epsilon = 1e-5
# Dataset root handed to load_val_data() for the final validation run.
data = "/media/wangbin/8057840b-9a1e-48c9-aa84-d353a6ba1090/ImageNet_ILSVRC2012/ILSVRC2012"

# Source: stock pretrained ResNet-18.  Target: the network that receives the
# folded conv weights/biases (net_bn_conv_merge.resnet18() is the alternative
# target this script was originally written against).
model = models.resnet18(pretrained=True)
merge_model = net_bn_conv_merge_quantize.resnet18()
state_dict = model.state_dict()
merge_state_dict = merge_model.state_dict()

# The classifier is copied over unchanged and then removed from the source
# dict so that only conv/BN entries remain.
merge_state_dict.update({"fc.weight": state_dict["fc.weight"],
                         "fc.bias": state_dict["fc.bias"]})
del state_dict["fc.weight"]
del state_dict["fc.bias"]

# The remaining keys are consumed positionally, five per conv/BN pair:
#   conv1.weight, bn1.weight, bn1.bias, bn1.running_mean, bn1.running_var,
#   layer1.0.conv1.weight, layer1.0.bn1.weight, layer1.0.bn1.bias, ...
# PyTorch builds that store the BatchNorm "num_batches_tracked" buffer emit a
# sixth key per pair, which would shift every following group, so that buffer
# is dropped before grouping.
params = np.array([key for key in state_dict
                   if not key.endswith("num_batches_tracked")])
params = params.reshape((-1, 5))
for conv_key, gamma_key, beta_key, mean_key, var_key in params:
    weight = state_dict[conv_key]
    if not conv_key.endswith("weight") or weight.dim() != 4:
        # Fail loudly instead of silently folding mis-aligned tensors.
        raise ValueError(
            "expected a 4-D conv weight at the start of a conv/BN group, "
            "got '{}'".format(conv_key))
    gamma = state_dict[gamma_key]
    beta = state_dict[beta_key]
    running_mean = state_dict[mean_key]
    running_var = state_dict[var_key]
    # Per-output-channel scale applied by BatchNorm in inference mode.
    delta = gamma / (torch.sqrt(running_var + epsilon))
    weight = weight * delta.view(-1, 1, 1, 1)
    bias = (0 - running_mean) * delta + beta
    # "<conv>.weight"[:-6] strips "weight", leaving "<conv>." for the bias key.
    merge_state_dict.update({conv_key: weight,
                             conv_key[:-6] + "bias": bias})
merge_model.load_state_dict(merge_state_dict)
merge_model_name = "resnet18_merge_bn_conv.pth.tar"
torch.save(merge_model.state_dict(), merge_model_name)

# Reload the checkpoint that was just written and validate it.
val_loader = load_val_data(data)
evaluate = merge_model_name
if os.path.isfile(evaluate):
    print("Loading evaluate model '{}'".format(evaluate))
    checkpoint = torch.load(evaluate)
    merge_model.load_state_dict(checkpoint)
    print("Loaded evaluate model '{}'".format(evaluate))
else:
    print("No evaluate mode found at '{}'".format(evaluate))

merge_model.cuda()
merge_model.eval()
criterion = torch.nn.CrossEntropyLoss().cuda()
validate(merge_model, val_loader, criterion)

graffiti/nowgood.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
from net import net_quantize_guide
33
from torchvision import models
44

5-
model = net_quantize_guide.resnet18()
6-
print(model.state_dict().keys())
7-
model = models.resnet18(pretrained=True)
8-
state_dict = model.state_dict()
9-
state_dict = {k: v.to(torch.uint8) for k, v in state_dict.items()}
10-
torch.save(state_dict, "nowgood.pth")
5+
6+
# Broadcasting demo: a (5, 1) column of per-row factors scales every element
# of the matching row of a (5, 3) tensor.
x = torch.ones(5, 3)
bias = torch.tensor([[4.0], [1.0], [1.0], [3.0], [1.0]])
y = torch.mul(x, bias)
print(y)

net/net_bn_conv_merge.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# coding=utf-8
2+
import torch.nn as nn
3+
import math
4+
import torch.utils.model_zoo as model_zoo
5+
6+
"""
7+
Network modification steps:
8+
1. Set bias=True on the convolution layers
9+
2. Remove the BN layers
10+
"""
11+
12+
# Names exported by `from <module> import *`: the network class and one
# factory function per supported depth.
__all__ = [
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
]
14+
15+
16+
# Checkpoint download location for each architecture name; looked up by the
# factory functions when `pretrained=True`.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)
23+
24+
25+
def conv3x3(in_planes, out_planes, stride=1):
    """Return a biased 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=True)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
29+
30+
31+
class BasicBlock(nn.Module):
    """Residual block made of two biased 3x3 convolutions and no norm layers.

    Args:
        inplanes: number of channels entering the block.
        planes: number of channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    # Ratio of the block's output channels to ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run conv -> ReLU -> conv, add the shortcut, then apply ReLU."""
        out = self.relu(self.conv1(x))
        out = self.conv2(out)

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
57+
58+
59+
class Bottleneck(nn.Module):
    """Residual block of three biased convolutions (1x1, 3x3, 1x1), no norm layers.

    Args:
        inplanes: number of channels entering the block.
        planes: number of channels of the two inner convolutions; the block
            outputs ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    # Ratio of the block's output channels to ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=True)
        # Derive the width from `expansion` rather than a literal 4: the
        # enclosing network sizes its shortcut and classifier from
        # `block.expansion`, so the two must never disagree.
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three convolutions, add the shortcut, then apply ReLU."""
        residual = x

        out = self.conv1(x)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.relu(out)

        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
90+
91+
92+
class ResNet(nn.Module):
    """ResNet backbone whose convolutions carry a bias and have no norm layers.

    Args:
        block: residual block class. It must expose an ``expansion`` class
            attribute and accept ``(inplanes, planes, stride, downsample)``.
        layers: four integers, the number of blocks in each of the four stages.
        num_classes: size of the final linear classifier output.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        # Running channel count, advanced by _make_layer() as stages are built.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Global average pooling.  For a 7x7 final feature map this matches
        # the former AvgPool2d(7, stride=1); other input resolutions no longer
        # break the classifier's expected feature size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Zero-mean normal init scaled by the kernel's fan-out.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                # Only reached if a caller-supplied block still contains
                # BatchNorm; the layers built in this class have none.
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage: ``blocks`` instances of ``block`` in sequence.

        Only the first block receives ``stride``.  It also receives a 1x1
        convolution shortcut when the stride is not 1 or the channel count
        changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=True),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
147+
148+
149+
def _bn_prefix_for(conv_prefix):
    """Map a conv layer's state-dict prefix to that of the BatchNorm after it.

    In the torchvision ResNet checkpoints ``convN`` is followed by ``bnN``,
    and a shortcut ``downsample.0`` convolution by ``downsample.1``.
    """
    parent, _, leaf = conv_prefix.rpartition('.')
    if leaf.startswith('conv'):
        bn_leaf = 'bn' + leaf[len('conv'):]
    else:
        bn_leaf = str(int(leaf) + 1)
    return parent + '.' + bn_leaf if parent else bn_leaf


def _fold_bn_into_conv(state_dict, eps=1e-5):
    """Return a state dict with every BatchNorm folded into its convolution.

    For each 4-D ``<conv>.weight`` the matching BatchNorm tensors are looked
    up and turned into a scaled weight plus a ``<conv>.bias``; the ``fc.*``
    entries are passed through unchanged.  ``eps`` is the value added to the
    running variance before the square root.
    """
    folded = {}
    for key, value in state_dict.items():
        if key.endswith('.weight') and value.dim() == 4:
            conv = key[:-len('.weight')]
            bn = _bn_prefix_for(conv)
            # Per-output-channel scale BatchNorm applies in inference mode.
            scale = state_dict[bn + '.weight'] / (
                state_dict[bn + '.running_var'] + eps).sqrt()
            folded[key] = value * scale.view(-1, 1, 1, 1)
            folded[conv + '.bias'] = (
                state_dict[bn + '.bias']
                - state_dict[bn + '.running_mean'] * scale)
        elif key.startswith('fc.'):
            folded[key] = value
    return folded


def _build_resnet(arch, block, layers, pretrained, kwargs):
    """Instantiate the network and optionally load folded pretrained weights.

    The published checkpoints contain BatchNorm tensors and no convolution
    biases, so they cannot be loaded into this BN-free network as-is; they
    are folded first.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls[arch])
        model.load_state_dict(_fold_bn_into_conv(checkpoint))
    return model


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model initialised from the
            ImageNet checkpoint with BatchNorm folded into the convolutions
    """
    return _build_resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained,
                         kwargs)


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model initialised from the
            ImageNet checkpoint with BatchNorm folded into the convolutions
    """
    return _build_resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained,
                         kwargs)


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model initialised from the
            ImageNet checkpoint with BatchNorm folded into the convolutions
    """
    return _build_resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained,
                         kwargs)


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model initialised from the
            ImageNet checkpoint with BatchNorm folded into the convolutions
    """
    return _build_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
                         kwargs)


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model initialised from the
            ImageNet checkpoint with BatchNorm folded into the convolutions
    """
    return _build_resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
                         kwargs)

0 commit comments

Comments
 (0)