Skip to content

Commit c59eaf7

Browse files
author
JeGa
committed
Added alpha expansion.
1 parent 4f398eb commit c59eaf7

File tree

1 file changed

+188
-14
lines changed

1 file changed

+188
-14
lines changed

binseg.py

Lines changed: 188 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import scipy.stats
44
import logging
55
import matplotlib.pyplot as plt
6-
from prettytable import PrettyTable
76
import progressbar
87
import maxflow
98
import os
10-
9+
import click
10+
import random
1111

1212
class GMM:
1313
def __init__(self, prob):
@@ -40,14 +40,13 @@ def unaryenergy(fg, bg, img):
4040
img: Image array (RGB)
4141
4242
Returns:
43-
Array A (ysize, xsize, 3) with
43+
Array A (ysize, xsize, 2) with
4444
A[y, x, 0] = energy background
4545
A[y, x, 1] = energy foreground
46-
A[y, x, 2] = label
4746
"""
4847
logging.info("Calculate unary energy functions.")
4948
ysize, xsize, _ = img.shape
50-
unary = np.empty((ysize, xsize, 3))
49+
unary = np.empty((ysize, xsize, 2))
5150

5251
with progressbar.ProgressBar(max_value=xsize, redirect_stdout=True) as progress:
5352
for x in range(xsize):
@@ -56,12 +55,6 @@ def unaryenergy(fg, bg, img):
5655
unary[y, x, 0] = -np.log(bg.prob(img[y, x]))
5756
# Foreground
5857
unary[y, x, 1] = -np.log(fg.prob(img[y, x]))
59-
60-
# Assign labels
61-
if unary[y, x, 0] < unary[y, x, 1]:
62-
unary[y, x, 2] = 0 # Background
63-
else:
64-
unary[y, x, 2] = 1 # Foreground
6558
progress.update(x)
6659

6760
return unary
@@ -155,6 +148,7 @@ def loop(self, edgecallback, nodecallback):
155148
nodecallback(self.getNode(self.ysize - 1, self.xsize - 1), self.g)
156149

157150
def loopnodes(self, callback):
151+
logging.info("Iterate through nodes.")
158152
for y in range(self.ysize):
159153
for x in range(self.xsize):
160154
callback(self.getNode(y, x), self.g)
@@ -193,7 +187,7 @@ def edge(self, node_i, node_j, graph):
193187
C = self.pairwiseenergy(1, 0, xi, xj)
194188
D = self.pairwiseenergy(1, 1, xi, xj)
195189

196-
#energy = self.pairwiseenergy(self.unaries[i[0], i[1], 2],
190+
# energy = self.pairwiseenergy(self.unaries[i[0], i[1], 2],
197191
# self.unaries[j[0], j[1], 2],
198192
# xi, xj)
199193

@@ -250,9 +244,156 @@ def getimg(self):
250244
return self.img
251245

252246

253-
def main():
254-
logging.basicConfig(level=logging.INFO)
247+
class BinsegAlphaexp:
248+
def __init__(self, img, unaries, numlabel):
249+
self.img = img
250+
self.unaries = unaries
251+
252+
# Available labels.
253+
self.label = range(numlabel)
254+
255+
# Initial labeling. All = 0
256+
self.y = np.empty((img.shape[0], img.shape[1]))
257+
258+
self.l = 0.5
259+
self.w = 3.5
260+
261+
# Current alpha
262+
self.alpha = 0
263+
264+
def constructgraph(self):
265+
nodegrid = Nodegrid(self.img.shape[0], self.img.shape[1])
266+
return nodegrid
267+
268+
def edge(self, node_i, node_j, graph):
269+
"""
270+
Callback for pairwise energy.
271+
"""
272+
273+
# Pixel coordinates.
274+
i = [node_i.y, node_i.x]
275+
j = [node_j.y, node_j.x]
276+
277+
# Current label.
278+
i_label = self.y[i[0], i[1]]
279+
j_label = self.y[j[0], j[1]]
280+
281+
# Pixel values
282+
xi = self.img[i[0], i[1]]
283+
xj = self.img[j[0], j[1]]
284+
285+
# Only for nodes that are not alpha.
286+
if i_label == self.alpha:
287+
return
288+
289+
sourceenergy = self.pairwiseenergy(i_label, j_label, xi, xj)
290+
graph.add_tedge(node_i.nodeid, sourceenergy, 0)
291+
292+
if j_label == self.alpha:
293+
return
294+
295+
energy = self.pairwiseenergy(self.alpha, j_label, xi, xj)
296+
energy += self.pairwiseenergy(i_label, self.alpha, xi, xj)
297+
energy -= self.pairwiseenergy(i_label, j_label, xi, xj)
298+
299+
graph.add_edge(node_i.nodeid, node_j.nodeid, energy, energy)
300+
301+
# graph.add_tedge(node_i.nodeid, C, A)
302+
# graph.add_tedge(node_j.nodeid, D, C)
303+
304+
def node_assign(self, node_i, graph):
305+
"""
306+
Callback for assigning unary energy.
307+
"""
308+
309+
# Pixel
310+
y = node_i.y
311+
x = node_i.x
312+
313+
# Label of pixel
314+
label = self.y[y, x]
315+
316+
# Just for nodes that are not alpha.
317+
if label == self.alpha:
318+
return
319+
320+
# Get unary for assigned label.
321+
source = self.unaries[y, x, label]
322+
323+
# Get unary for alpha.
324+
sink = self.unaries[y, x, self.alpha]
325+
326+
graph.add_tedge(node_i.nodeid, source, sink)
327+
328+
def node_segment(self, node_i, graph):
329+
"""
330+
Callback for segmentation.
331+
"""
332+
333+
# Pixel
334+
y = node_i.y
335+
x = node_i.x
336+
337+
# Label of pixel
338+
label = self.y[y, x]
339+
340+
# Just for nodes that are not alpha.
341+
if label == self.alpha:
342+
return
343+
344+
if graph.get_segment(node_i.nodeid) == 0: # Change to alpha
345+
self.y[y, x] = self.alpha
346+
347+
def pairwiseenergy(self, y1, y2, x1, x2):
348+
"""
349+
Returns pairwise energy between node i and node j using the Potts model.
350+
351+
:param y1: Label of i node.
352+
:param y2: Label of j node.
353+
:param x1: Pixel value at node i.
354+
:param x2: Pixel value at node j.
355+
:return: Pairwise energy.
356+
"""
357+
if y1 == y2:
358+
return 0.0
255359

360+
# Not same label
361+
# np.sum(np.power(x1 - x2, 2), 0)
362+
energy = self.w * np.exp(-self.l * np.power(np.linalg.norm(x1 - x2, 2), 2))
363+
return energy
364+
365+
def segment(self):
366+
for i in range(1):
367+
# For each label: Change current label to alpha?
368+
for alpha in self.label:
369+
self.alpha = alpha
370+
371+
logging.info("Alpha: " + str(alpha))
372+
373+
# Get graph with all nodes.
374+
nodegrid = self.constructgraph()
375+
376+
nodegrid.loop(self.edge, self.node_assign)
377+
378+
nodegrid.maxflow()
379+
380+
# Sets label to alpha if it should change.
381+
nodegrid.loopnodes(self.node_segment)
382+
383+
# Assign color.
384+
colors = []
385+
for i in self.label:
386+
colors.append([random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)])
387+
388+
for y in range(self.img.shape[0]):
389+
for x in range(self.img.shape[1]):
390+
self.img[y, x] = colors[int(self.y[y, x])]
391+
392+
def getimg(self):
393+
return self.img
394+
395+
396+
def binseg():
256397
logging.info("Read image.")
257398
img = misc.imread("banana3.png")
258399
img = np.array(img, dtype=np.float64) / 255
@@ -274,5 +415,38 @@ def main():
274415
plt.imsave("banana_out", img)
275416

276417

418+
def alphaexp():
419+
logging.info("Read image.")
420+
img = misc.imread("banana3.png")
421+
img = np.array(img, dtype=np.float64) / 255
422+
423+
if not os.path.exists("unary.npy"):
424+
generateunaries(img)
425+
426+
logging.info("Load unaries.")
427+
unaries = np.load("unary.npy")
428+
429+
binseg = BinsegAlphaexp(img, unaries, 2)
430+
binseg.segment()
431+
432+
logging.info("Save image.")
433+
img = binseg.getimg().astype(np.uint8)
434+
435+
plt.imshow(img)
436+
plt.show()
437+
plt.imsave("banana_out", img)
438+
439+
440+
@click.command()
441+
@click.option('--usealphaexp', is_flag=True)
442+
def main(usealphaexp):
443+
logging.basicConfig(level=logging.INFO)
444+
445+
if usealphaexp:
446+
alphaexp()
447+
else:
448+
binseg()
449+
450+
277451
if __name__ == '__main__':
278452
main()

0 commit comments

Comments
 (0)