Skip to content

Commit

Permalink
refactored vqgan
Browse files Browse the repository at this point in the history
  • Loading branch information
FirasGit committed Sep 19, 2022
1 parent 329ba75 commit bebabb2
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 165 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"workbench.colorTheme": "Darcula"
"workbench.colorTheme": "Default Dark+"
}
3 changes: 1 addition & 2 deletions config/dataset/brats.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
name: BRATS
root_dir: ???
root_dir: /data/BraTS/BraTS 2020
image_channels: 1
train: True
img_type: flair
1 change: 0 additions & 1 deletion config/dataset/mrnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@ root_dir: /data/home/firas/Desktop/work/MR_Knie/Data/MRNet/MRNet-v1.0/
image_channels: 1
task: acl
plane: sagittal
split: train


35 changes: 35 additions & 0 deletions config/model/vq_gan_3d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Hydra config group `model` for the 3D VQ-GAN (consumed as cfg.model.* in
# train/train_vqgan.py).
seed: 1234
batch_size: 30
num_workers: 30

# PyTorch Lightning trainer settings.
gpus: 1
accumulate_grad_batches: 1
default_root_dir: /data/home/firas/Desktop/work/other_groups/medicaldiffusion/checkpoints/vq_gan/
default_root_dir_postfix:    # empty -> null; joined onto default_root_dir/<dataset name> at runtime
resume_from_checkpoint:      # empty -> null; train_vqgan.py fills this with the most recent ckpt if one exists
max_steps: -1                # -1 = no step limit
precision: 16                # mixed-precision training
gradient_clip_val: 1.0


# VQ-GAN model hyperparameters.
embedding_dim: 256
n_codes: 2048                # codebook size
n_hiddens: 240
lr: 3e-4                     # base learning rate; rescaled by batch size / #GPUs / accumulation in train_vqgan.py
downsample: [4, 4, 4]        # presumably per-axis spatial downsampling factors — verify against VQGAN encoder
disc_channels: 64
disc_layers: 3
discriminator_iter_start: 50000   # step at which the discriminator loss kicks in
disc_loss_type: hinge
image_gan_weight: 1.0
video_gan_weight: 1.0
l1_weight: 4.0
gan_feat_weight: 0.0
perceptual_weight: 0.0
i3d_feat: False
restart_thres: 1.0
no_random_restart: False
norm_type: group
padding_type: replicate


12 changes: 0 additions & 12 deletions dataset/brats.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,3 @@ def __getitem__(self, index):
1, sp_size, sp_size, sp_size)

return {'data': imageout}

@staticmethod
def add_data_specific_args(parent_parser):
parser = argparse.ArgumentParser(
parents=[parent_parser], add_help=False)
parser.add_argument('--root_dir', type=str,
default='/data/BraTS/BraTS 2020')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=15)
parser.add_argument('--image_channels', type=int, default=1)
parser.add_argument('--imgtype', type=str, default='flair')
return parser
15 changes: 2 additions & 13 deletions dataset/mrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import math
import argparse

# TODO: Normalize image intensities to the range [-1, 1]


def reformat_label(label):
if label == 1:
Expand Down Expand Up @@ -133,16 +135,3 @@ def __getitem__(self, index):
data = array[self.plane]
data_org = array_org[self.plane]
return {'data': data, 'mean_org': data_org.mean(), 'std_org': data_org.std()}

@staticmethod
def add_data_specific_args(parent_parser):
parser = argparse.ArgumentParser(
parents=[parent_parser], add_help=False)
parser.add_argument('--root_dir', type=str,
default='/data/home/firas/Desktop/work/MR_Knie/Data/MRNet/MRNet-v1.0/')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--image_channels', type=int, default=1)
parser.add_argument('--task', type=str, default='acl')
parser.add_argument('--plane', type=str, default='sagittal')
return parser
18 changes: 15 additions & 3 deletions train/dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
from dataset import *
from torch.utils.data import WeightedRandomSampler


