3
3
import scipy .stats
4
4
import logging
5
5
import matplotlib .pyplot as plt
6
- from prettytable import PrettyTable
7
6
import progressbar
8
7
import maxflow
9
8
import os
10
-
9
+ import click
10
+ import random
11
11
12
12
class GMM :
13
13
def __init__ (self , prob ):
@@ -40,14 +40,13 @@ def unaryenergy(fg, bg, img):
40
40
img: Image array (RGB)
41
41
42
42
Returns:
43
- Array A (ysize, xsize, 3 ) with
43
+ Array A (ysize, xsize, 2 ) with
44
44
A[y, x, 0] = energy background
45
45
A[y, x, 1] = energy foreground
46
- A[y, x, 2] = label
47
46
"""
48
47
logging .info ("Calculate unary energy functions." )
49
48
ysize , xsize , _ = img .shape
50
- unary = np .empty ((ysize , xsize , 3 ))
49
+ unary = np .empty ((ysize , xsize , 2 ))
51
50
52
51
with progressbar .ProgressBar (max_value = xsize , redirect_stdout = True ) as progress :
53
52
for x in range (xsize ):
@@ -56,12 +55,6 @@ def unaryenergy(fg, bg, img):
56
55
unary [y , x , 0 ] = - np .log (bg .prob (img [y , x ]))
57
56
# Foreground
58
57
unary [y , x , 1 ] = - np .log (fg .prob (img [y , x ]))
59
-
60
- # Assign labels
61
- if unary [y , x , 0 ] < unary [y , x , 1 ]:
62
- unary [y , x , 2 ] = 0 # Background
63
- else :
64
- unary [y , x , 2 ] = 1 # Foreground
65
58
progress .update (x )
66
59
67
60
return unary
@@ -155,6 +148,7 @@ def loop(self, edgecallback, nodecallback):
155
148
nodecallback (self .getNode (self .ysize - 1 , self .xsize - 1 ), self .g )
156
149
157
150
def loopnodes (self , callback ):
151
+ logging .info ("Iterate through nodes." )
158
152
for y in range (self .ysize ):
159
153
for x in range (self .xsize ):
160
154
callback (self .getNode (y , x ), self .g )
@@ -193,7 +187,7 @@ def edge(self, node_i, node_j, graph):
193
187
C = self .pairwiseenergy (1 , 0 , xi , xj )
194
188
D = self .pairwiseenergy (1 , 1 , xi , xj )
195
189
196
- #energy = self.pairwiseenergy(self.unaries[i[0], i[1], 2],
190
+ # energy = self.pairwiseenergy(self.unaries[i[0], i[1], 2],
197
191
# self.unaries[j[0], j[1], 2],
198
192
# xi, xj)
199
193
@@ -250,9 +244,156 @@ def getimg(self):
250
244
return self .img
251
245
252
246
253
- def main ():
254
- logging .basicConfig (level = logging .INFO )
247
+ class BinsegAlphaexp :
248
+ def __init__ (self , img , unaries , numlabel ):
249
+ self .img = img
250
+ self .unaries = unaries
251
+
252
+ # Available labels.
253
+ self .label = range (numlabel )
254
+
255
+ # Initial labeling. All = 0
256
+ self .y = np .empty ((img .shape [0 ], img .shape [1 ]))
257
+
258
+ self .l = 0.5
259
+ self .w = 3.5
260
+
261
+ # Current alpha
262
+ self .alpha = 0
263
+
264
+ def constructgraph (self ):
265
+ nodegrid = Nodegrid (self .img .shape [0 ], self .img .shape [1 ])
266
+ return nodegrid
267
+
268
+ def edge (self , node_i , node_j , graph ):
269
+ """
270
+ Callback for pairwise energy.
271
+ """
272
+
273
+ # Pixel coordinates.
274
+ i = [node_i .y , node_i .x ]
275
+ j = [node_j .y , node_j .x ]
276
+
277
+ # Current label.
278
+ i_label = self .y [i [0 ], i [1 ]]
279
+ j_label = self .y [j [0 ], j [1 ]]
280
+
281
+ # Pixel values
282
+ xi = self .img [i [0 ], i [1 ]]
283
+ xj = self .img [j [0 ], j [1 ]]
284
+
285
+ # Only for nodes that are not alpha.
286
+ if i_label == self .alpha :
287
+ return
288
+
289
+ sourceenergy = self .pairwiseenergy (i_label , j_label , xi , xj )
290
+ graph .add_tedge (node_i .nodeid , sourceenergy , 0 )
291
+
292
+ if j_label == self .alpha :
293
+ return
294
+
295
+ energy = self .pairwiseenergy (self .alpha , j_label , xi , xj )
296
+ energy += self .pairwiseenergy (i_label , self .alpha , xi , xj )
297
+ energy -= self .pairwiseenergy (i_label , j_label , xi , xj )
298
+
299
+ graph .add_edge (node_i .nodeid , node_j .nodeid , energy , energy )
300
+
301
+ # graph.add_tedge(node_i.nodeid, C, A)
302
+ # graph.add_tedge(node_j.nodeid, D, C)
303
+
304
+ def node_assign (self , node_i , graph ):
305
+ """
306
+ Callback for assigning unary energy.
307
+ """
308
+
309
+ # Pixel
310
+ y = node_i .y
311
+ x = node_i .x
312
+
313
+ # Label of pixel
314
+ label = self .y [y , x ]
315
+
316
+ # Just for nodes that are not alpha.
317
+ if label == self .alpha :
318
+ return
319
+
320
+ # Get unary for assigned label.
321
+ source = self .unaries [y , x , label ]
322
+
323
+ # Get unary for alpha.
324
+ sink = self .unaries [y , x , self .alpha ]
325
+
326
+ graph .add_tedge (node_i .nodeid , source , sink )
327
+
328
+ def node_segment (self , node_i , graph ):
329
+ """
330
+ Callback for segmentation.
331
+ """
332
+
333
+ # Pixel
334
+ y = node_i .y
335
+ x = node_i .x
336
+
337
+ # Label of pixel
338
+ label = self .y [y , x ]
339
+
340
+ # Just for nodes that are not alpha.
341
+ if label == self .alpha :
342
+ return
343
+
344
+ if graph .get_segment (node_i .nodeid ) == 0 : # Change to alpha
345
+ self .y [y , x ] = self .alpha
346
+
347
+ def pairwiseenergy (self , y1 , y2 , x1 , x2 ):
348
+ """
349
+ Returns pairwise energy between node i and node j using the Potts model.
350
+
351
+ :param y1: Label of i node.
352
+ :param y2: Label of j node.
353
+ :param x1: Pixel value at node i.
354
+ :param x2: Pixel value at node j.
355
+ :return: Pairwise energy.
356
+ """
357
+ if y1 == y2 :
358
+ return 0.0
255
359
360
+ # Not same label
361
+ # np.sum(np.power(x1 - x2, 2), 0)
362
+ energy = self .w * np .exp (- self .l * np .power (np .linalg .norm (x1 - x2 , 2 ), 2 ))
363
+ return energy
364
+
365
+ def segment (self ):
366
+ for i in range (1 ):
367
+ # For each label: Change current label to alpha?
368
+ for alpha in self .label :
369
+ self .alpha = alpha
370
+
371
+ logging .info ("Alpha: " + str (alpha ))
372
+
373
+ # Get graph with all nodes.
374
+ nodegrid = self .constructgraph ()
375
+
376
+ nodegrid .loop (self .edge , self .node_assign )
377
+
378
+ nodegrid .maxflow ()
379
+
380
+ # Sets label to alpha if it should change.
381
+ nodegrid .loopnodes (self .node_segment )
382
+
383
+ # Assign color.
384
+ colors = []
385
+ for i in self .label :
386
+ colors .append ([random .randint (0 , 255 ), random .randint (0 , 255 ), random .randint (0 , 255 )])
387
+
388
+ for y in range (self .img .shape [0 ]):
389
+ for x in range (self .img .shape [1 ]):
390
+ self .img [y , x ] = colors [int (self .y [y , x ])]
391
+
392
+ def getimg (self ):
393
+ return self .img
394
+
395
+
396
+ def binseg ():
256
397
logging .info ("Read image." )
257
398
img = misc .imread ("banana3.png" )
258
399
img = np .array (img , dtype = np .float64 ) / 255
@@ -274,5 +415,38 @@ def main():
274
415
plt .imsave ("banana_out" , img )
275
416
276
417
418
+ def alphaexp ():
419
+ logging .info ("Read image." )
420
+ img = misc .imread ("banana3.png" )
421
+ img = np .array (img , dtype = np .float64 ) / 255
422
+
423
+ if not os .path .exists ("unary.npy" ):
424
+ generateunaries (img )
425
+
426
+ logging .info ("Load unaries." )
427
+ unaries = np .load ("unary.npy" )
428
+
429
+ binseg = BinsegAlphaexp (img , unaries , 2 )
430
+ binseg .segment ()
431
+
432
+ logging .info ("Save image." )
433
+ img = binseg .getimg ().astype (np .uint8 )
434
+
435
+ plt .imshow (img )
436
+ plt .show ()
437
+ plt .imsave ("banana_out" , img )
438
+
439
+
440
+ @click .command ()
441
+ @click .option ('--usealphaexp' , is_flag = True )
442
+ def main (usealphaexp ):
443
+ logging .basicConfig (level = logging .INFO )
444
+
445
+ if usealphaexp :
446
+ alphaexp ()
447
+ else :
448
+ binseg ()
449
+
450
+
277
451
if __name__ == '__main__' :
278
452
main ()
0 commit comments