From 083198cb62634e7bad76cf749e8a10eb85f7e7e3 Mon Sep 17 00:00:00 2001 From: Jack Cui Date: Fri, 22 May 2020 00:09:23 +0800 Subject: [PATCH] garbage garbage --- Pytorch-Seg/dataset.py | 77 ++++++++++++++ Pytorch-Seg/infer.py | 45 +++++++++ Pytorch-Seg/train.py | 221 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 343 insertions(+) create mode 100644 Pytorch-Seg/dataset.py create mode 100644 Pytorch-Seg/infer.py create mode 100644 Pytorch-Seg/train.py diff --git a/Pytorch-Seg/dataset.py b/Pytorch-Seg/dataset.py new file mode 100644 index 0000000..58060eb --- /dev/null +++ b/Pytorch-Seg/dataset.py @@ -0,0 +1,77 @@ +import torch +from PIL import Image +import os +import glob +from torch.utils.data import Dataset +import random +import torchvision.transforms as transforms +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +class Garbage_Loader(Dataset): + def __init__(self, txt_path, train_flag=True): + self.imgs_info = self.get_images(txt_path) + self.train_flag = train_flag + + self.train_tf = transforms.Compose([ + transforms.Resize(224), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.ToTensor(), + + ]) + self.val_tf = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + ]) + + def get_images(self, txt_path): + with open(txt_path, 'r', encoding='utf-8') as f: + imgs_info = f.readlines() + imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info)) + return imgs_info + + def padding_black(self, img): + + w, h = img.size + + scale = 224. / max(w, h) + img_fg = img.resize([int(x) for x in [w * scale, h * scale]]) + + size_fg = img_fg.size + size_bg = 224 + + img_bg = Image.new("RGB", (size_bg, size_bg)) + + img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2, + (size_bg - size_fg[1]) // 2)) + + img = img_bg + return img + + def __getitem__(self, index): + img_path, label = self.imgs_info[index] + img = Image.open(img_path) + img = img.convert('RGB') + img = self.padding_black(img) + if self.train_flag: + img = self.train_tf(img) + else: + img = self.val_tf(img) + label = int(label) + + return img, label + + def __len__(self): + return len(self.imgs_info) + + +if __name__ == "__main__": + train_dataset = Garbage_Loader("train.txt", True) + print("数据个数:", len(train_dataset)) + train_loader = torch.utils.data.DataLoader(dataset=train_dataset, + batch_size=1, + shuffle=True) + for image, label in train_loader: + print(image.shape) + print(label) \ No newline at end of file diff --git a/Pytorch-Seg/infer.py b/Pytorch-Seg/infer.py new file mode 100644 index 0000000..ec8a60f --- /dev/null +++ b/Pytorch-Seg/infer.py @@ -0,0 +1,45 @@ +from dataset import Garbage_Loader +from torch.utils.data import DataLoader +import torchvision.transforms as transforms +from torchvision import models +import torch.nn as nn +import torch +import os +import numpy as np +import matplotlib.pyplot as plt +#%matplotlib inline +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +def softmax(x): + exp_x = np.exp(x) + softmax_x = exp_x / np.sum(exp_x, 0) + return softmax_x + +with open('dir_label.txt', 'r', encoding='utf-8') as f: + labels = f.readlines() + labels = list(map(lambda x:x.strip().split('\t'), labels)) + +if __name__ == "__main__": + test_list = 'test.txt' + test_data = Garbage_Loader(test_list, train_flag=False) + test_loader = DataLoader(dataset=test_data, num_workers=1, pin_memory=True, batch_size=1) + model = models.resnet50(pretrained=False) + fc_inputs = model.fc.in_features + model.fc = nn.Linear(fc_inputs, 214) + model = model.cuda() + checkpoint = torch.load('model_best_checkpoint_resnet50.pth.tar') + model.load_state_dict(checkpoint['state_dict']) + model.eval() + for i, (image, label) in enumerate(test_loader): + src = image.numpy() + src = src.reshape(3, 224, 224) + src = np.transpose(src, (1, 2, 0)) + image = image.cuda() + label = label.cuda() + pred = model(image) + pred = pred.data.cpu().numpy()[0] + score = softmax(pred) + pred_id = np.argmax(score) + plt.imshow(src) + print('预测结果:', labels[pred_id][0]) + plt.show() \ No newline at end of file diff --git a/Pytorch-Seg/train.py b/Pytorch-Seg/train.py new file mode 100644 index 0000000..8d0b5a2 --- /dev/null +++ b/Pytorch-Seg/train.py @@ -0,0 +1,221 @@ +from dataset import Garbage_Loader +from torch.utils.data import DataLoader +from torchvision import models +import torch.nn as nn +import torch.optim as optim +import torch +import time +import os +import shutil +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +""" + Author : Jack Cui + Wechat : https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA +""" + +from tensorboardX import SummaryWriter + +def accuracy(output, target, topk=(1,)): + """ + 计算topk的准确率 + """ + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + class_to = pred[0].cpu().numpy() + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res, class_to + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + """ + 根据 is_best 存模型,一般保存 valid acc 最好的模型 + """ + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best_' + filename) + +def train(train_loader, model, criterion, optimizer, epoch, writer): + """ + 训练代码 + 参数: + train_loader - 训练集的 DataLoader + model - 模型 + criterion - 损失函数 + optimizer - 优化器 + epoch - 进行第几个 epoch + writer - 用于写 tensorboardX + """ + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to train mode + model.train() + + end = time.time() + for i, (input, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + input = input.cuda() + target = target.cuda() + + # compute output + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(prec1[0], input.size(0)) + top5.update(prec5[0], input.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % 10 == 0: + print('Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + epoch, i, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses, top1=top1, top5=top5)) + writer.add_scalar('loss/train_loss', losses.val, global_step=epoch) + +def validate(val_loader, model, criterion, epoch, writer, phase="VAL"): + """ + 验证代码 + 参数: + val_loader - 验证集的 DataLoader + model - 模型 + criterion - 损失函数 + epoch - 进行第几个 epoch + writer - 用于写 tensorboardX + """ + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (input, target) in enumerate(val_loader): + input = input.cuda() + target = target.cuda() + # compute output + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(prec1[0], input.size(0)) + top5.update(prec5[0], input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % 10 == 0: + print('Test-{0}: [{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + phase, i, len(val_loader), + batch_time=batch_time, + loss=losses, + top1=top1, top5=top5)) + + print(' * {} Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' + .format(phase, top1=top1, top5=top5)) + writer.add_scalar('loss/valid_loss', losses.val, global_step=epoch) + return top1.avg, top5.avg + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +if __name__ == "__main__": + # -------------------------------------------- step 1/4 : 加载数据 --------------------------- + train_dir_list = 'train.txt' + valid_dir_list = 'val.txt' + batch_size = 64 + epochs = 80 + num_classes = 214 + train_data = Garbage_Loader(train_dir_list, train_flag=True) + valid_data = Garbage_Loader(valid_dir_list, train_flag=False) + train_loader = DataLoader(dataset=train_data, num_workers=8, pin_memory=True, batch_size=batch_size, shuffle=True) + valid_loader = DataLoader(dataset=valid_data, num_workers=8, pin_memory=True, batch_size=batch_size) + train_data_size = len(train_data) + print('训练集数量:%d' % train_data_size) + valid_data_size = len(valid_data) + print('验证集数量:%d' % valid_data_size) + # ------------------------------------ step 2/4 : 定义网络 ------------------------------------ + model = models.resnet50(pretrained=True) + fc_inputs = model.fc.in_features + model.fc = nn.Linear(fc_inputs, num_classes) + model = model.cuda() + # ------------------------------------ step 3/4 : 定义损失函数和优化器等 ------------------------- + lr_init = 0.0001 + lr_stepsize = 20 + weight_decay = 0.001 + criterion = nn.CrossEntropyLoss().cuda() + optimizer = optim.Adam(model.parameters(), lr=lr_init, weight_decay=weight_decay) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_stepsize, gamma=0.1) + + writer = SummaryWriter('runs/resnet50') + # ------------------------------------ step 4/4 : 训练 ----------------------------------------- + best_prec1 = 0 + for epoch in range(epochs): + scheduler.step() + train(train_loader, model, criterion, optimizer, epoch, writer) + # 在验证集上测试效果 + valid_prec1, valid_prec5 = validate(valid_loader, model, criterion, epoch, writer, phase="VAL") + is_best = valid_prec1 > best_prec1 + best_prec1 = max(valid_prec1, best_prec1) + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': 'resnet50', + 'state_dict': model.state_dict(), + 'best_prec1': best_prec1, + 'optimizer' : optimizer.state_dict(), + }, is_best, + filename='checkpoint_resnet50.pth.tar') + writer.close()