|
7 | 7 | from tvm.relay.testing import ctx_list
|
8 | 8 | import topi.testing
|
9 | 9 |
|
10 |
| - |
11 | 10 | def test_resize_infer_type():
|
12 | 11 | n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
|
13 | 12 | x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
|
@@ -307,10 +306,46 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_
|
307 | 306 | verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2)
|
308 | 307 |
|
309 | 308 |
|
| 309 | +def test_yolo_reorg_infer_shape(): |
| 310 | + def verify_yolo_reorg(shape, stride, out_shape): |
| 311 | + x = relay.var("x", relay.TensorType(shape, "float32")) |
| 312 | + z = relay.vision.yolo_reorg(x, stride=stride) |
| 313 | + zz = relay.ir_pass.infer_type(z) |
| 314 | + assert "stride=" in z.astext() |
| 315 | + assert zz.checked_type == relay.ty.TensorType(out_shape, "float32") |
| 316 | + |
| 317 | + n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") |
| 318 | + verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2)) |
| 319 | + verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2)) |
| 320 | + |
| 321 | +def test_yolo_reorg(): |
| 322 | + def verify_yolo_reorg(shape, stride): |
| 323 | + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") |
| 324 | + ref_res = topi.testing.reorg_python(x_data, stride) |
| 325 | + |
| 326 | + x = relay.var("x", relay.TensorType(shape, "float32")) |
| 327 | + z = relay.vision.yolo_reorg(x, stride=stride) |
| 328 | + zz = relay.ir_pass.infer_type(z) |
| 329 | + assert "stride=" in z.astext() |
| 330 | + assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32") |
| 331 | + |
| 332 | + func = relay.Function([x], z) |
| 333 | + |
| 334 | + for target, ctx in ctx_list(): |
| 335 | + for kind in ["graph", "debug"]: |
| 336 | + intrp = relay.create_executor(kind, ctx=ctx, target=target) |
| 337 | + op_res = intrp.evaluate(func)(x_data) |
| 338 | + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) |
| 339 | + |
| 340 | + verify_yolo_reorg((1, 100, 20, 20), 10) |
| 341 | + verify_yolo_reorg((1, 4, 6, 6), 2) |
| 342 | + |
310 | 343 | if __name__ == "__main__":
|
311 | 344 | test_resize_infer_type()
|
312 | 345 | test_resize()
|
313 | 346 | test_multibox_prior()
|
314 | 347 | test_multibox_transform_loc()
|
315 | 348 | test_nms()
|
316 | 349 | test_roi_align()
|
| 350 | + test_yolo_reorg_infer_shape() |
| 351 | + test_yolo_reorg() |
0 commit comments