diff --git a/GPDM.py b/GPDM.py index 13a8010..bdefd3c 100644 --- a/GPDM.py +++ b/GPDM.py @@ -1,4 +1,5 @@ import os +from math import sqrt from torchvision.utils import save_image from tqdm import tqdm @@ -9,13 +10,14 @@ def generate(reference_images, criteria, - init_from: str = 'zeros', + init_from = 'zeros', pyramid_scales=(32, 64, 128, 256), lr: float = 0.01, num_steps: int = 300, aspect_ratio=(1, 1), + num_outputs=1, additive_noise_sigma=0.0, - device: str = 'cuda:0', + device: str = 'cuda:1', debug_dir=None): """ Run the GPDM model to generate an image/s with a similar patch distribution to reference_images/s with a given criteria. @@ -24,13 +26,19 @@ def generate(reference_images, if debug_dir: os.makedirs(debug_dir, exist_ok=True) - pbar = GPDMLogger(num_steps, len(pyramid_scales)) - criteria = criteria.to(device) reference_images = reference_images.to(device) synthesized_images = get_fist_initial_guess(reference_images, init_from, additive_noise_sigma).to(device) + synthesized_images = ensure_size(synthesized_images, num_outputs) original_image_shape = synthesized_images.shape[-2:] + + print(f"Matching the patches of {len(synthesized_images)} generated images to the patches of {len(reference_images)} reference images") + pbar = GPDMLogger(num_steps, len(pyramid_scales)) + + if debug_dir: + nrow = int(sqrt(len(synthesized_images))) + save_image(synthesized_images, os.path.join(debug_dir, f'init.png'), normalize=True, nrow=nrow) all_losses = [] for scale in pyramid_scales: pbar.new_lvl() @@ -42,12 +50,12 @@ def generate(reference_images, all_losses += losses if debug_dir: - save_image(lvl_references, os.path.join(debug_dir, f'references-lvl-{pbar.lvl}.png'), normalize=True) - save_image(synthesized_images, os.path.join(debug_dir, f'outputs-lvl-{pbar.lvl}.png'), normalize=True) + save_image(lvl_references, os.path.join(debug_dir, f'references-lvl-{pbar.lvl}.png'), normalize=True, nrow=nrow) + save_image(synthesized_images, os.path.join(debug_dir, f'outputs-lvl-{pbar.lvl}.png'), normalize=True, nrow=nrow) plot_loss(all_losses, os.path.join(debug_dir, f'train_losses.png')) pbar.pbar.close() - return synthesized_images + return synthesized_images, lvl_references def _match_patch_distributions(synthesized_images, reference_images, criteria, num_steps, lr, pbar): @@ -61,6 +69,7 @@ def _match_patch_distributions(synthesized_images, reference_images, criteria, n optim = torch.optim.Adam([synthesized_images], lr=lr) losses = [] for i in range(num_steps): + criteria.init() # Optimize image optim.zero_grad() loss = criteria(synthesized_images, reference_images) @@ -101,14 +110,17 @@ def print(self): def get_fist_initial_guess(reference_images, init_from, additive_noise_sigma): if init_from == "zeros": - synthesized_images = torch.zeros_like(reference_images) + synthesized_images = torch.zeros(1, *reference_images.shape[1:]) + elif init_from == "mean": + synthesized_images = torch.mean(reference_images, dim=0, keepdim=True) elif init_from == "target": synthesized_images = reference_images.clone() import torchvision synthesized_images = torchvision.transforms.GaussianBlur(7, sigma=7)(synthesized_images) + # elif type(init_from) == torch.Tensor: + # synthesized_images = init_from elif os.path.exists(init_from): synthesized_images = load_image(init_from) - synthesized_images = synthesized_images.repeat(reference_images.shape[0], 1, 1, 1) else: raise ValueError("Bad init mode", init_from) if additive_noise_sigma: @@ -117,6 +129,12 @@ def get_fist_initial_guess(reference_images, init_from, additive_noise_sigma): return synthesized_images +def ensure_size(batch, num_outputs): + if num_outputs > 1 and batch.shape[0] == 1: + batch = batch.repeat(num_outputs, 1, 1, 1) + return batch + + def get_output_shape(initial_image_shape, size, aspect_ratio): """Get the size of the output pyramid level""" h, w = initial_image_shape diff --git a/main.py b/main.py index 75487ac..decbdb3 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,7 @@ sys.path.append(dirname(dirname(abspath(__file__)))) import GPDM from patch_swd import PatchSWDLoss -from utils import load_image, dump_images, get_pyramid_scales +from utils import read_data, dump_images, get_pyramid_scales, show_nns import argparse @@ -11,8 +11,9 @@ def parse_args(): # IO parser = argparse.ArgumentParser(description='Run GPDM') parser.add_argument('target_image', help="This image has the reference patch distribution to be matched") - parser.add_argument('--output_dir', default="Outputs", help="Where to put the results") + parser.add_argument('--output_dir', default="outputs", help="Where to put the results") parser.add_argument('--debug', action='store_true', default=False, help="Dumpp debug images") + parser.add_argument('--device', default="cuda:0") # SWD parameters parser.add_argument('--patch_size', type=int, default=7) @@ -39,7 +40,7 @@ def parse_args(): parser.add_argument('--lr', type=float, default=0.01, help="Adam learning rate for the optimization") parser.add_argument('--num_steps', type=int, default=300, help="Number of Adam steps") parser.add_argument('--noise_sigma', type=float, default=1.5, help="Std of noise added to the first initial image") - parser.add_argument('--num_images', type=int, default=1, + parser.add_argument('--num_outputs', type=int, default=1, help="If > 1, batched inference is used (see paper) and multiple images are generated") return parser.parse_args() @@ -48,20 +49,23 @@ def parse_args(): if __name__ == '__main__': args = parse_args() - refernce_images = load_image(args.target_image).repeat(args.num_images, 1, 1, 1) + refernce_images = read_data(args.target_image) - criteria = PatchSWDLoss(patch_size=args.patch_size, stride=args.stride, num_proj=args.num_proj) + criteria = PatchSWDLoss(patch_size=args.patch_size, stride=args.stride, num_proj=args.num_proj, c=refernce_images.shape[1]) fine_dim = args.fine_dim if args.fine_dim is not None else refernce_images.shape[-2] outputs_dir = join(args.output_dir, basename(args.target_image)) - new_images = GPDM.generate(refernce_images, criteria, + new_images, last_lvl_references = GPDM.generate(refernce_images, criteria, pyramid_scales=get_pyramid_scales(fine_dim, args.coarse_dim, args.pyr_factor), aspect_ratio=(args.height_factor, args.width_factor), init_from=args.init_from, lr=args.lr, num_steps=args.num_steps, additive_noise_sigma=args.noise_sigma, - debug_dir=f"{outputs_dir}/debug" if args.debug else None + num_outputs=args.num_outputs, + debug_dir=f"{outputs_dir}/debug" if args.debug else None, + device=args.device ) dump_images(new_images, outputs_dir) + show_nns(new_images, last_lvl_references, outputs_dir) \ No newline at end of file diff --git a/patch_swd.py b/patch_swd.py index 093d795..7fe6dd8 100644 --- a/patch_swd.py +++ b/patch_swd.py @@ -3,23 +3,26 @@ class PatchSWDLoss(torch.nn.Module): - def __init__(self, patch_size=7, stride=1, num_proj=256): + def __init__(self, patch_size=7, stride=1, num_proj=256, c=3, l2=False): super(PatchSWDLoss, self).__init__() self.patch_size = patch_size self.stride = stride self.num_proj = num_proj + self.l2 = l2 + self.c = c + self.init() - def forward(self, x, y): - b, c, h, w = x.shape - + def init(self): # Sample random normalized projections - rand = torch.randn(self.num_proj, c*self.patch_size**2).to(x.device) # (slice_size**2*ch) + rand = torch.randn(self.num_proj, self.c*self.patch_size**2) # (slice_size**2*ch) rand = rand / torch.norm(rand, dim=1, keepdim=True) # noramlize to unit directions - rand = rand.reshape(self.num_proj, c, self.patch_size, self.patch_size) + self.rand = rand.reshape(self.num_proj, self.c, self.patch_size, self.patch_size) + def forward(self, x, y): + self.rand = self.rand.to(x.device) # Project patches - projx = F.conv2d(x, rand).transpose(1,0).reshape(self.num_proj, -1) - projy = F.conv2d(y, rand).transpose(1,0).reshape(self.num_proj, -1) + projx = F.conv2d(x, self.rand).transpose(1,0).reshape(self.num_proj, -1) + projy = F.conv2d(y, self.rand).transpose(1,0).reshape(self.num_proj, -1) # Duplicate patches if number does not equal projx, projy = duplicate_to_match_lengths(projx, projy) @@ -28,7 +31,10 @@ def forward(self, x, y): projx, _ = torch.sort(projx, dim=1) projy, _ = torch.sort(projy, dim=1) - loss = torch.abs(projx - projy).mean() + if self.l2: + loss = ((projx - projy)**2).mean() + else: + loss = torch.abs(projx - projy).mean() return loss diff --git a/utils.py b/utils.py index 9831198..437ebbc 100644 --- a/utils.py +++ b/utils.py @@ -19,6 +19,14 @@ def load_image(path): return img +def read_data(path): + if os.path.isdir(path): + refernce_images = torch.cat([load_image(f'{path}/{x}') for x in os.listdir(path)], dim=0) + else: + refernce_images = load_image(path) + return refernce_images + + def dump_images(images, out_dir): if os.path.exists(out_dir): i = len(os.listdir(out_dir)) @@ -33,11 +41,10 @@ def show_nns(images, ref_images, out_dir): nn_indices = [] for i in range(len(images)): dists = torch.mean((ref_images - images[i].unsqueeze(0))**2, dim=(1,2,3)) - j = dists.argmin() + j = dists.argmin().item() nn_indices.append(j) - - debug_image = torch.cat([images, ref_images[nn_indices]], dim=0) - save_image(debug_image, os.path.join(out_dir, f"NNs.png"), normalize=True, nrow=len(ref_images)) + debug_image = torch.cat([images, ref_images[nn_indices] ], dim=0) + save_image(debug_image, os.path.join(out_dir, f"NNs.png"), normalize=True, nrow=len(images)) def get_pyramid_scales(max_height, min_height, step):