forked from Jack-Cherish/Deep-Learning
Showing 3 changed files with 343 additions and 0 deletions.
dataset.py
@@ -0,0 +1,77 @@
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import ImageFile
# Let PIL load truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class Garbage_Loader(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag

        # Training transform: resize plus random flips for augmentation.
        self.train_tf = transforms.Compose([
            transforms.Resize(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ])
        # Validation transform: resize only, no augmentation.
        self.val_tf = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
        ])

    def get_images(self, txt_path):
        # Each line of the index file is "image_path<TAB>label".
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x: x.strip().split('\t'), imgs_info))
        return imgs_info

    def padding_black(self, img):
        # Letterbox: scale the longer side to 224, then paste the image
        # centered on a black 224x224 canvas to preserve the aspect ratio.
        w, h = img.size

        scale = 224. / max(w, h)
        img_fg = img.resize([int(x) for x in [w * scale, h * scale]])

        size_fg = img_fg.size
        size_bg = 224

        img_bg = Image.new("RGB", (size_bg, size_bg))

        img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,
                              (size_bg - size_fg[1]) // 2))

        return img_bg

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path)
        img = img.convert('RGB')
        img = self.padding_black(img)
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)

        return img, label

    def __len__(self):
        return len(self.imgs_info)


if __name__ == "__main__":
    train_dataset = Garbage_Loader("train.txt", True)
    print("Number of samples:", len(train_dataset))
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=1,
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)
        print(label)
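
The loader expects a plain-text index file with one "image_path<TAB>label" record per line, matching the split('\t') and int(label) above. A minimal sketch of how such a file could be generated (the paths and labels below are made-up placeholders, not from the original commit):

# build_index.py -- illustrative only; point it at real image files.
samples = [
    ("data/class_a/img_0001.jpg", 0),
    ("data/class_b/img_0002.jpg", 1),
]
with open("train.txt", "w", encoding="utf-8") as f:
    for path, label in samples:
        f.write("%s\t%d\n" % (path, label))
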
@@ -0,0 +1,45 @@
from dataset import Garbage_Loader
from torch.utils.data import DataLoader
from torchvision import models
import torch.nn as nn
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def softmax(x):
    # Softmax over a 1-D vector of logits.
    exp_x = np.exp(x)
    softmax_x = exp_x / np.sum(exp_x, 0)
    return softmax_x


# dir_label.txt holds one tab-separated record per class; the first field
# is the class name that gets printed for a prediction.
with open('dir_label.txt', 'r', encoding='utf-8') as f:
    labels = f.readlines()
    labels = list(map(lambda x: x.strip().split('\t'), labels))

if __name__ == "__main__":
    test_list = 'test.txt'
    test_data = Garbage_Loader(test_list, train_flag=False)
    test_loader = DataLoader(dataset=test_data, num_workers=1, pin_memory=True, batch_size=1)
    # No ImageNet weights needed here; the fine-tuned checkpoint is loaded below.
    model = models.resnet50(pretrained=False)
    fc_inputs = model.fc.in_features
    model.fc = nn.Linear(fc_inputs, 214)
    model = model.cuda()
    checkpoint = torch.load('model_best_checkpoint_resnet50.pth.tar')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    for i, (image, label) in enumerate(test_loader):
        # Keep an HWC copy of the input for display with matplotlib.
        src = image.numpy()
        src = src.reshape(3, 224, 224)
        src = np.transpose(src, (1, 2, 0))
        image = image.cuda()
        label = label.cuda()
        pred = model(image)
        pred = pred.data.cpu().numpy()[0]
        score = softmax(pred)
        pred_id = np.argmax(score)
        plt.imshow(src)
        print('Prediction:', labels[pred_id][0])
        plt.show()
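
The softmax above exponentiates raw logits directly, which can overflow for large scores. A numerically stable drop-in (a sketch, not part of the original commit) subtracts the maximum first; the result is mathematically identical:

import numpy as np

def softmax_stable(x):
    # Shifting by the max keeps np.exp in a safe range without changing the output.
    shifted = x - np.max(x)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, 0)
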
@@ -0,0 +1,221 @@
from dataset import Garbage_Loader
from torch.utils.data import DataLoader
from torchvision import models
import torch.nn as nn
import torch.optim as optim
import torch
import time
import os
import shutil
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

"""
    Author : Jack Cui
    Wechat : https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA
"""

from tensorboardX import SummaryWriter

def accuracy(output, target, topk=(1,)):
    """
        Compute the top-k accuracy.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        # Top-1 predicted class index for each sample in the batch.
        class_to = pred[0].cpu().numpy()

        res = []
        for k in topk:
            # reshape(-1) instead of view(-1): correct[:k] is non-contiguous
            # after the transpose above, and view() raises on recent PyTorch.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res, class_to
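
# Worked example (a sketch of accuracy() on toy data):
#   output = torch.tensor([[0.1, 0.9], [0.8, 0.2]])   # 2 samples, 2 classes
#   target = torch.tensor([1, 1])
#   Top-1 predictions are [1, 0], so only the first sample is correct:
#   accuracy(output, target, topk=(1,)) -> ([tensor([50.])], array([1, 0]))
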
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
        Save the checkpoint; if is_best, also keep a copy as the best model
        (usually the one with the highest validation accuracy).
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best_' + filename)

def train(train_loader, model, criterion, optimizer, epoch, writer):
    """
        Training loop.
        Args:
            train_loader - DataLoader for the training set
            model - the network
            criterion - loss function
            optimizer - optimizer
            epoch - index of the current epoch
            writer - tensorboardX SummaryWriter
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        input = input.cuda()
        target = target.cuda()

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
    writer.add_scalar('loss/train_loss', losses.val, global_step=epoch)

def validate(val_loader, model, criterion, epoch, writer, phase="VAL"):
    """
        Validation loop.
        Args:
            val_loader - DataLoader for the validation set
            model - the network
            criterion - loss function
            epoch - index of the current epoch
            writer - tensorboardX SummaryWriter
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            input = input.cuda()
            target = target.cuda()
            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                print('Test-{0}: [{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          phase, i, len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1, top5=top5))

        print(' * {} Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
              .format(phase, top1=top1, top5=top5))
    writer.add_scalar('loss/valid_loss', losses.val, global_step=epoch)
    return top1.avg, top5.avg

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
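
# Usage sketch: AverageMeter keeps a running mean weighted by sample count.
#   m = AverageMeter()
#   m.update(2.0, n=4)   # val=2.0, sum=8.0,  count=4, avg=2.0
#   m.update(4.0, n=4)   # val=4.0, sum=24.0, count=8, avg=3.0
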
if __name__ == "__main__":
    # -------------------------------------------- step 1/4 : load data ---------------------------
    train_dir_list = 'train.txt'
    valid_dir_list = 'val.txt'
    batch_size = 64
    epochs = 80
    num_classes = 214
    train_data = Garbage_Loader(train_dir_list, train_flag=True)
    valid_data = Garbage_Loader(valid_dir_list, train_flag=False)
    train_loader = DataLoader(dataset=train_data, num_workers=8, pin_memory=True, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(dataset=valid_data, num_workers=8, pin_memory=True, batch_size=batch_size)
    train_data_size = len(train_data)
    print('Training set size: %d' % train_data_size)
    valid_data_size = len(valid_data)
    print('Validation set size: %d' % valid_data_size)
    # ------------------------------------ step 2/4 : define the network ------------------------------------
    model = models.resnet50(pretrained=True)
    fc_inputs = model.fc.in_features
    model.fc = nn.Linear(fc_inputs, num_classes)
    model = model.cuda()
    # ------------------------------------ step 3/4 : define loss function, optimizer, scheduler -------------
    lr_init = 0.0001
    lr_stepsize = 20
    weight_decay = 0.001
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters(), lr=lr_init, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_stepsize, gamma=0.1)

    writer = SummaryWriter('runs/resnet50')
    # ------------------------------------ step 4/4 : train -----------------------------------------
    best_prec1 = 0
    for epoch in range(epochs):
        train(train_loader, model, criterion, optimizer, epoch, writer)
        # Step the LR scheduler after the optimizer updates, as PyTorch >= 1.1 expects.
        scheduler.step()
        # evaluate on the validation set
        valid_prec1, valid_prec5 = validate(valid_loader, model, criterion, epoch, writer, phase="VAL")
        is_best = valid_prec1 > best_prec1
        best_prec1 = max(valid_prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': 'resnet50',
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best,
        filename='checkpoint_resnet50.pth.tar')
    writer.close()
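
The checkpoint dictionary saved above also stores the optimizer state and epoch counter, so an interrupted run can be resumed. A minimal sketch (assuming checkpoint_resnet50.pth.tar exists from an earlier run, with model and optimizer constructed as above):

checkpoint = torch.load('checkpoint_resnet50.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
# Continue the training loop from start_epoch instead of 0.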