From afc06d74d032f84c953f21b8d09a3e8ea0681d67 Mon Sep 17 00:00:00 2001 From: wurenkai Date: Tue, 31 Jan 2023 17:42:11 +0800 Subject: [PATCH] remove --- engine.py | 156 ---------------------------- loader.py | 83 --------------- test.py | 97 ----------------- utils.py | 304 ------------------------------------------------------ 4 files changed, 640 deletions(-) delete mode 100644 engine.py delete mode 100644 loader.py delete mode 100644 test.py delete mode 100644 utils.py diff --git a/engine.py b/engine.py deleted file mode 100644 index b5ce8a5..0000000 --- a/engine.py +++ /dev/null @@ -1,156 +0,0 @@ -import numpy as np -from tqdm import tqdm -import torch -from torch.cuda.amp import autocast as autocast -from sklearn.metrics import confusion_matrix -from utils import save_imgs - - -def train_one_epoch(train_loader, - model, - criterion, - optimizer, - scheduler, - epoch, - logger, - config, - scaler=None): - ''' - train model for one epoch - ''' - # switch to train mode - model.train() - - loss_list = [] - - for iter, data in enumerate(train_loader): - optimizer.zero_grad() - images, targets = data - images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float() - if config.amp: - with autocast(): - out = model(images) - loss = criterion(out, targets) - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - out = model(images) - loss = criterion(out, targets) - loss.backward() - optimizer.step() - - loss_list.append(loss.item()) - - now_lr = optimizer.state_dict()['param_groups'][0]['lr'] - if iter % config.print_interval == 0: - log_info = f'train: epoch {epoch}, iter:{iter}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}' - print(log_info) - logger.info(log_info) - scheduler.step() - - -def val_one_epoch(test_loader, - model, - criterion, - epoch, - logger, - config): - # switch to evaluate mode - model.eval() - preds = [] - gts = [] - loss_list = [] - with torch.no_grad(): - for data in tqdm(test_loader): - img, msk = data - img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() - out = model(img) - loss = criterion(out, msk) - loss_list.append(loss.item()) - gts.append(msk.squeeze(1).cpu().detach().numpy()) - if type(out) is tuple: - out = out[0] - out = out.squeeze(1).cpu().detach().numpy() - preds.append(out) - - if epoch % config.val_interval == 0: - preds = np.array(preds).reshape(-1) - gts = np.array(gts).reshape(-1) - - y_pre = np.where(preds>=config.threshold, 1, 0) - y_true = np.where(gts>=0.5, 1, 0) - - confusion = confusion_matrix(y_true, y_pre) - TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] - - accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 - sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 - specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 - f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 - miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 - - log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ - specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' - print(log_info) - logger.info(log_info) - - else: - log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}' - print(log_info) - logger.info(log_info) - - return np.mean(loss_list) - - -def test_one_epoch(test_loader, - model, - criterion, - logger, - config, - test_data_name=None): - # switch to evaluate mode - model.eval() - preds = [] - gts = [] - loss_list = [] - with torch.no_grad(): - for i, data in enumerate(tqdm(test_loader)): - img, msk = data - img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() - out = model(img) - loss = criterion(out, msk) - loss_list.append(loss.item()) - msk = msk.squeeze(1).cpu().detach().numpy() - gts.append(msk) - if type(out) is tuple: - out = out[0] - out = out.squeeze(1).cpu().detach().numpy() - preds.append(out) - save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets, config.threshold, test_data_name=test_data_name) - - preds = np.array(preds).reshape(-1) - gts = np.array(gts).reshape(-1) - - y_pre = np.where(preds>=config.threshold, 1, 0) - y_true = np.where(gts>=0.5, 1, 0) - - confusion = confusion_matrix(y_true, y_pre) - TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] - - accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 - sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 - specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 - f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 - miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 - - if test_data_name is not None: - log_info = f'test_datasets_name: {test_data_name}' - print(log_info) - logger.info(log_info) - log_info = f'test of best model, loss: {np.mean(loss_list):.4f},miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ - specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' - print(log_info) - logger.info(log_info) - - return np.mean(loss_list) diff --git a/loader.py b/loader.py deleted file mode 100644 index 19b3943..0000000 --- a/loader.py +++ /dev/null @@ -1,83 +0,0 @@ -from torch.utils.data import Dataset, DataLoader -import torch -import numpy as np -import random -import os -from PIL import Image -from einops.layers.torch import Rearrange -from scipy.ndimage.morphology import binary_dilation -from torch.utils.data import Dataset -from torchvision import transforms -from scipy import ndimage -from utils import * - - -# ===== normalize over the dataset -def dataset_normalized(imgs): - imgs_normalized = np.empty(imgs.shape) - imgs_std = np.std(imgs) - imgs_mean = np.mean(imgs) - imgs_normalized = (imgs-imgs_mean)/imgs_std - for i in range(imgs.shape[0]): - imgs_normalized[i] = ((imgs_normalized[i] - np.min(imgs_normalized[i])) / (np.max(imgs_normalized[i])-np.min(imgs_normalized[i])))*255 - return imgs_normalized - - -## Temporary -class isic_loader(Dataset): - """ dataset class for Brats datasets - """ - def __init__(self, path_Data, train = True, Test = False): - super(isic_loader, self) - self.train = train - if train: - self.data = np.load(path_Data+'data_train.npy') - self.mask = np.load(path_Data+'mask_train.npy') - else: - if Test: - self.data = np.load(path_Data+'data_test.npy') - self.mask = np.load(path_Data+'mask_test.npy') - else: - self.data = np.load(path_Data+'data_val.npy') - self.mask = np.load(path_Data+'mask_val.npy') - - self.data = dataset_normalized(self.data) - self.mask = np.expand_dims(self.mask, axis=3) - self.mask = self.mask/255. - - def __getitem__(self, indx): - img = self.data[indx] - seg = self.mask[indx] - if self.train: - if random.random() > 0.5: - img, seg = self.random_rot_flip(img, seg) - if random.random() > 0.5: - img, seg = self.random_rotate(img, seg) - - seg = torch.tensor(seg.copy()) - img = torch.tensor(img.copy()) - img = img.permute( 2, 0, 1) - seg = seg.permute( 2, 0, 1) - - return img, seg - - def random_rot_flip(self,image, label): - k = np.random.randint(0, 4) - image = np.rot90(image, k) - label = np.rot90(label, k) - axis = np.random.randint(0, 2) - image = np.flip(image, axis=axis).copy() - label = np.flip(label, axis=axis).copy() - return image, label - - def random_rotate(self,image, label): - angle = np.random.randint(-360, 360) - image = ndimage.rotate(image, angle, order=0, reshape=False) - label = ndimage.rotate(label, angle, order=0, reshape=False) - return image, label - - - - def __len__(self): - return len(self.data) - \ No newline at end of file diff --git a/test.py b/test.py deleted file mode 100644 index 1670cf5..0000000 --- a/test.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch -from torch import nn -from torch.cuda.amp import autocast, GradScaler -from torch.utils.data import DataLoader -from loader import * - -from models.MHorUNet import MHorunet -from engine import * -import os -import sys -os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3" - -from utils import * -from configs.config_setting import setting_config - -import warnings -warnings.filterwarnings("ignore") - - - -def main(config): - - print('#----------Creating logger----------#') - sys.path.append(config.work_dir + '/') - log_dir = os.path.join(config.work_dir, 'log') - checkpoint_dir = os.path.join(config.work_dir, 'checkpoints') - resume_model = os.path.join('/root/MHorUNet/', 'best.pth') - outputs = os.path.join(config.work_dir, 'outputs') - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir) - if not os.path.exists(outputs): - os.makedirs(outputs) - - global logger - logger = get_logger('test', log_dir) - - log_config_info(config, logger) - - - - - - print('#----------GPU init----------#') - set_seed(config.seed) - gpu_ids = [0]# [0, 1, 2, 3] - torch.cuda.empty_cache() - - - - - - - print('#----------Preparing dataset----------#') - test_dataset = isic_loader(path_Data = config.data_path, train = False, Test = True) - test_loader = DataLoader(test_dataset, - batch_size=1, - shuffle=False, - pin_memory=True, - num_workers=config.num_workers, - drop_last=True) - - print('#----------Prepareing Models----------#') - model_cfg = config.model_config - model = MHorunet(num_classes=model_cfg['num_classes'], - input_channels=model_cfg['input_channels'], - c_list=model_cfg['c_list'], - split_att=model_cfg['split_att'], - bridge=model_cfg['bridge'], - drop_path_rate=model_cfg['drop_path_rate']) - - model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0]) - - - - print('#----------Prepareing loss, opt, sch and amp----------#') - criterion = config.criterion - optimizer = get_optimizer(config, model) - scheduler = get_scheduler(config, optimizer) - scaler = GradScaler() - - - print('#----------Testing----------#') - best_weight = torch.load(resume_model, map_location=torch.device('cpu')) - model.module.load_state_dict(best_weight) - loss = test_one_epoch( - test_loader, - model, - criterion, - logger, - config, - ) - - - -if __name__ == '__main__': - config = setting_config - main(config) \ No newline at end of file diff --git a/utils.py b/utils.py deleted file mode 100644 index a66da70..0000000 --- a/utils.py +++ /dev/null @@ -1,304 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.backends.cudnn as cudnn -import torchvision.transforms.functional as TF -import numpy as np -import os -import math -import random -import logging -import logging.handlers -from matplotlib import pyplot as plt - - -def set_seed(seed): - # for hash - os.environ['PYTHONHASHSEED'] = str(seed) - # for python and numpy - random.seed(seed) - np.random.seed(seed) - # for cpu gpu - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - # for cudnn - cudnn.benchmark = False - cudnn.deterministic = True - - -def get_logger(name, log_dir): - ''' - Args: - name(str): name of logger - log_dir(str): path of log - ''' - - if not os.path.exists(log_dir): - os.makedirs(log_dir) - - logger = logging.getLogger(name) - logger.setLevel(logging.INFO) - - info_name = os.path.join(log_dir, '{}.info.log'.format(name)) - info_handler = logging.handlers.TimedRotatingFileHandler(info_name, - when='D', - encoding='utf-8') - info_handler.setLevel(logging.INFO) - - formatter = logging.Formatter('%(asctime)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S') - - info_handler.setFormatter(formatter) - - logger.addHandler(info_handler) - - return logger - - -def log_config_info(config, logger): - config_dict = config.__dict__ - log_info = f'#----------Config info----------#' - logger.info(log_info) - for k, v in config_dict.items(): - if k[0] == '_': - continue - else: - log_info = f'{k}: {v},' - logger.info(log_info) - - - -def get_optimizer(config, model): - assert config.opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!' - - if config.opt == 'Adadelta': - return torch.optim.Adadelta( - model.parameters(), - lr = config.lr, - rho = config.rho, - eps = config.eps, - weight_decay = config.weight_decay - ) - elif config.opt == 'Adagrad': - return torch.optim.Adagrad( - model.parameters(), - lr = config.lr, - lr_decay = config.lr_decay, - eps = config.eps, - weight_decay = config.weight_decay - ) - elif config.opt == 'Adam': - return torch.optim.Adam( - model.parameters(), - lr = config.lr, - betas = config.betas, - eps = config.eps, - weight_decay = config.weight_decay, - amsgrad = config.amsgrad - ) - elif config.opt == 'AdamW': - return torch.optim.AdamW( - model.parameters(), - lr = config.lr, - betas = config.betas, - eps = config.eps, - weight_decay = config.weight_decay, - amsgrad = config.amsgrad - ) - elif config.opt == 'Adamax': - return torch.optim.Adamax( - model.parameters(), - lr = config.lr, - betas = config.betas, - eps = config.eps, - weight_decay = config.weight_decay - ) - elif config.opt == 'ASGD': - return torch.optim.ASGD( - model.parameters(), - lr = config.lr, - lambd = config.lambd, - alpha = config.alpha, - t0 = config.t0, - weight_decay = config.weight_decay - ) - elif config.opt == 'RMSprop': - return torch.optim.RMSprop( - model.parameters(), - lr = config.lr, - momentum = config.momentum, - alpha = config.alpha, - eps = config.eps, - centered = config.centered, - weight_decay = config.weight_decay - ) - elif config.opt == 'Rprop': - return torch.optim.Rprop( - model.parameters(), - lr = config.lr, - etas = config.etas, - step_sizes = config.step_sizes, - ) - elif config.opt == 'SGD': - return torch.optim.SGD( - model.parameters(), - lr = config.lr, - momentum = config.momentum, - weight_decay = config.weight_decay, - dampening = config.dampening, - nesterov = config.nesterov - ) - else: # default opt is SGD - return torch.optim.SGD( - model.parameters(), - lr = 0.01, - momentum = 0.9, - weight_decay = 0.05, - ) - - - -def get_scheduler(config, optimizer): - assert config.sch in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'CosineAnnealingLR', 'ReduceLROnPlateau', - 'CosineAnnealingWarmRestarts', 'WP_MultiStepLR', 'WP_CosineLR'], 'Unsupported scheduler!' - if config.sch == 'StepLR': - scheduler = torch.optim.lr_scheduler.StepLR( - optimizer, - step_size = config.step_size, - gamma = config.gamma, - last_epoch = config.last_epoch - ) - elif config.sch == 'MultiStepLR': - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, - milestones = config.milestones, - gamma = config.gamma, - last_epoch = config.last_epoch - ) - elif config.sch == 'ExponentialLR': - scheduler = torch.optim.lr_scheduler.ExponentialLR( - optimizer, - gamma = config.gamma, - last_epoch = config.last_epoch - ) - elif config.sch == 'CosineAnnealingLR': - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, - T_max = config.T_max, - eta_min = config.eta_min, - last_epoch = config.last_epoch - ) - elif config.sch == 'ReduceLROnPlateau': - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - mode = config.mode, - factor = config.factor, - patience = config.patience, - threshold = config.threshold, - threshold_mode = config.threshold_mode, - cooldown = config.cooldown, - min_lr = config.min_lr, - eps = config.eps - ) - elif config.sch == 'CosineAnnealingWarmRestarts': - scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( - optimizer, - T_0 = config.T_0, - T_mult = config.T_mult, - eta_min = config.eta_min, - last_epoch = config.last_epoch - ) - elif config.sch == 'WP_MultiStepLR': - lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else config.gamma**len( - [m for m in config.milestones if m <= epoch]) - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func) - elif config.sch == 'WP_CosineLR': - lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else 0.5 * ( - math.cos((epoch - config.warm_up_epochs) / (config.epochs - config.warm_up_epochs) * math.pi) + 1) - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func) - - return scheduler - - - -def save_imgs(img, msk, msk_pred, i, save_path, datasets, threshold=0.5, test_data_name=None): - img = img.squeeze(0).permute(1,2,0).detach().cpu().numpy() - img = img / 255. if img.max() > 1.1 else img - if datasets == 'retinal': - msk = np.squeeze(msk, axis=0) - msk_pred = np.squeeze(msk_pred, axis=0) - else: - msk = np.where(np.squeeze(msk, axis=0) > 0.5, 1, 0) - msk_pred = np.where(np.squeeze(msk_pred, axis=0) > threshold, 1, 0) - - plt.figure(figsize=(7,15)) - - plt.subplot(3,1,1) - plt.imshow(img) - plt.axis('off') - - plt.subplot(3,1,2) - plt.imshow(msk, cmap= 'gray') - plt.axis('off') - - plt.subplot(3,1,3) - plt.imshow(msk_pred, cmap = 'gray') - plt.axis('off') - - if test_data_name is not None: - save_path = save_path + test_data_name + '_' - plt.savefig(save_path + str(i) +'.png') - plt.close() - - - -class BCELoss(nn.Module): - def __init__(self): - super(BCELoss, self).__init__() - self.bceloss = nn.BCELoss() - - def forward(self, pred, target): - size = pred.size(0) - pred_ = pred.view(size, -1) - target_ = target.view(size, -1) - - return self.bceloss(pred_, target_) - - -class DiceLoss(nn.Module): - def __init__(self): - super(DiceLoss, self).__init__() - - def forward(self, pred, target): - smooth = 1 - size = pred.size(0) - - pred_ = pred.view(size, -1) - target_ = target.view(size, -1) - intersection = pred_ * target_ - dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth) - dice_loss = 1 - dice_score.sum()/size - - return dice_loss - - -class BceDiceLoss(nn.Module): - def __init__(self, wb=1, wd=1): - super(BceDiceLoss, self).__init__() - self.bce = BCELoss() - self.dice = DiceLoss() - self.wb = wb - self.wd = wd - - def forward(self, pred, target): - bceloss = self.bce(pred, target) - diceloss = self.dice(pred, target) - - loss = self.wd * diceloss + self.wb * bceloss - return loss - - - - - \ No newline at end of file