6
6
from prettytable import PrettyTable
7
7
import progressbar
8
8
import maxflow
9
- from maxflow . examples_utils import plot_graph_2d
9
+ from examples_utils import plot_graph_2d
10
10
11
11
12
12
class GMM :
13
13
def __init__ (self , prob ):
14
+ """
15
+ Arguments:
16
+ prob: List with mu, sigma and mix.
17
+ """
14
18
self .mu = prob [0 ]
15
19
self .sigma = prob [1 ]
16
20
self .mix = prob [2 ]
@@ -29,15 +33,29 @@ def prob(self, x):
29
33
30
34
31
35
def unaryenergy (fg , bg , img ):
36
+ """
37
+ Arguments:
38
+ fg: Foreground GMM
39
+ bg: Background GMM
40
+ img: Image array (RGB)
41
+ """
32
42
logging .info ("Calculate unary energy functions." )
33
43
ysize , xsize , _ = img .shape
34
- unary = np .empty ((ysize , xsize , 2 ))
44
+ unary = np .empty ((ysize , xsize , 3 ))
35
45
36
46
with progressbar .ProgressBar (max_value = xsize , redirect_stdout = True ) as progress :
37
47
for x in range (xsize ):
38
48
for y in range (ysize ):
49
+ # Background
39
50
unary [y , x , 0 ] = - np .log (bg .prob (img [y , x ]))
51
+ # Foreground
40
52
unary [y , x , 1 ] = - np .log (fg .prob (img [y , x ]))
53
+
54
+ # Assign labels
55
+ if unary [y , x , 0 ] < unary [y , x , 1 ]:
56
+ unary [y , x , 2 ] = 0 # Background
57
+ else :
58
+ unary [y , x , 2 ] = 1 # Foreground
41
59
progress .update (x )
42
60
43
61
return unary
@@ -55,9 +73,9 @@ def pairwiseenergy(unaries, img):
55
73
for x in range (xsize ):
56
74
for y in range (ysize ):
57
75
pass
58
- #if unaries[y, x, 3] ==
76
+ # if unaries[y, x, 3] ==
59
77
# np.exp(-l * np.linalg.norm())
60
- #else:
78
+ # else:
61
79
# pass
62
80
progress .update (x )
63
81
@@ -87,22 +105,89 @@ def main():
87
105
img = misc .imread ("banana3.png" )
88
106
img = np .array (img , dtype = np .float64 )
89
107
90
- unaries = unaryenergy (fg , bg , img )
108
+ # unaries = unaryenergy(fg, bg, img)
109
+ unaries = np .load ("unary.npy" )
110
+ # np.save("unary", unaries)
91
111
92
112
ysize , xsize , _ = img .shape
93
113
94
- g = maxflow .GraphFloat ()
95
- nodeids = g .add_grid_nodes ((5 , 5 ))
96
- structure = np .array ([[0 , 1 , 0 ],
97
- [1 , 0 , 1 ],
98
- [0 , 1 , 0 ]])
114
+ def pairwiseenergy (y1 , x1 , y2 , x2 ):
115
+ nonlocal unaries
116
+ nonlocal img
117
+
118
+ l = 0.005
119
+ w = 110
99
120
100
- g .add_grid_edges (nodeids , structure = structure , symmetric = True )
121
+ if unaries [y1 , x1 , 2 ] != unaries [y2 , x2 , 2 ]:
122
+ delta = 1
123
+ else :
124
+ delta = 0
101
125
102
- # Source and sink
103
- g .add_grid_tedges (nodeids , 0 , 2 )
126
+ # Not same label
127
+ energy = w * np .exp (- l * np .power (np .linalg .norm (img [y1 , x1 ] - img [y2 , x2 ], ord = 2 ), 2 )) * delta
128
+ #a = (np.linalg.norm(img[y1, x1] - img[y2, x2], ord=2))
129
+ return energy
104
130
105
- plot_graph_2d (g , nodeids .shape )
131
+ g = maxflow .GraphFloat ()
132
+ nodeids = g .add_grid_nodes ((ysize , xsize ))
133
+
134
+ for y in range (ysize - 1 ):
135
+ for x in range (xsize - 1 ):
136
+ e_right = pairwiseenergy (y , x , y , x + 1 )
137
+ if e_right >= 10.0 :
138
+ a = 2
139
+ g .add_edge (nodeids [y , x ], nodeids [y , x + 1 ], e_right , e_right )
140
+
141
+ e_down = pairwiseenergy (y , x , y + 1 , x )
142
+ g .add_edge (nodeids [y , x ], nodeids [y + 1 , x ], e_down , e_down )
143
+
144
+ # Source, sink
145
+ g .add_tedge (nodeids [y , x ], unaries [y , x , 0 ], unaries [y , x , 1 ])
146
+
147
+ for y in range (ysize - 1 ):
148
+ e_down = pairwiseenergy (y , xsize - 1 , y + 1 , xsize - 1 )
149
+ g .add_edge (nodeids [y , xsize - 1 ], nodeids [y + 1 , xsize - 1 ], e_down , e_down )
150
+
151
+ g .add_tedge (nodeids [y , xsize - 1 ], unaries [y , xsize - 1 , 0 ], unaries [y , xsize - 1 , 1 ])
152
+
153
+ for x in range (xsize - 1 ):
154
+ e_right = pairwiseenergy (ysize - 1 , x , ysize - 1 , x + 1 )
155
+
156
+ g .add_edge (nodeids [ysize - 1 , x ], nodeids [ysize - 1 , x + 1 ], e_right , e_right )
157
+ g .add_tedge (nodeids [ysize - 1 , x ], unaries [ysize - 1 , x + 1 , 0 ], unaries [ysize - 1 , x , 1 ])
158
+ g .add_tedge (nodeids [ysize - 1 , xsize - 1 ], unaries [ysize - 1 , xsize - 1 , 0 ], unaries [ysize - 1 , xsize - 1 , 1 ])
159
+
160
+ # g = maxflow.Graph[float]()
161
+ # nodeids = g.add_grid_nodes((5, 5))
162
+ #
163
+ # # Edges pointing backwards (left, left up and left down) with infinite
164
+ # # capacity
165
+ # structure = np.array([[0, 0, 0],
166
+ # [0, 0, 0],
167
+ # [0, 0, 0]
168
+ # ])
169
+ # g.add_grid_edges(nodeids, structure=structure, symmetric=False)
170
+ #
171
+ # # Set a few arbitrary weights
172
+ # weights = np.array([[100, 110, 120, 130, 140]]).T + np.array([0, 2, 4, 6, 8])
173
+ #
174
+ # print(weights)
175
+ #
176
+ # structure = np.zeros((3, 3))
177
+ # structure[1, 2] = 1
178
+ # g.add_grid_edges(nodeids, structure=structure, weights=weights, symmetric=False)
179
+
180
+ # plot_graph_2d(g, nodeids.shape)
181
+
182
+ logging .info ("Calculate max flow." )
183
+ g .maxflow ()
184
+
185
+ for y in range (ysize ):
186
+ for x in range (xsize ):
187
+ if g .get_segment (nodeids [y , x ]):
188
+ img [y , x ] = np .array ([0 , 0 , 0 ])
189
+ else :
190
+ img [y , x ] = np .array ([255 , 255 , 0 ])
106
191
107
192
logging .info ("Save image." )
108
193
img = img .astype (np .uint8 )
0 commit comments