modify train.py
yahuiwei123 committed Dec 20, 2023
1 parent 14625d9 commit 76f3be6
66 changes: 41 additions & 25 deletions train.py
@@ -1,9 +1,7 @@
 import argparse
 import sys
 
 import numpy as np
-import torch.nn.functional as F
 import torch
-from torch import nn
 from torch.utils.data import DataLoader
 
 sys.path.append('../../..')
@@ -12,18 +10,25 @@
 from torchvision.transforms import *
 import time
 from braincog.utils import setup_seed
-from dataset import SegmentationDataset
+from dataset import *
 from braincog.base.node.node import *
 from model import SegmentModel
+from tqdm import tqdm
+from torchmetrics import Dice
+import os
 
 
-device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 DATA_DIR = '/data/datasets'
 
 
-def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='mse'):
+def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, save_path='./checkpoints', losstype='mse'):
     best = 0
     net = net.to(device)
     print("training on ", device)
+    if not os.path.exists(save_path):
+        os.mkdir(save_path)
+
     if losstype == 'mse':
         loss = torch.nn.MSELoss()
     else:
@@ -35,32 +40,38 @@ def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs,
         learning_rate = param_group['lr']
 
         losss = []
+        train_acc = []
+        dice = Dice(average='micro').to(device)
         train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
-        for X, y in train_iter:
+        tbar = tqdm(train_iter)
+        for X, y in tbar:
             optimizer.zero_grad()
-            # X = X.to(device)
-            # y = y.to(device)
-            X = torch.ones(6, 8, 2, 128, 128).to(device)
-            y = torch.ones(8, 13, 128, 128).to(device)
+            X = X.to(device)
+            y = y.to(device)
+            # X = torch.ones(6, 8, 2, 128, 128).to(device)
+            # y = torch.ones(8, 13, 128, 128).to(device)
             y_hat = net(X)
             label = y
             l = loss(y_hat, label)  # loss between prediction and target
             losss.append(l.cpu().item())
             l.backward()
             optimizer.step()
             train_l_sum += l.cpu().item()
             # train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
+            with torch.no_grad():
+                acc = dice(y_hat.detach(), label.detach())
+                train_acc.append(acc.cpu().item())
             n += y.shape[0]
             batch_count += 1
-            scheduler.step()
+            tbar.set_description(f'Epoch {epoch}: Loss = {l.cpu().item(): .4f}, dice = {acc.item():.4f}')
+        scheduler.step()
         test_acc = evaluate_accuracy(test_iter, net)
         losses.append(np.mean(losss))
         print('epoch %d, lr %.6f, loss %.6f, train acc %.6f, test acc %.6f, time %.1f sec'
-              % (epoch + 1, learning_rate, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
+              % (epoch + 1, learning_rate, train_l_sum / batch_count, sum(train_acc) / len(train_acc), test_acc, time.time() - start))
 
         if test_acc > best:
             best = test_acc
-            torch.save(net.state_dict(), './checkpoints/Segment_SNN.pth')
+            torch.save(net.state_dict(), os.path.join(save_path, 'SegmentModel.pth'))

 def estimate_dice(gt_msk, prt_msk):
     intersection = gt_msk * prt_msk
@@ -71,19 +82,24 @@ def evaluate_accuracy(data_iter, net, device=None, only_onebatch=False):
     if device is None and isinstance(net, torch.nn.Module):
         device = list(net.parameters())[0].device
     acc_sum, n = 0.0, 0
+    dice = Dice(average='micro').to(device)
+    net.eval()
     with torch.no_grad():
-        for X, y in data_iter:
-            net.eval()
+        tbar = tqdm(data_iter)
+        for X, y in tbar:
             logits = net(X.to(device))
-            softmax = nn.Softmax(dim=1)
-            logits = softmax(logits)
-            mask = torch.argmax(logits, dim=1)
-            acc_sum += estimate_dice(mask, y)
-            net.train()
+            y = y.to(device)
+            # softmax = nn.Softmax(dim=1)
+            # logits = softmax(logits)
+            acc = dice(logits, y.detach())
+            acc_sum += acc.cpu().item()
             n += y.shape[0]
+            tbar.set_description(f'Validation acc = {acc: .4f}')
 
-            if only_onebatch: break
-    return acc_sum / n
+            if only_onebatch:
+                break
+    net.train()
+    return dice.compute().cpu().item()  # micro-averaged dice accumulated across all batches


 if __name__ == '__main__':
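For context, a minimal sketch of how the updated train() might be driven after this commit. The SegmentationDataset constructor arguments, the optimizer, and the scheduler below are assumptions for illustration, not taken from this repository:

import torch
from torch.utils.data import DataLoader

from dataset import SegmentationDataset  # repo module; the 'split' kwarg is assumed
from model import SegmentModel           # repo module; a default constructor is assumed
from train import train, device

net = SegmentModel()
train_loader = DataLoader(SegmentationDataset('/data/datasets', split='train'),
                          batch_size=8, shuffle=True)
val_loader = DataLoader(SegmentationDataset('/data/datasets', split='val'),
                        batch_size=8)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
train(net, train_loader, val_loader, optimizer, scheduler, device,
      num_epochs=100, save_path='./checkpoints', losstype='mse')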

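The evaluation path relies on torchmetrics' stateful metric API: calling the metric object updates its internal state and returns the value for that batch, while compute() returns the aggregate over every batch seen since the last reset(). A self-contained sketch of that pattern, with tensor shapes mirroring the ones in train.py:

import torch
from torchmetrics import Dice

dice = Dice(average='micro')
for _ in range(3):
    logits = torch.randn(8, 13, 128, 128)         # per-class scores for one batch
    target = torch.randint(0, 13, (8, 128, 128))  # integer ground-truth mask
    batch_dice = dice(logits, target)             # updates state, returns this batch's dice
    print(f'batch dice = {batch_dice:.4f}')
print(f'overall dice = {dice.compute():.4f}')     # micro-average over all batches
dice.reset()                                      # clear state before the next epoch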