
Commit

'Modified dataset.py and train.py'
AsCome11 committed Dec 21, 2023
1 parent 5b26b28 commit e898c53
Showing 3 changed files with 101 additions and 59 deletions.
119 changes: 75 additions & 44 deletions dataset.py
@@ -96,37 +96,37 @@ def pred_offset(self):
return 0

class COCOSegmentation(SegmentationDataset):
"""COCO Semantic Segmentation Dataset for VOC Pre-training.
"""
COCO数据集, 参考https://github.com/Tramac/awesome-semantic-segmentation-pytorch
Parameters
----------
root : string
Path to the COCO folder. Default is './datasets/coco'
split: string
'train', 'val' or 'test'
transform : callable, optional
A function that transforms the image
Examples
--------
>>> from torchvision import transforms
>>> import torch.utils.data as data
>>> # Transforms for Normalization
>>> input_transform = transforms.Compose([
...     transforms.ToTensor(),
...     transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
... ])
>>> # Create Dataset
>>> trainset = COCOSegmentation(split='train', transform=input_transform)
>>> # Create Training Loader
>>> train_data = data.DataLoader(
...     trainset, 4, shuffle=True,
...     num_workers=4)
"""
CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
1, 64, 20, 63, 7, 72]
NUM_CLASS = 21

def __init__(self, root='../datasets/coco', annotation_root='', split='train', mode=None, transform=None, stride=1, **kwargs):
"""
Parameters:
----------
root:
    Path to the image folder; the COCO images are expected under root/val2017.
annotation_root:
    Path to the corresponding COCO annotations; the JSON files live under annotation_root/annotations.
split:
    'train' / 'val'; selects the data-augmentation scheme. This project trains on the val2017 split (5,000 images), so split is set to 'val'.
mode:
    'train' / 'val'; if mode is not given, self.mode defaults to split.
transform:
    Basic image transforms (normalization with the ImageNet mean and std).
stride:
    Step size of the rapid-eye-movement (saccade-like) translation.
"""
super(COCOSegmentation, self).__init__(root, split, mode, transform, **kwargs)
# lazy import pycocotools
from pycocotools.coco import COCO
@@ -200,21 +200,6 @@ def _gen_seg_mask(self, target, h, w):
mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
return mask

def generate_dynamic_translation(self, image):
tracex = self.stride * 2 * np.array([0, 2, 1, 0, 2, 1, 1, 2, 1])
tracey = self.stride * 2 * np.array([0, 1, 2, 1, 0, 2, 1, 1, 2])
@@ -223,17 +208,18 @@ def generate_dynamic_translation(self, image):
channel = image.shape[0]
height = image.shape[1]
width = image.shape[2]

frames = torch.zeros((num_frames, channel, height, width))
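# Copy first so torch.from_numpy gets its own writable buffer (from_numpy shares memory with the source array).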
image_copy = image.copy()
image_tensor = torch.from_numpy(image_copy)
for i in range(num_frames):
anchor_x = tracex[i]
anchor_y = tracey[i]
frames[i, :, anchor_y // 2: height - anchor_y // 2, anchor_x // 2: width - anchor_x // 2] = image_tensor[:,anchor_y:, anchor_x:]
return frames
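
For intuition, the two trace arrays step the image around a small neighbourhood (scaled by stride), mimicking rapid eye movements; differencing the resulting frame stack later yields event-like inputs. A quick look at the visited offsets (an illustrative sketch, not part of the commit):

import numpy as np

stride = 1
tracex = stride * 2 * np.array([0, 2, 1, 0, 2, 1, 1, 2, 1])
tracey = stride * 2 * np.array([0, 1, 2, 1, 0, 2, 1, 1, 2])
# The nine (x, y) pixel offsets visited by the trace:
print(list(zip(tracex.tolist(), tracey.tolist())))
# -> [(0, 0), (4, 2), (2, 4), (0, 2), (4, 0), (2, 4), (2, 2), (4, 2), (2, 4)]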

def _preprocess(self, ids, ids_file):
"""
Preprocess: keep only the images whose segmentation mask covers more than 1000 pixels.
"""
print("Preprocessing mask, this will take a while." + \
"But don't worry, it only run once for each split.")
tbar = trange(len(ids))
@@ -264,6 +250,51 @@ def classes(self):
def __len__(self):
return len(self.ids)



class test_dataset(torch.utils.data.Dataset):

def __init__(self, path='../datasets/coco/val2017', transform=None, output_size=(480,480), stride=1, **kwargs):
super(test_dataset, self).__init__()
self.path = path
self.transform = transform
self.stride = stride
self.img_names = os.listdir(path)
self.output_size = output_size

def __getitem__(self, index):
img_id = self.img_names[index]
img = Image.open(os.path.join(self.path, img_id)).convert('RGB')
# general resize, normalize and toTensor
if self.transform is not None:
img = self.transform(img)
frames = torch.diff(self.generate_dynamic_translation(img), dim=0)
p_img = torch.zeros_like(frames)
n_img = torch.zeros_like(frames)
p_img[frames > 0] = frames[frames > 0]
n_img[frames < 0] = frames[frames < 0]
output = torch.concat([p_img, n_img], dim=1)
return output
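
The split above mimics event-camera polarity: torch.diff takes temporal differences between consecutive shifted frames, positive changes fill one set of channels and negative changes the other, doubling the channel count. A toy check (illustrative only):

import torch

frames = torch.tensor([[[[0.0, 1.0]]], [[[1.0, 0.0]]]])   # (T=2, C=1, H=1, W=2)
diff = torch.diff(frames, dim=0)                          # (1, 1, 1, 2) -> [[1., -1.]]
p = torch.zeros_like(diff); p[diff > 0] = diff[diff > 0]  # ON events
n = torch.zeros_like(diff); n[diff < 0] = diff[diff < 0]  # OFF events
out = torch.cat([p, n], dim=1)                            # (1, 2, 1, 2): channels doubled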

def generate_dynamic_translation(self, image):
tracex = self.stride * 2 * np.array([0, 2, 1, 0, 2, 1, 1, 2, 1])
tracey = self.stride * 2 * np.array([0, 1, 2, 1, 0, 2, 1, 1, 2])

num_frames = len(tracex)
channel = image.shape[0]
height = image.shape[1]
width = image.shape[2]

frames = torch.zeros((num_frames, channel, height, width))
for i in range(num_frames):
anchor_x = tracex[i]
anchor_y = tracey[i]
frames[i, :, anchor_y // 2: height - anchor_y // 2, anchor_x // 2: width - anchor_x // 2] = image[:, anchor_y:, anchor_x:]
return frames

def __len__(self):
return len(self.img_names)
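
A minimal usage sketch for test_dataset (the path, resize size, and transform here are assumptions that mirror the defaults above):

from torch.utils.data import DataLoader
from torchvision import transforms

tf = transforms.Compose([
    transforms.Resize((480, 480)),
    transforms.ToTensor(),
    transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
])
data = test_dataset(path='../datasets/coco/val2017', transform=tf)
loader = DataLoader(data, batch_size=1, shuffle=False, num_workers=0)
batch = next(iter(loader))  # (1, 8, 6, 480, 480): 8 frame diffs, ON+OFF channels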

if __name__ == '__main__':
input_transform = transforms.Compose([
transforms.ToTensor(),
15 changes: 11 additions & 4 deletions predict.py
@@ -5,8 +5,12 @@
from torch import nn
from torch.utils.data import DataLoader

from dataset import test_dataset
from model import SegmentModel
from braincog.base.node.node import *
from torchvision import transforms
from braincog.utils import setup_seed


if __name__ == '__main__':
parser = argparse.ArgumentParser()
@@ -25,13 +29,16 @@
output_size = args.output_size
output_dir = args.output_dir
device = 'cuda:0'

input_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
])
# get image
test_data = test_dataset(path=img_path, transform=input_transform)
test_iter = DataLoader(test_data, batch_size=1, shuffle=True, num_workers=0)


net = SegmentModel(output_size=output_size, out_cls=21, node=BiasLIFNode, step=step)
# load model
net.load_state_dict(torch.load('./checkpoints/Segment_SNN.pth', map_location=device))
net = net.to(device)
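
The remainder of predict.py is collapsed in this diff; a minimal inference loop consistent with the visible code might look like this (the loop body and output naming are assumptions):

import os
import torch

os.makedirs(output_dir, exist_ok=True)
net.eval()
with torch.no_grad():
    for i, X in enumerate(test_iter):
        logits = net(X.to(device))              # assumed output: (1, 21, H, W)
        pred = logits.argmax(dim=1).squeeze(0)  # (H, W) class-index map
        torch.save(pred.cpu(), os.path.join(output_dir, f'pred_{i}.pt'))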
26 changes: 15 additions & 11 deletions train.py
@@ -22,6 +22,10 @@
DATA_DIR = '/data/datasets'


device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')


def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, save_path='./checkpoints', losstype='mse'):
best = 0
net = net.to(device)
@@ -52,7 +56,11 @@ def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs,
# y = torch.ones(8, 13, 128, 128).to(device)
y_hat = net(X)
label = y
if losstype == 'mse':
one_hot_label = F.one_hot(label, 21)
l = loss(y_hat.permute(0,2,3,1), one_hot_label.float())
else:
l = loss(y_hat,label)
losss.append(l.cpu().item())
l.backward()
optimizer.step()
@@ -71,7 +79,8 @@ def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs,

if test_acc > best:
best = test_acc
torch.save(net.state_dict(), os.path.join(save_path, 'SegmentModel.pth'))
print('Best model saved! ')

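A shape note for the MSE branch in train() above: F.one_hot appends the class dimension last, so the logits are permuted to match it. A small self-contained check (21 is the class count used above):

import torch
import torch.nn.functional as F

y_hat = torch.randn(2, 21, 8, 8)         # logits: (N, C, H, W)
label = torch.randint(0, 21, (2, 8, 8))  # class indices: (N, H, W)
one_hot = F.one_hot(label, 21).float()   # (N, H, W, 21)
l = F.mse_loss(y_hat.permute(0, 2, 3, 1), one_hot)  # both sides (N, H, W, C)
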
def estimate_dice(gt_msk, prt_msk):
intersection = gt_msk * prt_msk
@@ -81,25 +90,20 @@ def estimate_dice(gt_msk, prt_msk):
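
The rest of estimate_dice is collapsed in this diff; a standard Dice coefficient consistent with its visible first line would be (an assumption about the hidden lines, shown under a hypothetical name):

def dice_coefficient(gt_msk, prt_msk, eps=1e-7):
    # Dice = 2 * |X ∩ Y| / (|X| + |Y|), with eps guarding empty masks.
    intersection = gt_msk * prt_msk
    return (2.0 * intersection.sum()) / (gt_msk.sum() + prt_msk.sum() + eps)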
def evaluate_accuracy(data_iter, net, device=None, only_onebatch=False):
if device is None and isinstance(net, torch.nn.Module):
device = list(net.parameters())[0].device
dice = Dice(average='micro').to(device)
acc = []
net.eval()
with torch.no_grad():
tbar = tqdm(data_iter)
for X, y in tbar:
logits = net(X.to(device))
y = y.to(device)
# softmax = nn.Softmax(dim=1)
# logits = softmax(logits)

acc.append(dice(logits, y.detach()).item())
tbar.set_description(f'Validation acc = {acc[-1]: .4f}')
if only_onebatch:
break
net.train()
return sum(acc) / len(acc)


if __name__ == '__main__':
