Skip to content

Commit

Permalink
Fixed diffusion part
Browse files Browse the repository at this point in the history
  • Loading branch information
FirasGit committed Sep 19, 2022
1 parent 9712589 commit 329ba75
Show file tree
Hide file tree
Showing 18 changed files with 136 additions and 178 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"workbench.colorTheme": "Darcula"
}
3 changes: 3 additions & 0 deletions config/base_cfg.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
defaults:
- dataset: ???
- model: ???
Empty file added config/dataset/adni.yaml
Empty file.
5 changes: 5 additions & 0 deletions config/dataset/brats.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name: BRATS
root_dir: ???
image_channels: 1
train: True
img_type: flair
8 changes: 8 additions & 0 deletions config/dataset/mrnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: MRNet
root_dir: /data/home/firas/Desktop/work/MR_Knie/Data/MRNet/MRNet-v1.0/
image_channels: 1
task: acl
plane: sagittal
split: train


27 changes: 27 additions & 0 deletions config/model/ddpm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
vqgan_ckpt: /data/home/firas/Desktop/work/other_groups/vq_gan_3d/checkpoints_generation/knee_mri_gen/lightning_logs/version_0/checkpoints/epoch=245-step=222000-train/recon_loss=0.81.ckpt
batch_size: 40
num_workers: 30
load_milestone: False
logger: wandb
objective: pred_x0
diffusion_img_size: 32
diffusion_depth_size: 4
diffusion_num_channels: 8
save_and_sample_every: 400
train_lr: 1e-4
timesteps: 300 # number of steps
sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
loss_type: l1 # L1 or L2
train_num_steps: 700000 # total training steps
gradient_accumulate_every: 2 # gradient accumulation steps
ema_decay: 0.995 # exponential moving average decay
amp: False # turn on mixed precision
num_sample_rows: 1
dim_muls: [1, 2, 4, 8]


# NOTE: diffusion_img_size, diffusion_depth_size and diffusion_num_channels
# (which have to be derived from the VQ-GAN latent space) are already set
# above; repeating the keys here is a duplicate-key error under strict YAML
# loaders, so the second copies were removed.

Empty file added config/model/vq_gan_3d.yaml
Empty file.
9 changes: 4 additions & 5 deletions dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import imp
from vq_gan_3d.dataset.breast_uka import BreastUKA
from vq_gan_3d.dataset.mrnet import MRNetDataset
from vq_gan_3d.dataset.brats import BRATSDataset
from vq_gan_3d.dataset.adni import ADNIDataset
from medicaldiffusion.dataset.breast_uka import BreastUKA
from medicaldiffusion.dataset.mrnet import MRNetDataset
from medicaldiffusion.dataset.brats import BRATSDataset
from medicaldiffusion.dataset.adni import ADNIDataset
2 changes: 2 additions & 0 deletions dataset/adni.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
""" Taken and adapted from https://github.com/cyclomon/3dbraingen """

import csv
import numpy as np
import torch
Expand Down
2 changes: 2 additions & 0 deletions dataset/brats.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
""" Taken and adapted from https://github.com/cyclomon/3dbraingen """

import csv
import numpy as np
import torch
Expand Down
1 change: 1 addition & 0 deletions ddpm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Re-export the public DDPM API. The old top-level `diffusion` package was
# removed in this commit; the module now lives inside this package as
# ddpm/diffusion.py, so the import must be package-relative.
from .diffusion import Unet3D, GaussianDiffusion, Trainer
6 changes: 4 additions & 2 deletions diffusion/video_diffusion_pytorch.py → ddpm/diffusion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch"

