
Commit 38c9eb1

codeislife99 and Ubuntu authored
Fix Bug in Bilinear Interpolation and Add Deform Conv to PT FrontEnd (#7397)
* Fix Bug in Bilinear Interpolation
* Add NHWC Tests
* clean
* Fix Bug and Add Deformable Conv PyTorch for completeness
* Add Tensor Utils
* Remove stuff
* Include vector
* PR Comments
* Empty Commit for CI

Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-251.us-east-2.compute.internal>
1 parent c118b08 commit 38c9eb1

File tree

7 files changed: +257 −88 lines changed

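
In summary, the bilinear-interpolation fix in include/tvm/topi/detail/tensor_utils.h below replaces the old clamp-and-lerp formulation of bilinear_sample_nchw/nhwc with a bounds-checked accumulation over the four neighbouring pixels: any corner that falls outside [0, max_y] x [0, max_x] contributes zero, whereas the old code clamped only the upper index and left the floor index unclamped. The sketch below is a minimal NumPy illustration of that scheme, not code from the commit; the in-tree Python reference in deformable_conv2d_python.py mirrors it.

import math

import numpy as np


def bilinear_sample_nchw(data, n, c, y, x):
    """Bounds-checked bilinear sampling, mirroring the updated TOPI helper."""
    _, _, height, width = data.shape
    y_low, x_low = int(math.floor(y)), int(math.floor(x))
    y_high, x_high = y_low + 1, x_low + 1
    wy_h, wx_h = y - y_low, x - x_low  # weights of the "high" neighbours
    wy_l, wx_l = 1 - wy_h, 1 - wx_h    # weights of the "low" neighbours

    val = 0.0
    for wx, xp in ((wx_l, x_low), (wx_h, x_high)):
        for wy, yp in ((wy_l, y_low), (wy_h, y_high)):
            if 0 <= yp < height and 0 <= xp < width:  # out-of-bounds corners contribute nothing
                val += wx * wy * data[n, c, yp, xp]
    return val


data = np.arange(16, dtype="float32").reshape(1, 1, 4, 4)
print(bilinear_sample_nchw(data, 0, 0, 1.5, 2.5))   # 8.5: ordinary bilinear blend of four pixels
print(bilinear_sample_nchw(data, 0, 0, -0.5, 1.0))  # 0.5: only the in-bounds neighbour contributes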

include/tvm/topi/detail/tensor_utils.h

Lines changed: 55 additions & 40 deletions
@@ -26,6 +26,7 @@
 
 #include <tvm/te/operation.h>
 
+#include <vector>
 namespace tvm {
 namespace topi {
 namespace detail {
@@ -64,29 +65,36 @@ inline bool is_empty_shape(const Array<PrimExpr>& x) {
  */
 inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& indices,
                                      const PrimExpr max_y, const PrimExpr max_x) {
+  auto batch_id = indices[0];
+  auto channel_id = indices[1];
   auto in_y = indices[2];
-  auto yf = tvm::floor(in_y);
-  auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y));
-
-  auto y0 = tvm::cast(DataType::Int(32), tvm::floor(in_y));
-  auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
-  auto y_lerp = in_y - yf;
-
   auto in_x = indices[3];
-  auto xf = tvm::floor(in_x);
-  auto xc = tvm::cast(DataType::Int(32), tvm::ceil(in_x));
-
-  auto x0 = tvm::cast(DataType::Int(32), tvm::floor(in_x));
-  auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
-  auto x_lerp = in_x - xf;
 
-  auto A = input(indices[0], indices[1], y0, x0);
-  auto B = input(indices[0], indices[1], y0, x1);
-  auto C = input(indices[0], indices[1], y1, x0);
-  auto D = input(indices[0], indices[1], y1, x1);
-
-  return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp +
-         D * x_lerp * y_lerp;
+  auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y));
+  auto y_high = y_low + 1;
+
+  auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x));
+  auto x_high = x_low + 1;
+
+  auto wy_h = in_y - y_low;
+  auto wx_h = in_x - x_low;
+  auto wy_l = 1 - wy_h;
+  auto wx_l = 1 - wx_h;
+
+  PrimExpr val = 0;
+  std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}};
+  std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}};
+  for (auto wx_xp_ele : wx_xp) {
+    for (auto wy_yp_ele : wy_yp) {
+      auto wx = wx_xp_ele[0];
+      auto xp = wx_xp_ele[1];
+      auto wy = wy_yp_ele[0];
+      auto yp = wy_yp_ele[1];
+      val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
+                               wx * wy * input(batch_id, channel_id, yp, xp), 0);
+    }
+  }
+  return val;
 }
 
 /*!
@@ -101,29 +109,36 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>&
  */
 inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices,
                                      const PrimExpr max_y, const PrimExpr max_x) {
+  auto batch_id = indices[0];
+  auto channel_id = indices[3];
   auto in_y = indices[1];
-  auto yf = tvm::floor(in_y);
-  auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y));
-
-  auto y0 = tvm::cast(DataType::Int(32), tvm::floor(in_y));
-  auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
-  auto y_lerp = in_y - yf;
-
   auto in_x = indices[2];
-  auto xf = tvm::floor(in_x);
-  auto xc = tvm::cast(DataType::Int(32), tvm::ceil(in_x));
-
-  auto x0 = tvm::cast(DataType::Int(32), tvm::floor(in_x));
-  auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
-  auto x_lerp = in_x - xf;
 
-  auto A = input(indices[0], y0, x0, indices[3]);
-  auto B = input(indices[0], y0, x1, indices[3]);
-  auto C = input(indices[0], y1, x0, indices[3]);
-  auto D = input(indices[0], y1, x1, indices[3]);
-
-  return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp +
-         D * x_lerp * y_lerp;
+  auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y));
+  auto y_high = y_low + 1;
+
+  auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x));
+  auto x_high = x_low + 1;
+
+  auto wy_h = in_y - y_low;
+  auto wx_h = in_x - x_low;
+  auto wy_l = 1 - wy_h;
+  auto wx_l = 1 - wx_h;
+
+  PrimExpr val = 0;
+  std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}};
+  std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}};
+  for (auto wx_xp_ele : wx_xp) {
+    for (auto wy_yp_ele : wy_yp) {
+      auto wx = wx_xp_ele[0];
+      auto xp = wx_xp_ele[1];
+      auto wy = wy_yp_ele[0];
+      auto yp = wy_yp_ele[1];
+      val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
+                               wx * wy * input(batch_id, yp, xp, channel_id), 0);
+    }
+  }
+  return val;
 }
 
 } // namespace detail

python/tvm/relay/frontend/pytorch.py

Lines changed: 27 additions & 0 deletions
@@ -1928,6 +1928,32 @@ def roi_align(self, inputs, input_types):
 
         return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio)
 
+    def deform_conv2d(self, inputs, input_types):
+        data = inputs[0]
+        weight = inputs[1]
+        offset = inputs[2]
+        strides = (inputs[4], inputs[5])
+        padding = (inputs[6], inputs[7])
+        dilation = (inputs[8], inputs[9])
+        groups = inputs[10]
+        deformable_groups = inputs[11]
+        weight_shape = self.infer_shape(weight)
+        output_channels = weight_shape[0]
+        kernel_size = (weight_shape[2], weight_shape[3])
+
+        return _op.nn.deformable_conv2d(
+            data,
+            offset,
+            weight,
+            strides,
+            padding,
+            dilation,
+            deformable_groups,
+            groups,
+            output_channels,
+            kernel_size,
+        )
+
     def unbind(self, inputs, input_types):
         data = inputs[0]
         dim = int(inputs[1])
@@ -2292,6 +2318,7 @@ def create_convert_map(self):
             "torchvision::nms": self.nms,
             "aten::logsumexp": self.logsumexp,
             "torchvision::roi_align": self.roi_align,
+            "torchvision::deform_conv2d": self.deform_conv2d,
            "aten::unbind": self.unbind,
            "aten::__and__": self.logical_and,
            "aten::logical_and": self.logical_and,

python/tvm/topi/testing/deformable_conv2d_python.py

Lines changed: 17 additions & 9 deletions
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name, too-many-locals, too-many-arguments
 """Deformable convolution in python"""
 import itertools
+import math
 import numpy as np
 from tvm.topi.nn.utils import get_pad_tuple
 
@@ -80,15 +81,22 @@ def deformable_conv2d_nchw_python(
     dilation_h, dilation_w = dilation
 
     def _bilinear(n, c, h, w):
-        low_h, low_w = int(h), int(w)
-        high_h = min(low_h + 1, in_height - 1)
-        high_w = min(low_w + 1, in_width - 1)
-        y_lerp = h - low_h
-        x_lerp = w - low_w
-
-        bottom = (1 - x_lerp) * a_np[n, c, low_h, low_w] + x_lerp * a_np[n, c, low_h, high_w]
-        top = (1 - x_lerp) * a_np[n, c, high_h, low_w] + x_lerp * a_np[n, c, high_h, high_w]
-        return (1 - y_lerp) * bottom + y_lerp * top
+        y_low = int(math.floor(h))
+        x_low = int(math.floor(w))
+        y_high = y_low + 1
+        x_high = x_low + 1
+
+        wy_h = h - y_low
+        wx_h = w - x_low
+        wy_l = 1 - wy_h
+        wx_l = 1 - wx_h
+
+        val = 0
+        for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
+            for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
+                if 0 <= yp < in_height and 0 <= xp < in_width:
+                    val += wx * wy * a_np[n, c, yp, xp]
+        return val
 
     a_deform = np.zeros((batch, in_channel, out_height, out_width, kernel_h, kernel_w), dtype=dtype)
     for n, h, w in itertools.product(range(batch), range(out_height), range(out_width)):

python/tvm/topi/testing/roi_align_python.py

Lines changed: 19 additions & 15 deletions
@@ -31,25 +31,29 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio
     else:
         pooled_size_h, pooled_size_w = pooled_size
 
-    def _bilinear(b, c, y, x):
+    def _bilinear(n, c, y, x):
         if y < -1 or y > height or x < -1 or x > width:
             return 0
-        y = max(y, 0.0)
-        x = max(x, 0.0)
-        y_low = int(y)
-        x_low = int(x)
 
-        y_high = min(y_low + 1, height - 1)
-        x_high = min(x_low + 1, width - 1)
+        y = min(max(y, 0), height - 1)
+        x = min(max(x, 0), width - 1)
 
-        ly = y - y_low
-        lx = x - x_low
-        return (
-            (1 - ly) * (1 - lx) * a_np[b, c, y_low, x_low]
-            + (1 - ly) * lx * a_np[b, c, y_low, x_high]
-            + ly * (1 - lx) * a_np[b, c, y_high, x_low]
-            + ly * lx * a_np[b, c, y_high, x_high]
-        )
+        y_low = int(math.floor(y))
+        x_low = int(math.floor(x))
+        y_high = y_low + 1
+        x_high = x_low + 1
+
+        wy_h = y - y_low
+        wx_h = x - x_low
+        wy_l = 1 - wy_h
+        wx_l = 1 - wx_h
+
+        val = 0
+        for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
+            for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
+                if 0 <= yp < height and 0 <= xp < width:
+                    val += wx * wy * a_np[n, c, yp, xp]
+        return val
 
     for i in range(num_roi):
         roi = rois_np[i]

python/tvm/topi/vision/rcnn/roi_align.py

Lines changed: 2 additions & 2 deletions
@@ -60,8 +60,8 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
 
     def _bilinear(i, c, y, x):
         outside = tvm.tir.any(y < -1.0, x < -1.0, y > height, x > width)
-        y = tvm.te.max(y, 0.0)
-        x = tvm.te.max(x, 0.0)
+        y = tvm.te.min(tvm.te.max(y, 0.0), height - 1)
+        x = tvm.te.min(tvm.te.max(x, 0.0), width - 1)
         val = bilinear_sample_nchw(data, (i, c, y, x), height - 1, width - 1)
         return tvm.tir.if_then_else(outside, 0.0, val)
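
For intuition on why the extra clamp is needed (an illustrative sketch, not part of the commit): with the bounds-checked sampling above, a coordinate such as y = height - 0.5 would lose the weight of its out-of-range neighbour y_high = height, whereas ROI align is expected to replicate the border value there; clamping to height - 1 first restores that behaviour. A one-dimensional NumPy check:

import math

import numpy as np

height = 4
row = np.array([1.0, 2.0, 3.0, 4.0])  # pixel values along one column, indexed by y


def sample(y):
    """Bounds-checked linear sample along one axis, mirroring the new scheme."""
    y_low = int(math.floor(y))
    y_high = y_low + 1
    wy_h = y - y_low
    val = 0.0
    for w, yp in ((1 - wy_h, y_low), (wy_h, y_high)):
        if 0 <= yp < height:
            val += w * row[yp]
    return val


print(sample(3.5))                           # 2.0: the out-of-range neighbour contributes nothing
print(sample(min(max(3.5, 0), height - 1)))  # 4.0: clamping first replicates the border value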

tests/python/frontend/pytorch/test_forward.py

Lines changed: 83 additions & 5 deletions
@@ -216,7 +216,6 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
 
     assert_shapes_match(baseline_output, compiled_output)
     tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)
-
     del model_name
     del baseline_model
     torch.cuda.empty_cache()
@@ -924,6 +923,85 @@ def test_forward_conv_transpose():
     verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data)
 
 
+def test_forward_deform_conv():
+    torch.set_grad_enabled(False)
+
+    def test_run(
+        batch_size,
+        in_channels,
+        out_channels,
+        in_height,
+        in_width,
+        out_height,
+        out_width,
+        offset_groups,
+        kh,
+        kw,
+        groups,
+    ):
+        input_shape = [batch_size, in_channels, in_height, in_width]
+        offset_shape = [batch_size, 2 * offset_groups * kh * kw, out_height, out_width]
+        weight_shape = [out_channels, in_channels // groups, kh, kw]
+        input_data = torch.rand(input_shape)
+        offset_data = torch.rand(offset_shape)
+        weight_data = torch.rand(weight_shape)
+
+        class DeformConv2D(Module):
+            def forward(self, *args):
+                return torchvision.ops.deform_conv2d(args[0], args[1], args[2])
+
+        verify_model(
+            DeformConv2D().float().eval(),
+            input_data=[input_data, offset_data, weight_data],
+            rtol=1e-4,
+            atol=1e-4,
+        )
+
+    batch_size = 4
+    in_channels, out_channels = 4, 6
+    in_height, in_width = 10, 10
+    out_height, out_width = 8, 8
+    offset_groups = 2
+    kh, kw = 3, 3
+    groups = 1
+
+    test_run(
+        batch_size,
+        in_channels,
+        out_channels,
+        in_height,
+        in_width,
+        out_height,
+        out_width,
+        offset_groups,
+        kh,
+        kw,
+        groups,
+    )
+
+    batch_size = 5
+    in_channels, out_channels = 4, 6
+    in_height, in_width = 10, 10
+    out_height, out_width = 8, 8
+    offset_groups = 1
+    kh, kw = 3, 3
+    groups = 1
+
+    test_run(
+        batch_size,
+        in_channels,
+        out_channels,
+        in_height,
+        in_width,
+        out_height,
+        out_width,
+        offset_groups,
+        kh,
+        kw,
+        groups,
+    )
+
+
 @tvm.testing.uses_gpu
 def test_forward_threshold():
     torch.set_grad_enabled(False)
@@ -1700,7 +1778,7 @@ def test_forward_roi_align():
     """ROI align"""
     torch.set_grad_enabled(False)
 
-    class ROIAlgin(Module):
+    class ROIAlign(Module):
         def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1):
             super().__init__()
             self.spatial_scale = spatial_scale
@@ -1721,9 +1799,9 @@ def forward(self, *args):
     in_batch = torch.zeros((35, 1), dtype=torch.float)
     in_boxes = torch.cat([in_batch, in_boxes], dim=1)
 
-    verify_model(ROIAlgin(7), [in_data, in_boxes])
-    verify_model(ROIAlgin((10, 10), 0.7, 5), [in_data, in_boxes])
-    verify_model(ROIAlgin(15, 0.9, 3), [in_data, in_boxes])
+    verify_model(ROIAlign(7), [in_data, in_boxes])
+    verify_model(ROIAlign((10, 10), 0.7, 5), [in_data, in_boxes])
+    verify_model(ROIAlign(15, 0.9, 3), [in_data, in_boxes])
 
 
 @tvm.testing.uses_gpu
