Skip to content

Commit 18b8d79

Browse files
authored
Fixed ROI Pooling Output Shape to Consider Multiple ROIs (keras-team#2350) (keras-team#2360)
* Fixed indentation and output shape in ROI pooling to consider multiple ROIs
* Formatted code
1 parent e170bfe commit 18b8d79

File tree

2 files changed

+53
-16
lines changed

2 files changed

+53
-16
lines changed

keras_cv/layers/object_detection/roi_pool.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,12 @@ def _pool_single_sample(self, args):
112112
feature_map: [H, W, C] float Tensor
113113
rois: [N, 4] float Tensor
114114
Returns:
115-
pooled_feature_map: [target_size, C] float Tensor
115+
pooled_feature_map: [N, target_height, target_width, C] float Tensor
116116
"""
117117
feature_map, rois = args
118118
num_rois = rois.get_shape().as_list()[0]
119119
height, width, channel = feature_map.get_shape().as_list()
120+
regions = []
120121
# TODO (consider vectorize it for better performance)
121122
for n in range(num_rois):
122123
# [4]
@@ -127,7 +128,7 @@ def _pool_single_sample(self, args):
127128
region_width = width * (roi[3] - roi[1])
128129
h_step = region_height / self.target_height
129130
w_step = region_width / self.target_width
130-
regions = []
131+
region_steps = []
131132
for i in range(self.target_height):
132133
for j in range(self.target_width):
133134
height_start = y_start + i * h_step
@@ -147,16 +148,18 @@ def _pool_single_sample(self, args):
147148
1, width_end - width_start
148149
)
149150
# [h_step, w_step, C]
150-
region = feature_map[
151+
region_step = feature_map[
151152
height_start:height_end, width_start:width_end, :
152153
]
153154
# target_height * target_width * [C]
154-
regions.append(tf.reduce_max(region, axis=[0, 1]))
155-
regions = tf.reshape(
156-
tf.stack(regions),
157-
[self.target_height, self.target_width, channel],
155+
region_steps.append(tf.reduce_max(region_step, axis=[0, 1]))
156+
regions.append(
157+
tf.reshape(
158+
tf.stack(region_steps),
159+
[self.target_height, self.target_width, channel],
160+
)
158161
)
159-
return regions
162+
return tf.stack(regions)
160163

161164
def get_config(self):
162165
config = {

keras_cv/layers/object_detection/roi_pool_test.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_no_quantize(self):
4343
# | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
4444
# --------------------------------------------
4545
expected_feature_map = tf.reshape(
46-
tf.constant([27, 31, 59, 63]), [1, 2, 2, 1]
46+
tf.constant([27, 31, 59, 63]), [1, 1, 2, 2, 1]
4747
)
4848
self.assertAllClose(expected_feature_map, pooled_feature_map)
4949

@@ -69,7 +69,7 @@ def test_roi_quantize_y(self):
6969
# | 56, 57, 58(max) | 59, 60, 61, 62(max) | 63 (removed)
7070
# --------------------------------------------
7171
expected_feature_map = tf.reshape(
72-
tf.constant([26, 30, 58, 62]), [1, 2, 2, 1]
72+
tf.constant([26, 30, 58, 62]), [1, 1, 2, 2, 1]
7373
)
7474
self.assertAllClose(expected_feature_map, pooled_feature_map)
7575

@@ -94,7 +94,7 @@ def test_roi_quantize_x(self):
9494
# | 48, 49, 50, 51(max) | 52, 53, 54, 55(max) |
9595
# --------------------------------------------
9696
expected_feature_map = tf.reshape(
97-
tf.constant([19, 23, 51, 55]), [1, 2, 2, 1]
97+
tf.constant([19, 23, 51, 55]), [1, 1, 2, 2, 1]
9898
)
9999
self.assertAllClose(expected_feature_map, pooled_feature_map)
100100

@@ -121,7 +121,7 @@ def test_roi_quantize_h(self):
121121
# | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
122122
# --------------------------------------------
123123
expected_feature_map = tf.reshape(
124-
tf.constant([11, 15, 35, 39, 59, 63]), [1, 3, 2, 1]
124+
tf.constant([11, 15, 35, 39, 59, 63]), [1, 1, 3, 2, 1]
125125
)
126126
self.assertAllClose(expected_feature_map, pooled_feature_map)
127127

@@ -147,7 +147,7 @@ def test_roi_quantize_w(self):
147147
# | 56, 57(max) | 58, 59, 60(max) | 61, 62, 63(max) |
148148
# --------------------------------------------
149149
expected_feature_map = tf.reshape(
150-
tf.constant([25, 28, 31, 57, 60, 63]), [1, 2, 3, 1]
150+
tf.constant([25, 28, 31, 57, 60, 63]), [1, 1, 2, 3, 1]
151151
)
152152
self.assertAllClose(expected_feature_map, pooled_feature_map)
153153

@@ -168,7 +168,8 @@ def test_roi_feature_map_height_smaller_than_roi(self):
168168
# ------------------repeated----------------------
169169
# | 12, 13(max) | 14, 15(max) |
170170
expected_feature_map = tf.reshape(
171-
tf.constant([1, 3, 1, 3, 5, 7, 9, 11, 9, 11, 13, 15]), [1, 6, 2, 1]
171+
tf.constant([1, 3, 1, 3, 5, 7, 9, 11, 9, 11, 13, 15]),
172+
[1, 1, 6, 2, 1],
172173
)
173174
self.assertAllClose(expected_feature_map, pooled_feature_map)
174175

@@ -189,7 +190,7 @@ def test_roi_feature_map_width_smaller_than_roi(self):
189190
# --------------------------------------------
190191
expected_feature_map = tf.reshape(
191192
tf.constant([4, 4, 5, 6, 6, 7, 12, 12, 13, 14, 14, 15]),
192-
[1, 2, 6, 1],
193+
[1, 1, 2, 6, 1],
193194
)
194195
self.assertAllClose(expected_feature_map, pooled_feature_map)
195196

@@ -203,10 +204,43 @@ def test_roi_empty(self):
203204
rois = tf.reshape(tf.constant([0.0, 0.0, 0.0, 0.0]), [1, 1, 4])
204205
pooled_feature_map = roi_pooler(feature_map, rois)
205206
# all outputs should be top-left pixel
206-
self.assertAllClose(tf.ones([1, 2, 2, 1]), pooled_feature_map)
207+
self.assertAllClose(tf.ones([1, 1, 2, 2, 1]), pooled_feature_map)
207208

208209
def test_invalid_image_shape(self):
209210
with self.assertRaisesRegex(ValueError, "dynamic shape"):
210211
_ = ROIPooler(
211212
"rel_yxyx", target_size=[2, 2], image_shape=[None, 224, 3]
212213
)
214+
215+
def test_multiple_rois(self):
216+
feature_map = tf.expand_dims(
217+
tf.reshape(tf.range(0, 64), [8, 8, 1]), axis=0
218+
)
219+
220+
roi_pooler = ROIPooler(
221+
bounding_box_format="yxyx",
222+
target_size=[2, 2],
223+
image_shape=[224, 224, 3],
224+
)
225+
rois = tf.constant(
226+
[[[0.0, 0.0, 112.0, 112.0], [0.0, 112.0, 224.0, 224.0]]],
227+
)
228+
229+
pooled_feature_map = roi_pooler(feature_map, rois)
230+
# each of the two rois (top-left quadrant and right half of the image) is
231+
# sharded into 2x2 blocks; the maximum is at the bottom-right of each block
232+
# | 0, 1, 2, 3 | 4, 5, 6, 7 |
233+
# | 8, 9, 10, 11 | 12, 13, 14, 15 |
234+
# | 16, 17, 18, 19 | 20, 21, 22, 23 |
235+
# | 24, 25, 26, 27(max) | 28, 29, 30, 31(max) |
236+
# --------------------------------------------
237+
# | 32, 33, 34, 35 | 36, 37, 38, 39 |
238+
# | 40, 41, 42, 43 | 44, 45, 46, 47 |
239+
# | 48, 49, 50, 51 | 52, 53, 54, 55 |
240+
# | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
241+
# --------------------------------------------
242+
243+
expected_feature_map = tf.reshape(
244+
tf.constant([9, 11, 25, 27, 29, 31, 61, 63]), [1, 2, 2, 2, 1]
245+
)
246+
self.assertAllClose(expected_feature_map, pooled_feature_map)

0 commit comments

Comments
 (0)