Skip to content

Commit

Permalink
modify predict.py
Browse files Browse the repository at this point in the history
modify README.md
  • Loading branch information
yahuiwei123 committed Dec 20, 2023
1 parent 2ba37cc commit 0b90273
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion .idea/segment_snn.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 27 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import argparse

import cv2
import torch
from torch import nn
from torch.utils.data import DataLoader

from dataset import SegmentationDataset
from model import SegmentModel

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--img_path", type=int, default=8, help="input image path")
parser.add_argument("-s", "--step", type=int, default=8, help="time slice")
parser.add_argument("-os", "--output_size", type=tuple, default=(128, 128), help="size of images(H, W)")
parser.add_argument("-od", "--output_dir", type=str, default='./test/mask.png', help="output directory")
args = parser.parse_args()

setup_seed(42)
Expand All @@ -14,8 +23,24 @@
img_path = args.img_path
step = args.step
output_size = args.output_size
output_dir = args.output_dir
device = 'cuda:0'

# get image
cv2.imread()
test_data = SegmentationDataset(root=img_path)
test_iter = DataLoader(test_data, batch_size=1, shuffle=True, num_workers=0)


net = SegmentModel(output_size=output_size, out_cls=train_dataset.num_class, node=BiasLIFNode, step=step)
# load model
torch.load(net, device, './checkpoints/Segment_SNN.pth')
net = net.to(device)
with torch.no_grad():
for (idx, img) in enumerate(test_iter):
logits = net(img.to(device))
softmax = nn.Softmax(dim=1)
logits = softmax(logits)
mask = torch.argmax(logits, dim=1)
cv2.imwrite(output_dir, mask)


net = SegmentModel(output_size=output_size, out_cls=train_dataset.num_class, node=BiasLIFNode, step=step)
5 changes: 4 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import argparse
import sys

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

sys.path.append('../../..')
Expand Down Expand Up @@ -57,7 +60,7 @@ def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs,

if test_acc > best:
best = test_acc
torch.save(net.state_dict(), './checkpoints/CIFAR10_VGG16.pth')
torch.save(net.state_dict(), './checkpoints/Segment_SNN.pth')

def estimate_dice(gt_msk, prt_msk):
intersection = gt_msk * prt_msk
Expand Down

0 comments on commit 0b90273

Please sign in to comment.