def get_dataset(cfg):
    """Build the train/val datasets (and optional sampler) selected by the config.

    Args:
        cfg: Hydra/OmegaConf config; reads ``cfg.dataset.name`` plus the
            dataset-specific fields (``root_dir``, ``task``, ``plane``,
            ``img_type``).

    Returns:
        Tuple ``(train_dataset, val_dataset, sampler)`` where ``sampler`` is a
        ``WeightedRandomSampler`` for MRNet (class-imbalance reweighting via
        ``train_dataset.sample_weight``) and ``None`` for BRATS.
        NOTE(review): the ADNI branch returns a single dataset instead of the
        3-tuple, which breaks callers that unpack — confirm intended.

    Raises:
        ValueError: if ``cfg.dataset.name`` names no known dataset.
    """
    if cfg.dataset.name == 'MRNet':
        train_dataset = MRNetDataset(
            root_dir=cfg.dataset.root_dir, task=cfg.dataset.task, plane=cfg.dataset.plane, split='train')
        val_dataset = MRNetDataset(root_dir=cfg.dataset.root_dir,
                                   task=cfg.dataset.task, plane=cfg.dataset.plane, split='valid')
        # Oversample per-example weights so the imbalanced labels are drawn evenly.
        sampler = WeightedRandomSampler(
            weights=train_dataset.sample_weight, num_samples=len(train_dataset.sample_weight))
        return train_dataset, val_dataset, sampler
    if cfg.dataset.name == 'BRATS':
        train_dataset = BRATSDataset(
            root_dir=cfg.dataset.root_dir, img_type=cfg.dataset.img_type, train=True)
        val_dataset = BRATSDataset(
            root_dir=cfg.dataset.root_dir, img_type=cfg.dataset.img_type, train=False)
        sampler = None
        return train_dataset, val_dataset, sampler
    if cfg.dataset.name == 'ADNI':
        return ADNIDataset()
    raise ValueError(f'{cfg.dataset.name} Dataset is not available')
Empty file.
7 changes: 7 additions & 0 deletions train/scripts/train_vqgan.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# TO TRAIN:
# First run `export PYTHONPATH=$PWD` from the repository root folder
# NCCL_DEBUG=WARN PL_TORCH_DISTRIBUTED_BACKEND=gloo python train/train_vqgan.py --gpus 1 --default_root_dir /data/home/firas/Desktop/work/other_groups/vq_gan_3d/checkpoints/knee_mri --precision 16 --embedding_dim 256 --n_hiddens 16 --downsample 16 16 16 --num_workers 32 --gradient_clip_val 1.0 --lr 3e-4 --discriminator_iter_start 10000 --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1 --gan_feat_weight 4 --batch_size 2 --n_codes 1024 --accumulate_grad_batches 1
# PL_TORCH_DISTRIBUTED_BACKEND=gloo CUDA_VISIBLE_DEVICES=1 python train/train_vqgan.py --gpus 1 --default_root_dir /data/home/firas/Desktop/work/other_groups/vq_gan_3d/checkpoints_generation/knee_mri_gen --precision 16 --embedding_dim 8 --n_hiddens 16 --downsample 8 8 8 --num_workers 32 --gradient_clip_val 1.0 --lr 3e-4 --discriminator_iter_start 10000 --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1 --gan_feat_weight 4 --batch_size 2 --n_codes 16384 --accumulate_grad_batches 1
# https://github.com/Lightning-AI/lightning/issues/9641

# PL_TORCH_DISTRIBUTED_BACKEND=gloo CUDA_VISIBLE_DEVICES=1 python train/train_vqgan.py --gpus 1 --default_root_dir /data/home/firas/Desktop/work/other_groups/vq_gan_3d/checkpoints_brats/flair --precision 16 --embedding_dim 8 --n_hiddens 16 --downsample 2 2 2 --num_workers 32 --gradient_clip_val 1.0 --lr 3e-4 --discriminator_iter_start 10000 --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1 --gan_feat_weight 4 --batch_size 2 --n_codes 16384 --accumulate_grad_batches 1 --dataset BRATS
2 changes: 1 addition & 1 deletion train/train_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def run(cfg: DictConfig):
# objective=cfg.objective
).cuda()

