forked from braindotai/Watermark-Removal-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
helper.py
70 lines (55 loc) · 2.26 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision.utils import make_grid
def pil_to_np_array(pil_image):
    """Convert an image into a float32 array divided by 255.

    Args:
        pil_image: Any object ``np.array`` can convert; callers in this file
            pass PIL images.

    Returns:
        A float32 array. A 3-D input has its last axis moved to the front
        (``transpose(2, 0, 1)``); any other input gains a new leading axis
        of length 1.
    """
    pixels = np.array(pil_image)
    if pixels.ndim == 3:
        # Last axis first, e.g. (H, W, C) -> (C, H, W).
        channel_first = pixels.transpose(2, 0, 1)
    else:
        # No channel axis present: add one in front.
        channel_first = pixels[None, ...]
    return channel_first.astype(np.float32) / 255.
def np_to_torch_array(np_array):
    """Wrap a NumPy array as a torch tensor with a new leading axis.

    Args:
        np_array: NumPy array accepted by ``torch.from_numpy``.

    Returns:
        A tensor whose shape is ``(1, *np_array.shape)``.
    """
    tensor = torch.from_numpy(np_array)
    # Indexing with None inserts the extra leading dimension.
    return tensor[None, :]
def torch_to_np_array(torch_array):
    """Return the first entry along axis 0 of a tensor as a NumPy array.

    Args:
        torch_array: Tensor with at least one dimension.

    Returns:
        ``numpy.ndarray`` holding element 0 of the leading axis.
    """
    # Detach from autograd and move to host memory before converting.
    detached = torch_array.detach()
    host_array = detached.cpu().numpy()
    return host_array[0]
def read_image(path, image_size = -1):
    """Open the image at ``path`` with PIL and return it.

    Args:
        path: Location handed straight to ``Image.open``.
        image_size: Not used by this function; kept in the signature so
            existing callers keep working.

    Returns:
        The object returned by ``Image.open(path)``.
    """
    return Image.open(path)
def crop_image(image, crop_factor = 64):
    """Centre-crop an image so both sides are multiples of ``crop_factor``.

    The previous implementation mixed ``image.size`` with ``image.shape``;
    the object is used as a PIL image (``.size`` and ``.crop``), which has
    no ``.shape`` attribute, so the call always failed. The box is now
    derived from ``image.size`` only.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method (a PIL image).
        crop_factor: Positive integer each output side must be divisible by.

    Returns:
        Whatever ``image.crop`` returns for the centred box
        ``(left, upper, right, lower)``. A side shorter than ``crop_factor``
        yields a zero-length span on that axis.

    Raises:
        ValueError: If ``crop_factor`` is smaller than 1.
    """
    if crop_factor < 1:
        raise ValueError('crop_factor must be a positive integer')

    width, height = image.size
    # Largest multiples of crop_factor that fit inside the image.
    new_width = width - width % crop_factor
    new_height = height - height % crop_factor

    # Split the discarded margin evenly; the extra pixel of an odd margin
    # is dropped from the right/bottom edge.
    left = (width - new_width) // 2
    upper = (height - new_height) // 2
    box = (left, upper, left + new_width, upper + new_height)
    return image.crop(box)
def get_image_grid(images, nrow = 3):
    """Tile several NumPy images into one grid array.

    Args:
        images: Iterable of NumPy arrays, each converted with
            ``torch.from_numpy``.
        nrow: Forwarded to ``make_grid`` as its second positional argument.

    Returns:
        The ``make_grid`` result converted back to a NumPy array.
    """
    tensors = []
    for image in images:
        tensors.append(torch.from_numpy(image))
    return make_grid(tensors, nrow).numpy()
def visualize_sample(*images_np, nrow = 3, size_factor = 10):
    """Show the given arrays side by side in a matplotlib figure.

    Any array whose first-axis length differs from the largest one among the
    inputs is stacked three times along axis 0 before the grid is built.

    Args:
        *images_np: NumPy arrays to display.
        nrow: Passed to ``get_image_grid``.
        size_factor: Added to both figure dimensions.
    """
    channels = max(image.shape[0] for image in images_np)

    prepared = []
    for image in images_np:
        if image.shape[0] != channels:
            # Repeat the array three times along axis 0.
            image = np.concatenate([image, image, image], axis = 0)
        prepared.append(image)

    grid = get_image_grid(prepared, nrow)

    fig_width = len(prepared) + size_factor
    fig_height = 12 + size_factor
    plt.figure(figsize = (fig_width, fig_height))
    plt.axis('off')
    # Move axis 0 of the grid to the end for imshow.
    plt.imshow(grid.transpose(1, 2, 0))
    plt.show()
def max_dimension_resize(image_pil, mask_pil, max_dim):
    """Resize an image and its mask so the longer side is at most ``max_dim``.

    The previous version checked the width first, so a portrait image whose
    width and height both exceeded ``max_dim`` had only its width clamped
    and kept a height above the limit. Scaling is now driven by the longer
    side. Results for all other inputs are unchanged.

    Args:
        image_pil: Object exposing ``size`` as ``(width, height)`` and
            ``resize((w, h))`` (a PIL image).
        mask_pil: Object exposing ``resize((w, h))``; it is resized to the
            same target size as the image, whatever its own size is.
        max_dim: Upper bound for the longer side of the output.

    Returns:
        Tuple ``(resized_image, resized_mask)``. ``resize`` is always
        called, even when no scaling is needed.
    """
    w, h = image_pil.size
    if max(w, h) > max_dim:
        if w >= h:
            # Width is the longer side: pin it and scale the height.
            h = max(1, int((h / w) * max_dim))
            w = max_dim
        else:
            # Height is the longer side: pin it and scale the width.
            w = max(1, int((w / h) * max_dim))
            h = max_dim
    # max(1, ...) above keeps extreme aspect ratios from producing a 0-px side.
    target = (w, h)
    return image_pil.resize(target), mask_pil.resize(target)
def preprocess_images(image_path, mask_path, max_dim):
    """Load an image/mask pair, resize both, and return them as arrays.

    Both files are opened with ``read_image`` and converted to RGB, resized
    with ``max_dimension_resize``, converted with ``pil_to_np_array``, and
    shown (image, mask, and their element-wise product) via
    ``visualize_sample``.

    Args:
        image_path: Path of the image file.
        mask_path: Path of the mask file.
        max_dim: Size limit; converted with ``int`` before use.

    Returns:
        Tuple ``(image_np, mask_np)`` of the converted arrays.
    """
    rgb_image = read_image(image_path).convert('RGB')
    rgb_mask = read_image(mask_path).convert('RGB')

    limit = int(max_dim)
    rgb_image, rgb_mask = max_dimension_resize(rgb_image, rgb_mask, limit)

    image_np, mask_np = pil_to_np_array(rgb_image), pil_to_np_array(rgb_mask)

    print('Visualizing mask overlap...')
    overlap = image_np * mask_np
    visualize_sample(image_np, mask_np, overlap, nrow = 3, size_factor = 10)

    return image_np, mask_np