Skip to content

Commit

Permalink
code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
chompaa committed Aug 1, 2023
1 parent 4257aaf commit 902d879
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 144 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"python.analysis.typeCheckingMode": "off"
}
123 changes: 72 additions & 51 deletions prepare.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,64 @@
import argparse
import glob

import h5py
import numpy as np
import PIL.Image as pil_image
from torchvision.utils import save_image
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pathlib import Path
from PIL import Image
from pathlib import Path
from torchvision import transforms
from torchvision.utils import save_image

from utils import convert_rgb_to_y


def get_image_paths(path):
    """Return every filesystem entry directly under *path*, sorted alphabetically."""
    matches = glob.glob(f"{path}/*")
    matches.sort()
    return matches

def make_hr_lr_images(image_paths, scale):
    """Build paired HR/LR luminance images for every image under *image_paths*.

    Each HR image is cropped so its dimensions divide evenly by *scale*; the LR
    counterpart is produced by bicubic down- then up-sampling, so every pair of
    arrays shares the same shape.

    Returns:
        (hr_images, lr_images): two parallel lists of float32 Y-channel arrays.
    """
    hr_images = []
    lr_images = []

    # Fractional scale factors would break the exact crop/resize arithmetic
    # below, so coerce to an integer up front.
    scale = int(scale)

    for image_path in get_image_paths(image_paths):
        # Fix: the original did `with Image.open(p).convert("RGB") as hr:`,
        # which closes the *converted copy* while the opened file object is
        # never explicitly closed.  Close the source inside the `with` and
        # keep only the detached RGB copy.
        with Image.open(image_path) as source:
            hr = source.convert("RGB")

        # Crop so the HR dimensions are exactly divisible by the scale.
        hr_width = (hr.width // scale) * scale
        hr_height = (hr.height // scale) * scale
        hr = hr.resize((hr_width, hr_height), resample=Image.BICUBIC)

        # Degrade: downsample then upsample back so LR matches HR's shape.
        lr = hr.resize(
            (hr_width // scale, hr_height // scale),
            resample=Image.BICUBIC,
        )
        lr = lr.resize(
            (lr.width * scale, lr.height * scale), resample=Image.BICUBIC
        )

        # Keep only the luminance (Y) channel, as float32.
        hr_images.append(convert_rgb_to_y(np.array(hr).astype(np.float32)))
        lr_images.append(convert_rgb_to_y(np.array(lr).astype(np.float32)))

    return hr_images, lr_images


def train(args):
h5_file = h5py.File(args.output_path, "w")

lr_patches = []
hr_patches = []

for image_path in sorted(glob.glob(f"{args.images_dir}/*")):
hr = pil_image.open(image_path).convert("RGB")
hr_width = (hr.width // args.scale) * args.scale
hr_height = (hr.height // args.scale) * args.scale
hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
hr = np.array(hr).astype(np.float32)
lr = np.array(lr).astype(np.float32)
hr = convert_rgb_to_y(hr)
lr = convert_rgb_to_y(lr)

for hr, lr in zip(*make_hr_lr_images(args.images_dir, args.scale)):
for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):
for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):
lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])
lr_patches.append(lr[i : i + args.patch_size, j : j + args.patch_size])
hr_patches.append(hr[i : i + args.patch_size, j : j + args.patch_size])

lr_patches = np.array(lr_patches)
hr_patches = np.array(hr_patches)
Expand All @@ -51,38 +75,35 @@ def eval(args):
lr_group = h5_file.create_group("lr")
hr_group = h5_file.create_group("hr")

for i, image_path in enumerate(sorted(glob.glob(f"{args.images_dir}/*"))):
hr = pil_image.open(image_path).convert("RGB")
hr_width = (hr.width // args.scale) * args.scale
hr_height = (hr.height // args.scale) * args.scale
hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
hr = np.array(hr).astype(np.float32)
lr = np.array(lr).astype(np.float32)
hr = convert_rgb_to_y(hr)
lr = convert_rgb_to_y(lr)

lr_group.create_dataset(str(i), data=lr)
hr_group.create_dataset(str(i), data=hr)
for index, (hr, lr) in enumerate(
zip(*make_hr_lr_images(args.images_dir, args.scale))
):
lr_group.create_dataset(str(index), data=lr)
hr_group.create_dataset(str(index), data=hr)

h5_file.close()


def resize(args):
    """Bicubically rescale every image in ``args.images_dir`` by ``args.scale``
    and write each result (same base name, PNG) into ``args.output_path``.
    """
    for image_path in get_image_paths(args.images_dir):
        with Image.open(image_path) as image:
            transform = transforms.Compose(
                [
                    transforms.Resize(
                        size=(
                            int(image.height * args.scale),
                            int(image.width * args.scale),
                        ),
                        interpolation=Image.BICUBIC,
                    ),
                    transforms.ToTensor(),
                ]
            )

            image = transform(image)
            # Fix: Path.stem handles both '/' and '\\' separators portably,
            # unlike the previous chained split("/")/split(".")/split("\\")
            # parsing (which also broke on file names containing dots).
            name = Path(image_path).stem

            save_image(image, f"{args.output_path}/{name}.png")


if __name__ == "__main__":
Expand All @@ -91,15 +112,15 @@ def resize(args):
parser.add_argument("--output-path", type=str, required=True)
parser.add_argument("--patch-size", type=int, default=33)
parser.add_argument("--stride", type=int, default=14)
parser.add_argument("--scale", type=int, default=2)
parser.add_argument("--resize-scale", type=float, default=0.1)
parser.add_argument("--scale", type=float, default=2)
parser.add_argument("--resize", action="store_true")
parser.add_argument("--eval", action="store_true")

args = parser.parse_args()

if not args.eval and not args.resize:
train(args)
elif args.resize:
resize(args)
else:
eval(args)
eval(args)
71 changes: 71 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import argparse

import torch
import torch.backends.cudnn as cudnn
import numpy as np
from PIL import Image

from srcnn import SRCNN
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calculate_psnr


if __name__ == "__main__":
    # Run SRCNN super-resolution on a single image: saves a bicubic baseline
    # and the network output next to the input file, and prints the PSNR of
    # the prediction against the degraded input.
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights-file", type=str, required=True)
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--scale", type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = SRCNN().to(device)

    # Copy the checkpoint into the freshly built model, failing loudly on any
    # unexpected key instead of silently ignoring it.
    state_dict = model.state_dict()
    for n, p in torch.load(
        args.weights_file, map_location=lambda storage, _: storage
    ).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    # Fix: removed the dead `ycbr = None` pre-declaration — it misspelled
    # `ycbcr` (so it never aliased the real variable) and was never read.

    with Image.open(args.image_file).convert("RGB") as image:
        # Crop so the dimensions divide evenly by the scale, then degrade the
        # image (down- then up-sample) to create the bicubic baseline.
        image_width = (image.width // args.scale) * args.scale
        image_height = (image.height // args.scale) * args.scale
        image = image.resize((image_width, image_height), resample=Image.BICUBIC)
        image = image.resize(
            (image.width // args.scale, image.height // args.scale),
            resample=Image.BICUBIC,
        )
        image = image.resize(
            (image.width * args.scale, image.height * args.scale),
            resample=Image.BICUBIC,
        )
        image.save(args.image_file.replace(".", f"_bicubic_x{args.scale}."))

        image = np.array(image).astype(np.float32)

        ycbcr = convert_rgb_to_ycbcr(image)

        # Run only the normalized luminance channel through the network,
        # shaped (1, 1, H, W) for the conv layers.
        y = ycbcr[..., 0]
        y /= 255.0
        y = torch.from_numpy(y).to(device)
        y = y.unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            preds = model(y).clamp(0.0, 1.0)

        psnr = calculate_psnr(y, preds)
        print("PSNR: {:.2f}".format(psnr))

        preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

        # Recombine the predicted Y channel with the original chroma and
        # convert back to RGB for saving.
        output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
        output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
        output = Image.fromarray(output)
        output.save(args.image_file.replace(".", f"_srcnn_x{args.scale}."))
Binary file added test/butterfly_GT.bmp
Binary file not shown.
Binary file added test/butterfly_GT_bicubic_x2.bmp
Binary file not shown.
Binary file added test/butterfly_GT_srcnn_x2.bmp
Binary file not shown.
Loading

0 comments on commit 902d879

Please sign in to comment.