forked from WhuEven/CNN_model_ColorConstancy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
updates to generate patches on the fly
- Loading branch information
Hien Pham
committed
May 18, 2018
0 parents
commit 0544dbc
Showing
6 changed files
with
355 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Wed May 2 11:44:06 2018 | ||
@author: phamh | ||
""" | ||
|
||
import numpy as np; | ||
import cv2; | ||
|
||
from error_measurement import angular_error; | ||
from glob import glob; | ||
from white_balancing import white_balancing; | ||
from progress_timer import progress_timer; | ||
|
||
casted_list = glob('Color-casted\\*.png'); | ||
corrected_list = glob('Corrected\\*.png'); | ||
gt_list = glob('Ground-truth\\*.png'); | ||
ang_error = []; | ||
patch_size = (64, 64); | ||
|
||
pt = progress_timer(n_iter = len(casted_list), description = 'Processing :'); | ||
|
||
for i in range (0, len(casted_list)): | ||
|
||
img = cv2.imread(casted_list[i]); | ||
image_name = casted_list[i].replace('Color-casted\\', ''); | ||
image_name = image_name.replace('.png', ''); | ||
|
||
img_gt = cv2.imread(gt_list[i]); | ||
|
||
img_cor = white_balancing(img, image_name, patch_size); | ||
cv2.imwrite('Corrected\\' + image_name + '_cor.png', img_cor); | ||
|
||
error = angular_error(img_gt, img_cor, 'mean'); | ||
ang_error.append(error); | ||
|
||
pt.update(); | ||
|
||
avg_ang_error = np.mean(ang_error); | ||
pt.finish(); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Mon Apr 9 14:30:04 2018 | ||
@author: phamh | ||
""" | ||
import numpy as np; | ||
|
||
def rgb2xyz(B_channel, G_channel, R_channel, color_space): | ||
B_channel = B_channel/255; G_channel = G_channel/255; R_channel = R_channel/255; | ||
if color_space == 'AdobeRGB': | ||
X = 0.5767*R_channel + 0.1856*G_channel + 0.1882*B_channel; | ||
Y = 0.2974*R_channel + 0.6273*G_channel + 0.0753*B_channel; | ||
Z = 0.0270*R_channel + 0.0707*G_channel + 0.9911*B_channel; | ||
return (X, Y, Z); | ||
elif color_space == 'sRGB': | ||
X = 0.4124*R_channel + 0.3576*G_channel + 0.1805*B_channel; | ||
Y = 0.2126*R_channel + 0.7152*G_channel + 0.0722*B_channel; | ||
Z = 0.0193*R_channel + 0.1192*G_channel + 0.9505*B_channel; | ||
return (X, Y, Z); | ||
|
||
def xyz2rgb(X, Y, Z, color_space): | ||
if color_space == 'AdobeRGB': | ||
B_channel = 0.0134*X + -0.1184*Y + 1.0154*Z; | ||
G_channel = -0.9693*X + 1.8760*Y + 0.0416*Z; | ||
R_channel = 2.0414*X + -0.5649*Y + -0.3447*Z; | ||
B_channel = B_channel*255; G_channel = G_channel*255; R_channel = R_channel*255; | ||
return (B_channel, G_channel, R_channel); | ||
elif color_space == 'sRGB': | ||
B_channel = 0.0556*X + -0.2040*Y + 1.0572*Z; | ||
G_channel = -0.9693*X + 1.8760*Y + 0.0416*Z; | ||
R_channel = 3.2405*X + -1.5371*Y + -0.4985*Z; | ||
B_channel = B_channel*255; G_channel = G_channel*255; R_channel = R_channel*255; | ||
return (B_channel, G_channel, R_channel); | ||
|
||
def f_function(i, i_ref): | ||
r = i/i_ref; | ||
m, n = r.shape; | ||
f = np.zeros((m, n)); | ||
x, y = np.where(r > 0.008856); | ||
f[x, y] = r[x, y]**1/3; | ||
x, y = np.where(r <= 0.008856); | ||
f[x, y] = (7.787*r[x, y] + 16/116); | ||
return f; | ||
|
||
def xyz2lab(X, Y, Z, illuminant): | ||
if illuminant == 'D65': | ||
Xn = 95.047; | ||
Yn = 100; | ||
Zn = 108.883; | ||
elif illuminant == 'D50': | ||
Xn = 96.6797; | ||
Yn = 100; | ||
Zn = 82.5188; | ||
|
||
L = 116*f_function(Y, Yn) - 16; | ||
a = 500*f_function(X, Xn) + -500*f_function(Y, Yn); | ||
b = 200*f_function(Y, Yn) + -200*f_function(Z, Zn); | ||
|
||
return (L, a, b); | ||
|
||
def lab2xyz(L, a, b, illuminant): | ||
if illuminant == 'D65': | ||
Xn = 95.047; | ||
Yn = 100; | ||
Zn = 108.883; | ||
elif illuminant == 'D50': | ||
Xn = 96.6797; | ||
Yn = 100; | ||
Zn = 82.5188; | ||
|
||
if L > 7.9996: | ||
X = Xn*((L/116 + a/500 + 16/116)**3); | ||
Y = Yn*((L/116 + 16/116)**3); | ||
Z = Zn*((L/116 - b/200 + 16/116)**3); | ||
elif L <= 7.9996: | ||
X = Xn*(1/7.787)*(L/116 + a/500); | ||
Y = Yn*(1/7.787)*(L/116); | ||
Z = Zn*(1/7.787)*(L/116 - b/200); | ||
|
||
return (X, Y, Z); | ||
|
||
def rgb2lab(B_channel, G_channel, R_channel, color_space, illuminant): | ||
X, Y, Z = rgb2xyz(B_channel, G_channel, R_channel, color_space); | ||
L, a, b = xyz2lab(X, Y, Z, illuminant); | ||
return (L, a, b); | ||
|
||
def lab2rgb(L, a, b, color_space, illuminant): | ||
X, Y, Z = lab2xyz(L, a, b, illuminant); | ||
B_channel, G_channel, R_channel = xyz2rgb(X, Y, Z, color_space); | ||
return (B_channel, G_channel, R_channel); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Thu Apr 26 14:52:45 2018 | ||
@author: phamh | ||
""" | ||
|
||
import numpy as np; | ||
import h5py; | ||
import cv2; | ||
|
||
from glob import glob; | ||
from tensorflow.python.keras.models import model_from_json; | ||
|
||
def create_local_illum_map(image, patch_size): | ||
|
||
#load cc_model | ||
json_file = open('cc_model.json', 'r'); | ||
loaded_model_json = json_file.read(); | ||
json_file.close(); | ||
cc_model = model_from_json(loaded_model_json); | ||
|
||
#load weights into cc_model | ||
cc_model.load_weights("cc_model.h5"); | ||
cc_model.compile(optimizer = 'Adam', loss = 'cosine_proximity', metrics = ['acc']); | ||
|
||
n_r, n_c, _ = image.shape; | ||
patch_r, patch_c = patch_size; | ||
|
||
total_patch = int(((n_r - n_r%patch_r)/patch_r)*((n_c - n_c%patch_c)/patch_c)); | ||
|
||
img_resize = cv2.resize(image, ((n_r - n_r%patch_r), (n_c - n_c%patch_c))); | ||
img_reshape = np.reshape(img_resize, (int(patch_r), -1, 3)); | ||
|
||
illum_estimate = []; | ||
|
||
for i in range(total_patch): | ||
|
||
img_patch = img_reshape[0:patch_r, i*patch_c:(i+1)*patch_c]; | ||
|
||
patch_cvt = cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB); | ||
patch_cvt = np.expand_dims(patch_cvt, axis=0); | ||
illum_estimate.append(cc_model.predict(patch_cvt)); | ||
|
||
illum_estimate = np.asarray(illum_estimate); | ||
illum_map = np.reshape(illum_estimate, (int((n_r - n_r%patch_r)/patch_r), int((n_c - n_c%patch_c)/patch_c), 3)); | ||
|
||
return illum_map; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Mon Apr 9 10:23:49 2018 | ||
@author: phamh | ||
""" | ||
import numpy as np; | ||
import cv2; | ||
from color_space_conversion import rgb2lab; | ||
|
||
def angular_error(ground_truth_image, corrected_image, measurement_type): | ||
B_gt, G_gt, R_gt = cv2.split(ground_truth_image); | ||
B_cor, G_cor, R_cor = cv2.split(corrected_image); | ||
|
||
if measurement_type == 'mean': | ||
e_gt = np.array([np.mean(B_gt), np.mean(G_gt), np.mean(R_gt)]); | ||
e_est = np.array([np.mean(B_cor), np.mean(G_cor), np.mean(R_cor)]); | ||
|
||
elif measurement_type == 'median': | ||
e_gt = np.array([np.median(B_gt), np.median(G_gt), np.median(R_gt)]); | ||
e_est = np.array([np.median(B_cor), np.median(G_cor), np.median(R_cor)]); | ||
|
||
error_cos = np.dot(e_gt, e_est)/(np.linalg.norm(e_gt)*np.linalg.norm(e_est)); | ||
e_angular = np.degrees(np.arccos(error_cos)); | ||
return e_angular; | ||
|
||
def Euclidean_distance(ground_truth_image, corrected_image): | ||
B_gt, G_gt, R_gt = cv2.split(ground_truth_image); | ||
B_cor, G_cor, R_cor = cv2.split(corrected_image); | ||
|
||
L_gt, a_gt, b_gt = rgb2lab(B_gt, G_gt, R_gt, 'AdobeRGB', 'D65'); | ||
L_cor, a_cor, b_cor= rgb2lab(B_cor, G_cor, R_cor, 'AdobeRGB', 'D65'); | ||
|
||
delta = np.sqrt(np.square(L_gt - L_cor) + np.square(a_gt - a_cor) + np.square(b_gt - b_cor)); | ||
average_Euclidean_dist = np.mean(delta); | ||
|
||
return average_Euclidean_dist; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Thu Apr 26 11:34:51 2018 | ||
@author: phamh | ||
""" | ||
|
||
import numpy as np; | ||
import cv2; | ||
import os; | ||
|
||
from create_local_illum_map import create_local_illum_map; | ||
#from sklearn.svm import SVR | ||
|
||
|
||
def pool_function(illum_map, pool_size, stride, mode): | ||
|
||
n_h, n_w, n_c = illum_map.shape; | ||
pool_h, pool_w = pool_size; | ||
stride_h, stride_w = stride; | ||
|
||
# Define the dimensions of the output | ||
h_out = int(1 + (n_h - pool_h) / stride_h); | ||
w_out = int(1 + (n_w - pool_w) / stride_w); | ||
c_out = n_c; | ||
|
||
pool_out = np.zeros((h_out, w_out, c_out)); | ||
|
||
for h in range(h_out): | ||
for w in range(w_out): | ||
for c in range (c_out): | ||
|
||
vert_start = h*stride_h; | ||
vert_end = h*stride_h + pool_h; | ||
horiz_start = w*stride_w; | ||
horiz_end = w*stride_w + pool_w; | ||
|
||
slice_window = illum_map[vert_start:vert_end, horiz_start:horiz_end, c]; | ||
|
||
if mode == "std": | ||
pool_out[h, w, c] = np.std(slice_window); | ||
elif mode == "average": | ||
pool_out[h, w, c] = np.mean(slice_window); | ||
|
||
return pool_out; | ||
|
||
def local_2_global(image_name, image, patch_size): | ||
|
||
n_r, n_c, _ = image.shape; | ||
|
||
illum_map = create_local_illum_map(image, patch_size); #patch_size = (32, 32) | ||
|
||
#Gaussian Smooting with 5x5 kernel | ||
illum_map_smoothed = cv2.GaussianBlur(illum_map, (5,5), cv2.BORDER_DEFAULT); | ||
|
||
#Perform median pooling | ||
median_pooling = np.zeros((1, 3)); | ||
median_pooling[0, 0] = np.median(illum_map_smoothed[:, :, 0]); | ||
median_pooling[0, 1] = np.median(illum_map_smoothed[:, :, 1]); | ||
median_pooling[0, 2] = np.median(illum_map_smoothed[:, :, 2]); | ||
|
||
#Perform avarage pooling | ||
n_r, n_c, _ = illum_map.shape; | ||
pool_size = (int(n_r/3), int(n_c/3)); | ||
stride = (int(n_r/3), int(n_c/3)); | ||
average_pooling = pool_function(illum_map, pool_size, stride, mode = 'average'); | ||
|
||
average_pooling[0, 0] = np.mean(average_pooling[:, :, 0]); | ||
average_pooling[0, 1] = np.mean(average_pooling[:, :, 1]); | ||
average_pooling[0, 2] = np.mean(average_pooling[:, :, 2]); | ||
|
||
#reshape_average = np.reshape(average_pooling, (-1, )); | ||
#svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1) | ||
#illum_global = svr_rbf.fit(np.reshape(median_pooling, (3, 1)), reshape_average).predict(median_pooling); | ||
|
||
#Switch to average pooling if you prefer: | ||
#illum_global = average_pooling; | ||
illum_global = median_pooling; | ||
|
||
return illum_global; | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Thu Apr 26 18:28:36 2018 | ||
@author: phamh | ||
""" | ||
|
||
import numpy as np; | ||
import cv2; | ||
from gamma_linearization import gamma_decode, gamma_encode; | ||
from local_2_global import local_2_global; | ||
|
||
#image name correspond to patches's sub-folder | ||
#Exp: image 0015 corresponds to patches/0015 folder | ||
|
||
def white_balancing(img, image_name, patch_size): | ||
|
||
m, n, _ = img.shape; | ||
B_channel, G_channel, R_channel = cv2.split(img); | ||
Color_space = 'AdobeRGB'; | ||
|
||
#Undo Gamma Correction | ||
B_channel, G_channel, R_channel = gamma_decode(B_channel, G_channel, R_channel, Color_space); | ||
|
||
#Compute local illuminant map then aggregate to global illuminant | ||
illum_global = local_2_global(image_name, img, patch_size); | ||
|
||
#Gain of red and blue channels | ||
alpha = np.max(illum_global)/illum_global[0, 2]; | ||
beta = np.max(illum_global)/illum_global[0, 1]; | ||
ceta = np.max(illum_global)/illum_global[0, 0]; | ||
|
||
#Corrected Image | ||
B_cor = alpha*B_channel; | ||
G_cor = beta*G_channel; | ||
R_cor = ceta*R_channel; | ||
|
||
#Gamma correction to display | ||
B_cor, G_cor, R_cor = gamma_encode(B_cor, G_cor, R_cor, Color_space); | ||
B_cor[B_cor > 255] = 255; | ||
G_cor[G_cor > 255] = 255; | ||
R_cor[R_cor > 255] = 255; | ||
|
||
#Convert to uint8 to display | ||
B_cor = B_cor.astype(np.uint8); | ||
G_cor = G_cor.astype(np.uint8); | ||
R_cor = R_cor.astype(np.uint8); | ||
img_cor = cv2.merge((B_cor, G_cor, R_cor)); | ||
|
||
return img_cor; | ||
|
||
|
||
#cv2.imwrite(image_name + '_cor.png', img_cor); |