Skip to content

Commit 37598b1

Browse files
committed
init
0 parents  commit 37598b1

File tree

14 files changed

+1408
-0
lines changed

14 files changed

+1408
-0
lines changed

PyTorch实战指南.md

Lines changed: 831 additions & 0 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# PyTorch 实践指南
2+
3+
本文是文章[PyTorch实践指南](https://zhuanlan.zhihu.com/p/29024978)配套代码,请参照[知乎专栏原文](https://zhuanlan.zhihu.com/p/29024978)或者[对应的markdown文件](PyTorch实战指南.md),以便更好地了解文件组织和代码细节。
4+
5+
6+
## 数据下载
7+
- 从[kaggle比赛官网](https://www.kaggle.com/c/dogs-vs-cats/data)下载所需的数据
8+
- 解压并把训练集和测试集分别放在一个文件夹中
9+
10+
11+
## 安装
12+
- PyTorch : 可按照[PyTorch官网](http://pytorch.org)的指南,根据自己的平台安装指定的版本
13+
- 安装指定依赖:
14+
15+
```
16+
pip install -r requirements.txt
17+
```
18+
19+
## 训练
20+
必须首先启动visdom:
21+
22+
```
23+
python -m visdom.server
24+
```
25+
26+
然后使用如下命令启动训练:
27+
28+
```
29+
# 在gpu0上训练,并把可视化结果保存在visdom 的classifier env上
30+
python main.py train --data-root=./data/train --use-gpu=True --env=classifier
31+
```
32+
33+
34+
查看详细的使用说明,可使用如下命令:
35+
```
36+
python main.py help
37+
```
38+
39+
## 测试
40+
41+
```
42+
python main.py test --data-root=./data/test --use-gpu=False --batch-size=256
43+
```

config.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#coding:utf8
import warnings


class DefaultConfig(object):
    '''
    Default hyper-parameters and runtime settings.

    Every attribute can be overridden from the command line via
    ``opt.parse(kwargs)`` (fire passes CLI flags through as kwargs).
    '''
    env = 'default'  # visdom environment name
    model = 'AlexNet'  # must match a class name exposed by models/__init__.py

    train_data_root = './data/train/'  # training-set directory
    test_data_root = './data/test1'  # test-set directory
    load_model_path = 'checkpoints/model.pth'  # pretrained weights; None = don't load

    batch_size = 128  # batch size
    use_gpu = True  # use GPU or not
    num_workers = 4  # how many workers for loading data
    print_freq = 20  # print info every N batches

    debug_file = '/tmp/debug'  # if os.path.exists(debug_file): enter ipdb
    result_file = 'result.csv'  # where test() writes its CSV output

    max_epoch = 10
    lr = 0.1  # initial learning rate
    lr_decay = 0.95  # when val_loss increases, lr = lr * lr_decay
    weight_decay = 1e-4  # L2 penalty passed to the optimizer


def parse(self, kwargs):
    '''
    Update config attributes from the dict ``kwargs``.

    Unknown keys produce a warning but are still set; afterwards the
    full effective configuration is printed.
    '''
    # dict.iteritems() is Python-2-only and raises AttributeError on
    # Python 3; .items() works on both.
    for k, v in kwargs.items():
        if not hasattr(self, k):
            warnings.warn("Warning: opt has no attribute %s" % k)
        setattr(self, k, v)

    print('user config:')
    for k in self.__class__.__dict__:
        if not k.startswith('__'):
            print(k, getattr(self, k))


# Attach parse as a method so callers write opt.parse(kwargs).
DefaultConfig.parse = parse
opt = DefaultConfig()

data/__init__.py

Whitespace-only changes.

data/dataset.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#coding:utf8
2+
import os
3+
from PIL import Image
4+
from torch.utils import data
5+
import numpy as np
6+
from torchvision import transforms as T
7+
8+
9+
class DogCat(data.Dataset):
    '''
    Dogs-vs-Cats dataset.

    Collects all image paths under ``root`` and splits them:
      - test : every image (file names like data/test1/8973.jpg)
      - train: first 70% of the sorted images (cat.10004.jpg / dog.*.jpg)
      - val  : remaining 30%
    '''

    def __init__(self, root, transforms=None, train=True, test=False):
        '''
        Gather all image paths and split into train/val/test subsets.
        '''
        self.test = test
        imgs = [os.path.join(root, img) for img in os.listdir(root)]

        # Sort by numeric id so the split is deterministic:
        # test1: data/test1/8973.jpg
        # train: data/train/cat.10004.jpg
        if self.test:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
        else:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))

        imgs_num = len(imgs)

        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]

        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

            if self.test or not train:
                # Deterministic eval-time pipeline.
                # T.Scale was deprecated and later removed from
                # torchvision; T.Resize is the current name.
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            else:
                # Augmented train-time pipeline.
                # T.RandomSizedCrop was renamed T.RandomResizedCrop.
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])

    def __getitem__(self, index):
        '''
        Return one (image_tensor, label) pair.

        In test mode the "label" is the numeric image id (for the
        submission file); otherwise dog -> 1, cat -> 0.
        '''
        img_path = self.imgs[index]
        if self.test:
            label = int(self.imgs[index].split('.')[-2].split('/')[-1])
        else:
            label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, label

    def __len__(self):
        return len(self.imgs)

