Skip to content

Commit a9812b5

Browse files
author
JeGa
committed
Graph building ...
1 parent 022efce commit a9812b5

File tree

1 file changed

+99
-14
lines changed

1 file changed

+99
-14
lines changed

binseg.py

Lines changed: 99 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66
from prettytable import PrettyTable
77
import progressbar
88
import maxflow
9-
from maxflow.examples_utils import plot_graph_2d
9+
from examples_utils import plot_graph_2d
1010

1111

1212
class GMM:
1313
def __init__(self, prob):
14+
"""
15+
Arguments:
16+
prob: List with mu, sigma and mix.
17+
"""
1418
self.mu = prob[0]
1519
self.sigma = prob[1]
1620
self.mix = prob[2]
@@ -29,15 +33,29 @@ def prob(self, x):
2933

3034

3135
def unaryenergy(fg, bg, img):
36+
"""
37+
Arguments:
38+
fg: Foreground GMM
39+
bg: Background GMM
40+
img: Image array (RGB)
41+
"""
3242
logging.info("Calculate unary energy functions.")
3343
ysize, xsize, _ = img.shape
34-
unary = np.empty((ysize, xsize, 2))
44+
unary = np.empty((ysize, xsize, 3))
3545

3646
with progressbar.ProgressBar(max_value=xsize, redirect_stdout=True) as progress:
3747
for x in range(xsize):
3848
for y in range(ysize):
49+
# Background
3950
unary[y, x, 0] = -np.log(bg.prob(img[y, x]))
51+
# Foreground
4052
unary[y, x, 1] = -np.log(fg.prob(img[y, x]))
53+
54+
# Assign labels
55+
if unary[y, x, 0] < unary[y, x, 1]:
56+
unary[y, x, 2] = 0 # Background
57+
else:
58+
unary[y, x, 2] = 1 # Foreground
4159
progress.update(x)
4260

4361
return unary
@@ -55,9 +73,9 @@ def pairwiseenergy(unaries, img):
5573
for x in range(xsize):
5674
for y in range(ysize):
5775
pass
58-
#if unaries[y, x, 3] ==
76+
# if unaries[y, x, 3] ==
5977
# np.exp(-l * np.linalg.norm())
60-
#else:
78+
# else:
6179
# pass
6280
progress.update(x)
6381

@@ -87,22 +105,89 @@ def main():
87105
img = misc.imread("banana3.png")
88106
img = np.array(img, dtype=np.float64)
89107

90-
unaries = unaryenergy(fg, bg, img)
108+
# unaries = unaryenergy(fg, bg, img)
109+
unaries = np.load("unary.npy")
110+
# np.save("unary", unaries)
91111

92112
ysize, xsize, _ = img.shape
93113

94-
g = maxflow.GraphFloat()
95-
nodeids = g.add_grid_nodes((5, 5))
96-
structure = np.array([[0, 1, 0],
97-
[1, 0, 1],
98-
[0, 1, 0]])
114+
def pairwiseenergy(y1, x1, y2, x2):
115+
nonlocal unaries
116+
nonlocal img
117+
118+
l = 0.005
119+
w = 110
99120

100-
g.add_grid_edges(nodeids, structure=structure, symmetric=True)
121+
if unaries[y1, x1, 2] != unaries[y2, x2, 2]:
122+
delta = 1
123+
else:
124+
delta = 0
101125

102-
# Source and sink
103-
g.add_grid_tedges(nodeids, 0, 2)
126+
# Not same label
127+
energy = w * np.exp(-l * np.power(np.linalg.norm(img[y1, x1] - img[y2, x2], ord=2), 2)) * delta
128+
#a = (np.linalg.norm(img[y1, x1] - img[y2, x2], ord=2))
129+
return energy
104130

105-
plot_graph_2d(g, nodeids.shape)
131+
g = maxflow.GraphFloat()
132+
nodeids = g.add_grid_nodes((ysize, xsize))
133+
134+
for y in range(ysize - 1):
135+
for x in range(xsize - 1):
136+
e_right = pairwiseenergy(y, x, y, x + 1)
137+
if e_right >= 10.0:
138+
a = 2
139+
g.add_edge(nodeids[y, x], nodeids[y, x + 1], e_right, e_right)
140+
141+
e_down = pairwiseenergy(y, x, y + 1, x)
142+
g.add_edge(nodeids[y, x], nodeids[y + 1, x], e_down, e_down)
143+
144+
# Source, sink
145+
g.add_tedge(nodeids[y, x], unaries[y, x, 0], unaries[y, x, 1])
146+
147+
for y in range(ysize - 1):
148+
e_down = pairwiseenergy(y, xsize - 1, y + 1, xsize - 1)
149+
g.add_edge(nodeids[y, xsize - 1], nodeids[y + 1, xsize - 1], e_down, e_down)
150+
151+
g.add_tedge(nodeids[y, xsize - 1], unaries[y, xsize - 1, 0], unaries[y, xsize - 1, 1])
152+
153+
for x in range(xsize - 1):
154+
e_right = pairwiseenergy(ysize - 1, x, ysize - 1, x + 1)
155+
156+
g.add_edge(nodeids[ysize - 1, x], nodeids[ysize - 1, x + 1], e_right, e_right)
157+
g.add_tedge(nodeids[ysize - 1, x], unaries[ysize - 1, x + 1, 0], unaries[ysize - 1, x, 1])
158+
g.add_tedge(nodeids[ysize - 1, xsize - 1], unaries[ysize - 1, xsize - 1, 0], unaries[ysize - 1, xsize - 1, 1])
159+
160+
# g = maxflow.Graph[float]()
161+
# nodeids = g.add_grid_nodes((5, 5))
162+
#
163+
# # Edges pointing backwards (left, left up and left down) with infinite
164+
# # capacity
165+
# structure = np.array([[0, 0, 0],
166+
# [0, 0, 0],
167+
# [0, 0, 0]
168+
# ])
169+
# g.add_grid_edges(nodeids, structure=structure, symmetric=False)
170+
#
171+
# # Set a few arbitrary weights
172+
# weights = np.array([[100, 110, 120, 130, 140]]).T + np.array([0, 2, 4, 6, 8])
173+
#
174+
# print(weights)
175+
#
176+
# structure = np.zeros((3, 3))
177+
# structure[1, 2] = 1
178+
# g.add_grid_edges(nodeids, structure=structure, weights=weights, symmetric=False)
179+
180+
# plot_graph_2d(g, nodeids.shape)
181+
182+
logging.info("Calculate max flow.")
183+
g.maxflow()
184+
185+
for y in range(ysize):
186+
for x in range(xsize):
187+
if g.get_segment(nodeids[y, x]):
188+
img[y, x] = np.array([0, 0, 0])
189+
else:
190+
img[y, x] = np.array([255, 255, 0])
106191

107192
logging.info("Save image.")
108193
img = img.astype(np.uint8)

0 commit comments

Comments
 (0)