Skip to content

Commit

Permalink
Working on ImgMch, 5
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed Jul 5, 2019
1 parent 70c0284 commit 62bca0d
Showing 1 changed file with 81 additions and 37 deletions.
118 changes: 81 additions & 37 deletions gluon/gluoncv2/models/superpointnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class SPDetector(HybridBlock):
Confidence threshold.
nms_dist : int, default 4
NMS distance.
use_batch_box_nms : bool, default True
Whether allow to hybridize this block.
hybridizable : bool, default True
Whether allow to hybridize this block.
batch_size : int, default 1
Expand All @@ -77,6 +79,7 @@ def __init__(self,
mid_channels,
conf_thresh=0.015,
nms_dist=4,
use_batch_box_nms=True,
hybridizable=True,
batch_size=1,
in_size=(224, 224),
Expand All @@ -87,6 +90,7 @@ def __init__(self,
assert ((in_size is not None) or not hybridizable)
self.conf_thresh = conf_thresh
self.nms_dist = nms_dist
self.use_batch_box_nms = use_batch_box_nms
self.hybridizable = hybridizable
self.batch_size = batch_size
self.in_size = in_size
Expand All @@ -108,46 +112,86 @@ def hybrid_forward(self, F, x):
heatmap = nodust.transpose(axes=(0, 2, 3, 1))
heatmap = heatmap.reshape(shape=(0, 0, 0, self.reduction, self.reduction))
heatmap = heatmap.transpose(axes=(0, 1, 3, 2, 4))
heatmap = heatmap.reshape(shape=(0, -1))

in_size = self.in_size if self.in_size is not None else (x.shape[2] * self.reduction,
x.shape[3] * self.reduction)
batch_size = self.batch_size if self.batch_size is not None else x.shape[0]

in_nms = F.stack(
heatmap,
F.arange(in_size[0], repeat=in_size[1]).tile((batch_size, 1)),
F.arange(in_size[1]).tile((batch_size, in_size[0])),
F.zeros_like(heatmap) + self.nms_dist,
F.zeros_like(heatmap) + self.nms_dist,
axis=2)
out_nms = F.contrib.box_nms(
data=in_nms,
overlap_thresh=1e-3,
valid_thresh=self.conf_thresh,
coord_start=1,
score_index=0,
id_index=-1,
force_suppress=False,
in_format="center",
out_format="center")

confs = out_nms.slice_axis(axis=2, begin=0, end=1).reshape(shape=(0, -1))
pts = out_nms.slice_axis(axis=2, begin=1, end=3)

if self.hybridizable:
return pts, confs

counts = (confs > 0).sum(axis=1)
confs_list = []
pts_list = []
for i in range(batch_size):
count_i = int(counts[i].asscalar())
confs_i = confs[i].slice_axis(axis=0, begin=0, end=count_i)
pts_i = pts[i].slice_axis(axis=0, begin=0, end=count_i)
confs_list.append(confs_i)
pts_list.append(pts_i)
return pts_list, confs_list
if self.use_batch_box_nms:
heatmap = heatmap.reshape(shape=(0, -1))

in_nms = F.stack(
heatmap,
F.arange(in_size[0], repeat=in_size[1]).tile((batch_size, 1)),
F.arange(in_size[1]).tile((batch_size, in_size[0])),
F.zeros_like(heatmap) + self.nms_dist,
F.zeros_like(heatmap) + self.nms_dist,
axis=2)
out_nms = F.contrib.box_nms(
data=in_nms,
overlap_thresh=1e-3,
valid_thresh=self.conf_thresh,
coord_start=1,
score_index=0,
id_index=-1,
force_suppress=False,
in_format="center",
out_format="center")

confs = out_nms.slice_axis(axis=2, begin=0, end=1).reshape(shape=(0, -1))
pts = out_nms.slice_axis(axis=2, begin=1, end=3)

if self.hybridizable:
return pts, confs

confs_list = []
pts_list = []
counts = (confs > 0).sum(axis=1)
for i in range(batch_size):
count_i = int(counts[i].asscalar())
confs_i = confs[i].slice_axis(axis=0, begin=0, end=count_i)
pts_i = pts[i].slice_axis(axis=0, begin=0, end=count_i)
confs_list.append(confs_i)
pts_list.append(pts_i)
return pts_list, confs_list

else:
heatmap = heatmap.reshape(shape=(0, -3, -3))

in_nms = F.stack(
heatmap,
F.arange(in_size[0], repeat=in_size[1]).tile((batch_size, 1)),
F.arange(in_size[1]).tile((batch_size, in_size[0])),
F.zeros_like(heatmap) + self.nms_dist,
F.zeros_like(heatmap) + self.nms_dist,
axis=2)
out_nms = F.contrib.box_nms(
data=in_nms,
overlap_thresh=1e-3,
valid_thresh=self.conf_thresh,
coord_start=1,
score_index=0,
id_index=-1,
force_suppress=False,
in_format="center",
out_format="center")

confs = out_nms.slice_axis(axis=2, begin=0, end=1).reshape(shape=(0, -1))
pts = out_nms.slice_axis(axis=2, begin=1, end=3)

if self.hybridizable:
return pts, confs

confs_list = []
pts_list = []
counts = (confs > 0).sum(axis=1)
for i in range(batch_size):
count_i = int(counts[i].asscalar())
confs_i = confs[i].slice_axis(axis=0, begin=0, end=count_i)
pts_i = pts[i].slice_axis(axis=0, begin=0, end=count_i)
confs_list.append(confs_i)
pts_list.append(pts_i)
return pts_list, confs_list


class SPDescriptor(HybridBlock):
Expand Down Expand Up @@ -373,7 +417,7 @@ def _test():
hybridizable = False
batch_size = 1
# in_size = (224, 224)
in_size = (200, 400)
in_size = (2000, 4000)

models = [
superpointnet,
Expand All @@ -383,7 +427,7 @@ def _test():

net = model(pretrained=pretrained, hybridizable=hybridizable, batch_size=batch_size, in_size=in_size)

ctx = mx.cpu()
ctx = mx.gpu(0)
if not pretrained:
net.initialize(ctx=ctx)

Expand Down

0 comments on commit 62bca0d

Please sign in to comment.