Skip to content

Commit

Permalink
add visual process in ddpm diffusion
Browse files Browse the repository at this point in the history
Janspiry committed Aug 6, 2021
1 parent b67b002 commit d068920
Showing 4 changed files with 25 additions and 13 deletions.
7 changes: 4 additions & 3 deletions config/basic_ddpm.json
Original file line number Diff line number Diff line change
@@ -42,15 +42,16 @@
"inner_channel": 64,
"channel_multiplier": [
1,
1,
2,
2,
4,
8,
8,
4
],
"attn_res": [
16
],
"res_blocks": 3,
"res_blocks": 2,
"dropout": 0.2
},
"beta_schedule": {
6 changes: 3 additions & 3 deletions data/LRHR_dataset.py
Original file line number Diff line number Diff line change
@@ -32,9 +32,9 @@ def __len__(self):

def AugmentWithTransform(self, img_list, hflip=True, rot=False):
# horizontal flip OR rotate
hflip = hflip and (self.split == 'train' and random.random() < 1)
vflip = rot and (self.split == 'train' and random.random() < 1)
rot90 = rot and (self.split == 'train' and random.random() < 1)
hflip = hflip and (self.split == 'train' and random.random() < 0.5)
vflip = rot and (self.split == 'train' and random.random() < 0.5)
rot90 = rot and (self.split == 'train' and random.random() < 0.5)

def _augment(img):
if hflip:
23 changes: 17 additions & 6 deletions model/ddpm_modules/diffusion.py
Original file line number Diff line number Diff line change
@@ -190,35 +190,46 @@ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False, condition_x=Non
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

@torch.no_grad()
def p_sample_loop(self, x_in):
def p_sample_loop(self, x_in, continous=False):
device = self.betas.device
sample_inter = self.num_timesteps//10

if not self.conditional:
shape = x_in
b = shape[0]
img = torch.randn(shape, device=device)
ret_img = img
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full(
(b,), i, device=device, dtype=torch.long))
if i % sample_inter == 0:
ret_img = torch.cat([ret_img, img], dim=0)
return img
else:
x = x_in
shape = x.shape
b = shape[0]
img = torch.randn(shape, device=device)
ret_img = x
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full(
(b,), i, device=device, dtype=torch.long), condition_x=x)
return img
if i % sample_inter == 0:
ret_img = torch.cat([ret_img, img], dim=0)
if continous:
return ret_img
else:
return ret_img[-1]

@torch.no_grad()
def sample(self, batch_size=16):
def sample(self, batch_size=16, continous=False):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size))
return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)

@torch.no_grad()
def super_resolution(self, x_in):
return self.p_sample_loop(x_in)
def super_resolution(self, x_in, continous=False):
return self.p_sample_loop(x_in, continous)

@torch.no_grad()
def interpolate(self, x1, x2, t=None, lam=0.5):
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='config/basic_sr3.json',
parser.add_argument('-c', '--config', type=str, default='config/basic_ddpm.json',
help='JSON file for configuration')
parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'],
help='Run either train(training) or val(generation)', default='val')

0 comments on commit d068920

Please sign in to comment.