Skip to content

Commit 6a4df80

Browse files
readme update
1 parent a3af8c5 commit 6a4df80

27 files changed

+141
-4
lines changed

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
# Advanced Deep Learning Practical Course : Realistic Composite Image Creation Using GANs
1+
# Deep Image Compositing : Realistic Composite Image Creation Using GANs
2+
3+
This work has been accepted at 3rd BAI Workshop, NeurIPS Conference 2019. Find the attached poster (here)[https://www.academia.edu/40947406/Deep_Image_Compositing].
24

35
## 1. How To train the model
46

5-
```bash
7+
```
68
usage: main.py [-h] [-d DATASET] [--data-dirpath DATA_DIRPATH]
79
[--n-workers N_WORKERS] [--gpu GPU] [-rs RANDOM_SEED]
810
[-dr DISCRIMINATOR] [-gr GENERATOR] [-d_lr D_LR] [-g_lr G_LR]
@@ -51,13 +53,13 @@ optional arguments:
5153
```
5254

5355
### Sample Command for training
54-
```bash
56+
```
5557
python main.py -b 5 --gpu 1 -d_lr 1e-7 -g_lr 1e-5 -m pix2pix_patch_hue_total -e 1000 -tf tf_logs/pix2pix_patch_hue_total -rl l1 -dr patch -gr skip2
5658
```
5759

5860
## 2. How To generate results on the model
5961

60-
```bash
62+
```
6163
usage: evaluate_models.py [-h] [-d DATASET] [--data-dirpath DATA_DIRPATH]
6264
[--n-workers N_WORKERS] [--gpu GPU] [-rs RANDOM_SEED]
6365
[-dr DISCRIMINATOR] [-gr GENERATOR] [-d_lr D_LR] [-g_lr G_LR]

calc_metrics.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Main file to train and evaluate the models
5+
"""
6+
7+
from deep_adversarial_network.metrics.metric_eval import *
8+
from PIL import Image
9+
import numpy as np
10+
import os
11+
from skimage.measure import compare_ssim as ssim
12+
13+
14+
def main():
15+
"""
16+
main function that parses the arguments and trains
17+
:param args: arguments related
18+
:return: None
19+
"""
20+
mse_avg_total = 0.
21+
psnr_avg_total = 0.
22+
ssim_avg_total = 0.
23+
vif_avg_total = 0.
24+
25+
image_list = ['4','36','39','70','121','149']
26+
for img in image_list:
27+
28+
gt_image = Image.open(os.getcwd()+'/metrics/gt/gt_'+img+'.png')
29+
gt_image = np.array(gt_image)
30+
test_image = Image.open(os.getcwd() + '/metrics/ours/ht_' + img + '.png')
31+
test_image = np.array(test_image)
32+
33+
mse_avg_iter, psnr_avg_iter = calc_mse_psnr_img(test_image, gt_image)
34+
#tv_avg_iter = get_total_variation(tf.convert_to_tensor(test_image))
35+
ssim_avg_iter = ssim(test_image,gt_image,multichannel=True)
36+
vif_avg_iter = calc_vif_img(test_image, gt_image)
37+
38+
mse_avg_total += mse_avg_iter
39+
psnr_avg_total += psnr_avg_iter
40+
#tv_avg_total += tv_avg_iter
41+
ssim_avg_total += ssim_avg_iter
42+
vif_avg_total += vif_avg_iter
43+
44+
mse_avg_total /= len(image_list)
45+
psnr_avg_total /= len(image_list)
46+
ssim_avg_total /= len(image_list)
47+
vif_avg_total /= len(image_list)
48+
49+
print("MSE : %.3f, PSNR : %.3f, SSIM : %.3f VIF: %3F" % (
50+
mse_avg_total, psnr_avg_total, ssim_avg_total, vif_avg_total))
51+
52+
53+
if __name__ == '__main__':
54+
main()

deep_adversarial_network/metrics/metric_eval.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,67 @@ def calc_vif(img_list1, img_list2):
8888
return vifp/num_imgs
8989

9090

91+
def calc_vif_img(img1, img2):
92+
"""
93+
Calculate VIF for a set of Images
94+
:param img_list1: Test Image
95+
:param img_list2: Ground Truth Image
96+
:return: VIF
97+
"""
98+
99+
sigma_nsq = 2
100+
eps = 1e-10
101+
102+
num = 0.0
103+
den = 0.0
104+
vifp = 0.
105+
ref = img1
106+
dist = img2
107+
for scale in range(1, 5):
108+
109+
N = 2 ** (4 - scale + 1) + 1
110+
sd = N / 5.0
111+
112+
if scale > 1:
113+
ref = scipy.ndimage.gaussian_filter(ref, sd)
114+
dist = scipy.ndimage.gaussian_filter(dist, sd)
115+
ref = ref[::2, ::2]
116+
dist = dist[::2, ::2]
117+
118+
mu1 = scipy.ndimage.gaussian_filter(ref, sd)
119+
mu2 = scipy.ndimage.gaussian_filter(dist, sd)
120+
mu1_sq = mu1 * mu1
121+
mu2_sq = mu2 * mu2
122+
mu1_mu2 = mu1 * mu2
123+
sigma1_sq = scipy.ndimage.gaussian_filter(ref * ref, sd) - mu1_sq
124+
sigma2_sq = scipy.ndimage.gaussian_filter(dist * dist, sd) - mu2_sq
125+
sigma12 = scipy.ndimage.gaussian_filter(ref * dist, sd) - mu1_mu2
126+
127+
sigma1_sq[sigma1_sq < 0] = 0
128+
sigma2_sq[sigma2_sq < 0] = 0
129+
130+
g = sigma12 / (sigma1_sq + eps)
131+
sv_sq = sigma2_sq - g * sigma12
132+
133+
g[sigma1_sq < eps] = 0
134+
sv_sq[sigma1_sq < eps] = sigma2_sq[sigma1_sq < eps]
135+
sigma1_sq[sigma1_sq < eps] = 0
136+
137+
g[sigma2_sq < eps] = 0
138+
sv_sq[sigma2_sq < eps] = 0
139+
140+
sv_sq[g < 0] = sigma2_sq[g < 0]
141+
g[g < 0] = 0
142+
sv_sq[sv_sq <= eps] = eps
143+
144+
num += np.sum(np.log10(1 + g * g * sigma1_sq / (sv_sq + sigma_nsq)))
145+
den += np.sum(np.log10(1 + sigma1_sq / sigma_nsq))
146+
147+
vifp += num / den
148+
149+
return vifp
150+
151+
91152
def calc_mse_psnr(img_list1, img_list2):
92153
"""
93154
Calculate MSE and PSNR for a set of Images
@@ -119,6 +180,26 @@ def calc_mse_psnr(img_list1, img_list2):
119180
return total_mse/num_imgs, total_psnr/num_imgs
120181

121182

183+
def calc_mse_psnr_img(img1, img2):
184+
"""
185+
Calculate MSE and PSNR for a set of Images
186+
:param img_list1: Image List1
187+
:param img_list2: Image List2
188+
:return: MSE, PSNR
189+
"""
190+
mse_val = calc_mse(img1,img2)
191+
if mse_val == 0.:
192+
psnr = 100
193+
else:
194+
# im1 = tf.image.convert_image_dtype(img_list1[i], tf.float32)
195+
# im2 = tf.image.convert_image_dtype(img_list2[i], tf.float32)
196+
# psnr = tf.image.psnr(im1, im2, max_val=255)
197+
psnr = 20 * math.log10(PIXEL_MAX / math.sqrt(mse_val))
198+
199+
200+
return mse_val, psnr
201+
202+
122203
def d_accuracy(real_prob, fake_prob):
123204
"""
124205
Calculate Discriminator Accuracy

metrics/gt/gt_121.png

130 KB
Loading

metrics/gt/gt_149.png

87 KB
Loading

metrics/gt/gt_36.png

110 KB
Loading

metrics/gt/gt_39.png

74.7 KB
Loading

metrics/gt/gt_4.png

68.1 KB
Loading

metrics/gt/gt_70.png

130 KB
Loading

metrics/ht/ht_121.png

125 KB
Loading

0 commit comments

Comments
 (0)