Skip to content

Commit

Permalink
Added scripts for cond synthesis and Misc. Changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Kushagra Pandey authored and kpandey008 committed Oct 3, 2023
1 parent aa457e6 commit 4125ce2
Show file tree
Hide file tree
Showing 39 changed files with 1,365 additions and 2,916 deletions.
99 changes: 64 additions & 35 deletions main/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,31 @@ def update_weights(
targ.mul_(self.tau).add_(src, alpha=1 - self.tau)


class ImageWriter(BasePredictionWriter):
# TODO: Add Support for saving momentum states
class SimpleImageWriter(BasePredictionWriter):
def __init__(
self,
output_dir,
write_interval,
n_steps=None,
eval_mode="sample",
conditional=True,
sample_prefix="",
path_prefix="",
save_mode="image",
is_norm=True,
is_augmented=True,
save_batch=False,
conditional=False
):
super().__init__(write_interval)
assert eval_mode in ["sample", "recons"]
self.output_dir = output_dir
self.n_steps = 1000 if n_steps is None else n_steps
self.eval_mode = eval_mode
self.conditional = conditional
self.sample_prefix = sample_prefix
self.path_prefix = path_prefix
self.is_norm = is_norm
self.is_augmented = is_augmented
self.save_fn = save_as_images if save_mode == "image" else save_as_np
self.save_batch = save_batch

def write_on_batch_end(
self,
Expand All @@ -96,68 +100,70 @@ def write_on_batch_end(
dataloader_idx,
):
rank = pl_module.global_rank
if self.conditional:
ddpm_samples_dict, vae_samples = prediction

if self.save_vae:
vae_samples = vae_samples.cpu()
vae_save_path = os.path.join(self.output_dir, "vae")
os.makedirs(vae_save_path, exist_ok=True)
self.save_fn(
vae_samples,
file_name=os.path.join(
vae_save_path,
f"output_vae_{self.sample_prefix}_{rank}_{batch_idx}",
),
denorm=self.is_norm,
)
else:
ddpm_samples_dict = prediction

# Write output images
# NOTE: We need to use gpu rank during saving to prevent
# processes from overwriting images
for k, ddpm_samples in ddpm_samples_dict.items():
ddpm_samples = ddpm_samples.cpu()
samples = prediction.cpu()

if self.is_augmented:
samples, _ = torch.chunk(samples, 2, dim=1)

# Setup dirs
base_save_path = os.path.join(self.output_dir, k)
img_save_path = os.path.join(base_save_path, "images")
os.makedirs(img_save_path, exist_ok=True)
# Setup dirs
if self.path_prefix != "":
base_save_path = os.path.join(self.output_dir, str(self.path_prefix))
else:
base_save_path = self.output_dir
img_save_path = os.path.join(base_save_path, "images")
os.makedirs(img_save_path, exist_ok=True)

# Save
# Save images
self.save_fn(
samples,
file_name=os.path.join(
img_save_path, f"output_{self.sample_prefix }_{rank}_{batch_idx}"
),
denorm=self.is_norm,
)

# Save batch
if self.save_batch:
batch_save_path = os.path.join(base_save_path, "batch")
os.makedirs(batch_save_path, exist_ok=True)
img = batch
self.save_fn(
ddpm_samples,
img,
file_name=os.path.join(
img_save_path, f"output_{self.sample_prefix }_{rank}_{batch_idx}"
batch_save_path, f"output_{self.sample_prefix }_{rank}_{batch_idx}"
),
denorm=self.is_norm,
)


class SimpleImageWriter(BasePredictionWriter):
class InpaintingImageWriter(BasePredictionWriter):
def __init__(
self,
output_dir,
write_interval,
eval_mode="sample",
conditional=True,
sample_prefix="",
path_prefix="",
save_mode="image",
is_norm=True,
is_augmented=True,
save_batch=False,
conditional=False
):
super().__init__(write_interval)
assert eval_mode in ["sample", "recons"]
self.output_dir = output_dir
self.eval_mode = eval_mode
self.conditional = conditional
self.sample_prefix = sample_prefix
self.path_prefix = path_prefix
self.is_norm = is_norm
self.is_augmented = is_augmented
self.save_fn = save_as_images if save_mode == "image" else save_as_np
self.save_batch = save_batch

def write_on_batch_end(
self,
Expand Down Expand Up @@ -187,11 +193,34 @@ def write_on_batch_end(
img_save_path = os.path.join(base_save_path, "images")
os.makedirs(img_save_path, exist_ok=True)

# Save
# Save images
self.save_fn(
samples,
file_name=os.path.join(
img_save_path, f"output_{self.sample_prefix }_{rank}_{batch_idx}"
),
denorm=self.is_norm,
)

# Save batch (For inpainting)
if self.save_batch:
batch_save_path = os.path.join(base_save_path, "batch")
corr_save_path = os.path.join(base_save_path, "corrupt")
os.makedirs(batch_save_path, exist_ok=True)
os.makedirs(corr_save_path, exist_ok=True)
img, mask = batch
img = img * 0.5 + 0.5
self.save_fn(
img * mask,
file_name=os.path.join(
corr_save_path, f"output_{self.sample_prefix }_{rank}_{batch_idx}"
),
denorm=False,
)
self.save_fn(
img,
file_name=os.path.join(
batch_save_path, f"output_{self.sample_prefix }_{rank}_{batch_idx}"
),
denorm=False,
)
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
diffusion:
data:
root: ???
name: "celebahq256"
image_size: 256
name: "afhqv2"
image_size: 128
hflip: True
num_channels: 3
norm: True
apply_cond: False
cond_size: 32

return_target: False
model:
pl_module: 'sde_wrapper'
score_fn:
Expand All @@ -17,7 +17,7 @@ diffusion:
out_ch: 6
nonlinearity: "swish"
nf : 128
ch_mult: [1,1,2,2,2,2,2]
ch_mult: [1,2,2,2,3]
num_res_blocks: 2
attn_resolutions: [16,]
dropout: 0.1
Expand All @@ -43,11 +43,10 @@ diffusion:
gamma: 0.01
kappa: 0.04
decomp_mode: "lower"
numerical_eps: 1e-6
numerical_eps: 1e-9
n_timesteps: 1000
is_augmented: True
use_ms: False

training:
seed: 0
continuous: True
Expand Down Expand Up @@ -81,7 +80,6 @@ diffusion:
results_dir: ???
workers: 1
chkpt_prefix: ""

evaluation:
# Sampler specific config goes here
sampler:
Expand All @@ -102,34 +100,85 @@ diffusion:
batch_size: 64
save_mode: image
sample_prefix: "gpu"

# VAE config used for VAE training
vae:
path_prefix: ""
clf:
data:
root: ???
name: "cifar10"
image_size: 32
n_channels: 3
hflip: False

name: afhqv2
image_size: 128
hflip: true
num_channels: 3
norm: true
apply_cond: false
cond_size: 16
return_target: true
model:
enc_block_config : "32x7,32d2,32t16,16x4,16d2,16t8,8x4,8d2,8t4,4x3,4d4,4t1,1x3"
enc_channel_config: "32:64,16:128,8:256,4:256,1:512"
dec_block_config: "1x1,1u4,1t4,4x2,4u2,4t8,8x3,8u2,8t16,16x7,16u2,16t32,32x15"
dec_channel_config: "32:64,16:128,8:256,4:256,1:512"

pl_module: tclf_wrapper
clf_fn:
name: ncsnpp_clf
in_ch: 6
nonlinearity: swish
nf: 128
ch_mult:
- 1
- 2
- 2
- 2
num_res_blocks: 2
attn_resolutions:
- 16
dropout: 0.1
resamp_with_conv: true
noise_cond: true
fir: false
fir_kernel:
- 1
- 3
- 3
- 1
skip_rescale: true
resblock_type: biggan
progressive: none
progressive_input: none
progressive_combine: sum
embedding_type: positional
init_scale: 0
fourier_scale: 16
n_cls: ???
training:
seed: 0
fp16: False
batch_size: 128
epochs: 1000
continuous: true
loss:
name: tce_loss
l_type: l2
reduce_mean: true
optimizer:
name: Adam
lr: 0.0002
beta_1: 0.9
beta_2: 0.999
weight_decay: 0
eps: 1e-8
warmup: 5000
fp16: false
batch_size: 32
epochs: 500
log_step: 1
device: "gpu:0"
accelerator: gpu
devices:
- 0
chkpt_interval: 1
optimizer: "Adam"
lr: 1e-4
restore_path: ""
results_dir: ???
workers: 2
workers: 1
chkpt_prefix: ""
alpha: 1.0
evaluation:
seed: 0
chkpt_path: ???
accelerator: gpu
devices:
- 0
workers: 1
batch_size: 64
clf_temp: 1.0
label_to_sample: 0
Loading

0 comments on commit 4125ce2

Please sign in to comment.