train_dataset = get_dataset(cfg)
train_dataset, *_ = get_dataset(cfg)

trainer = Trainer(
diffusion,
Expand Down
104 changes: 40 additions & 64 deletions train/train_vqgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,69 +4,36 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch
from ddpm.diffusion import default
from vq_gan_3d.model import VQGAN
from dataset import MRNetDataset, BRATSDataset, ADNIDataset
from train.callbacks import ImageLogger, VideoLogger
from train.dataset import get_dataset
import hydra
from omegaconf import DictConfig, open_dict

# TO TRAIN:
# export PYTHONPATH=$PWD in previous folder
# NCCL_DEBUG=WARN PL_TORCH_DISTRIBUTED_BACKEND=gloo python train/train_vqgan.py --gpus 1 --default_root_dir /data/home/firas/Desktop/work/other_groups/vq_gan_3d/checkpoints/knee_mri --precision 16 --embedding_dim 256 --n_hiddens 16 --downsample 16 16 16 --num_workers 32 --gradient_clip_val 1.0 --lr 3e-4 --discriminator_iter_start 10000 --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1 --gan_feat_weight 4 --batch_size 2 --n_codes 1024 --accumulate_grad_batches 1
# PL_TORCH_DISTRIBUTED_BACKEND=gloo CUDA_VISIBLE_DEVICES=1 python train/train_vqgan.py --gpus 1 --default_root_dir /data/home/firas/Desktop/work/other_groups/vq_gan_3d/checkpoints_generation/knee_mri_gen --precision 16 --embedding_dim 8 --n_hiddens 16 --downsample 8 8 8 --num_workers 32 --gradient_clip_val 1.0 --lr 3e-4 --discriminator_iter_start 10000 --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1 --gan_feat_weight 4 --batch_size 2 --n_codes 16384 --accumulate_grad_batches 1
# https://github.com/Lightning-AI/lightning/issues/9641

# PL_TORCH_DISTRIBUTED_BACKEND=gloo CUDA_VISIBLE_DEVICES=1 python train/train_vqgan.py --gpus 1 --default_root_dir /data/home/firas/Desktop/work/other_groups/vq_gan_3d/checkpoints_brats/flair --precision 16 --embedding_dim 8 --n_hiddens 16 --downsample 2 2 2 --num_workers 32 --gradient_clip_val 1.0 --lr 3e-4 --discriminator_iter_start 10000 --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1 --gan_feat_weight 4 --batch_size 2 --n_codes 16384 --accumulate_grad_batches 1 --dataset BRATS
@hydra.main(config_path='../config', config_name='base_cfg')
def run(cfg: DictConfig):
pl.seed_everything(cfg.model.seed)

train_dataset, val_dataset, sampler = get_dataset(cfg)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=cfg.model.batch_size,
num_workers=cfg.model.num_workers, sampler=sampler)
val_dataloader = DataLoader(val_dataset, batch_size=cfg.model.batch_size,
shuffle=False, num_workers=cfg.model.num_workers)

def main():
DATASET = BRATSDataset
# automatically adjust learning rate
bs, base_lr, ngpu, accumulate = cfg.model.batch_size, cfg.model.lr, cfg.model.gpus, cfg.model.accumulate_grad_batches

pl.seed_everything(1234)

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = VQGAN.add_model_specific_args(parser)
parser = DATASET.add_data_specific_args(parser)
args = parser.parse_args()

