Skip to content

Commit

Permalink
[Parser] Added @ restrict attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
dmed256 committed Apr 4, 2019
1 parent d2cddcc commit bf1dd16
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 22 deletions.
1 change: 0 additions & 1 deletion include/occa/lang/mode/opencl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ namespace occa {
openclParser(const occa::properties &settings_ = occa::properties());

virtual void onClear();

virtual void beforePreprocessing();

virtual void beforeKernelSplit();
Expand Down
4 changes: 4 additions & 0 deletions include/occa/lang/parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ namespace occa {

//---[ Misc ]---------------------
occa::properties settings;
qualifier_t *restrictQualifier;
//================================

parser_t(const occa::properties &settings_ = occa::properties());
Expand Down Expand Up @@ -95,10 +96,13 @@ namespace occa {

void setSource(const std::string &source,
const bool isFile);
void setupLoadTokens();
void loadTokens();
void parseTokens();

keyword_t& getKeyword(token_t *token);
keyword_t& getKeyword(const std::string &name);

opType_t getOperatorType(token_t *token);
//================================

Expand Down
2 changes: 2 additions & 0 deletions include/occa/lang/statement/statement.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ namespace occa {
extern const int return_;

extern const int attribute;

extern const int blockStatements;
}

class statement_t {
Expand Down
1 change: 1 addition & 0 deletions include/occa/lang/transforms/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <occa/lang/transforms/builtins/dim.hpp>
#include <occa/lang/transforms/builtins/finders.hpp>
#include <occa/lang/transforms/builtins/restrict.hpp>
#include <occa/lang/transforms/builtins/tile.hpp>

#endif
26 changes: 26 additions & 0 deletions include/occa/lang/transforms/builtins/restrict.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef OCCA_LANG_TRANSFORMS_BUILTINS_RESTRICT_HEADER
#define OCCA_LANG_TRANSFORMS_BUILTINS_RESTRICT_HEADER

#include <occa/lang/transforms/statementTransform.hpp>

namespace occa {
namespace lang {
class qualifier_t;

namespace transforms {
class restrict : public statementTransform {
public:
const qualifier_t &restrictQualifier;

restrict(const qualifier_t &restrictQualifier_);

virtual statement_t* transformStatement(statement_t &smnt);
};

bool applyRestrictTransforms(statement_t &smnt,
const qualifier_t &restrictQualifier);
}
}
}

#endif
1 change: 0 additions & 1 deletion src/lang/mode/okl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ namespace occa {
parser.addAttribute<attributes::inner>();
parser.addAttribute<attributes::kernel>();
parser.addAttribute<attributes::outer>();
parser.addAttribute<attributes::restrict>();
parser.addAttribute<attributes::shared>();
}

Expand Down
15 changes: 5 additions & 10 deletions src/lang/mode/opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace occa {

okl::addAttributes(*this);

if (!settings.has("options/restrict")) {
settings["options/restrict"] = "restrict";
}
settings["opencl/extensions/cl_khr_fp64"] = true;
}

Expand Down Expand Up @@ -236,16 +239,8 @@ namespace occa {

void openclParser::addStructQualifiers() {
statementPtrVector statements;
findStatements(statementType::block |
statementType::elif_ |
statementType::else_ |
statementType::for_ |
statementType::function |
statementType::functionDecl |
statementType::if_ |
statementType::namespace_ |
statementType::switch_ |
statementType::while_,
findStatements(statementType::blockStatements |
statementType::function,
root,
updateScopeStructVariables,
statements);
Expand Down
52 changes: 42 additions & 10 deletions src/lang/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ namespace occa {
root(NULL, NULL),
up(&root),
success(true),
settings(settings_) {
settings(settings_),
restrictQualifier(NULL) {
// Properly implement `identifier-nondigit` for identifiers
// Meanwhile, we use the unknownFilter
stream = (tokenizer
Expand Down Expand Up @@ -77,12 +78,14 @@ namespace occa {
addAttribute<attributes::dim>();
addAttribute<attributes::dimOrder>();
addAttribute<attributes::tile>();
addAttribute<attributes::restrict>();
}

parser_t::~parser_t() {
clear();

freeKeywords(keywords);
delete restrictQualifier;

nameToAttributeMap::iterator it = attributeMap.begin();
while (it != attributeMap.end()) {
Expand Down Expand Up @@ -231,6 +234,7 @@ namespace occa {
tokenizer.set(source.c_str());
}

setupLoadTokens();
loadTokens();

delete root.source;
Expand All @@ -241,9 +245,23 @@ namespace occa {
);
}

void parser_t::loadTokens() {
void parser_t::setupLoadTokens() {
beforePreprocessing();

// Setup @restrict
const std::string restrictStr = (
settings.get<std::string>("options/restrict",
"__restrict__")
);

if (restrictStr != "disabled") {
restrictQualifier = new qualifier_t(restrictStr,
qualifierType::custom);
addKeyword(keywords, new qualifierKeyword(*restrictQualifier));
}
}

void parser_t::loadTokens() {
token_t *token;
while (!stream.isEmpty()) {
stream >> token;
Expand All @@ -263,12 +281,22 @@ namespace occa {
void parser_t::parseTokens() {
beforeParsing();
if (!success) return;

loadAllStatements();
if (!success) return;

if (restrictQualifier) {
success = transforms::applyRestrictTransforms(root,
*restrictQualifier);
if (!success) return;
}

success = transforms::applyDimTransforms(root);
if (!success) return;

success = transforms::applyTileTransforms(root);
if (!success) return;

afterParsing();
}

Expand All @@ -287,28 +315,32 @@ namespace occa {
return noKeyword;
}

std::string identifier;
std::string name;
if (tType & tokenType::identifier) {
identifier = token->to<identifierToken>().value;
name = token->to<identifierToken>().value;
}
else if (tType & tokenType::qualifier) {
identifier = token->to<qualifierToken>().qualifier.name;
name = token->to<qualifierToken>().qualifier.name;
}
else if (tType & tokenType::type) {
identifier = token->to<typeToken>().value.name();
name = token->to<typeToken>().value.name();
}
else if (tType & tokenType::variable) {
identifier = token->to<variableToken>().value.name();
name = token->to<variableToken>().value.name();
}
else if (tType & tokenType::function) {
identifier = token->to<functionToken>().value.name();
name = token->to<functionToken>().value.name();
}

keywordMapIterator it = keywords.find(identifier);
return getKeyword(name);
}

keyword_t& parser_t::getKeyword(const std::string &name) {
keywordMapIterator it = keywords.find(name);
if (it != keywords.end()) {
return *(it->second);
}
return up->getScopeKeyword(identifier);
return up->getScopeKeyword(name);
}

opType_t parser_t::getOperatorType(token_t *token) {
Expand Down
12 changes: 12 additions & 0 deletions src/lang/statement/statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ namespace occa {
const int return_ = (1 << 27);

const int attribute = (1 << 28);

const int blockStatements = (
block |
elif_ |
else_ |
for_ |
functionDecl |
if_ |
namespace_ |
switch_ |
while_
);
}

statement_t::statement_t(blockStatement *up_,
Expand Down
49 changes: 49 additions & 0 deletions src/lang/transforms/builtins/restrict.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include <occa/lang/statement.hpp>
#include <occa/lang/variable.hpp>
#include <occa/lang/qualifier.hpp>
#include <occa/lang/builtins/types.hpp>
#include <occa/lang/transforms/builtins/restrict.hpp>

namespace occa {
namespace lang {
namespace transforms {
restrict::restrict(const qualifier_t &restrictQualifier_) :
restrictQualifier(restrictQualifier_) {
validStatementTypes = (statementType::functionDecl |
statementType::function);
}

statement_t* restrict::transformStatement(statement_t &smnt) {
function_t &func = (
(smnt.type() & statementType::function)
? smnt.to<functionStatement>().function
: smnt.to<functionDeclStatement>().function
);

const int argc = (int) func.args.size();
for (int i = 0; i < argc; ++i) {
variable_t *arg = func.args[i];
if (arg && arg->hasAttribute("restrict")) {
const int pointerCount = (int) arg->vartype.pointers.size();
if (pointerCount) {
arg->vartype.pointers[pointerCount - 1] += restrictQualifier;
} else {
arg->attributes["restrict"].printError(
"[@restrict] can only be applied to pointer function arguments"
);
return NULL;
}
}
}

return &smnt;
}

bool applyRestrictTransforms(statement_t &smnt,
const qualifier_t &restrictQualifier) {
restrict restrictTransform(restrictQualifier);
return restrictTransform.statementTransform::apply(smnt);
}
}
}
}
4 changes: 4 additions & 0 deletions tests/src/lang/parser/statementLoading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,10 @@ void testAttributeLoading() {
parseAndPrintSource("@dim(1,2,3) @dimOrder(2,1,0) int *x; x(1,2,3);");
std::cerr << "==============================================\n";

std::cerr << "\n---[ @restrict Transformations ]------------\n";
parseAndPrintSource("void foo(@restrict int *a) {}");
std::cerr << "==============================================\n";

std::cerr << "\n---[ @tile Transformations ]--------------------\n";
parseAndPrintSource("for (int i = 0; i < (1 + 2 + N + 6); ++i; @tile(16, @outer, @inner, check=false)) {"
" int x;"
Expand Down

0 comments on commit bf1dd16

Please sign in to comment.