
Commit 1ab6179
Revise the generator to match the official implementation
MingtaoGuo authored Oct 26, 2022
1 parent f08ec32 commit 1ab6179
Showing 1 changed file with 49 additions and 29 deletions.
78 changes: 49 additions & 29 deletions face_model/gpen_model.py
@@ -10,6 +10,7 @@
from torch.nn import functional as F
from torch.autograd import Function


from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d

class PixelNorm(nn.Module):
@@ -282,7 +283,7 @@ def forward(self, input, style):


class NoiseInjection(nn.Module):
def __init__(self, isconcat=False):
def __init__(self, isconcat=True):
super().__init__()

self.isconcat = isconcat
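
Note: with isconcat now defaulting to True, the injected "noise" (which in this model is an encoder feature map with the same shape as the input, see FullGenerator.forward below) is concatenated along the channel axis rather than added, doubling the feature width. A minimal sketch of that behavior, assuming the usual weight-scaled injection:

import torch
from torch import nn

class NoiseInjectionSketch(nn.Module):
    # Illustrative only: concat-style injection doubles the channel count.
    def __init__(self, isconcat=True):
        super().__init__()
        self.isconcat = isconcat
        self.weight = nn.Parameter(torch.zeros(1))  # learned injection strength

    def forward(self, image, noise):
        if self.isconcat:
            # noise has the same shape as image, so output channels double
            return torch.cat([image, self.weight * noise], dim=1)
        return image + self.weight * noise
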
@@ -382,17 +383,15 @@ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1
self.upsample = Upsample(blur_kernel, device=device)

self.conv = ModulatedConv2d(in_channel, 1, 1, style_dim, demodulate=False, device=device)
self.bias = nn.Parameter(torch.zeros(1, 1, 1, 1))

def forward(self, input, style, skip=None):
out = self.conv(input, style)
out = out + self.bias

if skip is not None:
skip = self.upsample(skip)

out = out + skip

out = F.sigmoid(out)
return out
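
Note: reading the hunk above, ToMask drops its learned bias and squashes the output with a sigmoid before returning, so downstream callers consume a mask already in (0, 1). A sketch of the resulting forward under that reading (torch.sigmoid is the non-deprecated spelling of F.sigmoid):

def forward(self, input, style, skip=None):
    out = self.conv(input, style)       # 1x1 modulated conv, no bias term
    if skip is not None:
        out = out + self.upsample(skip)
    return torch.sigmoid(out)           # mask values in (0, 1)
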

class Generator(nn.Module):
@@ -404,7 +403,7 @@ def __init__(
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
isconcat=False,
isconcat=True,
narrow=1,
device='cpu'
):
@@ -413,7 +412,7 @@ def __init__(
self.size = size
self.n_mlp = n_mlp
self.style_dim = style_dim
self.feat_multiplier = 1# if isconcat else 1
self.feat_multiplier = 2 if isconcat else 1

layers = [PixelNorm()]

@@ -445,7 +444,6 @@
)
self.to_rgb1 = ToRGB(self.channels[4]*self.feat_multiplier, style_dim, upsample=False, device=device)
self.to_mask1 = ToMask(self.channels[4]*self.feat_multiplier, style_dim, upsample=False, device=device)

self.log_size = int(math.log(size, 2))

self.convs = nn.ModuleList()
@@ -460,7 +458,7 @@

self.convs.append(
StyledConv(
in_channel * 2,
in_channel*self.feat_multiplier,
out_channel,
3,
style_dim,
@@ -482,7 +480,7 @@

in_channel = out_channel

self.n_latent = self.log_size * 3 - 3
self.n_latent = self.log_size * 2 - 2

def make_noise(self):
device = self.input.input.device
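
Note: the new n_latent matches the official StyleGAN2 bookkeeping of two latents per resolution (the old loop advanced i by 3). Worked out for an example 256-resolution generator:

import math
size = 256                           # example resolution
log_size = int(math.log(size, 2))    # 8
n_latent_new = log_size * 2 - 2      # 14: official two-per-resolution count
n_latent_old = log_size * 3 - 3      # 21: previous three-per-resolution count
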
@@ -506,6 +504,16 @@ def mean_latent(self, n_latent):
def get_latent(self, input):
return self.style(input)

def conver_noise_withmask(self, noise, mask):
att_msk = mask
masked_img = noise * att_msk
color_sum = torch.sum(masked_img, (2, 3))
mask_sum = torch.sum(att_msk, (2, 3))
ratio = color_sum / mask_sum
ratio = ratio.unsqueeze(2).unsqueeze(3)
ratio_mask = ratio * att_msk + (1 - att_msk) * noise
return ratio_mask
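
Note: conver_noise_withmask flattens the noise inside the masked region to its per-channel mean while leaving the unmasked region untouched (mask_sum is assumed non-zero; an epsilon would guard the division). A toy check of the behavior:

import torch

noise = torch.randn(1, 3, 4, 4)      # stands in for an encoder feature map
mask = torch.zeros(1, 1, 4, 4)
mask[..., :2, :] = 1.0               # top half masked

ratio = (noise * mask).sum((2, 3)) / mask.sum((2, 3))   # per-channel masked mean
out = ratio[..., None, None] * mask + (1 - mask) * noise
# inside the mask: constant per-channel mean; outside: original noise
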

def forward(
self,
styles,
@@ -515,6 +523,7 @@ def forward(
truncation_latent=None,
input_is_latent=False,
noise=None,
noise_blank_level=9,
):
if not input_is_latent:
styles = [self.style(s) for s in styles]
@@ -556,26 +565,28 @@ def forward(
out = self.input(latent)
out = self.conv1(out, latent[:, 0], noise=noise[0])

masks = []
skip = self.to_rgb1(out, latent[:, 1])
skip_mask = self.to_mask1(out, latent[:, 2])

mask = self.to_mask1(out, latent[:, 1])
masks.append(mask)
i = 1
for conv1, conv2, noise, to_rgb, to_mask in zip(
self.convs[::2], self.convs[1::2], noise, self.to_rgbs, self.to_masks):
out = torch.cat([out, noise * (1 - F.sigmoid(skip_mask))], dim=1)
out = conv1(out, latent[:, i])
out = conv2(out, latent[:, i + 1])
for conv1, conv2, noise1, noise2, to_rgb, to_mask in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs, self.to_masks
):
if i > noise_blank_level:
blankmask = F.interpolate(masks[-1], noise1.shape[-2:])
noise1 = self.conver_noise_withmask(noise1, blankmask)
noise2 = self.conver_noise_withmask(noise2, blankmask)
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
skip_mask = to_mask(out, latent[:, i + 3], skip_mask)
i += 3
mask = to_mask(out, latent[:, i + 2], mask)
masks.append(mask)
i += 2

image = skip

if return_latents:
return image, latent, F.sigmoid(skip_mask)

else:
return image, None, F.sigmoid(skip_mask)
return image, latent, mask
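
Note: the rewritten loop is the official two-latents-per-resolution walk, with the mask threaded through its own skip chain; past noise_blank_level the injected encoder features are flattened inside the previous mask. The index bookkeeping spelled out for size=256 (an illustration, not code from the file):

# initial conv1           -> latent[:, 0]
# to_rgb1 / to_mask1      -> latent[:, 1]
i, used = 1, []
for _ in range(6):                   # resolutions 8, 16, ..., 256
    used.append((i, i + 1, i + 2))   # conv1, conv2, to_rgb/to_mask indices
    i += 2
assert used[-1] == (11, 12, 13)      # top index = n_latent - 1 = 13
# with the default noise_blank_level=9, only the final iteration
# (entered with i = 11, the finest resolution) flattens its noise
# inside the previous mask
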

class ConvLayer(nn.Sequential):
def __init__(
@@ -656,7 +667,7 @@ def __init__(
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
isconcat=False,
isconcat=True,
narrow=1,
device='cpu'
):
@@ -688,27 +699,31 @@ def __init__(
conv = [ConvLayer(in_channel, out_channel, 3, downsample=True, device=device)]
setattr(self, self.names[self.log_size-i+1], nn.Sequential(*conv))
in_channel = out_channel
self.final_linear = nn.Sequential(EqualLinear(style_dim, style_dim, activation='fused_lrelu', device=device))

def forward(self,
inputs,
z_id,
id_emb,
return_latents=False,
inject_index=None,
truncation=1,
truncation_latent=None,
input_is_latent=False,
):
noise = []
target_img = inputs.clone()
for i in range(self.log_size-1):
ecd = getattr(self, self.names[i])
inputs = ecd(inputs)
noise.append(inputs)
#print(inputs.shape)
# inputs = inputs.view(inputs.shape[0], -1)
outs = self.final_linear(id_emb)
#print(outs.shape)
noise = noise[::-1]
outs = self.generator([z_id], return_latents, inject_index, truncation, truncation_latent, input_is_latent, noise=noise)
I_st, latents, M = outs
return I_st, latents, M
noise = list(itertools.chain.from_iterable(itertools.repeat(x, 2) for x in noise))[::-1]
out_img, latent, mask = self.generator([outs], return_latents, inject_index, truncation, truncation_latent, input_is_latent, noise=noise[1:])
blend_res = out_img * mask + (1 - mask) * target_img
return out_img, latent, mask
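
Note: the encoder collects one feature map per scale (log_size - 1 of them), and the chain.from_iterable/repeat idiom duplicates each so both convs at a given resolution receive the same injected features; slicing off the first element of the reversed list aligns it with the generator's initial conv1 plus one (noise1, noise2) pair per resolution. This assumes itertools is imported at module top (not visible in this hunk); note also that blend_res is computed but not returned as committed. A toy version of the duplication:

import itertools

feats = ["f0", "f1", "f2"]           # encoder outputs, finest scale first
noise = list(itertools.chain.from_iterable(
    itertools.repeat(x, 2) for x in feats))[::-1]
print(noise)   # ['f2', 'f2', 'f1', 'f1', 'f0', 'f0']  -- coarsest first
# the generator receives noise[1:]: one map for the initial conv1,
# then the same map twice for each (conv1, conv2) pair
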

class Discriminator(nn.Module):
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], narrow=1, device='cpu'):
@@ -800,6 +815,11 @@ def __init__(self, requires_grad=False):
param.requires_grad = False

def forward(self, X):
mu = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1).float()
std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1).float()
X = (X + 1) / 2
X = F.interpolate(X, [224, 224], mode="bilinear")
X = (X - mu) / std
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
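
Note: the added preprocessing rescales generator output from [-1, 1] to [0, 1], resizes to VGG's native 224x224, and applies ImageNet normalization. The hard-coded .cuda() ties the module to GPU; a device-agnostic sketch using registered buffers (an alternative, not what the file does):

import torch
import torch.nn.functional as F

class VGGPreprocess(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("mu", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        x = (x + 1) / 2                                  # [-1, 1] -> [0, 1]
        x = F.interpolate(x, size=(224, 224), mode="bilinear",
                          align_corners=False)
        return (x - self.mu) / self.std                  # ImageNet statistics
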