if args.dataset == 'MRNet':
train_dataset = MRNetDataset(
root_dir=args.root_dir, task=args.task, plane=args.plane, split='train')
sampler = WeightedRandomSampler(
weights=train_dataset.sample_weight, num_samples=len(train_dataset.sample_weight))
train_dataloader = DataLoader(
dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, sampler=sampler)
val_dataset = MRNetDataset(
root_dir=args.root_dir, task=args.task, plane=args.plane, split='valid')
val_dataloader = DataLoader(
dataset=val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == 'BRATS':
train_dataset = BRATSDataset(
root_dir=args.root_dir, imgtype=args.imgtype, train=True)
train_dataloader = DataLoader(
dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
val_dataset = BRATSDataset(
root_dir=args.root_dir, imgtype=args.imgtype, train=False)
val_dataloader = DataLoader(
dataset=val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == 'ADNI':
train_dataset = MRNetDataset(
root_dir=args.root_dir, task=args.task, plane=args.plane, split='train')
sampler = WeightedRandomSampler(
weights=train_dataset.sample_weight, num_samples=len(train_dataset.sample_weight))
train_dataloader = DataLoader(
dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, sampler=sampler)
val_dataset = MRNetDataset(
root_dir=args.root_dir, task=args.task, plane=args.plane, split='valid')
val_dataloader = DataLoader(
dataset=val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

# automatically adjust learning rate
bs, base_lr, ngpu, accumulate = args.batch_size, args.lr, args.gpus, args.accumulate_grad_batches
args.lr = accumulate * (ngpu/8.) * (bs/4.) * base_lr
with open_dict(cfg):
cfg.model.lr = accumulate * (ngpu/8.) * (bs/4.) * base_lr
cfg.model.default_root_dir = os.path.join(
cfg.model.default_root_dir, cfg.dataset.name, cfg.model.default_root_dir_postfix)
print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus/8) * {} (batchsize/4) * {:.2e} (base_lr)".format(
args.lr, accumulate, ngpu/8, bs/4, base_lr))
cfg.model.lr, accumulate, ngpu/8, bs/4, base_lr))

model = VQGAN(args)
model = VQGAN(cfg)

callbacks = []
callbacks.append(ModelCheckpoint(monitor='val/recon_loss',
Expand All @@ -80,12 +47,8 @@ def main():
callbacks.append(VideoLogger(
batch_frequency=1500, max_videos=4, clamp=True))

kwargs = dict()
if args.gpus > 1:
kwargs = dict(accelerator='ddp', gpus=args.gpus)

# load the most recent checkpoint file
base_dir = os.path.join(args.default_root_dir, 'lightning_logs')
base_dir = os.path.join(cfg.model.default_root_dir, 'lightning_logs')
if os.path.exists(base_dir):
log_folder = ckpt_file = ''
version_id_used = step_used = 0
Expand All @@ -102,16 +65,29 @@ def main():
os.rename(os.path.join(ckpt_folder, fn),
os.path.join(ckpt_folder, ckpt_file))
if len(ckpt_file) > 0:
args.resume_from_checkpoint = os.path.join(
cfg.model.resume_from_checkpoint = os.path.join(
ckpt_folder, ckpt_file)
print('will start from the recent ckpt %s' %
args.resume_from_checkpoint)

trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks,
max_steps=args.max_steps, **kwargs)
cfg.model.resume_from_checkpoint)

accelerator = None
if cfg.model.gpus > 1:
accelerator = 'ddp'

trainer = pl.Trainer(
gpus=cfg.model.gpus,
accumulate_grad_batches=cfg.model.accumulate_grad_batches,
default_root_dir=cfg.model.default_root_dir,
resume_from_checkpoint=cfg.model.resume_from_checkpoint,
callbacks=callbacks,
max_steps=cfg.model.max_steps,
precision=cfg.model.precision,
gradient_clip_val=cfg.model.gradient_clip_val,
accelerator=accelerator,
)

trainer.fit(model, train_dataloader, val_dataloader)


if __name__ == '__main__':
main()
run()
Loading

0 comments on commit bebabb2

Please sign in to comment.