diff --git a/.idea/misc.xml b/.idea/misc.xml index ab530bf..8161a60 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/.idea/segment_snn.iml b/.idea/segment_snn.iml index d870a4a..b7077a5 100644 --- a/.idea/segment_snn.iml +++ b/.idea/segment_snn.iml @@ -2,7 +2,7 @@ - + \ No newline at end of file diff --git a/predict.py b/predict.py index 6b16096..f7d927d 100644 --- a/predict.py +++ b/predict.py @@ -1,10 +1,19 @@ +import argparse + import cv2 +import torch +from torch import nn +from torch.utils.data import DataLoader + +from dataset import SegmentationDataset +from model import SegmentModel if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("-i", "--img_path", type=int, default=8, help="input image path") parser.add_argument("-s", "--step", type=int, default=8, help="time slice") parser.add_argument("-os", "--output_size", type=tuple, default=(128, 128), help="size of images(H, W)") + parser.add_argument("-od", "--output_dir", type=str, default='./test/mask.png', help="output directory") args = parser.parse_args() setup_seed(42) @@ -14,8 +23,24 @@ img_path = args.img_path step = args.step output_size = args.output_size + output_dir = args.output_dir + device = 'cuda:0' # get image - cv2.imread() + test_data = SegmentationDataset(root=img_path) + test_iter = DataLoader(test_data, batch_size=1, shuffle=True, num_workers=0) + + + net = SegmentModel(output_size=output_size, out_cls=train_dataset.num_class, node=BiasLIFNode, step=step) + # load model + torch.load(net, device, './checkpoints/Segment_SNN.pth') + net = net.to(device) + with torch.no_grad(): + for (idx, img) in enumerate(test_iter): + logits = net(img.to(device)) + softmax = nn.Softmax(dim=1) + logits = softmax(logits) + mask = torch.argmax(logits, dim=1) + cv2.imwrite(output_dir, mask) + - net = SegmentModel(output_size=output_size, out_cls=train_dataset.num_class, node=BiasLIFNode, step=step) \ No newline at end of file diff --git a/train.py b/train.py index ff82bd1..df0fbab 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,9 @@ +import argparse import sys +import numpy as np import torch +from torch import nn from torch.utils.data import DataLoader sys.path.append('../../..') @@ -57,7 +60,7 @@ def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, if test_acc > best: best = test_acc - torch.save(net.state_dict(), './checkpoints/CIFAR10_VGG16.pth') + torch.save(net.state_dict(), './checkpoints/Segment_SNN.pth') def estimate_dice(gt_msk, prt_msk): intersection = gt_msk * prt_msk