import math
import copy
import torch
Expand Down Expand Up @@ -957,7 +959,7 @@ class Trainer(object):
def __init__(
self,
diffusion_model,
args,
cfg,
folder=None,
dataset=None,
*,
Expand Down Expand Up @@ -993,7 +995,7 @@ def __init__(
channels = diffusion_model.channels
num_frames = diffusion_model.num_frames

self.args = args
self.cfg = cfg
if dataset:
self.ds = dataset
else:
Expand Down
1 change: 0 additions & 1 deletion diffusion/__init__.py

This file was deleted.

82 changes: 0 additions & 82 deletions diffusion/text.py

This file was deleted.

11 changes: 11 additions & 0 deletions train/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from dataset import *


def get_dataset(cfg):
    """Build the training dataset selected by ``cfg.dataset.name``.

    Supported names are ``MRNet``, ``BRATS`` and ``ADNI``; each branch
    forwards only the config fields its dataset class needs.

    Raises:
        ValueError: if ``cfg.dataset.name`` is not one of the supported names.
    """
    ds_cfg = cfg.dataset
    if ds_cfg.name == 'MRNet':
        return MRNetDataset(
            root_dir=ds_cfg.root_dir,
            task=ds_cfg.task,
            plane=ds_cfg.plane,
            split=ds_cfg.split,
        )
    if ds_cfg.name == 'BRATS':
        return BRATSDataset(
            root_dir=ds_cfg.root_dir,
            train=ds_cfg.train,
            img_type=ds_cfg.img_type,
        )
    if ds_cfg.name == 'ADNI':
        # ADNI takes no config-driven arguments here.
        return ADNIDataset()
    raise ValueError(f'{ds_cfg.name} Dataset is not available')
64 changes: 64 additions & 0 deletions train/train_ddpm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from re import I
from ddpm import Unet3D, GaussianDiffusion, Trainer
from dataset import MRNetDataset, BRATSDataset
import argparse
import wandb
import hydra
from omegaconf import DictConfig, OmegaConf
from train.dataset import get_dataset


# NCCL_P2P_DISABLE=1 accelerate launch train/train_ddpm.py

@hydra.main(config_path='../config', config_name='base_cfg')
def run(cfg: DictConfig):
    """Train a 3D DDPM operating on the latent space of a pretrained VQ-GAN.

    Hydra composes ``cfg`` from config/base_cfg.yaml plus the selected
    ``dataset`` and ``model`` groups; all hyperparameters below come from
    the ``model`` group (config/model/ddpm.yaml).
    """
    # config/model/ddpm.yaml defines these keys directly under the model
    # group — there is no `unet` sub-section — so read them from cfg.model.
    # NOTE(review): the yaml spells the multiplier key `dim_muls`.
    model = Unet3D(
        dim=cfg.model.diffusion_img_size,
        dim_mults=cfg.model.dim_muls,
        channels=cfg.model.diffusion_num_channels,
    ).cuda()

    diffusion = GaussianDiffusion(
        model,
        vqgan_ckpt=cfg.model.vqgan_ckpt,
        image_size=cfg.model.diffusion_img_size,
        num_frames=cfg.model.diffusion_depth_size,
        channels=cfg.model.diffusion_num_channels,
        timesteps=cfg.model.timesteps,
        # sampling_timesteps=cfg.model.sampling_timesteps,  # enable for DDIM sampling
        loss_type=cfg.model.loss_type,
        # objective=cfg.model.objective
    ).cuda()

    train_dataset = get_dataset(cfg)

    trainer = Trainer(
        diffusion,
        cfg=cfg,
        dataset=train_dataset,
        train_batch_size=cfg.model.batch_size,
        save_and_sample_every=cfg.model.save_and_sample_every,
        train_lr=cfg.model.train_lr,
        train_num_steps=cfg.model.train_num_steps,
        gradient_accumulate_every=cfg.model.gradient_accumulate_every,
        ema_decay=cfg.model.ema_decay,
        amp=cfg.model.amp,
        num_sample_rows=cfg.model.num_sample_rows,
        # logger=cfg.model.logger
    )

    # load_milestone defaults to False (see ddpm.yaml); a truthy value
    # resumes training from that saved milestone.
    if cfg.model.load_milestone:
        trainer.load(cfg.model.load_milestone)

    trainer.train()


# Hydra CLI entry point; the config tree is resolved from ../config/base_cfg.yaml.
if __name__ == '__main__':
    run()

# wandb.finish()

# Incorporate GAN loss in DDPM training?
# Incorporate GAN loss in UNET segmentation?
# Maybe better if I don't use ema updates?
# Use with other vqgan latent space (the one with more channels?)
86 changes: 0 additions & 86 deletions train/train_video_diffusion.py

This file was deleted.

4 changes: 2 additions & 2 deletions train/train_vqgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch
from vq_gan_3d.model import VQGAN
from vq_gan_3d.dataset import MRNetDataset, BRATSDataset, ADNIDataset
from vq_gan_3d.train.callbacks import ImageLogger, VideoLogger
from dataset import MRNetDataset, BRATSDataset, ADNIDataset
from train.callbacks import ImageLogger, VideoLogger

# TO TRAIN:
# export PYTHONPATH=$PWD in previous folder
Expand Down

0 comments on commit 329ba75

Please sign in to comment.