data/get_data.sh

Whitespace-only changes.

main.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#coding:utf8
2+
from config import opt
3+
import os
4+
import torch as t
5+
import models
6+
from data.dataset import DogCat
7+
from torch.utils.data import DataLoader
8+
from torch.autograd import Variable
9+
from torchnet import meter
10+
from utils.visualize import Visualizer
11+
12+
def test(**kwargs):
13+
opt.parse(kwargs)
14+
import ipdb;
15+
ipdb.set_trace()
16+
# configure model
17+
model = getattr(models, opt.model)().eval()
18+
if opt.load_model_path:
19+
model.load(opt.load_model_path)
20+
if opt.use_gpu: model.cuda()
21+
22+
# data
23+
train_data = DogCat(opt.test_data_root,test=True)
24+
test_dataloader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers)
25+
results = []
26+
for ii,(data,path) in enumerate(test_dataloader):
27+
input = t.autograd.Variable(data,volatile = True)
28+
if opt.use_gpu: input = input.cuda()
29+
score = model(input)
30+
probability = t.nn.functional.softmax(score)[:,0].data.tolist()
31+
# label = score.max(dim = 1)[1].data.tolist()
32+
33+
batch_results = [(path_,probability_) for path_,probability_ in zip(path,probability) ]
34+
35+
results += batch_results
36+
write_csv(results,opt.result_file)
37+
38+
return results
39+
40+
def write_csv(results, file_name):
    '''
    Write (id, label) rows to ``file_name`` as CSV with a header row.
    '''
    import csv
    # newline='' stops the csv module from emitting blank lines between
    # rows on Windows (csv does its own newline translation).
    with open(file_name, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['id', 'label'])
        writer.writerows(results)
