26
26
27
27
#include < tvm/te/operation.h>
28
28
29
+ #include < vector>
29
30
namespace tvm {
30
31
namespace topi {
31
32
namespace detail {
@@ -64,29 +65,36 @@ inline bool is_empty_shape(const Array<PrimExpr>& x) {
64
65
*/
65
66
inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& indices,
                                     const PrimExpr max_y, const PrimExpr max_x) {
  // indices is read positionally as (batch, channel, y, x); y and x are the
  // sampling coordinates that get floored below.
  auto batch = indices[0];
  auto chan = indices[1];
  auto pos_y = indices[2];
  auto pos_x = indices[3];

  // Integer corner coordinates surrounding the sampling point: the floor of
  // each coordinate and the next integer above it.
  auto y_floor = tvm::cast(DataType::Int(32), tvm::floor(pos_y));
  auto y_next = y_floor + 1;
  auto x_floor = tvm::cast(DataType::Int(32), tvm::floor(pos_x));
  auto x_next = x_floor + 1;

  // Interpolation weights: the fractional offset weights the upper corner,
  // its complement weights the lower corner.
  auto frac_y = pos_y - y_floor;
  auto frac_x = pos_x - x_floor;
  auto rest_y = 1 - frac_y;
  auto rest_x = 1 - frac_x;

  // Weighted contribution of a single corner. A corner lying outside
  // [0, max_y] x [0, max_x] contributes 0 instead of reading the tensor.
  auto corner_term = [&](const PrimExpr& wx, const PrimExpr& xp, const PrimExpr& wy,
                         const PrimExpr& yp) -> PrimExpr {
    return tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
                             wx * wy * input(batch, chan, yp, xp), 0);
  };

  // Accumulate the four corners; x varies in the outer position and y in the
  // inner one, which fixes the order of the terms in the resulting sum.
  PrimExpr acc = 0;
  acc += corner_term(rest_x, x_floor, rest_y, y_floor);
  acc += corner_term(rest_x, x_floor, frac_y, y_next);
  acc += corner_term(frac_x, x_next, rest_y, y_floor);
  acc += corner_term(frac_x, x_next, frac_y, y_next);
  return acc;
}
91
99
92
100
/* !
@@ -101,29 +109,36 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>&
101
109
*/
102
110
inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices,
                                     const PrimExpr max_y, const PrimExpr max_x) {
  // indices is read positionally as (batch, y, x, channel); y and x are the
  // sampling coordinates that get floored below.
  auto batch = indices[0];
  auto pos_y = indices[1];
  auto pos_x = indices[2];
  auto chan = indices[3];

  // Integer corner coordinates surrounding the sampling point: the floor of
  // each coordinate and the next integer above it.
  auto y_floor = tvm::cast(DataType::Int(32), tvm::floor(pos_y));
  auto y_next = y_floor + 1;
  auto x_floor = tvm::cast(DataType::Int(32), tvm::floor(pos_x));
  auto x_next = x_floor + 1;

  // Interpolation weights: the fractional offset weights the upper corner,
  // its complement weights the lower corner.
  auto frac_y = pos_y - y_floor;
  auto frac_x = pos_x - x_floor;
  auto rest_y = 1 - frac_y;
  auto rest_x = 1 - frac_x;

  // Weighted contribution of a single corner. A corner lying outside
  // [0, max_y] x [0, max_x] contributes 0 instead of reading the tensor.
  auto corner_term = [&](const PrimExpr& wx, const PrimExpr& xp, const PrimExpr& wy,
                         const PrimExpr& yp) -> PrimExpr {
    return tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
                             wx * wy * input(batch, yp, xp, chan), 0);
  };

  // Accumulate the four corners; x varies in the outer position and y in the
  // inner one, which fixes the order of the terms in the resulting sum.
  PrimExpr acc = 0;
  acc += corner_term(rest_x, x_floor, rest_y, y_floor);
  acc += corner_term(rest_x, x_floor, frac_y, y_next);
  acc += corner_term(frac_x, x_next, rest_y, y_floor);
  acc += corner_term(frac_x, x_next, frac_y, y_next);
  return acc;
}
128
143
129
144
} // namespace detail
0 commit comments