Skip to content

Commit 1eacad4

Browse files
author
JeGa
committed
Improved code structure. Todo: Pairwise energy.
1 parent a9812b5 commit 1eacad4

File tree

1 file changed

+139
-107
lines changed

1 file changed

+139
-107
lines changed

binseg.py

Lines changed: 139 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from prettytable import PrettyTable
77
import progressbar
88
import maxflow
9-
from examples_utils import plot_graph_2d
9+
import os
1010

1111

1212
class GMM:
@@ -38,6 +38,12 @@ def unaryenergy(fg, bg, img):
3838
fg: Foreground GMM
3939
bg: Background GMM
4040
img: Image array (RGB)
41+
42+
Returns:
43+
Array A (ysize, xsize, 3) with
44+
A[y, x, 0] = energy background
45+
A[y, x, 1] = energy foreground
46+
A[y, x, 2] = label
4147
"""
4248
logging.info("Calculate unary energy functions.")
4349
ysize, xsize, _ = img.shape
@@ -61,27 +67,6 @@ def unaryenergy(fg, bg, img):
6167
return unary
6268

6369

64-
def pairwiseenergy(unaries, img):
65-
logging.info("Calculate pairwise energy functions.")
66-
ysize, xsize, _ = img.shape
67-
pairwise = np.array()
68-
69-
l = 0.5
70-
w = 1.0
71-
72-
with progressbar.ProgressBar(max_value=xsize, redirect_stdout=True) as progress:
73-
for x in range(xsize):
74-
for y in range(ysize):
75-
pass
76-
# if unaries[y, x, 3] ==
77-
# np.exp(-l * np.linalg.norm())
78-
# else:
79-
# pass
80-
progress.update(x)
81-
82-
return pairwise
83-
84-
8570
def readprobfile(filename):
8671
file = np.load(filename)
8772
mu = file["arr_0"]
@@ -91,106 +76,153 @@ def readprobfile(filename):
9176
return [mu, sigma, mix]
9277

9378

94-
def main():
95-
logging.basicConfig(level=logging.INFO)
96-
79+
def generateunaries(img):
9780
logging.info("Read GMM for unaries.")
9881
probf = readprobfile("prob_foreground.npz")
9982
probb = readprobfile("prob_background.npz")
10083

10184
fg = GMM(probf)
10285
bg = GMM(probb)
10386

104-
logging.info("Read image.")
105-
img = misc.imread("banana3.png")
106-
img = np.array(img, dtype=np.float64)
87+
logging.info("Generate unaries.")
88+
unaries = unaryenergy(fg, bg, img)
89+
np.save("unary", unaries)
10790

108-
# unaries = unaryenergy(fg, bg, img)
109-
unaries = np.load("unary.npy")
110-
# np.save("unary", unaries)
11191

112-
ysize, xsize, _ = img.shape
92+
class Nodegrid:
93+
def __init__(self, ysize, xsize):
94+
self.g = maxflow.GraphFloat()
95+
self.nodeids = self.g.add_grid_nodes((ysize, xsize))
11396

114-
def pairwiseenergy(y1, x1, y2, x2):
115-
nonlocal unaries
116-
nonlocal img
97+
self.ysize = ysize
98+
self.xsize = xsize
99+
100+
def loop(self, edgecallback, nodecallback):
101+
"""
102+
Loops over the grid of nodes. Two callback functions are required:
103+
104+
:param edgecallback: Called for every edge with node id (i, j)
105+
as parameter.
106+
:param nodecallback: called for every node with node id as parameter.
107+
"""
108+
logging.info("Iterate through graph.")
109+
110+
for y in range(self.ysize - 1):
111+
for x in range(self.xsize - 1):
112+
# Right edge
113+
edgecallback(self.nodeids[y, x],
114+
self.nodeids[y, x + 1], self.g)
115+
116+
# Down edge
117+
edgecallback(self.nodeids[y, x],
118+
self.nodeids[y + 1, x], self.g)
119+
120+
# Node
121+
nodecallback(self.nodeids[y, x], y, x, self.g)
122+
123+
# Last column
124+
for y in range(self.ysize - 1):
125+
# Down edge
126+
edgecallback(self.nodeids[y, self.xsize - 1],
127+
self.nodeids[y + 1, self.xsize - 1], self.g)
128+
129+
# Node
130+
nodecallback(self.nodeids[y, self.xsize - 1], y, x, self.g)
131+
132+
# Last row
133+
for x in range(self.xsize - 1):
134+
# Right edge
135+
edgecallback(self.nodeids[self.ysize - 1, x],
136+
self.nodeids[self.ysize - 1, x + 1], self.g)
137+
138+
# Node
139+
nodecallback(self.nodeids[self.ysize - 1, x], y, x, self.g)
140+
141+
# Last node
142+
nodecallback(self.nodeids[self.ysize - 1, self.xsize - 1], y, x, self.g)
117143

118-
l = 0.005
119-
w = 110
144+
def loopnodes(self, callback):
145+
for y in range(self.ysize):
146+
for x in range(self.xsize):
147+
callback(self.nodeids[y, x], y, x, self.g)
120148

121-
if unaries[y1, x1, 2] != unaries[y2, x2, 2]:
122-
delta = 1
149+
def maxflow(self):
150+
logging.info("Calculate max flow.")
151+
self.g.maxflow()
152+
153+
154+
class Binseg:
155+
def __init__(self, img, unaries):
156+
self.img = img
157+
self.unaries = unaries
158+
159+
self.nodegrid = Nodegrid(img.shape[0], img.shape[1])
160+
161+
self.l = 0.005
162+
self.w = 100
163+
164+
def edge(self, nodeid_i, nodeid_j, graph):
165+
"""
166+
Callback for pairwise energy.
167+
"""
168+
pass
169+
170+
def node_assign(self, nodeid_i, y, x, graph):
171+
"""
172+
Callback for assigning unary energy.
173+
"""
174+
graph.add_tedge(nodeid_i, self.unaries[y, x, 1], self.unaries[y, x, 0])
175+
176+
def node_segment(self, nodeid_i, y, x, graph):
177+
"""
178+
Callback for segmentation.
179+
"""
180+
if graph.get_segment(nodeid_i) == 0:
181+
self.img[y, x] = np.array([0, 0, 0])
123182
else:
124-
delta = 0
125-
126-
# Not same label
127-
energy = w * np.exp(-l * np.power(np.linalg.norm(img[y1, x1] - img[y2, x2], ord=2), 2)) * delta
128-
#a = (np.linalg.norm(img[y1, x1] - img[y2, x2], ord=2))
129-
return energy
130-
131-
g = maxflow.GraphFloat()
132-
nodeids = g.add_grid_nodes((ysize, xsize))
133-
134-
for y in range(ysize - 1):
135-
for x in range(xsize - 1):
136-
e_right = pairwiseenergy(y, x, y, x + 1)
137-
if e_right >= 10.0:
138-
a = 2
139-
g.add_edge(nodeids[y, x], nodeids[y, x + 1], e_right, e_right)
140-
141-
e_down = pairwiseenergy(y, x, y + 1, x)
142-
g.add_edge(nodeids[y, x], nodeids[y + 1, x], e_down, e_down)
143-
144-
# Source, sink
145-
g.add_tedge(nodeids[y, x], unaries[y, x, 0], unaries[y, x, 1])
146-
147-
for y in range(ysize - 1):
148-
e_down = pairwiseenergy(y, xsize - 1, y + 1, xsize - 1)
149-
g.add_edge(nodeids[y, xsize - 1], nodeids[y + 1, xsize - 1], e_down, e_down)
150-
151-
g.add_tedge(nodeids[y, xsize - 1], unaries[y, xsize - 1, 0], unaries[y, xsize - 1, 1])
152-
153-
for x in range(xsize - 1):
154-
e_right = pairwiseenergy(ysize - 1, x, ysize - 1, x + 1)
155-
156-
g.add_edge(nodeids[ysize - 1, x], nodeids[ysize - 1, x + 1], e_right, e_right)
157-
g.add_tedge(nodeids[ysize - 1, x], unaries[ysize - 1, x + 1, 0], unaries[ysize - 1, x, 1])
158-
g.add_tedge(nodeids[ysize - 1, xsize - 1], unaries[ysize - 1, xsize - 1, 0], unaries[ysize - 1, xsize - 1, 1])
159-
160-
# g = maxflow.Graph[float]()
161-
# nodeids = g.add_grid_nodes((5, 5))
162-
#
163-
# # Edges pointing backwards (left, left up and left down) with infinite
164-
# # capacity
165-
# structure = np.array([[0, 0, 0],
166-
# [0, 0, 0],
167-
# [0, 0, 0]
168-
# ])
169-
# g.add_grid_edges(nodeids, structure=structure, symmetric=False)
170-
#
171-
# # Set a few arbitrary weights
172-
# weights = np.array([[100, 110, 120, 130, 140]]).T + np.array([0, 2, 4, 6, 8])
173-
#
174-
# print(weights)
175-
#
176-
# structure = np.zeros((3, 3))
177-
# structure[1, 2] = 1
178-
# g.add_grid_edges(nodeids, structure=structure, weights=weights, symmetric=False)
179-
180-
# plot_graph_2d(g, nodeids.shape)
181-
182-
logging.info("Calculate max flow.")
183-
g.maxflow()
184-
185-
for y in range(ysize):
186-
for x in range(xsize):
187-
if g.get_segment(nodeids[y, x]):
188-
img[y, x] = np.array([0, 0, 0])
189-
else:
190-
img[y, x] = np.array([255, 255, 0])
183+
self.img[y, x] = np.array([255, 255, 0])
184+
185+
def pairwiseenergy(y1, x1, y2, x2):
186+
pass
187+
# if unaries[y1, x1, 2] != unaries[y2, x2, 2]:
188+
# delta = 1
189+
# else:
190+
# delta = 0
191+
#
192+
# # Not same label
193+
# energy = w * np.exp(-l * np.power(np.linalg.norm(img[y1, x1] - img[y2, x2], ord=2), 2)) * delta
194+
# # a = (np.linalg.norm(img[y1, x1] - img[y2, x2], ord=2))
195+
# return energy
196+
197+
def segment(self):
198+
self.nodegrid.loop(self.edge, self.node_assign)
199+
200+
self.nodegrid.maxflow()
201+
202+
self.nodegrid.loopnodes(self.node_segment)
203+
204+
def getimg(self):
205+
return self.img
206+
207+
208+
def main():
209+
logging.basicConfig(level=logging.INFO)
210+
211+
logging.info("Read image.")
212+
img = misc.imread("banana3.png")
213+
img = np.array(img, dtype=np.float64)
214+
215+
if not os.path.exists("unary.npy"):
216+
generateunaries(img)
217+
218+
logging.info("Load unaries.")
219+
unaries = np.load("unary.npy")
220+
221+
binseg = Binseg(img, unaries)
222+
binseg.segment()
191223

192224
logging.info("Save image.")
193-
img = img.astype(np.uint8)
225+
img = binseg.getimg().astype(np.uint8)
194226

195227
plt.imshow(img)
196228
plt.show()

0 commit comments

Comments
 (0)