6
6
from prettytable import PrettyTable
7
7
import progressbar
8
8
import maxflow
9
- from examples_utils import plot_graph_2d
9
+ import os
10
10
11
11
12
12
class GMM :
@@ -38,6 +38,12 @@ def unaryenergy(fg, bg, img):
38
38
fg: Foreground GMM
39
39
bg: Background GMM
40
40
img: Image array (RGB)
41
+
42
+ Returns:
43
+ Array A (ysize, xsize, 3) with
44
+ A[y, x, 0] = energy background
45
+ A[y, x, 1] = energy foreground
46
+ A[y, x, 2] = label
41
47
"""
42
48
logging .info ("Calculate unary energy functions." )
43
49
ysize , xsize , _ = img .shape
@@ -61,27 +67,6 @@ def unaryenergy(fg, bg, img):
61
67
return unary
62
68
63
69
64
- def pairwiseenergy (unaries , img ):
65
- logging .info ("Calculate pairwise energy functions." )
66
- ysize , xsize , _ = img .shape
67
- pairwise = np .array ()
68
-
69
- l = 0.5
70
- w = 1.0
71
-
72
- with progressbar .ProgressBar (max_value = xsize , redirect_stdout = True ) as progress :
73
- for x in range (xsize ):
74
- for y in range (ysize ):
75
- pass
76
- # if unaries[y, x, 3] ==
77
- # np.exp(-l * np.linalg.norm())
78
- # else:
79
- # pass
80
- progress .update (x )
81
-
82
- return pairwise
83
-
84
-
85
70
def readprobfile (filename ):
86
71
file = np .load (filename )
87
72
mu = file ["arr_0" ]
@@ -91,106 +76,153 @@ def readprobfile(filename):
91
76
return [mu , sigma , mix ]
92
77
93
78
94
- def main ():
95
- logging .basicConfig (level = logging .INFO )
96
-
79
+ def generateunaries (img ):
97
80
logging .info ("Read GMM for unaries." )
98
81
probf = readprobfile ("prob_foreground.npz" )
99
82
probb = readprobfile ("prob_background.npz" )
100
83
101
84
fg = GMM (probf )
102
85
bg = GMM (probb )
103
86
104
- logging .info ("Read image ." )
105
- img = misc . imread ( "banana3.png" )
106
- img = np .array ( img , dtype = np . float64 )
87
+ logging .info ("Generate unaries ." )
88
+ unaries = unaryenergy ( fg , bg , img )
89
+ np .save ( "unary" , unaries )
107
90
108
- # unaries = unaryenergy(fg, bg, img)
109
- unaries = np .load ("unary.npy" )
110
- # np.save("unary", unaries)
111
91
112
- ysize , xsize , _ = img .shape
92
+ class Nodegrid :
93
+ def __init__ (self , ysize , xsize ):
94
+ self .g = maxflow .GraphFloat ()
95
+ self .nodeids = self .g .add_grid_nodes ((ysize , xsize ))
113
96
114
- def pairwiseenergy (y1 , x1 , y2 , x2 ):
115
- nonlocal unaries
116
- nonlocal img
97
+ self .ysize = ysize
98
+ self .xsize = xsize
99
+
100
+ def loop (self , edgecallback , nodecallback ):
101
+ """
102
+ Loops over the grid of nodes. Two callback functions are required:
103
+
104
+ :param edgecallback: Called for every edge with node id (i, j)
105
+ as parameter.
106
+ :param nodecallback: called for every node with node id as parameter.
107
+ """
108
+ logging .info ("Iterate through graph." )
109
+
110
+ for y in range (self .ysize - 1 ):
111
+ for x in range (self .xsize - 1 ):
112
+ # Right edge
113
+ edgecallback (self .nodeids [y , x ],
114
+ self .nodeids [y , x + 1 ], self .g )
115
+
116
+ # Down edge
117
+ edgecallback (self .nodeids [y , x ],
118
+ self .nodeids [y + 1 , x ], self .g )
119
+
120
+ # Node
121
+ nodecallback (self .nodeids [y , x ], y , x , self .g )
122
+
123
+ # Last column
124
+ for y in range (self .ysize - 1 ):
125
+ # Down edge
126
+ edgecallback (self .nodeids [y , self .xsize - 1 ],
127
+ self .nodeids [y + 1 , self .xsize - 1 ], self .g )
128
+
129
+ # Node
130
+ nodecallback (self .nodeids [y , self .xsize - 1 ], y , x , self .g )
131
+
132
+ # Last row
133
+ for x in range (self .xsize - 1 ):
134
+ # Right edge
135
+ edgecallback (self .nodeids [self .ysize - 1 , x ],
136
+ self .nodeids [self .ysize - 1 , x + 1 ], self .g )
137
+
138
+ # Node
139
+ nodecallback (self .nodeids [self .ysize - 1 , x ], y , x , self .g )
140
+
141
+ # Last node
142
+ nodecallback (self .nodeids [self .ysize - 1 , self .xsize - 1 ], y , x , self .g )
117
143
118
- l = 0.005
119
- w = 110
144
+ def loopnodes (self , callback ):
145
+ for y in range (self .ysize ):
146
+ for x in range (self .xsize ):
147
+ callback (self .nodeids [y , x ], y , x , self .g )
120
148
121
- if unaries [y1 , x1 , 2 ] != unaries [y2 , x2 , 2 ]:
122
- delta = 1
149
+ def maxflow (self ):
150
+ logging .info ("Calculate max flow." )
151
+ self .g .maxflow ()
152
+
153
+
154
+ class Binseg :
155
+ def __init__ (self , img , unaries ):
156
+ self .img = img
157
+ self .unaries = unaries
158
+
159
+ self .nodegrid = Nodegrid (img .shape [0 ], img .shape [1 ])
160
+
161
+ self .l = 0.005
162
+ self .w = 100
163
+
164
+ def edge (self , nodeid_i , nodeid_j , graph ):
165
+ """
166
+ Callback for pairwise energy.
167
+ """
168
+ pass
169
+
170
+ def node_assign (self , nodeid_i , y , x , graph ):
171
+ """
172
+ Callback for assigning unary energy.
173
+ """
174
+ graph .add_tedge (nodeid_i , self .unaries [y , x , 1 ], self .unaries [y , x , 0 ])
175
+
176
+ def node_segment (self , nodeid_i , y , x , graph ):
177
+ """
178
+ Callback for segmentation.
179
+ """
180
+ if graph .get_segment (nodeid_i ) == 0 :
181
+ self .img [y , x ] = np .array ([0 , 0 , 0 ])
123
182
else :
124
- delta = 0
125
-
126
- # Not same label
127
- energy = w * np .exp (- l * np .power (np .linalg .norm (img [y1 , x1 ] - img [y2 , x2 ], ord = 2 ), 2 )) * delta
128
- #a = (np.linalg.norm(img[y1, x1] - img[y2, x2], ord=2))
129
- return energy
130
-
131
- g = maxflow .GraphFloat ()
132
- nodeids = g .add_grid_nodes ((ysize , xsize ))
133
-
134
- for y in range (ysize - 1 ):
135
- for x in range (xsize - 1 ):
136
- e_right = pairwiseenergy (y , x , y , x + 1 )
137
- if e_right >= 10.0 :
138
- a = 2
139
- g .add_edge (nodeids [y , x ], nodeids [y , x + 1 ], e_right , e_right )
140
-
141
- e_down = pairwiseenergy (y , x , y + 1 , x )
142
- g .add_edge (nodeids [y , x ], nodeids [y + 1 , x ], e_down , e_down )
143
-
144
- # Source, sink
145
- g .add_tedge (nodeids [y , x ], unaries [y , x , 0 ], unaries [y , x , 1 ])
146
-
147
- for y in range (ysize - 1 ):
148
- e_down = pairwiseenergy (y , xsize - 1 , y + 1 , xsize - 1 )
149
- g .add_edge (nodeids [y , xsize - 1 ], nodeids [y + 1 , xsize - 1 ], e_down , e_down )
150
-
151
- g .add_tedge (nodeids [y , xsize - 1 ], unaries [y , xsize - 1 , 0 ], unaries [y , xsize - 1 , 1 ])
152
-
153
- for x in range (xsize - 1 ):
154
- e_right = pairwiseenergy (ysize - 1 , x , ysize - 1 , x + 1 )
155
-
156
- g .add_edge (nodeids [ysize - 1 , x ], nodeids [ysize - 1 , x + 1 ], e_right , e_right )
157
- g .add_tedge (nodeids [ysize - 1 , x ], unaries [ysize - 1 , x + 1 , 0 ], unaries [ysize - 1 , x , 1 ])
158
- g .add_tedge (nodeids [ysize - 1 , xsize - 1 ], unaries [ysize - 1 , xsize - 1 , 0 ], unaries [ysize - 1 , xsize - 1 , 1 ])
159
-
160
- # g = maxflow.Graph[float]()
161
- # nodeids = g.add_grid_nodes((5, 5))
162
- #
163
- # # Edges pointing backwards (left, left up and left down) with infinite
164
- # # capacity
165
- # structure = np.array([[0, 0, 0],
166
- # [0, 0, 0],
167
- # [0, 0, 0]
168
- # ])
169
- # g.add_grid_edges(nodeids, structure=structure, symmetric=False)
170
- #
171
- # # Set a few arbitrary weights
172
- # weights = np.array([[100, 110, 120, 130, 140]]).T + np.array([0, 2, 4, 6, 8])
173
- #
174
- # print(weights)
175
- #
176
- # structure = np.zeros((3, 3))
177
- # structure[1, 2] = 1
178
- # g.add_grid_edges(nodeids, structure=structure, weights=weights, symmetric=False)
179
-
180
- # plot_graph_2d(g, nodeids.shape)
181
-
182
- logging .info ("Calculate max flow." )
183
- g .maxflow ()
184
-
185
- for y in range (ysize ):
186
- for x in range (xsize ):
187
- if g .get_segment (nodeids [y , x ]):
188
- img [y , x ] = np .array ([0 , 0 , 0 ])
189
- else :
190
- img [y , x ] = np .array ([255 , 255 , 0 ])
183
+ self .img [y , x ] = np .array ([255 , 255 , 0 ])
184
+
185
+ def pairwiseenergy (y1 , x1 , y2 , x2 ):
186
+ pass
187
+ # if unaries[y1, x1, 2] != unaries[y2, x2, 2]:
188
+ # delta = 1
189
+ # else:
190
+ # delta = 0
191
+ #
192
+ # # Not same label
193
+ # energy = w * np.exp(-l * np.power(np.linalg.norm(img[y1, x1] - img[y2, x2], ord=2), 2)) * delta
194
+ # # a = (np.linalg.norm(img[y1, x1] - img[y2, x2], ord=2))
195
+ # return energy
196
+
197
+ def segment (self ):
198
+ self .nodegrid .loop (self .edge , self .node_assign )
199
+
200
+ self .nodegrid .maxflow ()
201
+
202
+ self .nodegrid .loopnodes (self .node_segment )
203
+
204
+ def getimg (self ):
205
+ return self .img
206
+
207
+
208
+ def main ():
209
+ logging .basicConfig (level = logging .INFO )
210
+
211
+ logging .info ("Read image." )
212
+ img = misc .imread ("banana3.png" )
213
+ img = np .array (img , dtype = np .float64 )
214
+
215
+ if not os .path .exists ("unary.npy" ):
216
+ generateunaries (img )
217
+
218
+ logging .info ("Load unaries." )
219
+ unaries = np .load ("unary.npy" )
220
+
221
+ binseg = Binseg (img , unaries )
222
+ binseg .segment ()
191
223
192
224
logging .info ("Save image." )
193
- img = img .astype (np .uint8 )
225
+ img = binseg . getimg () .astype (np .uint8 )
194
226
195
227
plt .imshow (img )
196
228
plt .show ()
0 commit comments