Skip to content

Commit

Permalink
Old changes i commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ariel415el committed Aug 8, 2023
1 parent 5047276 commit e3f7d5c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 22 deletions.
12 changes: 10 additions & 2 deletions super_resolution/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,10 @@ def __init__(self, ref_image, scale_factor, p=20, s=1, n_proj=64, num_steps=500,
def loss(self, image):
from torchvision.transforms import Resize
self.criteria.init()
sr_filters = self.criteria.rand.to(image.device)
lr_filters = Resize(self.p//self.scale_factor, antialias=True)(sr_filters.clone())
# sr_filters = self.criteria.rand.to(image.device)
# lr_filters = Resize(self.p//self.scale_factor, antialias=True)(sr_filters.clone())
lr_filters = self.criteria.rand.to(image.device)
sr_filters = Resize(self.p*self.scale_factor, antialias=True)(lr_filters.clone())
projy = F.conv2d(image, sr_filters).transpose(1,0).reshape(self.n_proj, -1)
projx = F.conv2d(self.ref_image, lr_filters).transpose(1,0).reshape(self.n_proj, -1)

Expand All @@ -213,6 +215,12 @@ def loss(self, image):
loss = torch.abs(projx - projy).mean()
return loss

def run(self, init_image):
img, losses = match_patch_distributions(init_image, self.loss, self.num_steps, self.lr)
if self.gradient_projector is not None:
img = self.gradient_projector(img)
img = torch.clip(img.detach(), -1, 1)
return img, losses

class GD_gradient_projector:
def __init__(self, corrupt_image, operator, n_steps=100, lr=0.0001, reg_weight=0):
Expand Down
33 changes: 17 additions & 16 deletions super_resolution/scripts/test_projections_histogram.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import sys
import torch
from matplotlib import pyplot as plt

from super_resolution.debug_utils import plot_hists, plot_img
from super_resolution.predefined_filters import get_random_filters, get_gabor_filters, appply_filter, resize_filters
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from sr_utils.debug_utils import plot_hists, plot_img
from sr_utils.predefined_filters import get_random_filters, get_gabor_filters, appply_filter, resize_filters, normalize_filters


def split_list_randomly(values):
Expand All @@ -13,22 +15,21 @@ def split_list_randomly(values):
return l1,l2


def compare_resized_histograms(p, reverse=True):
factor = 4
def compare_resized_histograms(factor, p, reverse=True):
low_res = Resize(int(im_size // factor), antialias=True)(hig_res)

if reverse:
p = factor * p
factor = 1/factor

rand_projs = get_random_filters(p, n=50)
rand_projs = normalize_filters(get_random_filters(p, n=50))
# naive_projs = get_naive_kernels(p)
gabor_projs = get_gabor_filters(p)
gabor_projs = normalize_filters(get_gabor_filters(p))


resized_rand_projs = resize_filters(rand_projs, p, factor)
resized_rand_projs = resize_filters(rand_projs, p, factor, normalize=True)
# resized_naive_projs = resize_filters(naive_projs, p, factor, False)
resized_gabor_projs = resize_filters(gabor_projs, p, factor)
resized_gabor_projs = resize_filters(gabor_projs, p, factor, normalize=True)

filters = [
(rand_projs[0], resized_rand_projs[0], f"Random"),
Expand Down Expand Up @@ -60,9 +61,7 @@ def compare_resized_histograms(p, reverse=True):
plt.show()


def hist_sanity():
p=5
factor=4
def hist_sanity(factor, p):
low_res = Resize(int(im_size // factor), antialias=True)(hig_res)
low_res_big = Resize(im_size, antialias=True)(low_res)

Expand Down Expand Up @@ -108,9 +107,11 @@ def hist_sanity():
device = torch.device("cpu")
im_size = 1024
nbins = 100
hig_res = load_image('../data/images/SR/fox2.jpg').to(device)
hig_res = load_image('data/images/SR/fox2.jpg').to(device)
hig_res = Resize(im_size, antialias=True)(hig_res)

hist_sanity()
compare_resized_histograms(p=5, reverse=True)
compare_resized_histograms(p=5, reverse=False)
factor = 2
p=2
hist_sanity(factor, p)
compare_resized_histograms(factor, p, reverse=True)
compare_resized_histograms(factor, p, reverse=False)
2 changes: 1 addition & 1 deletion super_resolution/sr_utils/predefined_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_random_filters(p, n=1, c=3):
rand = torch.randn(n, c * p ** 2) # (slice_size**2*ch)
rand = rand / torch.norm(rand, dim=1, keepdim=True) # noramlize to unit directions
rand = rand.reshape(n, c, p, p)
rand /= torch.sum(rand, dim=(1,2,3), keepdim=True)
# rand /= torch.sum(rand, dim=(1,2,3), keepdim=True)
return rand


Expand Down
10 changes: 7 additions & 3 deletions super_resolution/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,12 @@ def dump_image(img, path, normalize=True):

# Define models
models = [
# Fixe vs resample:
# DirectSWD(refernce_image, p=args.p, s=args.s, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj),
# DirectSWD(refernce_image, p=args.p, s=args.s, mode="Fixed", num_steps=args.num_steps, n_proj=args.n_proj),

# # Patch sizes
DirectSWD(refernce_image, p=8, s=1, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj, lr=100),
# DirectSWD(refernce_image, p=8, s=1, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj, lr=100),
# DirectSWD(refernce_image, p=16, s=2, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj),
# DirectSWD(refernce_image, p=32, s=1, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj),
# MSSWD(refernce_image, ps=(3, 5, 7, 9, 11), s=1, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj, lr=50),
Expand All @@ -101,8 +105,8 @@ def dump_image(img, path, normalize=True):
# gradient_projector=back_projector(corrupt_image, operator, n_steps=100), name="BackProject"),

# # Self
# TwoScalesSWD(corrupt_image, scale_factor=4, p=args.p, s=args.s, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj, name="self-rescaled"),
# DirectSWD(corrupt_image, p=args.p, s=args.s, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj, name="self"),
TwoScalesSWD(corrupt_image, scale_factor=4, p=args.p, s=args.s, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj, name="self-rescaled"),
DirectSWD(corrupt_image, p=args.p, s=args.s, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj, name="self"),
# DirectSWD(corrupt_image, p=args.p, s=args.s, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj, name="self-PGD",
# gradient_projector=gradient_projector),
# DirectSWD(corrupt_image, p=args.p, s=args.s, mode="Resample", num_steps=args.num_steps, n_proj=args.n_proj, name="self-PGDreg",
Expand Down

0 comments on commit e3f7d5c

Please sign in to comment.