Skip to content

Commit

Permalink
Add fixed support for multi-image inputs and outputs
Browse files Browse the repository at this point in the history
Rate limit · GitHub

Whoa there!

You have triggered an abuse detection mechanism.

Please wait a few minutes before you try again;
in some cases this may take up to an hour.

ariel415el committed Mar 9, 2023
1 parent 8edac2d commit 68dd629
Showing 4 changed files with 64 additions and 29 deletions.
36 changes: 27 additions & 9 deletions GPDM.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from math import sqrt

from torchvision.utils import save_image
from tqdm import tqdm
@@ -9,13 +10,14 @@

def generate(reference_images,
criteria,
init_from: str = 'zeros',
init_from = 'zeros',
pyramid_scales=(32, 64, 128, 256),
lr: float = 0.01,
num_steps: int = 300,
aspect_ratio=(1, 1),
num_outputs=1,
additive_noise_sigma=0.0,
device: str = 'cuda:0',
device: str = 'cuda:1',
debug_dir=None):
"""
Run the GPDM model to generate an image/s with a similar patch distribution to reference_images/s with a given criteria.
@@ -24,13 +26,19 @@ def generate(reference_images,
if debug_dir:
os.makedirs(debug_dir, exist_ok=True)

pbar = GPDMLogger(num_steps, len(pyramid_scales))

criteria = criteria.to(device)

reference_images = reference_images.to(device)
synthesized_images = get_fist_initial_guess(reference_images, init_from, additive_noise_sigma).to(device)
synthesized_images = ensure_size(synthesized_images, num_outputs)
original_image_shape = synthesized_images.shape[-2:]

print(f"Matching the patches of {len(synthesized_images)} generated images to the patches of {len(reference_images)} reference images")
pbar = GPDMLogger(num_steps, len(pyramid_scales))

if debug_dir:
nrow = int(sqrt(len(synthesized_images)))
save_image(synthesized_images, os.path.join(debug_dir, f'init.png'), normalize=True, nrow=nrow)
all_losses = []
for scale in pyramid_scales:
pbar.new_lvl()
@@ -42,12 +50,12 @@ def generate(reference_images,
all_losses += losses

if debug_dir:
save_image(lvl_references, os.path.join(debug_dir, f'references-lvl-{pbar.lvl}.png'), normalize=True)
save_image(synthesized_images, os.path.join(debug_dir, f'outputs-lvl-{pbar.lvl}.png'), normalize=True)
save_image(lvl_references, os.path.join(debug_dir, f'references-lvl-{pbar.lvl}.png'), normalize=True, nrow=nrow)
save_image(synthesized_images, os.path.join(debug_dir, f'outputs-lvl-{pbar.lvl}.png'), normalize=True, nrow=nrow)
plot_loss(all_losses, os.path.join(debug_dir, f'train_losses.png'))

pbar.pbar.close()
return synthesized_images
return synthesized_images, lvl_references


def _match_patch_distributions(synthesized_images, reference_images, criteria, num_steps, lr, pbar):
@@ -61,6 +69,7 @@ def _match_patch_distributions(synthesized_images, reference_images, criteria, n
optim = torch.optim.Adam([synthesized_images], lr=lr)
losses = []
for i in range(num_steps):
criteria.init()
# Optimize image
optim.zero_grad()
loss = criteria(synthesized_images, reference_images)
@@ -101,14 +110,17 @@ def print(self):

def get_fist_initial_guess(reference_images, init_from, additive_noise_sigma):
if init_from == "zeros":
synthesized_images = torch.zeros_like(reference_images)
synthesized_images = torch.zeros(1, *reference_images.shape[1:])
elif init_from == "mean":
synthesized_images = torch.mean(reference_images, dim=0, keepdim=True)
elif init_from == "target":
synthesized_images = reference_images.clone()
import torchvision
synthesized_images = torchvision.transforms.GaussianBlur(7, sigma=7)(synthesized_images)
# elif type(init_from) == torch.Tensor:
# synthesized_images = init_from
elif os.path.exists(init_from):
synthesized_images = load_image(init_from)
synthesized_images = synthesized_images.repeat(reference_images.shape[0], 1, 1, 1)
else:
raise ValueError("Bad init mode", init_from)
if additive_noise_sigma:
@@ -117,6 +129,12 @@ def get_fist_initial_guess(reference_images, init_from, additive_noise_sigma):
return synthesized_images


def ensure_size(batch, num_outputs):
    """Tile a single-item batch so that it holds ``num_outputs`` items.

    The batch is repeated along its first dimension only when more than one
    output is requested and the batch currently holds exactly one item. In
    every other case the input is returned unchanged; a batch that already
    has several items is never resized.

    Args:
        batch: Tensor whose first dimension is the batch dimension.
        num_outputs: Requested number of items in the batch.

    Returns:
        The tiled tensor, or ``batch`` itself when no tiling applies.
    """
    if num_outputs > 1 and batch.shape[0] == 1:
        # Repeat along dim 0 only; every other dim gets a factor of 1, so the
        # call works for any rank instead of assuming a 4-D tensor.
        trailing = [1] * (batch.dim() - 1)
        batch = batch.repeat(num_outputs, *trailing)
    return batch


def get_output_shape(initial_image_shape, size, aspect_ratio):
"""Get the size of the output pyramid level"""
h, w = initial_image_shape
18 changes: 11 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
@@ -3,16 +3,17 @@
sys.path.append(dirname(dirname(abspath(__file__))))
import GPDM
from patch_swd import PatchSWDLoss
from utils import load_image, dump_images, get_pyramid_scales
from utils import read_data, dump_images, get_pyramid_scales, show_nns
import argparse


def parse_args():
# IO
parser = argparse.ArgumentParser(description='Run GPDM')
parser.add_argument('target_image', help="This image has the reference patch distribution to be matched")
parser.add_argument('--output_dir', default="Outputs", help="Where to put the results")
parser.add_argument('--output_dir', default="outputs", help="Where to put the results")
parser.add_argument('--debug', action='store_true', default=False, help="Dumpp debug images")
parser.add_argument('--device', default="cuda:0")

# SWD parameters
parser.add_argument('--patch_size', type=int, default=7)
@@ -39,7 +40,7 @@ def parse_args():
parser.add_argument('--lr', type=float, default=0.01, help="Adam learning rate for the optimization")
parser.add_argument('--num_steps', type=int, default=300, help="Number of Adam steps")
parser.add_argument('--noise_sigma', type=float, default=1.5, help="Std of noise added to the first initial image")
parser.add_argument('--num_images', type=int, default=1,
parser.add_argument('--num_outputs', type=int, default=1,
help="If > 1, batched inference is used (see paper) and multiple images are generated")

return parser.parse_args()
@@ -48,20 +49,23 @@ def parse_args():
if __name__ == '__main__':
args = parse_args()

refernce_images = load_image(args.target_image).repeat(args.num_images, 1, 1, 1)
refernce_images = read_data(args.target_image)

criteria = PatchSWDLoss(patch_size=args.patch_size, stride=args.stride, num_proj=args.num_proj)
criteria = PatchSWDLoss(patch_size=args.patch_size, stride=args.stride, num_proj=args.num_proj, c=refernce_images.shape[1])

fine_dim = args.fine_dim if args.fine_dim is not None else refernce_images.shape[-2]

outputs_dir = join(args.output_dir, basename(args.target_image))
new_images = GPDM.generate(refernce_images, criteria,
new_images, last_lvl_references = GPDM.generate(refernce_images, criteria,
pyramid_scales=get_pyramid_scales(fine_dim, args.coarse_dim, args.pyr_factor),
aspect_ratio=(args.height_factor, args.width_factor),
init_from=args.init_from,
lr=args.lr,
num_steps=args.num_steps,
additive_noise_sigma=args.noise_sigma,
debug_dir=f"{outputs_dir}/debug" if args.debug else None
num_outputs=args.num_outputs,
debug_dir=f"{outputs_dir}/debug" if args.debug else None,
device=args.device
)
dump_images(new_images, outputs_dir)
show_nns(new_images, last_lvl_references, outputs_dir)
24 changes: 15 additions & 9 deletions patch_swd.py
Original file line number Diff line number Diff line change
@@ -3,23 +3,26 @@


class PatchSWDLoss(torch.nn.Module):
def __init__(self, patch_size=7, stride=1, num_proj=256):
def __init__(self, patch_size=7, stride=1, num_proj=256, c=3, l2=False):
super(PatchSWDLoss, self).__init__()
self.patch_size = patch_size
self.stride = stride
self.num_proj = num_proj
self.l2 = l2
self.c = c
self.init()

def forward(self, x, y):
b, c, h, w = x.shape

def init(self):
# Sample random normalized projections
rand = torch.randn(self.num_proj, c*self.patch_size**2).to(x.device) # (slice_size**2*ch)
rand = torch.randn(self.num_proj, self.c*self.patch_size**2) # (slice_size**2*ch)
rand = rand / torch.norm(rand, dim=1, keepdim=True) # noramlize to unit directions
rand = rand.reshape(self.num_proj, c, self.patch_size, self.patch_size)
self.rand = rand.reshape(self.num_proj, self.c, self.patch_size, self.patch_size)

def forward(self, x, y):
self.rand = self.rand.to(x.device)
# Project patches
projx = F.conv2d(x, rand).transpose(1,0).reshape(self.num_proj, -1)
projy = F.conv2d(y, rand).transpose(1,0).reshape(self.num_proj, -1)
projx = F.conv2d(x, self.rand).transpose(1,0).reshape(self.num_proj, -1)
projy = F.conv2d(y, self.rand).transpose(1,0).reshape(self.num_proj, -1)

# Duplicate patches if number does not equal
projx, projy = duplicate_to_match_lengths(projx, projy)
@@ -28,7 +31,10 @@ def forward(self, x, y):
projx, _ = torch.sort(projx, dim=1)
projy, _ = torch.sort(projy, dim=1)

loss = torch.abs(projx - projy).mean()
if self.l2:
loss = ((projx - projy)**2).mean()
else:
loss = torch.abs(projx - projy).mean()

return loss

15 changes: 11 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,14 @@ def load_image(path):
return img


def read_data(path):
    """Load reference image(s) from a single file or from a directory.

    Args:
        path: Path to an image file, or to a directory whose entries are each
            loaded with ``load_image``.

    Returns:
        For a file, whatever ``load_image`` returns. For a directory, the
        per-file results concatenated along dim 0 in sorted file-name order.

    Raises:
        ValueError: If ``path`` is a directory that contains no entries.
    """
    if not os.path.isdir(path):
        return load_image(path)

    # os.listdir returns entries in arbitrary order; sort them so the order of
    # the reference batch is reproducible across runs and platforms.
    file_names = sorted(os.listdir(path))
    if not file_names:
        # Fail early with a clear message instead of an opaque torch.cat error.
        raise ValueError(f"No reference images found in directory: {path}")

    return torch.cat(
        [load_image(os.path.join(path, name)) for name in file_names],
        dim=0,
    )


def dump_images(images, out_dir):
if os.path.exists(out_dir):
i = len(os.listdir(out_dir))
@@ -33,11 +41,10 @@ def show_nns(images, ref_images, out_dir):
nn_indices = []
for i in range(len(images)):
dists = torch.mean((ref_images - images[i].unsqueeze(0))**2, dim=(1,2,3))
j = dists.argmin()
j = dists.argmin().item()
nn_indices.append(j)

debug_image = torch.cat([images, ref_images[nn_indices]], dim=0)
save_image(debug_image, os.path.join(out_dir, f"NNs.png"), normalize=True, nrow=len(ref_images))
debug_image = torch.cat([images, ref_images[nn_indices] ], dim=0)
save_image(debug_image, os.path.join(out_dir, f"NNs.png"), normalize=True, nrow=len(images))


def get_pyramid_scales(max_height, min_height, step):

0 comments on commit 68dd629

Please sign in to comment.