Skip to content

Commit

Permalink
InpaWorking better with RGB data, also faster than before
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasBizzozzero committed May 23, 2018
1 parent 5937054 commit 6adb11b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 32 deletions.
10 changes: 5 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
LENA_GRAY_512, LIVINGROOM, MANDRIL_COLOR, MANDRIL_GRAY, PEPPERS_COLOR, PEPPERS_GRAY, PIRATE, WALKBRIDGE, \
WOMAN_BLONDE, WOMAN_DARKHAIR, CASTLE, OUTDOOR
from src.picture_tools.codage import Codage
from src.picture_tools.picture import Picture, show_patch
from src.picture_tools.picture import Picture, show_patch, flatten, unflatten, VALUE_MISSING_PIXEL, VALUE_OUT_OF_BOUNDS, get_patch
from src.usps_tools import test_all_usps_1_vs_all, test_all_usps
from src.linear.linear_regression import LinearRegression, identite, mse_g, l1, l1_g, l2, l2_g, DescenteDeGradient
from src.inpainting import InPainting
Expand All @@ -12,10 +12,10 @@
import numpy as np


PATCH_SIZE = 9
STEP = PATCH_SIZE
PICTURE_PATH = OUTDOOR
CODAGE = Codage.HSV
PATCH_SIZE = 5
STEP = PATCH_SIZE // 2
CODAGE = Codage.RGB
PICTURE_PATH = LENA_COLOR_512


