diff --git a/.idea/misc.xml b/.idea/misc.xml
index ab530bf..8161a60 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,4 @@
-
+
\ No newline at end of file
diff --git a/.idea/segment_snn.iml b/.idea/segment_snn.iml
index d870a4a..b7077a5 100644
--- a/.idea/segment_snn.iml
+++ b/.idea/segment_snn.iml
@@ -2,7 +2,7 @@
-
+
\ No newline at end of file
diff --git a/predict.py b/predict.py
index 6b16096..f7d927d 100644
--- a/predict.py
+++ b/predict.py
@@ -1,10 +1,19 @@
+import argparse
+
import cv2
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+from dataset import SegmentationDataset
+from model import SegmentModel
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--img_path", type=int, default=8, help="input image path")
parser.add_argument("-s", "--step", type=int, default=8, help="time slice")
parser.add_argument("-os", "--output_size", type=tuple, default=(128, 128), help="size of images(H, W)")
+ parser.add_argument("-od", "--output_dir", type=str, default='./test/mask.png', help="output directory")
args = parser.parse_args()
setup_seed(42)
@@ -14,8 +23,24 @@
img_path = args.img_path
step = args.step
output_size = args.output_size
+ output_dir = args.output_dir
+ device = 'cuda:0'
# get image
- cv2.imread()
+ test_data = SegmentationDataset(root=img_path)
+ test_iter = DataLoader(test_data, batch_size=1, shuffle=True, num_workers=0)
+
+
+ net = SegmentModel(output_size=output_size, out_cls=train_dataset.num_class, node=BiasLIFNode, step=step)
+ # load model
+ torch.load(net, device, './checkpoints/Segment_SNN.pth')
+ net = net.to(device)
+ with torch.no_grad():
+ for (idx, img) in enumerate(test_iter):
+ logits = net(img.to(device))
+ softmax = nn.Softmax(dim=1)
+ logits = softmax(logits)
+ mask = torch.argmax(logits, dim=1)
+ cv2.imwrite(output_dir, mask)
+
- net = SegmentModel(output_size=output_size, out_cls=train_dataset.num_class, node=BiasLIFNode, step=step)
\ No newline at end of file
diff --git a/train.py b/train.py
index ff82bd1..df0fbab 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,9 @@
+import argparse
import sys
+import numpy as np
import torch
+from torch import nn
from torch.utils.data import DataLoader
sys.path.append('../../..')
@@ -57,7 +60,7 @@ def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs,
if test_acc > best:
best = test_acc
- torch.save(net.state_dict(), './checkpoints/CIFAR10_VGG16.pth')
+ torch.save(net.state_dict(), './checkpoints/Segment_SNN.pth')
def estimate_dice(gt_msk, prt_msk):
intersection = gt_msk * prt_msk