forked from FirasGit/medicaldiffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
128 additions
and
165 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
{ | ||
"workbench.colorTheme": "Darcula" | ||
"workbench.colorTheme": "Default Dark+" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
name: BRATS | ||
root_dir: ??? | ||
root_dir: /data/BraTS/BraTS 2020 | ||
image_channels: 1 | ||
train: True | ||
img_type: flair |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
seed: 1234 | ||
batch_size: 30 | ||
num_workers: 30 | ||
|
||
gpus: 1 | ||
accumulate_grad_batches: 1 | ||
default_root_dir: /data/home/firas/Desktop/work/other_groups/medicaldiffusion/checkpoints/vq_gan/ | ||
default_root_dir_postfix: | ||
resume_from_checkpoint: | ||
max_steps: -1 | ||
precision: 16 | ||
gradient_clip_val: 1.0 | ||
|
||
|
||
embedding_dim: 256 | ||
n_codes: 2048 | ||
n_hiddens: 240 | ||
lr: 3e-4 | ||
downsample: [4, 4, 4] | ||
disc_channels: 64 | ||
disc_layers: 3 | ||
discriminator_iter_start: 50000 | ||
disc_loss_type: hinge | ||
image_gan_weight: 1.0 | ||
video_gan_weight: 1.0 | ||
l1_weight: 4.0 | ||
gan_feat_weight: 0.0 | ||
perceptual_weight: 0.0 | ||
i3d_feat: False | ||
restart_thres: 1.0 | ||
no_random_restart: False | ||
norm_type: group | ||
padding_type: replicate | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,23 @@ | ||
from dataset import * | ||
from torch.utils.data import WeightedRandomSampler | ||
|
||
|
||
def get_dataset(cfg): | ||
if cfg.dataset.name == 'MRNet': | ||
return MRNetDataset(root_dir=cfg.dataset.root_dir, task=cfg.dataset.task, plane=cfg.dataset.plane, split=cfg.dataset.split) | ||
train_dataset = MRNetDataset( | ||
root_dir=cfg.dataset.root_dir, task=cfg.dataset.task, plane=cfg.dataset.plane, split='train') | ||
val_dataset = MRNetDataset(root_dir=cfg.dataset.root_dir, | ||
task=cfg.dataset.task, plane=cfg.dataset.plane, split='valid') | ||
sampler = WeightedRandomSampler( | ||
weights=train_dataset.sample_weight, num_samples=len(train_dataset.sample_weight)) | ||
return train_dataset, val_dataset, sampler | ||
if cfg.dataset.name == 'BRATS': | ||
return BRATSDataset(root_dir=cfg.dataset.root_dir, train=cfg.dataset.train, img_type=cfg.dataset.img_type) | ||
train_dataset = BRATSDataset( | ||
root_dir=cfg.dataset.root_dir, img_type=cfg.dataset.img_type, train=True) | ||
val_dataset = BRATSDataset( | ||
root_dir=cfg.dataset.root_dir, img_type=cfg.dataset.img_type, train=False) | ||
sampler = None | ||
return train_dataset, val_dataset, sampler | ||
if cfg.dataset.name == 'ADNI': | ||
return ADNIDataset() | ||
raise NotImplementedError | ||
raise ValueError(f'{cfg.dataset.name} Dataset is not available') |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# TO TRAIN: | ||
# export PYTHONPATH=$PWD in previous folder | ||
# NCCL_DEBUG=WARN PL_TORCH_DISTRIBUTED_BACKEND=gloo python train/train_vqgan.py --gpus 1 --default_root_dir /data/home/firas/Desktop/work/other_groups/vq_gan_3d/checkpoints/knee_mri --precision 16 --embedding_dim 256 --n_hiddens 16 --downsample 16 16 16 --num_workers 32 --gradient_clip_val 1.0 --lr 3e-4 --discriminator_iter_start 10000 --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1 --gan_feat_weight 4 --batch_size 2 --n_codes 1024 --accumulate_grad_batches 1 | ||
# PL_TORCH_DISTRIBUTED_BACKEND=gloo CUDA_VISIBLE_DEVICES=1 python train/train_vqgan.py --gpus 1 --default_root_dir /data/home/firas/Desktop/work/other_groups/vq_gan_3d/checkpoints_generation/knee_mri_gen --precision 16 --embedding_dim 8 --n_hiddens 16 --downsample 8 8 8 --num_workers 32 --gradient_clip_val 1.0 --lr 3e-4 --discriminator_iter_start 10000 --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1 --gan_feat_weight 4 --batch_size 2 --n_codes 16384 --accumulate_grad_batches 1 | ||
# https://github.com/Lightning-AI/lightning/issues/9641 | ||
|
||
# PL_TORCH_DISTRIBUTED_BACKEND=gloo CUDA_VISIBLE_DEVICES=1 python train/train_vqgan.py --gpus 1 --default_root_dir /data/home/firas/Desktop/work/other_groups/vq_gan_3d/checkpoints_brats/flair --precision 16 --embedding_dim 8 --n_hiddens 16 --downsample 2 2 2 --num_workers 32 --gradient_clip_val 1.0 --lr 3e-4 --discriminator_iter_start 10000 --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1 --gan_feat_weight 4 --batch_size 2 --n_codes 16384 --accumulate_grad_batches 1 --dataset BRATS |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.