Skip to content

Commit

Permalink
Fixed remapping bug
Browse files Browse the repository at this point in the history
  • Loading branch information
clementfarabet committed Jun 18, 2012
1 parent 6f23498 commit bc1a3d0
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
26 changes: 18 additions & 8 deletions generic/imgraph.c
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ static int imgraph_(segmentmst)(lua_State *L) {
THTensor *src = (THTensor *)luaT_checkudata(L, 2, torch_(Tensor_id));
real thres = lua_tonumber(L, 3);
int minsize = lua_tonumber(L, 4);
int color = lua_toboolean(L, 5);
int adaptivethres = lua_toboolean(L, 5);
int color = lua_toboolean(L, 6);

// dims
long nmaps = src->size[0];
Expand Down Expand Up @@ -397,7 +398,11 @@ static int imgraph_(segmentmst)(lua_State *L) {
if ((edges[i].w <= threshold[a]) && (edges[i].w <= threshold[b])) {
set_join(set, a, b);
a = set_find(set, a);
threshold[a] = edges[i].w + thres/set->elts[a].surface;
if (adaptivethres) {
threshold[a] = edges[i].w + thres/set->elts[a].surface;
} else {
threshold[a] = edges[i].w;
}
}
}
}
Expand Down Expand Up @@ -461,7 +466,8 @@ static int imgraph_(segmentmstsparse)(lua_State *L) {
THTensor *src = (THTensor *)luaT_checkudata(L, 2, torch_(Tensor_id));
real thres = lua_tonumber(L, 3);
int minsize = lua_tonumber(L, 4);
int color = lua_toboolean(L, 5);
int adaptivethres = lua_toboolean(L, 5);
int color = lua_toboolean(L, 6);

// dims
long nedges = src->size[0];
Expand All @@ -476,8 +482,8 @@ static int imgraph_(segmentmstsparse)(lua_State *L) {
edges = (Edge *)calloc(nedges, sizeof(Edge));
int i;
for (i = 0; i < nedges; i++) {
edges[i].a = src_data[3*i + 0];
edges[i].b = src_data[3*i + 1];
edges[i].a = src_data[3*i + 0] - 1;
edges[i].b = src_data[3*i + 1] - 1;
edges[i].w = src_data[3*i + 2];
if (src_data[3*i + 0] > nnodes) nnodes = src_data[3*i + 0];
if (src_data[3*i + 1] > nnodes) nnodes = src_data[3*i + 1];
Expand All @@ -503,7 +509,11 @@ static int imgraph_(segmentmstsparse)(lua_State *L) {
if ((edges[i].w <= threshold[a]) && (edges[i].w <= threshold[b])) {
set_join(set, a, b);
a = set_find(set, a);
threshold[a] = edges[i].w + thres/set->elts[a].surface;
if (adaptivethres) {
threshold[a] = edges[i].w + thres/set->elts[a].surface;
} else {
threshold[a] = edges[i].w;
}
}
}
}
Expand All @@ -522,7 +532,7 @@ static int imgraph_(segmentmstsparse)(lua_State *L) {
THTensor_(fill)(colormap, -1);
THTensor_(resize2d)(dst, nnodes, 3);
for (i = 0; i < nnodes; i++) {
int comp = set_find(set, (i+1));
int comp = set_find(set, i);
real check = THTensor_(get2d)(colormap, comp, 0);
if (check == -1) {
THTensor_(set2d)(colormap, comp, 0, rand0to1());
Expand All @@ -540,7 +550,7 @@ static int imgraph_(segmentmstsparse)(lua_State *L) {
THTensor_(resize1d)(dst, nnodes);
real *dst_data = THTensor_(data)(dst);
for (i = 0; i < nnodes; i++) {
dst_data[i] = set_find(set, (i+1));
dst_data[i] = set_find(set, i) + 1;
}
}

Expand Down
17 changes: 11 additions & 6 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -530,42 +530,47 @@ end
function imgraph.segmentmst(...)
--get args
local args = {...}
local dest, graph, thres, minsize, colorize
local dest, graph, thres, minsize, colorize, adaptive
local arg2 = torch.typename(args[2])
if arg2 and arg2:find('Tensor') then
dest = args[1]
graph = args[2]
thres = args[3]
minsize = args[4]
colorize = args[5]
adaptive = args[6]
else
graph = args[1]
thres = args[2]
minsize = args[3]
colorize = args[4]
adaptive = args[5]
end

-- defaults
thres = thres or 3
minsize = minsize or 20
colorize = colorize or false
if adaptive == nil then adaptive = true end

-- usage
if not graph then
print(xlua.usage('imgraph.segmentmst',
'segment an edge-weighted graph, using a surface adaptive criterion\n'
.. 'on the min-spanning tree of the graph (see Felzenszwalb et al. 2004)',
'segment an edge-weighted graph, by thresholding its mininum spanning tree\n'
..'(an adaptive threshold is used by default, as in Felzenszwalb et al.)',
nil,
{type='torch.Tensor', help='input graph', req=true},
{type='number', help='base threshold for merging', default=3},
{type='number', help='min size: merge components of smaller size', default=20},
{type='boolean', help='replace components id by random colors', default=false},
{type='boolean', help='use adaptive threshold (Felzenszwalb trick)', default=true},
"",
{type='torch.Tensor', help='destination tensor', req=true},
{type='torch.Tensor', help='input graph', req=true},
{type='number', help='base threshold for merging', default=3},
{type='number', help='min size: merge components of smaller size', default=20},
{type='boolean', help='replace components id by random colors', default=false}))
{type='boolean', help='replace components id by random colors', default=false},
{type='boolean', help='use adaptive threshold (Felzenszwalb trick)', default=true}))
xlua.error('incorrect arguments', 'imgraph.segmentmst')
end

Expand All @@ -574,10 +579,10 @@ function imgraph.segmentmst(...)
local nelts
if graph:nDimension() == 3 then
-- dense image graph (input is a KxHxW graph, K=1/2 connexity, nnodes=H*W)
nelts = graph.imgraph.segmentmst(dest, graph, thres, minsize, colorize)
nelts = graph.imgraph.segmentmst(dest, graph, thres, minsize, adaptive, colorize)
else
-- sparse graph (input is a Nx3 graph, nnodes=N, each entry input[i] is an edge: {node1, node2, weight})
nelts = graph.imgraph.segmentmstsparse(dest, graph, thres, minsize, colorize)
nelts = graph.imgraph.segmentmstsparse(dest, graph, thres, minsize, adaptive, colorize)
end

-- return image
Expand Down

0 comments on commit bc1a3d0

Please sign in to comment.