46+
47+
def train(**kwargs):
    '''
    Train opt.model on the dogs-vs-cats training set.

    CLI kwargs override entries in opt (see config.DefaultConfig).
    Each epoch: train on the 70% split, save a checkpoint, validate on
    the 30% split, log metrics to visdom, and decay the learning rate
    when the mean training loss stopped improving.
    '''
    opt.parse(kwargs)
    vis = Visualizer(opt.env)

    # step1: configure model
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu: model.cuda()

    # step2: data -- both loaders read train_data_root; DogCat itself
    # performs the 70/30 train/val split.
    train_data = DogCat(opt.train_data_root,train=True)
    val_data = DogCat(opt.train_data_root,train=False)
    train_dataloader = DataLoader(train_data,opt.batch_size,
                        shuffle=True,num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,opt.batch_size,
                        shuffle=False,num_workers=opt.num_workers)

    # step3: criterion and optimizer
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(),lr = lr,weight_decay = opt.weight_decay)

    # step4: meters (running mean loss + 2-class confusion matrix)
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e100  # sentinel so the first epoch never triggers decay

    # train
    for epoch in range(opt.max_epoch):

        loss_meter.reset()
        confusion_matrix.reset()

        for ii,(data,label) in enumerate(train_dataloader):

            # train model
            input = Variable(data)
            target = Variable(label)
            if opt.use_gpu:
                input = input.cuda()
                target = target.cuda()

            optimizer.zero_grad()
            score = model(input)
            loss = criterion(score,target)
            loss.backward()
            optimizer.step()

            # meters update and visualize
            # NOTE(review): loss.data[0] is the pre-0.4 PyTorch scalar
            # accessor; on >=0.4 this would need loss.item().
            loss_meter.add(loss.data[0])
            confusion_matrix.add(score.data, target.data)

            if ii%opt.print_freq==opt.print_freq-1:
                vis.plot('loss', loss_meter.value()[0])

                # enter debug mode when the marker file exists
                if os.path.exists(opt.debug_file):
                    import ipdb;
                    ipdb.set_trace()

        model.save()

        # validate and visualize
        val_cm,val_accuracy = val(model,val_dataloader)

        vis.plot('val_accuracy',val_accuracy)
        vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format(
                    epoch = epoch,loss = loss_meter.value()[0],val_cm = str(val_cm.value()),train_cm=str(confusion_matrix.value()),lr=lr))

        # update learning rate: decay when mean loss did not improve.
        # Updating param_groups in place (instead of rebuilding the
        # optimizer) preserves Adam's moment estimates.
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = loss_meter.value()[0]
128+
129+
def val(model,dataloader):
    '''
    Evaluate the model on the validation set.

    Returns (confusion_matrix, accuracy) where accuracy is a percentage
    derived from the 2-class confusion-matrix diagonal. Switches the
    model to eval mode for the pass and restores train mode afterwards.
    '''
    model.eval()
    confusion_matrix = meter.ConfusionMeter(2)
    for ii, data in enumerate(dataloader):
        input, label = data
        val_input = Variable(input, volatile=True)  # volatile: pre-0.4 no-grad inference
        val_label = Variable(label.type(t.LongTensor), volatile=True)
        if opt.use_gpu:
            val_input = val_input.cuda()
            val_label = val_label.cuda()
        score = model(val_input)
        # NOTE(review): val_label is never read after this point; the
        # confusion matrix consumes the CPU label tensor directly.
        confusion_matrix.add(score.data.squeeze(), label.type(t.LongTensor))

    model.train()
    cm_value = confusion_matrix.value()
    accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum())
    return confusion_matrix, accuracy
149+
150+
def help():
    '''
    Print usage information: python file.py help
    '''
    # fixed user-facing typo: "avaiable" -> "available"
    print('''
    usage : python file.py <function> [--args=value]
    <function> := train | test | help
    example:
            python {0} train --env='env0701' --lr=0.01
            python {0} test --dataset='path/to/dataset/root/'
            python {0} help
    available args:'''.format(__file__))

    # Dump the DefaultConfig source so users can see every overridable arg.
    from inspect import getsource
    source = (getsource(opt.__class__))
    print(source)
167+
168+
if __name__=='__main__':
    # Expose train/test/help as CLI subcommands via Google Fire,
    # e.g. `python main.py train --lr=0.01`.
    import fire
    fire.Fire()

models/AlexNet.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#coding:utf8
2+
from torch import nn
3+
from .BasicModule import BasicModule
4+
5+
class AlexNet(BasicModule):
    '''
    AlexNet (one-weird-trick variant), adapted from
    torchvision/models/alexnet.py.
    Architecture reference: <https://arxiv.org/abs/1404.5997>
    '''

    def __init__(self, num_classes=2):
        super(AlexNet, self).__init__()

        self.model_name = 'alexnet'

        # Convolutional feature extractor: five conv+ReLU stages with
        # max-pooling after the 1st, 2nd and 5th convolutions.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Fully-connected head over the flattened 256x6x6 feature maps.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        # Extract features, flatten, then classify.
        feature_maps = self.features(x)
        flattened = feature_maps.view(feature_maps.size(0), 256 * 6 * 6)
        return self.classifier(flattened)

0 commit comments

Comments
 (0)