Skip to content

Commit

Permalink
Insert subgroup size attribute in SYCL kernel lambdas.
Browse files Browse the repository at this point in the history
  • Loading branch information
kris-rowe committed Dec 15, 2023
1 parent 8747afb commit 996bd56
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 25 deletions.
5 changes: 5 additions & 0 deletions src/occa/internal/lang/attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ namespace occa {
bool attributeArg_t::exists() const {
return expr;
}

bool attributeArg_t::canEvaluate() const {
if (!expr) return false;
return expr->canEvaluate();
}
//==================================

//---[ Attribute ]------------------
Expand Down
2 changes: 2 additions & 0 deletions src/occa/internal/lang/attribute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ namespace occa {
void clear();

bool exists() const;

bool canEvaluate() const;
};
//==================================

Expand Down
92 changes: 67 additions & 25 deletions src/occa/internal/lang/modes/dpcpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,76 @@
#include <occa/internal/lang/builtins/types.hpp>
#include <occa/internal/lang/expr.hpp>
#include <occa/internal/lang/attribute.hpp>
// #include <stringstream>

namespace {
class SubgroupSize : public occa::lang::attribute_t {
public:
SubgroupSize() = default;

const std::string& name() const override {
static const std::string name_ = "intel::reqd_sub_group_size";
return name_;
class dpcppLambda_t : public occa::lang::lambda_t {
public:
int simd_length{-1};

dpcppLambda_t(occa::lang::capture_t capture_, int simd_length_)
: lambda_t(capture_), simd_length(simd_length_) {}

dpcppLambda_t(const dpcppLambda_t& other)
: lambda_t(other), simd_length(other.simd_length) {}

~dpcppLambda_t() = default;

Check warning on line 22 in src/occa/internal/lang/modes/dpcpp.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/modes/dpcpp.cpp#L22

Added line #L22 was not covered by tests

bool equals(const type_t &other) const override {
const dpcppLambda_t &other_ = other.to<dpcppLambda_t>();
if (simd_length != other_.simd_length) return false;
return lambda_t::equals(other);

Check warning on line 27 in src/occa/internal/lang/modes/dpcpp.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/modes/dpcpp.cpp#L24-L27

Added lines #L24 - L27 were not covered by tests
}

void printDeclaration(occa::lang::printer &pout) const override {
pout << "[";

switch (this->capture) {
case occa::lang::capture_t::byValue:
pout << "=";
break;
case occa::lang::capture_t::byReference:
pout << "&";
break;

Check warning on line 39 in src/occa/internal/lang/modes/dpcpp.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/modes/dpcpp.cpp#L38-L39

Added lines #L38 - L39 were not covered by tests
default:
pout << "???";
break;

Check warning on line 42 in src/occa/internal/lang/modes/dpcpp.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/modes/dpcpp.cpp#L41-L42

Added lines #L41 - L42 were not covered by tests
}

pout << "](";

if (!args.empty()) {
const std::string argIndent = pout.indentFromNewline();
args[0]->printDeclaration(pout);
for (std::size_t i = 1; i < args.size(); ++i) {
pout << ",\n" << argIndent;
args[i]->printDeclaration(pout);

Check warning on line 52 in src/occa/internal/lang/modes/dpcpp.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/modes/dpcpp.cpp#L51-L52

Added lines #L51 - L52 were not covered by tests
}
}
pout << ") ";

if (0 < simd_length) {
pout << "[[intel::reqd_sub_group_size(";
pout.print(simd_length);
pout << ")]]";
}

pout << " {";

pout.printNewline();
pout.pushInlined(false);
pout.addIndentation();

body->print(pout);

pout.removeIndentation();
pout.popInlined();
pout.printNewline();
pout.printIndentation();
pout << "}\n";
}
bool forStatementType(const int sType) const override { return false;}
bool isValid(const occa::lang::attributeToken_t &attr) const override {return true;}
};

}

namespace occa
Expand Down Expand Up @@ -196,29 +252,15 @@ namespace occa
lambda_t &cg_function = *(new lambda_t(capture_t::byReference));
cg_function.addArgument(sycl_handler);

lambda_t &sycl_kernel = *(new lambda_t(capture_t::byValue));
sycl_kernel.addArgument(sycl_nditem);

int simd_length = simd_length_default;
if (k.hasAttribute("simd_length")) {
const attributeToken_t& attr = k.attributes["simd_length"];
simd_length = attr.args[0].expr->evaluate();
}

if(0 < simd_length) {
attributeArg_t subgroup_size_arg(
new identifierNode(
new identifierToken(originSource::builtin, "subgroup_size_arg"),
std::to_string(simd_length)));

attributeToken_t subgroup_size_token(new SubgroupSize(),
new identifierToken(originSource::builtin, "subgroup_size"),
occa::lang::attribute_style::cpp);
subgroup_size_token.args.push_back(subgroup_size_arg);

sycl_kernel.attributes["subgroup_size"] = subgroup_size_token;
}

dpcppLambda_t& sycl_kernel = *(new dpcppLambda_t(capture_t::byValue, simd_length));
sycl_kernel.addArgument(sycl_nditem);
sycl_kernel.body->swap(k);

lambdaNode sycl_kernel_node(sycl_kernel.source, sycl_kernel);
Expand Down

0 comments on commit 996bd56

Please sign in to comment.