Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeqiang-Lai committed May 21, 2023
1 parent c5fde4c commit 9971956
Show file tree
Hide file tree
Showing 12 changed files with 2,088 additions and 0 deletions.
211 changes: 211 additions & 0 deletions drag_gan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as FF
from torchvision import utils
import torch.optim
from stylegan2.model import Generator
import copy

class CustomGenerator(Generator):
    """StyleGAN2 generator split into two stages for DragGAN.

    ``prepare`` builds the W+ latent and per-layer noise (mirroring the
    first half of the upstream ``Generator.forward``); ``generate``
    synthesizes the image from that latent and additionally exposes the
    intermediate 256x256 feature map used for motion supervision and
    point tracking.
    """

    def prepare(
        self,
        styles,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Map input styles to a W+ latent and a per-layer noise list.

        Returns:
            (latent, noise): ``latent`` has shape (B, n_latent, style_dim);
            ``noise`` is a list with one entry per synthesis layer
            (``None`` entries mean "sample fresh noise each forward").
        """
        # z -> w through the mapping network, unless callers already
        # supply latent-space vectors.
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]

        # Truncation trick: pull each style toward the mean latent.
        if truncation < 1:
            truncated = []
            for style in styles:
                truncated.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )
            styles = truncated

        if len(styles) < 2:
            inject_index = self.n_latent
            if styles[0].ndim < 3:
                # Broadcast the single w to every layer -> W+ tensor.
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        else:
            # Style mixing: first style up to inject_index, second after.
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
            latent = torch.cat([latent, latent2], 1)

        return latent, noise

    def generate(
        self,
        latent,
        noise,
    ):
        """Synthesize an image from a W+ latent.

        Returns:
            (image, feat): the generated image and the feature map captured
            at the 256x256 synthesis layer, bilinearly upsampled to the
            image resolution.
        """
        feat = None  # feature map captured at the 256x256 layer
        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            if out.shape[-1] == 256:
                feat = out
            i += 2

        image = skip
        # Bug fix: the original left the feature variable unbound when the
        # generator never produces a 256x256 layer (output size < 256),
        # raising a confusing UnboundLocalError. Fail explicitly instead.
        if feat is None:
            raise RuntimeError(
                "generate() requires a 256x256 intermediate layer; "
                "the generator output resolution is too small"
            )
        feat = FF.interpolate(feat, image.shape[-2:], mode='bilinear')
        return image, feat


def stylegan2(
    size=1024,
    channel_multiplier=2,
    latent=512,
    n_mlp=8,
    ckpt='stylegan2-ffhq-config-f.pt'
):
    """Build a frozen ``CustomGenerator`` and load pretrained EMA weights.

    Args:
        size: output resolution of the generator.
        channel_multiplier: channel width multiplier for the synthesis net.
        latent: dimensionality of the latent/style space.
        n_mlp: number of layers in the mapping network.
        ckpt: path to a stylegan2 checkpoint containing a ``"g_ema"`` entry.

    Returns:
        The generator in eval mode with gradients disabled (only latents,
        not weights, are optimized during dragging).
    """
    g_ema = CustomGenerator(size, latent, n_mlp, channel_multiplier=channel_multiplier)
    # Bug fix: load to CPU first so a GPU-saved checkpoint also loads on
    # CUDA-less machines; callers move the model to their device afterwards.
    checkpoint = torch.load(ckpt, map_location='cpu')
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # presumably to tolerate noise buffers absent from the checkpoint; confirm.
    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
    g_ema.requires_grad_(False)
    g_ema.eval()
    return g_ema


def bilinear_interpolate_torch(im, y, x):
    """Differentiably sample ``im`` at fractional pixel locations.

    Args:
        im: tensor of shape (B, C, H, W).
        y: tensor of shape (N,) — fractional row coordinates.
        x: tensor of shape (N,) — fractional column coordinates.

    Returns:
        Tensor of shape (B, C, N): bilinear interpolation of the four
        pixels surrounding each (y, x). Gradients flow back to ``im``
        and, through the weights, to ``x``/``y``.
    """
    x0 = torch.floor(x).long()
    x1 = x0 + 1

    y0 = torch.floor(y).long()
    y1 = y0 + 1

    # Interpolation weights — computed BEFORE clamping so they stay exact
    # for in-range coordinates.
    wa = (x1.float() - x) * (y1.float() - y)
    wb = (x1.float() - x) * (y - y0.float())
    wc = (x - x0.float()) * (y1.float() - y)
    wd = (x - x0.float()) * (y - y0.float())

    # Bug fix: the original "instead of clamp" trick
    # (x1 = x1 - floor(x1 / W)) only handles x1 == W and leaves any other
    # out-of-range index invalid, and never bounded x0/y0 at all.
    # Clamp all four corner indices into the image; identical results for
    # coordinates already inside [0, W-1] x [0, H-1].
    h, w = im.shape[2], im.shape[3]
    x0 = x0.clamp(0, w - 1)
    x1 = x1.clamp(0, w - 1)
    y0 = y0.clamp(0, h - 1)
    y1 = y1.clamp(0, h - 1)

    Ia = im[:, :, y0, x0]
    Ib = im[:, :, y1, x0]
    Ic = im[:, :, y0, x1]
    Id = im[:, :, y1, x1]

    return Ia * wa + Ib * wb + Ic * wc + Id * wd


def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points, mask, max_iters=1000):
    """Iteratively drag ``handle_points`` toward ``target_points`` by
    optimizing the first 6 layers of the W+ latent (DragGAN-style).

    Alternates motion-supervision gradient steps on the latent with a
    point-tracking pass that relocates each handle, and yields
    ``(image, latent, feature_map)`` after every outer iteration so a UI
    can stream intermediate results.

    Args:
        g_ema: a CustomGenerator (provides ``generate(latent, noise)``).
        latent: W+ latent of shape (B, n_latent, style_dim).
        noise: per-layer noise list as returned by ``prepare``.
        F: feature map of the current image (as returned by ``generate``).
        handle_points / target_points: lists of (row, col) float tensors.
            NOTE: ``handle_points`` is mutated in place during tracking.
        mask: editable-region mask — currently UNUSED (the mask loss term
            below is commented out, so ``lam`` has no effect).
        max_iters: number of outer (supervision + tracking) iterations.
    """
    # Original handle positions: point tracking always compares against
    # features sampled at the *initial* handle locations in F0.
    handle_points0 = copy.deepcopy(handle_points)
    n = len(handle_points)
    # r1: motion-supervision neighborhood radius, r2: tracking search
    # radius, lam: mask-loss weight (unused, see above), d: unused.
    r1, r2, lam, d = 3, 12, 20, 1

    def neighbor(x, y, d):
        # All integer offsets in the half-open square [x-d, x+d) x [y-d, y+d).
        # NOTE(review): hard-codes .cuda(), so this function (and the whole
        # optimization) requires a CUDA device.
        points = []
        for i in range(x - d, x + d):
            for j in range(y - d, y + d):
                points.append(torch.tensor([i, j]).float().cuda())
        return points

    # Frozen copy of the initial feature map, used as the tracking reference.
    F0 = F.detach().clone()
    # latent = latent.detach().clone().requires_grad_(True)
    # Only the first 6 W+ layers are optimized; the rest stay fixed so
    # fine-grained appearance is preserved.
    latent_trainable = latent[:, :6, :].detach().clone().requires_grad_(True)
    latent_untrainable = latent[:, 6:, :].detach().clone().requires_grad_(False)
    optimizer = torch.optim.Adam([latent_trainable], lr=2e-3)
    for iter in range(max_iters):  # NOTE(review): shadows the builtin `iter`
        # 5 gradient steps of motion supervision per tracking update.
        for s in range(5):
            optimizer.zero_grad()
            latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
            sample2, F2 = g_ema.generate(latent, noise)

            # motion supervision
            loss = 0
            for i in range(n):
                pi, ti = handle_points[i], target_points[i]
                # Unit-ish step direction from handle toward target.
                # NOTE(review): divides by squared distance rather than the
                # norm, so the step shrinks as the handle approaches the
                # target — confirm this is intended.
                di = (ti - pi) / torch.sum((ti - pi)**2)

                for qi in neighbor(int(pi[0]), int(pi[1]), r1):
                    # f1 = F[..., int(qi[0]), int(qi[1])]
                    # f2 = F2[..., int(qi[0] + di[0]), int(qi[1] + di[1])]
                    # f1 is detached: the loss pulls the feature at the
                    # shifted location toward the (fixed) current feature,
                    # nudging content along di.
                    f1 = bilinear_interpolate_torch(F2, qi[0], qi[1]).detach()
                    f2 = bilinear_interpolate_torch(F2, qi[0] + di[0], qi[1] + di[1])
                    loss += FF.l1_loss(f2, f1)

            # loss += ((F-F0) * (1-mask)).abs().mean() * lam

            loss.backward()
            optimizer.step()

            # Debug output: first few entries of the trainable latent.
            print(latent_trainable[0,0,:10])
            # if s % 10 ==0:
            #     utils.save_image(sample2, "test2.png", normalize=True, range=(-1, 1))

        # point tracking
        with torch.no_grad():
            sample2, F2 = g_ema.generate(latent, noise)
            for i in range(n):
                pi = handle_points0[i]
                # f = F0[..., int(pi[0]), int(pi[1])]
                # Reference feature at the ORIGINAL handle position.
                f0 = bilinear_interpolate_torch(F0, pi[0], pi[1])
                # Nearest-neighbor search (L1 in feature space) around the
                # current handle position within radius r2.
                minv = 1e9
                minx = 1e9
                miny = 1e9
                for qi in neighbor(int(handle_points[i][0]), int(handle_points[i][1]), r2):
                    # f2 = F2[..., int(qi[0]), int(qi[1])]
                    try:
                        f2 = bilinear_interpolate_torch(F2, qi[0], qi[1])
                    except:
                        # NOTE(review): bare except + debugger left in from
                        # development — should be removed or narrowed.
                        import ipdb; ipdb.set_trace()
                    v = torch.norm(f2 - f0, p=1)
                    if v < minv:
                        minv = v
                        minx = int(qi[0])
                        miny = int(qi[1])
                # Move the handle to the best-matching location (in place).
                handle_points[i][0] = minx
                handle_points[i][1] = miny

        F = F2.detach().clone()
        if iter % 1 == 0:  # currently logs every iteration
            print(iter, loss.item(), handle_points, target_points)
            # p = handle_points[0].int()
            # sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] = sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] * 0
            # t = target_points[0].int()
            # sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] = sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] * 255

        # sample2[0, :, 210, 134] = sample2[0, :, 210, 134] * 0
        # Debug snapshot of the current image each iteration.
        utils.save_image(sample2, "test2.png", normalize=True, range=(-1, 1))

        yield sample2, latent, F2
76 changes: 76 additions & 0 deletions gradio_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import gradio as gr
import torch
from drag_gan import stylegan2, drag_gan
from PIL import Image

# Demo globals: the generator is built once at import time.
# NOTE(review): device is hard-coded to CUDA (drag_gan's neighbor() also
# calls .cuda()), so this app cannot run on CPU-only machines — confirm.
device = 'cuda'
torch.cuda.manual_seed(25)  # fixed CUDA seed so the initial sample is reproducible
g_ema = stylegan2().to(device)


def to_image(tensor):
    """Convert a (1, C, H, W) tensor to an (H, W, C) uint8 numpy image.

    The tensor is min-max normalized to [0, 255]. A constant tensor maps
    to an all-zero image instead of dividing by zero.
    """
    tensor = tensor.squeeze(0).permute(1, 2, 0)
    arr = tensor.detach().cpu().numpy()
    lo, hi = arr.min(), arr.max()
    # Bug fix: guard against a constant image — the original divided by
    # zero (producing NaNs) when max == min.
    span = hi - lo if hi > lo else 1.0
    arr = (arr - lo) / span
    arr = arr * 255
    return arr.astype('uint8')


def on_click(image, target_point, points, evt: gr.SelectData):
    """Record a clicked point and paint a 5x5 marker on the image.

    The gradio click index arrives as (col, row); points are stored and
    drawn as (row, col). Target points are painted white and handle
    points black, depending on the ``target_point`` checkbox.
    """
    col, row = evt.index[0], evt.index[1]
    if target_point:
        points['target'].append([row, col])
        image[row:row + 5, col:col + 5, :] = 255  # white marker
    else:
        points['handle'].append([row, col])
        image[row:row + 5, col:col + 5, :] = 0  # black marker
    return image, str(evt.index)


def on_drag(points, max_iters, state):
    """Run the DragGAN optimization and stream intermediate images.

    Generator callback for the gradio button: pulls the latent, noise and
    feature map from ``state``, then yields an updated (cleared) points
    dict, the current image, and the updated state after every iteration.
    """
    iters = int(max_iters)
    latent, noise, feat = state['latent'], state['noise'], state['F']

    handles = [torch.tensor(p).float() for p in points['handle']]
    targets = [torch.tensor(p).float() for p in points['target']]

    # NOTE(review): hard-coded editable-region mask for a 1024x1024 image;
    # presumably the area allowed to change — verify (currently unused by
    # drag_gan's loss anyway).
    mask = torch.zeros((1, 1, 1024, 1024)).to(device)
    mask[..., 720:820, 390:600] = 1

    for sample, latent, feat in drag_gan(g_ema, latent, noise, feat,
                                         handles, targets, mask,
                                         max_iters=iters):
        points = {'target': [], 'handle': []}
        image = to_image(sample)

        state['F'] = feat
        state['latent'] = latent
        yield points, image, state


def main():
    """Generate one random face and launch the gradio dragging demo."""
    z = torch.randn([1, 512], device=device)
    latent, noise = g_ema.prepare([z])
    sample, feat = g_ema.generate(latent, noise)

    with gr.Blocks() as demo:
        # Per-session generator state threaded through the callbacks.
        state = gr.State({'latent': latent, 'noise': noise, 'F': feat})
        max_iters = gr.Slider(1, 20, 5, label='Max Iterations')
        image = gr.Image(to_image(sample)).style(height=512, width=512)
        text = gr.Textbox()
        btn = gr.Button('Drag it')
        # Clicked points accumulated until the drag button is pressed.
        points = gr.State({'target': [], 'handle': []})
        target_point = gr.Checkbox(label='Target Point')
        image.select(on_click, [image, target_point, points], [image, text])
        btn.click(on_drag, inputs=[points, max_iters, state], outputs=[points, image, state])

    demo.queue(concurrency_count=5, max_size=20).launch()


if __name__ == '__main__':
    main()
Empty file added stylegan2/_init__.py
Empty file.
Loading

0 comments on commit 9971956

Please sign in to comment.