fix bug of wgen and add no_use_fbn
hao-pt committed Nov 5, 2022
1 parent 4a73169 commit 6dc17cc
Showing 4 changed files with 21 additions and 11 deletions.
score_sde/models/discriminator.py (3 additions, 3 deletions)
@@ -324,7 +324,6 @@ def forward(self, input, t_emb):

out = self.act(out)

skip = self.skip(input)

if self.downsample:
# outLL, outH = self.dwt(out)
@@ -337,13 +336,14 @@ def forward(self, input, t_emb):

# inputLL, inputH = self.dwt(input)
# inputLH, inputHL, inputHH = torch.unbind(inputH[0], dim=2)
skipLL, skipLH, skipHL, skipHH = self.dwt(skip)
inputLL, inputLH, inputHL, inputHH = self.dwt(input)

# input = (inputLL + inputLH + inputHL + inputHH) / (2. * 4.)
# skip = torch.cat((skipLL, skipLH, skipHL, skipHH), dim=1) / 2.
skip = skipLL / 2.
input = inputLL / 2.

out = self.conv2(out)
skip = self.skip(input) # new
out = (out + skip) / np.sqrt(2)


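The discriminator change reorders the skip path: the 1x1 skip projection used to run on the full-resolution input and then be wavelet-decomposed (all four bands of the projected tensor computed, only LL kept); now the input is reduced to its LL band first and the projection runs on the downsampled tensor. A minimal sketch of the fixed ordering, with hypothetical names and a plain Haar low-pass standing in for the repo's DWT_2D("haar"):

# Hypothetical names; haar_ll stands in for the repo's DWT_2D("haar").
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def haar_ll(x):
    # LL band of an orthonormal Haar DWT: 2x2 mean scaled by 2,
    # so haar_ll(x) / 2. is just the 2x2 average of x.
    return F.avg_pool2d(x, 2) * 2.

class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1)  # channel-matching 1x1

    def forward(self, x):
        out = F.leaky_relu(self.conv1(x), 0.2)
        out = haar_ll(out) / 2.   # downsample the residual branch
        x = haar_ll(x) / 2.       # downsample the raw input first ...
        out = self.conv2(out)
        skip = self.skip(x)       # ... then project it (the reordering above)
        return (out + skip) / np.sqrt(2)

Projecting after the downsample means only one band ever passes through the 1x1 conv, instead of projecting at full resolution and discarding three of the four resulting bands.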
score_sde/models/ncsnpp_generator_adagn.py (16 additions, 8 deletions)
@@ -284,7 +284,6 @@ def __init__(self, config):
mapping_layers.append(self.act)
self.z_transform = nn.Sequential(*mapping_layers)


def forward(self, x, time_cond, z): # return_mid=False
# patchify
x = rearrange(x, "n c (h p1) (w p2) -> n (p1 p2 c) h w", p1=self.patch_size, p2=self.patch_size)
@@ -695,6 +694,8 @@ def __init__(self, config):
self.dwt = DWT_2D("haar")
self.iwt = IDWT_2D("haar")

self.no_use_fbn = getattr(self.config, "no_use_fbn", False)

def forward(self, x, time_cond, z): # return_mid=False
# patchify
x = rearrange(x, "n c (h p1) (w p2) -> n (p1 p2 c) h w", p1=self.patch_size, p2=self.patch_size)
@@ -786,23 +787,30 @@ def forward(self, x, time_cond, z): # return_mid=False
# h = modules[m_idx](h/2., temb, zemb)
# h = self.iwt((h*2., hH))

h, hlh, hhl, hhh = self.dwt(h)
h = modules[m_idx](h/2., temb, zemb)
h = self.iwt(h*2., hlh, hhl, hhh)
if self.no_use_fbn:
    h = modules[m_idx](h, temb, zemb)
else:
    h, hlh, hhl, hhh = self.dwt(h)
    h = modules[m_idx](h/2., temb, zemb)
    h = self.iwt(h*2., hlh, hhl, hhh)
m_idx += 1

# attn block
h = modules[m_idx](h)
m_idx += 1

h = modules[m_idx](h, temb, zemb)
# h = modules[m_idx](h, temb, zemb)

# h, hH = self.dwt(h)
# h = modules[m_idx](h/2., temb, zemb)
# h = self.iwt((h*2., hH))

h, hlh, hhl, hhh = self.dwt(h)
h = modules[m_idx](h/2., temb, zemb)
h = self.iwt(h*2., hlh, hhl, hhh)
if self.no_use_fbn:
    h = modules[m_idx](h, temb, zemb)
else:
    h, hlh, hhl, hhh = self.dwt(h)
    h = modules[m_idx](h/2., temb, zemb)
    h = self.iwt(h*2., hlh, hhl, hhh)
m_idx += 1

mid_out = h
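Both bottleneck ResBlocks in the generator are now gated by the new no_use_fbn flag: with the flag set, the block runs directly on the full-resolution feature map; otherwise it keeps the frequency-bottleneck path, processing only the half-magnitude LL band and splicing the untouched high-frequency bands back through the inverse transform. A self-contained sketch of the pattern, with a hand-rolled orthonormal Haar pair standing in for DWT_2D/IDWT_2D (helper names and band ordering are assumptions):

# Assumed helper names; DWT_2D/IDWT_2D replaced by explicit Haar code.
import torch

def haar_dwt(x):
    # Orthonormal 2x2 Haar analysis: (LL, LH, HL, HH), each at half size.
    a, b = x[..., 0::2, 0::2], x[..., 0::2, 1::2]
    c, d = x[..., 1::2, 0::2], x[..., 1::2, 1::2]
    return ((a + b + c + d) / 2., (a + b - c - d) / 2.,
            (a - b + c - d) / 2., (a - b - c + d) / 2.)

def haar_iwt(ll, lh, hl, hh):
    # Exact inverse of haar_dwt.
    n, ch, h, w = ll.shape
    x = ll.new_zeros(n, ch, h * 2, w * 2)
    x[..., 0::2, 0::2] = (ll + lh + hl + hh) / 2.
    x[..., 0::2, 1::2] = (ll + lh - hl - hh) / 2.
    x[..., 1::2, 0::2] = (ll - lh + hl - hh) / 2.
    x[..., 1::2, 1::2] = (ll - lh - hl + hh) / 2.
    return x

def bottleneck(h, resblock, temb, zemb, no_use_fbn=False):
    if no_use_fbn:
        # Plain bottleneck: the ResBlock sees the full-resolution features.
        return resblock(h, temb, zemb)
    # Frequency bottleneck: run the ResBlock on the half-magnitude LL band
    # only, then merge the untouched high-frequency bands back.
    ll, lh, hl, hh = haar_dwt(h)
    ll = resblock(ll / 2., temb, zemb)
    return haar_iwt(ll * 2., lh, hl, hh)

The /2. and *2. bracketing the ResBlock mirror the diff above: the LL band of an orthonormal Haar transform is twice the local mean, so halving it keeps the block's input in the same range as an ordinary feature map.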
test_wddgan.py (1 addition, 0 deletions)
@@ -102,6 +102,7 @@ def sample_and_test(args):
if args.magnify_data:
fake_sample = demagnified_function(fake_sample, train_mode=args.infer_mode)

fake_sample *= 2.
fake_sample = iwt((fake_sample[:, :3], [torch.stack((fake_sample[:, 3:6], fake_sample[:, 6:9], fake_sample[:, 9:12]), dim=2)]))
fake_sample = torch.clamp(fake_sample, -1, 1)

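The sampling fix doubles the generated wavelet coefficients before the inverse transform, undoing the 1/2 scaling the generator works under so the reconstructed image lands back in [-1, 1] before clamping. A hypothetical shape walk-through (batch size and resolution invented; the 12-channel [LL | LH | HL | HH] x RGB packing follows the indexing in the line above):

# Invented batch size/resolution; channel packing taken from the diff.
import torch

fake_sample = torch.randn(8, 12, 16, 16)  # generator output, half resolution
fake_sample = fake_sample * 2.            # undo the 1/2 scale used in training
ll = fake_sample[:, :3]                   # 3-channel low-pass band
highs = torch.stack(
    (fake_sample[:, 3:6], fake_sample[:, 6:9], fake_sample[:, 9:12]), dim=2
)                                         # (8, 3, 3, 16, 16): LH, HL, HH
# iwt((ll, [highs])) then reconstructs an (8, 3, 32, 32) image batch.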
train_wddgan.py (1 addition, 0 deletions)
@@ -598,6 +598,7 @@ def train(rank, gpu, args):
parser.add_argument("--disc_net_type", default="normal")
parser.add_argument("--num_disc_layers", default=6, type=int)
parser.add_argument("--magnify_data", action="store_true")
parser.add_argument("--no_use_fbn", action="store_true")


parser.add_argument('--save_content', action='store_true',default=False)
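On the training side the flag only needs to reach the model's config object; because the generator reads it via getattr(self.config, "no_use_fbn", False), older configs and checkpoints that lack the attribute keep the frequency bottleneck enabled. A minimal sketch of that defensive-default pattern (wiring assumed: the parsed args double as the generator's config):

# Assumed wiring: args passed straight through as the model config.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no_use_fbn", action="store_true")

config = parser.parse_args(["--no_use_fbn"])   # simulate passing the flag
print(getattr(config, "no_use_fbn", False))    # True: bottleneck disabled

config = parser.parse_args([])                 # flag omitted, e.g. old script
print(getattr(config, "no_use_fbn", False))    # False: bottleneck stays on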