def main():
Expand Down
24 changes: 12 additions & 12 deletions src/inpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def inpaint(self, picture: Picture) -> Picture:
next_pixel_value = self.predict(x - next_pixel[0] + (self.patch_size // 2),
y - next_pixel[1] + (self.patch_size // 2),
dictionary)
picture.pixels[y, x] = next_pixel_value
picture.pixels[x, y] = next_pixel_value
progress_bar.update(progress_bar.value + 1)

progress_bar.finish()
Expand All @@ -73,9 +73,9 @@ def fit(self, dictionary, patch):
self._classifier_value.fit(datax_value, datay_value)

def predict(self, x, y, dictionary):
hue = self._classifier_hue.predict(dictionary[:, y, x, 0].reshape(1, -1))
saturation = self._classifier_saturation.predict(dictionary[:, y, x, 1].reshape(1, -1))
value = self._classifier_value.predict(dictionary[:, y, x, 2].reshape(1, -1))
hue = self._classifier_hue.predict(dictionary[:, x, y, 0].reshape(1, -1))
saturation = self._classifier_saturation.predict(dictionary[:, x, y, 1].reshape(1, -1))
value = self._classifier_value.predict(dictionary[:, x, y, 2].reshape(1, -1))
return np.hstack((hue, saturation, value))

def _get_next_patch(self, picture: Picture, size: int, value_out_of_bounds: np.ndarray = VALUE_OUT_OF_BOUNDS,
Expand All @@ -87,7 +87,7 @@ def _get_next_patch(self, picture: Picture, size: int, value_out_of_bounds: np.n
# for (x, y) in picture.get_patches()}
# return max(patches_priorities.keys(), key=lambda k: patches_priorities[k])

missing_pixels_y, missing_pixels_x, *_ = np.where(picture.pixels == self.value_missing_pixel)
missing_pixels_x, missing_pixels_y, *_ = np.where(picture.pixels == self.value_missing_pixel)
return zip(missing_pixels_x, missing_pixels_y).__next__()

def _preprocess_training_data(self, patch, dictionary):
Expand All @@ -97,13 +97,13 @@ def _preprocess_training_data(self, patch, dictionary):
for x in range(self.patch_size):
for y in range(self.patch_size):
# Si on tombe sur une valeur manquante, on ne l'ajoute évidemment pas (impossible à apprendre)
if np.all(patch[y, x] != self.value_missing_pixel) and np.all(patch[y, x] != self.value_out_of_bounds):
datax_hue.append(dictionary[:, y, x, 0])
datax_saturation.append(dictionary[:, y, x, 1])
datax_value.append(dictionary[:, y, x, 2])
datay_hue.append(patch[y, x, 0])
datay_saturation.append(patch[y, x, 1])
datay_value.append(patch[y, x, 2])
if np.all(patch[x, y] != self.value_missing_pixel) and np.all(patch[x, y] != self.value_out_of_bounds):
datax_hue.append(dictionary[:, x, y, 0])
datax_saturation.append(dictionary[:, x, y, 1])
datax_value.append(dictionary[:, x, y, 2])
datay_hue.append(patch[x, y, 0])
datay_saturation.append(patch[x, y, 1])
datay_value.append(patch[x, y, 2])

return np.array(datax_hue), np.array(datax_saturation), \
np.array(datax_value), np.array(datay_hue), \
Expand Down
30 changes: 15 additions & 15 deletions src/picture_tools/picture.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def add_noise(self, threshold: float = 0.05):
"""
for x in range(self.largeur):
for y in range(self.hauteur):
self.pixels[y, x] = VALUE_MISSING_PIXEL if random.random(
) < threshold else self.pixels[y, x]
self.pixels[x, y] = VALUE_MISSING_PIXEL if random.random(
) < threshold else self.pixels[x, y]

def add_rectangle(self, x: int, y: int, hauteur: int, largeur: int) -> None:
""" Ajoute aléatoirement un rectangle de bruit dans l'image.
Expand All @@ -90,15 +90,15 @@ def add_rectangle(self, x: int, y: int, hauteur: int, largeur: int) -> None:
:param: hauteur, La hauteur du rectangle
:param: largeur, La largeur du rectangle
"""
self.pixels[y:y + hauteur, x:x + largeur] = VALUE_MISSING_PIXEL
self.pixels[x:x + largeur, y:y + hauteur] = VALUE_MISSING_PIXEL

def get_pixel(self, x: int, y: int) -> np.ndarray:
""" Retourne le pixel de l'image aux indexes (x, y).
:param: x, l'index de la colonne.
:param: y, l'index de la ligne.
:return: Le contenu du pixel demandé.
"""
return self.pixels[y, x]
return self.pixels[x, y]

def get_patch(self, x: int, y: int, size: int) -> np.ndarray:
""" Retourne le patch de l'image centré aux indexes (x, y).
Expand All @@ -108,12 +108,12 @@ def get_patch(self, x: int, y: int, size: int) -> np.ndarray:
:return: Le contenu du patch demandé.
"""
if not self.out_of_bounds_patch(x, y, size):
return self.pixels[y - (size // 2):y + (size // 2) + 1, x - (size // 2): x + (size // 2) + 1]
return self.pixels[x - (size // 2):x + (size // 2) + 1, y - (size // 2): y + (size // 2) + 1]
else:
patch = []
for index_y in range(y - (size // 2), y + (size // 2) + 1):
for index_x in range(x - (size // 2), x + (size // 2) + 1):
new_line = []
for index_x in range(x - (size // 2), x + (size // 2) + 1):
for index_y in range(y - (size // 2), y + (size // 2) + 1):
if not self.out_of_bounds(index_x, index_y):
new_line.append(self.get_pixel(index_x, index_y))
else:
Expand Down Expand Up @@ -228,14 +228,14 @@ def unflatten(vector: np.ndarray, size_patch: int) -> np.ndarray:

def out_of_bounds(pixels: np.ndarray, x: int, y: int) -> bool:
""" Check if the pixel located at (x, y) is out of the bounds of the picture. """
return not (0 <= x < pixels.shape[1] and 0 <= y < pixels.shape[0])
return not (0 <= x < pixels.shape[0] and 0 <= y < pixels.shape[1])


def out_of_bounds_patch(pixels: np.ndarray, x: int, y: int, size: int) -> bool:
return (x - (size // 2) <= 0) or \
(x + (size // 2) + 1 < pixels.shape[1]) or \
(x + (size // 2) + 1 < pixels.shape[0]) or \
(y - (size // 2) <= 0) or \
(y + (size // 2) + 1 < pixels.shape[0])
(y + (size // 2) + 1 < pixels.shape[1])


def get_patch(pixels: np.ndarray, x: int, y: int, size: int,
Expand All @@ -247,14 +247,14 @@ def get_patch(pixels: np.ndarray, x: int, y: int, size: int,
:return: Le contenu du patch demandé.
"""
if not out_of_bounds_patch(pixels, x, y, size):
return pixels[y - (size // 2):y + (size // 2) + 1, x - (size // 2): x + (size // 2) + 1]
return pixels[x - (size // 2):x + (size // 2) + 1, y - (size // 2): y + (size // 2) + 1]
else:
patch = []
for index_y in range(y - (size // 2), y + (size // 2) + 1):
for index_x in range(x - (size // 2), x + (size // 2) + 1):
new_line = []
for index_x in range(x - (size // 2), x + (size // 2) + 1):
for index_y in range(y - (size // 2), y + (size // 2) + 1):
if not out_of_bounds(pixels, index_x, index_y):
new_line.append(pixels[index_y, index_x])
new_line.append(pixels[index_x, index_y])
else:
new_line.append(value_out_of_bounds)
patch.append(np.array(new_line))
Expand All @@ -270,7 +270,7 @@ def iter_patch(x: int, y: int, size: int):
def iter_patch_empty(pixels: np.ndarray, x: int, y: int, size: int):
for index_x, index_y in iter_patch(x, y, size):
if not out_of_bounds(pixels, index_x, index_y):
if all(pixels[index_y, index_x] == VALUE_MISSING_PIXEL):
if all(pixels[index_x, index_y] == VALUE_MISSING_PIXEL):
yield index_x, index_y


Expand Down

0 comments on commit 6adb11b

Please sign in to comment.