fpn_psroi_rotatedpooling.py
# --------------------------------------------------------
# Copyright (c) 2017 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Haozhi Qi, Yuwen Xiong
# --------------------------------------------------------
import mxnet as mx
import numpy as np
from mxnet.contrib import autograd
import gc
import pdb
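

# FPN-style position-sensitive rotated RoI pooling wrapped as an MXNet CustomOp.
# Each rotated RoI is routed to one pyramid level by its scale, pooled with the
# repo's PSROIROTATEDPooling operator, and the pooled features are re-assembled
# in the original RoI order.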
class FPNPSROIROTATEDPoolingOperator(mx.operator.CustomOp):
    def __init__(self, feat_strides, pooled_height, pooled_width, output_dim):
        self.pooled_height = pooled_height
        self.pooled_width = pooled_width
        self.feat_strides = feat_strides
        self.output_dim = output_dim
        self.in_grad_hist_list = []
        self.num_strides = len(self.feat_strides)
        self.roi_pool = [None for _ in range(self.num_strides)]
        self.feat_idx = [None for _ in range(self.num_strides)]
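
    # forward(): route each rotated RoI to one pyramid level by its scale
    # (columns 3 and 4 of `rois` hold the box width and height), pool every
    # level separately, then restore the original RoI order in the output.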
    def forward(self, is_train, req, in_data, out_data, aux):
        rois = in_data[-1].asnumpy()
        # w = rois[:, 3] - rois[:, 1] + 1
        # h = rois[:, 4] - rois[:, 2] + 1
        w = np.maximum(rois[:, 3], 1)
        h = np.maximum(rois[:, 4], 1)
        # TODO: carefully scale the w, h
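        # FPN level assignment (Lin et al., "Feature Pyramid Networks"):
        # k = floor(k0 + log2(sqrt(w * h) / 224)) with k0 = 2, clipped to the
        # available levels, so a 224x224 RoI maps to the stride-16 level here.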
        feat_id = np.clip(np.floor(2 + np.log2(np.sqrt(w * h) / 224)), 0, len(self.feat_strides) - 1)
        pyramid_idx = []
        rois_p = [None for _ in range(self.num_strides)]
        for i in range(self.num_strides):
            self.feat_idx[i] = np.where(feat_id == i)[0]
            if len(self.feat_idx[i]) == 0:
                # pad a dummy RoI so this level still produces one (discarded) output
                rois_p[i] = np.zeros((1, 6))
                pyramid_idx.append(-1)
            else:
                rois_p[i] = rois[self.feat_idx[i]]
                pyramid_idx.append(self.feat_idx[i])
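        # pyramid_idx holds the original index of every RoI in per-level order
        # (with -1 for padded levels); argsort inverts that permutation, and
        # keeping the last rois.shape[0] entries drops the -1 placeholders.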
        rois_idx = np.argsort(np.hstack(pyramid_idx))[-rois.shape[0]:]
        # pdb.set_trace()
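        # Training: record the per-level pooling ops with the old
        # mxnet.contrib.autograd API so backward() can replay them and route
        # gradients into in_grad_hist_list (one buffer per feature map).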
        if is_train:
            # start from fresh gradient buffers each iteration so the history
            # list does not grow across forward calls
            self.in_grad_hist_list = []
            for i in range(self.num_strides):
                self.in_grad_hist_list.append(mx.nd.zeros_like(in_data[i]))
            autograd.mark_variables([in_data[i] for i in range(self.num_strides)], self.in_grad_hist_list)
            with autograd.train_section():
                for i in range(self.num_strides):
                    self.roi_pool[i] = mx.contrib.nd.PSROIROTATEDPooling(data=in_data[i],
                                                                         rois=mx.nd.array(rois_p[i], in_data[i].context),
                                                                         group_size=self.pooled_height,
                                                                         pooled_size=self.pooled_height,
                                                                         output_dim=self.output_dim,
                                                                         spatial_scale=1.0 / self.feat_strides[i])
            roi_pool = mx.nd.concatenate(self.roi_pool, axis=0)
        else:
            # during testing there is no need to record variables, which saves memory
            # pdb.set_trace()
            roi_pool = [None for _ in range(self.num_strides)]
            for i in range(self.num_strides):
                roi_pool[i] = mx.contrib.nd.PSROIROTATEDPooling(data=in_data[i],
                                                                rois=mx.nd.array(rois_p[i], in_data[i].context),
                                                                group_size=self.pooled_height,
                                                                pooled_size=self.pooled_height,
                                                                output_dim=self.output_dim,
                                                                spatial_scale=1.0 / self.feat_strides[i])
            roi_pool = mx.nd.concatenate(roi_pool, axis=0)
        # put the pooled features back into the original RoI order
        roi_pool = mx.nd.take(roi_pool, mx.nd.array(rois_idx, roi_pool.context))
        self.assign(out_data[0], req[0], roi_pool)
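
    # backward(): the incoming gradient for each level is multiplied with the
    # recorded pooling output and pushed through autograd.compute_gradient();
    # since d(sum(g * y))/dx equals g backpropagated through y, this writes the
    # correct feature-map gradients into in_grad_hist_list.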
    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        for i in range(len(in_grad)):
            self.assign(in_grad[i], req[i], 0)
        with autograd.train_section():
            for i in range(self.num_strides):
                if len(self.feat_idx[i]) > 0:
                    autograd.compute_gradient([mx.nd.take(out_grad[0], mx.nd.array(self.feat_idx[i], out_grad[0].context)) * self.roi_pool[i]])
        for i in range(0, self.num_strides):
            self.assign(in_grad[i], req[i], self.in_grad_hist_list[i])
        gc.collect()
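

# CustomOpProp receives every constructor argument as a string from
# mx.sym.Custom, so the stride tuple and the integer sizes are parsed here.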
@mx.operator.register('fpn_psroi_rotatedpooling')
class FPNPSROIROTATEDPoolingProp(mx.operator.CustomOpProp):
    def __init__(self, feat_strides='(4,8,16,32)', pooled_height='7', pooled_width='7', output_dim='10'):
        super(FPNPSROIROTATEDPoolingProp, self).__init__(need_top_grad=True)
        self.pooled_height = int(pooled_height)
        self.pooled_width = int(pooled_width)
        self.feat_strides = np.fromstring(feat_strides[1:-1], dtype=int, sep=',')
        self.output_dim = int(output_dim)
        self.num_strides = len(self.feat_strides)

    def list_arguments(self):
        args_list = []
        for i in range(self.num_strides):
            args_list.append('data_p{}'.format(2 + i))
        args_list.append('Rrois')
        return args_list

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        # pdb.set_trace()
        # output_feat_shape = [in_shape[-1][0], in_shape[0][1], self.pooled_height, self.pooled_width]
        output_feat_shape = [in_shape[-1][0], self.output_dim, self.pooled_height, self.pooled_width]
        return in_shape, [output_feat_shape]

    def create_operator(self, ctx, shapes, dtypes):
        return FPNPSROIROTATEDPoolingOperator(self.feat_strides, self.pooled_height, self.pooled_width, self.output_dim)

    def declare_backward_dependency(self, out_grad, in_data, out_data):
        return [out_grad[0]]
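

# Wiring sketch: assumes FPN feature symbols data_p2..data_p5 and a rotated-RoI
# symbol Rrois already exist in the graph; positional inputs follow the order
# returned by list_arguments(), and the string kwargs go to the Prop above.
#
#   pooled = mx.sym.Custom(data_p2, data_p3, data_p4, data_p5, Rrois,
#                          op_type='fpn_psroi_rotatedpooling',
#                          feat_strides='(4,8,16,32)',
#                          pooled_height='7', pooled_width='7',
#                          output_dim='10')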