Skip to content
This repository was archived by the owner on Mar 1, 2025. It is now read-only.

Commit 70486f2

Browse files
committed
Multithreaded rulebook
1 parent 1171aae commit 70486f2

File tree

1 file changed

+50
-11
lines changed

1 file changed

+50
-11
lines changed

sparseconvnet/SCN/Metadata/SubmanifoldConvolutionRules.h

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#ifndef SUBMANIFOLDCONVOLUTIONRULES_H
88
#define SUBMANIFOLDCONVOLUTIONRULES_H
99

10+
#include <algorithm>
#include <cstddef>
#include <exception>
#include <thread>
#include <vector>
11+
1012
// Full input region for an output point
1113
template <Int dimension>
1214
RectangularRegion<dimension>
@@ -27,20 +29,57 @@ template <Int dimension>
2729
double SubmanifoldConvolution_SgToRules(SparseGrid<dimension> &grid,
2830
RuleBook &rules, long *size) {
2931
double countActiveInputs = 0;
30-
for (auto const &outputIter : grid.mp) {
31-
auto inRegion =
32-
InputRegionCalculator_Submanifold<dimension>(outputIter.first, size);
33-
Int rulesOffset = 0;
34-
for (auto inputPoint : inRegion) {
35-
auto inputIter = grid.mp.find(inputPoint);
36-
if (inputIter != grid.mp.end()) {
37-
rules[rulesOffset].push_back(inputIter->second + grid.ctr);
38-
rules[rulesOffset].push_back(outputIter.second + grid.ctr);
39-
countActiveInputs++;
32+
const Int threadCount = 4;
33+
std::vector<std::thread> threads;
34+
std::array<int, threadCount> activeInputs = {};
35+
std::vector<RuleBook> rulebooks;
36+
for (Int t = 0; t < threadCount; ++t) {
37+
rulebooks.push_back(RuleBook(rules.size()));
38+
}
39+
40+
auto func = [&](const int order) {
41+
auto outputIter = grid.mp.begin();
42+
auto &rb = rulebooks[order];
43+
int rem = grid.mp.size();
44+
int aciveInputCount = 0;
45+
46+
if (rem > order) {
47+
std::advance(outputIter, order);
48+
rem -= order;
49+
50+
for (; outputIter != grid.mp.end();
51+
std::advance(outputIter, std::min(threadCount, rem)),
52+
rem -= threadCount) {
53+
auto inRegion = InputRegionCalculator_Submanifold<dimension>(
54+
outputIter->first, size);
55+
Int rulesOffset = 0;
56+
for (auto inputPoint : inRegion) {
57+
auto inputIter = grid.mp.find(inputPoint);
58+
if (inputIter != grid.mp.end()) {
59+
aciveInputCount++;
60+
rb[rulesOffset].push_back(inputIter->second + grid.ctr);
61+
rb[rulesOffset].push_back(outputIter->second + grid.ctr);
62+
}
63+
rulesOffset++;
64+
}
4065
}
41-
rulesOffset++;
66+
activeInputs[order] = aciveInputCount;
4267
}
68+
};
69+
70+
for (Int t = 0; t < threadCount; ++t) {
71+
threads.push_back(std::thread(func, t));
4372
}
73+
74+
for (Int t = 0; t < threadCount; ++t) {
75+
threads[t].join();
76+
countActiveInputs += activeInputs[t];
77+
for (std::size_t i = 0; i < rulebooks[t].size(); ++i) {
78+
rules[i].insert(rules[i].end(), rulebooks[t][i].begin(),
79+
rulebooks[t][i].end());
80+
}
81+
}
82+
4483
return countActiveInputs;
4584
}
4685

0 commit comments

Comments
 (0)