Skip to content

Commit 32bcbae

Browse files
committed
refactor grid sample
1 parent 7f30d9d commit 32bcbae

File tree

1 file changed

+17
-35
lines changed

1 file changed

+17
-35
lines changed

examples/Utility.hpp

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -236,53 +236,35 @@ inline cv::Mat bilinearGridSample(const cv::Mat& input, const cv::Mat& grid, boo
236236
cv::Mat x1Mat = x0Mat + 1;
237237
cv::Mat y1Mat = x0Mat + 1;
238238

239-
cv::Mat wa = (x1Mat - xMat).mul(y1Mat - yMat);
240-
cv::Mat wb = (x1Mat - xMat).mul(yMat - y0Mat);
241-
cv::Mat wc = (xMat - x0Mat).mul(y1Mat - yMat);
242-
cv::Mat wd = (xMat - x0Mat).mul(yMat - y0Mat);
239+
std::vector<cv::Mat> weights = {(x1Mat - xMat).mul(y1Mat - yMat), (x1Mat - xMat).mul(yMat - y0Mat),
240+
(xMat - x0Mat).mul(y1Mat - yMat), (xMat - x0Mat).mul(yMat - y0Mat)};
243241

244-
std::vector<int> newSize{batch, channel, grid.size[1] * grid.size[2]};
245-
cv::Mat result = cv::Mat::zeros(3, newSize.data(), CV_32F);
242+
cv::Mat result = cv::Mat::zeros(3, std::vector<int>{batch, channel, grid.size[1] * grid.size[2]}.data(), CV_32F);
246243

247244
auto isCoordSafe = [](int size, int maxSize) -> bool { return size > 0 && size < maxSize; };
245+
248246
for (int b = 0; b < batch; ++b) {
249247
for (int i = 0; i < grid.size[1] * grid.size[2]; ++i) {
250248
int x0 = x0Mat.at<float>(b, i);
251249
int y0 = y0Mat.at<float>(b, i);
252250
int x1 = x1Mat.at<float>(b, i);
253251
int y1 = y1Mat.at<float>(b, i);
254252

255-
cv::Mat Ia = cv::Mat::zeros(channel, 1, CV_32F);
256-
cv::Mat Ib = cv::Mat::zeros(channel, 1, CV_32F);
257-
cv::Mat Ic = cv::Mat::zeros(channel, 1, CV_32F);
258-
cv::Mat Id = cv::Mat::zeros(channel, 1, CV_32F);
259-
260-
if (isCoordSafe(x0, width) && isCoordSafe(y0, height)) {
261-
Ia = input({cv::Range(b, b + 1), cv::Range::all(), cv::Range(y0, y0 + 1), cv::Range(x0, x0 + 1)})
262-
.clone()
263-
.reshape(1, channel);
264-
}
265-
266-
if (isCoordSafe(x0, width) && isCoordSafe(y1, height)) {
267-
Ib = input({cv::Range(b, b + 1), cv::Range::all(), cv::Range(y1, y1 + 1), cv::Range(x0, x0 + 1)})
268-
.clone()
269-
.reshape(1, channel);
270-
}
271-
272-
if (isCoordSafe(x1, width) && isCoordSafe(y0, height)) {
273-
Ic = input({cv::Range(b, b + 1), cv::Range::all(), cv::Range(y0, y0 + 1), cv::Range(x1, x1 + 1)})
274-
.clone()
275-
.reshape(1, channel);
276-
}
277-
278-
if (isCoordSafe(x1, width) && isCoordSafe(y1, height)) {
279-
Id = input({cv::Range(b, b + 1), cv::Range::all(), cv::Range(y1, y1 + 1), cv::Range(x1, x1 + 1)})
280-
.clone()
281-
.reshape(1, channel);
253+
std::vector<std::pair<int, int>> pairs = {{x0, y0}, {x0, y1}, {x1, y0}, {x1, y1}};
254+
std::vector<cv::Mat> Is(4, cv::Mat::zeros(channel, 1, CV_32F));
255+
256+
for (int k = 0; k < 4; ++k) {
257+
if (isCoordSafe(pairs[k].first, width) && isCoordSafe(pairs[k].second, height)) {
258+
Is[k] =
259+
input({cv::Range(b, b + 1), cv::Range::all(), cv::Range(pairs[k].second, pairs[k].second + 1),
260+
cv::Range(pairs[k].first, pairs[k].first + 1)})
261+
.clone()
262+
.reshape(1, channel);
263+
}
282264
}
283265

284-
cv::Mat curDescriptor =
285-
Ia * wa.at<float>(i) + Ib * wb.at<float>(i) + Ic * wc.at<float>(i) + Id * wd.at<float>(i);
266+
cv::Mat curDescriptor = Is[0] * weights[0].at<float>(i) + Is[1] * weights[1].at<float>(i) +
267+
Is[2] * weights[2].at<float>(i) + Is[3] * weights[3].at<float>(i);
286268

287269
for (int c = 0; c < channel; ++c) {
288270
result.at<float>(b, c, i) = curDescriptor.at<float>(c);

0 commit comments

Comments
 (0)