Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Loops] Adds occa::forLoop #465

Merged
merged 13 commits into from
Jan 20, 2021
Prev Previous commit
Next Next commit
[Loops] Added iterator stubs
  • Loading branch information
dmed256 committed Jan 20, 2021
commit 154ba102e3b6c37187fa77426c6d9d4036db0c1b
10 changes: 7 additions & 3 deletions include/occa/functional/range.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
namespace occa {
class range : public typelessArray {
public:
const dim_t start;
const dim_t end;
const dim_t step;
dim_t start;
dim_t end;
dim_t step;

range(const dim_t end_);

Expand All @@ -31,6 +31,10 @@ namespace occa {
const dim_t end_,
const dim_t step);

range(const range &other);

range& operator = (const range &other);

private:
void setupArrayScopeOverrides(occa::scope &scope) const;

Expand Down
44 changes: 44 additions & 0 deletions include/occa/loops/iteration.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifndef OCCA_LOOPS_ITERATION_HEADER
#define OCCA_LOOPS_ITERATION_HEADER

#include <occa/functional/array.hpp>
#include <occa/functional/range.hpp>

namespace occa {
enum class iterationType {
undefined,
range,
indexArray
};

class iteration {
private:
iterationType type;
occa::range range;
occa::array<int> indices;

public:
iteration();

iteration(const int rangeEnd);

iteration(const occa::range &range_);

iteration(const occa::array<int> &indices_);

iteration(const iteration &other);

iteration& operator = (const iteration &other);

std::string buildForLoop(occa::scope &scope,
const std::string &iteratorName) const;

std::string buildRangeForLoop(occa::scope &scope,
const std::string &iteratorName) const;

std::string buildIndexForLoop(occa::scope &scope,
const std::string &iteratorName) const;
};
}

#endif
26 changes: 6 additions & 20 deletions include/occa/loops/typelessForLoop.hpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,9 @@
#ifndef OCCA_LOOPS_TYPELESSFORLOOP_HEADER
#define OCCA_LOOPS_TYPELESSFORLOOP_HEADER

#include <occa/functional/array.hpp>
#include <occa/functional/range.hpp>
#include <occa/loops/iteration.hpp>

namespace occa {
class iteration {
public:
inline iteration() {}

inline iteration(const int) {}

inline iteration(const range &) {}

inline iteration(const array<int> &) {}

inline std::string buildForLoop(const std::string &iterationName) const {
return "";
}
};

class typelessForLoop {
public:
occa::device device;
Expand All @@ -39,12 +23,14 @@ namespace occa {
occa::scope getForLoopScope(const occa::scope &scope,
const baseFunction &fn) const;

std::string buildOuterLoop(const int index) const;
std::string buildOuterLoop(occa::scope &scope,
const int index) const;

std::string buildInnerLoop(const int index) const;
std::string buildInnerLoop(occa::scope &scope,
const int index) const;

std::string buildIndexInitializer(const std::string &indexName,
const int count) const;
const int iterationCount) const;
};
}

Expand Down
16 changes: 16 additions & 0 deletions src/functional/range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ namespace occa {
setupTypelessArray(device__, dtype::get<int>());
}

range::range(const range &other) :
typelessArray(other),
start(other.start),
end(other.end),
step(other.step) {}

range& range::operator = (const range &other) {
typelessArray::operator = (other);

start = other.start;
end = other.end;
step = other.step;

return *this;
}

void range::setupArrayScopeOverrides(occa::scope &scope) const {
// Step compile-time defines on the common cases:
// - Starting at 0
Expand Down
4 changes: 2 additions & 2 deletions src/loops/forLoop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace occa {

outerForLoop2 forLoop::outer(occa::iteration outerIteration0,
occa::iteration outerIteration1) {
outerForLoop1 loop(device);
outerForLoop2 loop(device);

loop.outerIterationCount = 2;

Expand All @@ -31,7 +31,7 @@ namespace occa {
outerForLoop3 forLoop::outer(occa::iteration outerIteration0,
occa::iteration outerIteration1,
occa::iteration outerIteration2) {
outerForLoop1 loop(device);
outerForLoop3 loop(device);

loop.outerIterationCount = 3;

Expand Down
63 changes: 63 additions & 0 deletions src/loops/iteration.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include <occa/loops/iteration.hpp>

namespace occa {
iteration::iteration() :
type(iterationType::undefined),
range(-1) {}

iteration::iteration(const int rangeEnd) :
type(iterationType::range),
range(rangeEnd) {}

iteration::iteration(const occa::range &range_) :
type(iterationType::range),
range(range_) {}

iteration::iteration(const occa::array<int> &indices_) :
type(iterationType::indexArray),
range(-1),
indices(indices_) {}

iteration::iteration(const iteration &other) :
type(other.type),
range(other.range),
indices(other.indices) {}

iteration& iteration::operator = (const iteration &other) {
type = other.type;
range = other.range;
indices = other.indices;

return *this;
}

std::string iteration::buildForLoop(occa::scope &scope,
const std::string &iteratorName) const {
OCCA_ERROR("Iteration not defined",
type != iterationType::undefined);

if (type == iterationType::range) {
return buildRangeForLoop(scope, iteratorName);
} else {
return buildIndexForLoop(scope, iteratorName);
}
}

std::string iteration::buildRangeForLoop(occa::scope &scope,
const std::string &iteratorName) const {
std::stringstream ss;
ss << "for (int " << iteratorName << " = 0;"
<< " " << iteratorName << " < 10;"
<< " ++" << iteratorName << ") {";
return ss.str();
}

std::string iteration::buildIndexForLoop(occa::scope &scope,
const std::string &iteratorName) const {
std::stringstream ss;
ss << "for (int " << iteratorName << " = 0;"
<< " " << iteratorName << " < 10;"
<< " ++" << iteratorName << ") {";
return ss.str();
}
}
39 changes: 23 additions & 16 deletions src/loops/typelessForLoop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@ namespace occa {
loopScope.device = device;
}

// Temporary until issue #175 is resolved
loopScope.props["okl/validate"] = false;

// Inject the function
loopScope.props["functions/occa_loop_function"] = fn;

// Setup @outer loops
std::string outerForLoopsStart, outerForLoopsEnd;
for (int i = 0; i < outerIterationCount; ++i) {
outerForLoopsStart += buildOuterLoop(i);
outerForLoopsStart += buildOuterLoop(loopScope, i);
outerForLoopsEnd += "}";
}
loopScope.props["defines/OCCA_LOOP_START_OUTER_LOOPS"] = outerForLoopsStart;
Expand All @@ -57,7 +60,7 @@ namespace occa {
// Setup @inner loops
std::string innerForLoopsStart, innerForLoopsEnd;
for (int i = 0; i < innerIterationCount; ++i) {
innerForLoopsStart += buildInnerLoop(i);
innerForLoopsStart += buildInnerLoop(loopScope, i);
innerForLoopsEnd += "}";
}
loopScope.props["defines/OCCA_LOOP_START_INNER_LOOPS"] = innerForLoopsStart;
Expand Down Expand Up @@ -87,33 +90,37 @@ namespace occa {
return loopScope;
}

std::string typelessForLoop::buildOuterLoop(const int index) const {
return outerIterations[index].buildForLoop(
std::string typelessForLoop::buildOuterLoop(occa::scope &scope,
const int index) const {
return "@outer " + outerIterations[index].buildForLoop(
scope,
"OUTER_INDEX_" + std::to_string(index)
);
}

std::string typelessForLoop::buildInnerLoop(const int index) const {
return innerIterations[index].buildForLoop(
std::string typelessForLoop::buildInnerLoop(occa::scope &scope,
const int index) const {
return "@inner " + innerIterations[index].buildForLoop(
scope,
"INNER_INDEX_" + std::to_string(index)
);
}

std::string typelessForLoop::buildIndexInitializer(const std::string &indexName,
const int count) const {
const int iterationCount) const {
std::stringstream ss;

if (innerIterationCount == 1) {
ss << "const int " << indexName << " = " << indexName << "_1;";
} else if (innerIterationCount == 2) {
if (iterationCount == 1) {
ss << "const int " << indexName << " = " << indexName << "_0;";
} else if (iterationCount == 2) {
ss << "int2 " << indexName << ";\n"
<< "" << indexName << ".x = " << indexName << "_1;"
<< "" << indexName << ".y = " << indexName << "_2;";
} else if (innerIterationCount == 3) {
<< "" << indexName << ".x = " << indexName << "_0;"
<< "" << indexName << ".y = " << indexName << "_1;";
} else if (iterationCount == 3) {
ss << "int3 " << indexName << ";\n"
<< "" << indexName << ".x = " << indexName << "_1;"
<< "" << indexName << ".y = " << indexName << "_2;"
<< "" << indexName << ".z = " << indexName << "_3;";
<< "" << indexName << ".x = " << indexName << "_0;"
<< "" << indexName << ".y = " << indexName << "_1;"
<< "" << indexName << ".z = " << indexName << "_2;";
}

return ss.str();
Expand Down