From bf1dd169dfba475ab7ac35a4529fe4c6890d38cd Mon Sep 17 00:00:00 2001 From: dmed256 Date: Thu, 4 Apr 2019 03:06:34 -0700 Subject: [PATCH] [Parser] Added @ restrict attribute --- include/occa/lang/mode/opencl.hpp | 1 - include/occa/lang/parser.hpp | 4 ++ include/occa/lang/statement/statement.hpp | 2 + include/occa/lang/transforms/builtins.hpp | 1 + .../lang/transforms/builtins/restrict.hpp | 26 ++++++++++ src/lang/mode/okl.cpp | 1 - src/lang/mode/opencl.cpp | 15 ++---- src/lang/parser.cpp | 52 +++++++++++++++---- src/lang/statement/statement.cpp | 12 +++++ src/lang/transforms/builtins/restrict.cpp | 49 +++++++++++++++++ tests/src/lang/parser/statementLoading.cpp | 4 ++ 11 files changed, 145 insertions(+), 22 deletions(-) create mode 100644 include/occa/lang/transforms/builtins/restrict.hpp create mode 100644 src/lang/transforms/builtins/restrict.cpp diff --git a/include/occa/lang/mode/opencl.hpp b/include/occa/lang/mode/opencl.hpp index b724f9fe4..c5ac3e75d 100644 --- a/include/occa/lang/mode/opencl.hpp +++ b/include/occa/lang/mode/opencl.hpp @@ -16,7 +16,6 @@ namespace occa { openclParser(const occa::properties &settings_ = occa::properties()); virtual void onClear(); - virtual void beforePreprocessing(); virtual void beforeKernelSplit(); diff --git a/include/occa/lang/parser.hpp b/include/occa/lang/parser.hpp index 9429e715a..c73ed1aa4 100644 --- a/include/occa/lang/parser.hpp +++ b/include/occa/lang/parser.hpp @@ -64,6 +64,7 @@ namespace occa { //---[ Misc ]--------------------- occa::properties settings; + qualifier_t *restrictQualifier; //================================ parser_t(const occa::properties &settings_ = occa::properties()); @@ -95,10 +96,13 @@ namespace occa { void setSource(const std::string &source, const bool isFile); + void setupLoadTokens(); void loadTokens(); void parseTokens(); keyword_t& getKeyword(token_t *token); + keyword_t& getKeyword(const std::string &name); + opType_t getOperatorType(token_t *token); //================================ diff --git a/include/occa/lang/statement/statement.hpp b/include/occa/lang/statement/statement.hpp index 313eb9480..2d0b25275 100644 --- a/include/occa/lang/statement/statement.hpp +++ b/include/occa/lang/statement/statement.hpp @@ -63,6 +63,8 @@ namespace occa { extern const int return_; extern const int attribute; + + extern const int blockStatements; } class statement_t { diff --git a/include/occa/lang/transforms/builtins.hpp b/include/occa/lang/transforms/builtins.hpp index 6f986f49c..24d1f983d 100644 --- a/include/occa/lang/transforms/builtins.hpp +++ b/include/occa/lang/transforms/builtins.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #endif diff --git a/include/occa/lang/transforms/builtins/restrict.hpp b/include/occa/lang/transforms/builtins/restrict.hpp new file mode 100644 index 000000000..ac97c4b69 --- /dev/null +++ b/include/occa/lang/transforms/builtins/restrict.hpp @@ -0,0 +1,26 @@ +#ifndef OCCA_LANG_TRANSFORMS_BUILTINS_RESTRICT_HEADER +#define OCCA_LANG_TRANSFORMS_BUILTINS_RESTRICT_HEADER + +#include + +namespace occa { + namespace lang { + class qualifier_t; + + namespace transforms { + class restrict : public statementTransform { + public: + const qualifier_t &restrictQualifier; + + restrict(const qualifier_t &restrictQualifier_); + + virtual statement_t* transformStatement(statement_t &smnt); + }; + + bool applyRestrictTransforms(statement_t &smnt, + const qualifier_t &restrictQualifier); + } + } +} + +#endif diff --git a/src/lang/mode/okl.cpp b/src/lang/mode/okl.cpp index a1bb54533..730fb8caf 100644 --- a/src/lang/mode/okl.cpp +++ b/src/lang/mode/okl.cpp @@ -388,7 +388,6 @@ namespace occa { parser.addAttribute(); parser.addAttribute(); parser.addAttribute(); - parser.addAttribute(); parser.addAttribute(); } diff --git a/src/lang/mode/opencl.cpp b/src/lang/mode/opencl.cpp index e4823141c..20963c644 100644 --- a/src/lang/mode/opencl.cpp +++ b/src/lang/mode/opencl.cpp @@ -17,6 +17,9 @@ namespace occa { okl::addAttributes(*this); + if (!settings.has("options/restrict")) { + settings["options/restrict"] = "restrict"; + } settings["opencl/extensions/cl_khr_fp64"] = true; } @@ -236,16 +239,8 @@ namespace occa { void openclParser::addStructQualifiers() { statementPtrVector statements; - findStatements(statementType::block | - statementType::elif_ | - statementType::else_ | - statementType::for_ | - statementType::function | - statementType::functionDecl | - statementType::if_ | - statementType::namespace_ | - statementType::switch_ | - statementType::while_, + findStatements(statementType::blockStatements | + statementType::function, root, updateScopeStructVariables, statements); diff --git a/src/lang/parser.cpp b/src/lang/parser.cpp index a3913ad5e..588214b42 100644 --- a/src/lang/parser.cpp +++ b/src/lang/parser.cpp @@ -19,7 +19,8 @@ namespace occa { root(NULL, NULL), up(&root), success(true), - settings(settings_) { + settings(settings_), + restrictQualifier(NULL) { // Properly implement `identifier-nondigit` for identifiers // Meanwhile, we use the unknownFilter stream = (tokenizer @@ -77,12 +78,14 @@ namespace occa { addAttribute(); addAttribute(); addAttribute(); + addAttribute(); } parser_t::~parser_t() { clear(); freeKeywords(keywords); + delete restrictQualifier; nameToAttributeMap::iterator it = attributeMap.begin(); while (it != attributeMap.end()) { @@ -231,6 +234,7 @@ namespace occa { tokenizer.set(source.c_str()); } + setupLoadTokens(); loadTokens(); delete root.source; @@ -241,9 +245,23 @@ namespace occa { ); } - void parser_t::loadTokens() { + void parser_t::setupLoadTokens() { beforePreprocessing(); + // Setup @restrict + const std::string restrictStr = ( + settings.get("options/restrict", + "__restrict__") + ); + + if (restrictStr != "disabled") { + restrictQualifier = new qualifier_t(restrictStr, + qualifierType::custom); + addKeyword(keywords, new qualifierKeyword(*restrictQualifier)); + } + } + + void parser_t::loadTokens() { token_t *token; while (!stream.isEmpty()) { stream >> token; @@ -263,12 +281,22 @@ namespace occa { void parser_t::parseTokens() { beforeParsing(); if (!success) return; + loadAllStatements(); if (!success) return; + + if (restrictQualifier) { + success = transforms::applyRestrictTransforms(root, + *restrictQualifier); + if (!success) return; + } + success = transforms::applyDimTransforms(root); if (!success) return; + success = transforms::applyTileTransforms(root); if (!success) return; + afterParsing(); } @@ -287,28 +315,32 @@ namespace occa { return noKeyword; } - std::string identifier; + std::string name; if (tType & tokenType::identifier) { - identifier = token->to().value; + name = token->to().value; } else if (tType & tokenType::qualifier) { - identifier = token->to().qualifier.name; + name = token->to().qualifier.name; } else if (tType & tokenType::type) { - identifier = token->to().value.name(); + name = token->to().value.name(); } else if (tType & tokenType::variable) { - identifier = token->to().value.name(); + name = token->to().value.name(); } else if (tType & tokenType::function) { - identifier = token->to().value.name(); + name = token->to().value.name(); } - keywordMapIterator it = keywords.find(identifier); + return getKeyword(name); + } + + keyword_t& parser_t::getKeyword(const std::string &name) { + keywordMapIterator it = keywords.find(name); if (it != keywords.end()) { return *(it->second); } - return up->getScopeKeyword(identifier); + return up->getScopeKeyword(name); } opType_t parser_t::getOperatorType(token_t *token) { diff --git a/src/lang/statement/statement.cpp b/src/lang/statement/statement.cpp index 7da201da1..153b4f143 100644 --- a/src/lang/statement/statement.cpp +++ b/src/lang/statement/statement.cpp @@ -46,6 +46,18 @@ namespace occa { const int return_ = (1 << 27); const int attribute = (1 << 28); + + const int blockStatements = ( + block | + elif_ | + else_ | + for_ | + functionDecl | + if_ | + namespace_ | + switch_ | + while_ + ); } statement_t::statement_t(blockStatement *up_, diff --git a/src/lang/transforms/builtins/restrict.cpp b/src/lang/transforms/builtins/restrict.cpp new file mode 100644 index 000000000..09d762880 --- /dev/null +++ b/src/lang/transforms/builtins/restrict.cpp @@ -0,0 +1,49 @@ +#include +#include +#include +#include +#include + +namespace occa { + namespace lang { + namespace transforms { + restrict::restrict(const qualifier_t &restrictQualifier_) : + restrictQualifier(restrictQualifier_) { + validStatementTypes = (statementType::functionDecl | + statementType::function); + } + + statement_t* restrict::transformStatement(statement_t &smnt) { + function_t &func = ( + (smnt.type() & statementType::function) + ? smnt.to().function + : smnt.to().function + ); + + const int argc = (int) func.args.size(); + for (int i = 0; i < argc; ++i) { + variable_t *arg = func.args[i]; + if (arg && arg->hasAttribute("restrict")) { + const int pointerCount = (int) arg->vartype.pointers.size(); + if (pointerCount) { + arg->vartype.pointers[pointerCount - 1] += restrictQualifier; + } else { + arg->attributes["restrict"].printError( + "[@restrict] can only be applied to pointer function arguments" + ); + return NULL; + } + } + } + + return &smnt; + } + + bool applyRestrictTransforms(statement_t &smnt, + const qualifier_t &restrictQualifier) { + restrict restrictTransform(restrictQualifier); + return restrictTransform.statementTransform::apply(smnt); + } + } + } +} diff --git a/tests/src/lang/parser/statementLoading.cpp b/tests/src/lang/parser/statementLoading.cpp index 9d611d876..8b6cc6536 100644 --- a/tests/src/lang/parser/statementLoading.cpp +++ b/tests/src/lang/parser/statementLoading.cpp @@ -706,6 +706,10 @@ void testAttributeLoading() { parseAndPrintSource("@dim(1,2,3) @dimOrder(2,1,0) int *x; x(1,2,3);"); std::cerr << "==============================================\n"; + std::cerr << "\n---[ @restrict Transformations ]------------\n"; + parseAndPrintSource("void foo(@restrict int *a) {}"); + std::cerr << "==============================================\n"; + std::cerr << "\n---[ @tile Transformations ]--------------------\n"; parseAndPrintSource("for (int i = 0; i < (1 + 2 + N + 6); ++i; @tile(16, @outer, @inner, check=false)) {" " int x;"