Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Loops] Adds occa::forLoop #465

Merged
merged 13 commits into from
Jan 20, 2021
Prev Previous commit
Next Next commit
[Lang] Handle OKL as a function/macro
  • Loading branch information
dmed256 committed Jan 20, 2021
commit 92af3805e7ead9910db44de667765bde729519d7
2 changes: 0 additions & 2 deletions include/occa/defines/okl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
# undef OCCA_JIT
#endif

#define OKL(SOURCE_CODE)

#define OCCA_JIT(OKL_SCOPE, OKL_SOURCE) \
do { \
static ::occa::kernelBuilder _occaJitKernelBuilder( \
Expand Down
7 changes: 7 additions & 0 deletions include/occa/functional/baseFunction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
#include <occa/utils/hash.hpp>
#include <occa/types/typedefs.hpp>

// Macro-like function used to expand source code in captured lambdas
// - Host: No-op function
// - OKL : Macro to expand source code
inline void OKL(const std::string &sourceCode) {}

namespace occa {
class baseFunction {
public:
Expand All @@ -16,6 +21,8 @@ namespace occa {

functionDefinition& definition();

const functionDefinition& definition() const;

virtual int argumentCount() const = 0;

hash_t hash() const;
Expand Down
16 changes: 8 additions & 8 deletions include/occa/functional/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
#define OCCA_FUNCTION_1(lambda) \
::occa::functional::inferFunction({}, lambda, #lambda)

#define OCCA_FUNCTION(...) \
OCCA_FUNCTION_EXPAND_1(OCCA_ARG_COUNT(__VA_ARGS__), __VA_ARGS__)
#define OCCA_FUNCTION(...) \
OCCA_FUNCTION_EXPAND_NAME_1(OCCA_ARG_COUNT(__VA_ARGS__)) (__VA_ARGS__)

#define OCCA_FUNCTION_EXPAND_1(ARG_COUNT, ...) \
OCCA_FUNCTION_EXPAND_2(ARG_COUNT, __VA_ARGS__)
#define OCCA_FUNCTION_EXPAND_NAME_1(ARG_COUNT) \
OCCA_FUNCTION_EXPAND_NAME_2(ARG_COUNT)

#define OCCA_FUNCTION_EXPAND_2(ARG_COUNT, ...) \
OCCA_FUNCTION_EXPAND_3(ARG_COUNT, __VA_ARGS__)
#define OCCA_FUNCTION_EXPAND_NAME_2(ARG_COUNT) \
OCCA_FUNCTION_EXPAND_NAME_3(ARG_COUNT)

#define OCCA_FUNCTION_EXPAND_3(ARG_COUNT, ...) \
OCCA_FUNCTION_ ## ARG_COUNT (__VA_ARGS__)
#define OCCA_FUNCTION_EXPAND_NAME_3(ARG_COUNT) \
OCCA_FUNCTION_ ## ARG_COUNT

namespace occa {
template <class Function>
Expand Down
5 changes: 5 additions & 0 deletions src/functional/baseFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ namespace occa {
return *functionStore.get(hash_);
}

const functionDefinition& baseFunction::definition() const {
// Should be initialized at this point
return *functionStore.get(hash_);
}

hash_t baseFunction::hash() const {
return hash_;
}
Expand Down
18 changes: 2 additions & 16 deletions src/loops/typelessForLoop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace occa {
OCCA_LOOP_INIT_OUTER_INDEX(outerIndex)
OCCA_LOOP_INIT_INNER_INDEX(innerIndex)

OCCA_LOOP_FUNCTION(outerIndex, innerIndex);
OCCA_LOOP_FUNCTION

OCCA_LOOP_END_INNER_LOOPS
OCCA_LOOP_END_OUTER_LOOPS
Expand All @@ -37,11 +37,8 @@ namespace occa {
loopScope.device = device;
}

// Temporary until issue #175 is resolved
loopScope.props["okl/validate"] = false;

// Inject the function
loopScope.props["functions/occa_loop_function"] = fn;
loopScope.props["defines/OCCA_LOOP_FUNCTION"] = fn.definition().bodySource;

// Setup @outer loops
std::string outerForLoopsStart, outerForLoopsEnd;
Expand Down Expand Up @@ -76,17 +73,6 @@ namespace occa {
loopScope.props["defines/OCCA_LOOP_INIT_INNER_INDEX(INDEX)"] = "";
}

// Define function call
strVector argumentValues;
if (innerIterationCount) {
argumentValues = {"OUTER_INDEX", "INNER_INDEX"};
} else {
argumentValues = {"OUTER_INDEX"};
}
loopScope.props["defines/OCCA_LOOP_FUNCTION(OUTER_INDEX, INNER_INDEX)"] = (
fn.buildFunctionCall("occa_loop_function", argumentValues)
);

return loopScope;
}

Expand Down
19 changes: 19 additions & 0 deletions src/occa/internal/lang/specialMacros.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,27 @@ namespace occa {
return;
}

// Check if the next argument is an empty semicolon to remove
//
// Example:
// OKL("@inner");
// for(...)
// ->
// @inner
// for(...)
//
token_t *nextToken = pp.getSourceToken();
if (nextToken) {
if (token_t::safeOperatorType(nextToken) == operatorType::semicolon) {
delete nextToken;
} else {
pp.pushInput(nextToken);
}
}

const std::string &content = token->to<stringToken>().value;
pp.injectSourceCode(*token, strip(content));

}
}
}
3 changes: 3 additions & 0 deletions tests/src/internal/lang/preprocessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ void testSpecialMacros() {
"__COUNTER__\n"
"__DATE__ __TIME__\n"
"OKL(\"123\")\n"
"OKL(\"\\\"456\\\"\");\n"
);

// __COUNTER__
Expand Down Expand Up @@ -634,6 +635,8 @@ void testSpecialMacros() {
// OKL
ASSERT_EQ_BINARY(123,
(int) nextTokenPrimitiveValue());
ASSERT_EQ_BINARY("456",
nextTokenStringValue());

while(!tokenStream.isEmpty()) {
getToken();
Expand Down
72 changes: 36 additions & 36 deletions tests/src/loops/forLoop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,58 +41,58 @@ void testOuterForLoops(occa::device device) {

occa::scope scope({
{"output", output}
}, {
}, {
{"defines/length", length}
});

occa::forLoop()
.outer(length)
.run(OCCA_FUNCTION(scope, [=](const int outerIndex) -> void {
OKL("@inner")
for (int i = 0; i < 2; ++i) {
const int globalIndex = i + (2 * outerIndex);
output[globalIndex] = globalIndex;
}
OKL("@inner");
for (int i = 0; i < 2; ++i) {
const int globalIndex = i + (2 * outerIndex);
output[globalIndex] = globalIndex;
}
}));

ASSERT_EQ(0, output[0]);
ASSERT_EQ((2 * length) - 1,
ASSERT_EQ((float) 0, output[0]);
ASSERT_EQ((float) (2 * length) - 1,
output[(2 * length) - 1]);
ASSERT_EQ(-1,
ASSERT_EQ((float) -1,
output[2 * length]);

occa::forLoop()
.outer(length, occa::range(length))
.run(OCCA_FUNCTION(scope, [=](const int2 outerIndex) -> void {
OKL("@inner")
for (int i = 0; i < 2; ++i) {
const int globalIndex = (
i + (2 * (outerIndex.y + length * outerIndex.x))
);
output[globalIndex] = -globalIndex;
}
OKL("@inner");
for (int i = 0; i < 2; ++i) {
const int globalIndex = (
i + (2 * (outerIndex.y + length * outerIndex.x))
);
output[globalIndex] = -globalIndex;
}
}));

ASSERT_EQ(0, output[0]);
ASSERT_EQ(-((2 * length * length) - 1),
ASSERT_EQ((float) 0, output[0]);
ASSERT_EQ((float) -((2 * length * length) - 1),
output[2 * length * length - 1]);
ASSERT_EQ(-1,
ASSERT_EQ((float) -1,
output[2 * length * length]);

occa::forLoop()
.outer(length, occa::range(length), indexArray)
.run(OCCA_FUNCTION(scope, [=](const int3 outerIndex) -> void {
OKL("@inner")
for (int i = 0; i < 2; ++i) {
const int globalIndex = (
i + (2 * (outerIndex.z + length * (outerIndex.y + length * outerIndex.x)))
);
output[globalIndex] = globalIndex;
}
OKL("@inner");
for (int i = 0; i < 2; ++i) {
const int globalIndex = (
i + (2 * (outerIndex.z + length * (outerIndex.y + length * outerIndex.x)))
);
output[globalIndex] = globalIndex;
}
}));

ASSERT_EQ(0, output[0]);
ASSERT_EQ((2 * length * length * length) - 1,
ASSERT_EQ((float) 0, output[0]);
ASSERT_EQ((float) (2 * length * length * length) - 1,
output[(2 * length * length * length) - 1]);
}

Expand All @@ -117,10 +117,10 @@ void testFullForLoops(occa::device device) {
output[globalIndex] = globalIndex;
}));

ASSERT_EQ(0, output[0]);
ASSERT_EQ((2 * length) - 1,
ASSERT_EQ((float) 0, output[0]);
ASSERT_EQ((float) (2 * length) - 1,
output[(2 * length) - 1]);
ASSERT_EQ(-1,
ASSERT_EQ((float) -1,
output[2 * length]);

occa::forLoop()
Expand All @@ -133,10 +133,10 @@ void testFullForLoops(occa::device device) {
output[globalIndex] = -globalIndex;
}));

ASSERT_EQ(0, output[0]);
ASSERT_EQ(-((2 * length * length) - 1),
ASSERT_EQ((float) 0, output[0]);
ASSERT_EQ((float) -((2 * length * length) - 1),
output[2 * length * length - 1]);
ASSERT_EQ(-1,
ASSERT_EQ((float) -1,
output[2 * length * length]);

occa::forLoop()
Expand All @@ -149,7 +149,7 @@ void testFullForLoops(occa::device device) {
output[globalIndex] = globalIndex;
}));

ASSERT_EQ(0, output[0]);
ASSERT_EQ((2 * length * length * length) - 1,
ASSERT_EQ((float) 0, output[0]);
ASSERT_EQ((float) (2 * length * length * length) - 1,
output[(2 * length * length * length) - 1]);
}