From 996bd56608805bc35955950c13e36b286ce0302e Mon Sep 17 00:00:00 2001 From: Kris Rowe Date: Wed, 13 Dec 2023 07:16:30 +0000 Subject: [PATCH] Insert subgroup size attribute in SYCL kernel lambdas. --- src/occa/internal/lang/attribute.cpp | 5 ++ src/occa/internal/lang/attribute.hpp | 2 + src/occa/internal/lang/modes/dpcpp.cpp | 92 +++++++++++++++++++------- 3 files changed, 74 insertions(+), 25 deletions(-) diff --git a/src/occa/internal/lang/attribute.cpp b/src/occa/internal/lang/attribute.cpp index 85e6750ab..4ce0ee659 100644 --- a/src/occa/internal/lang/attribute.cpp +++ b/src/occa/internal/lang/attribute.cpp @@ -54,6 +54,11 @@ namespace occa { bool attributeArg_t::exists() const { return expr; } + + bool attributeArg_t::canEvaluate() const { + if (!expr) return false; + return expr->canEvaluate(); + } //================================== //---[ Attribute ]------------------ diff --git a/src/occa/internal/lang/attribute.hpp b/src/occa/internal/lang/attribute.hpp index 0bdb0bb1d..9a8066a9e 100644 --- a/src/occa/internal/lang/attribute.hpp +++ b/src/occa/internal/lang/attribute.hpp @@ -63,6 +63,8 @@ namespace occa { void clear(); bool exists() const; + + bool canEvaluate() const; }; //================================== diff --git a/src/occa/internal/lang/modes/dpcpp.cpp b/src/occa/internal/lang/modes/dpcpp.cpp index 48a73409c..eea56cd79 100644 --- a/src/occa/internal/lang/modes/dpcpp.cpp +++ b/src/occa/internal/lang/modes/dpcpp.cpp @@ -6,20 +6,76 @@ #include #include #include -// #include namespace { -class SubgroupSize : public occa::lang::attribute_t { -public: - SubgroupSize() = default; - const std::string& name() const override { - static const std::string name_ = "intel::reqd_sub_group_size"; - return name_; +class dpcppLambda_t : public occa::lang::lambda_t { +public: + int simd_length{-1}; + + dpcppLambda_t(occa::lang::capture_t capture_, int simd_length_) + : lambda_t(capture_), simd_length(simd_length_) {} + + dpcppLambda_t(const dpcppLambda_t& other) + : lambda_t(other), simd_length(other.simd_length) {} + + ~dpcppLambda_t() = default; + + bool equals(const type_t &other) const override { + const dpcppLambda_t &other_ = other.to(); + if (simd_length != other_.simd_length) return false; + return lambda_t::equals(other); + } + + void printDeclaration(occa::lang::printer &pout) const override { + pout << "["; + + switch (this->capture) { + case occa::lang::capture_t::byValue: + pout << "="; + break; + case occa::lang::capture_t::byReference: + pout << "&"; + break; + default: + pout << "???"; + break; + } + + pout << "]("; + + if (!args.empty()) { + const std::string argIndent = pout.indentFromNewline(); + args[0]->printDeclaration(pout); + for (std::size_t i = 1; i < args.size(); ++i) { + pout << ",\n" << argIndent; + args[i]->printDeclaration(pout); + } + } + pout << ") "; + + if (0 < simd_length) { + pout << "[[intel::reqd_sub_group_size("; + pout.print(simd_length); + pout << ")]]"; + } + + pout << " {"; + + pout.printNewline(); + pout.pushInlined(false); + pout.addIndentation(); + + body->print(pout); + + pout.removeIndentation(); + pout.popInlined(); + pout.printNewline(); + pout.printIndentation(); + pout << "}\n"; } - bool forStatementType(const int sType) const override { return false;} - bool isValid(const occa::lang::attributeToken_t &attr) const override {return true;} }; + } namespace occa @@ -196,8 +252,6 @@ namespace occa lambda_t &cg_function = *(new lambda_t(capture_t::byReference)); cg_function.addArgument(sycl_handler); - lambda_t &sycl_kernel = *(new lambda_t(capture_t::byValue)); - sycl_kernel.addArgument(sycl_nditem); int simd_length = simd_length_default; if (k.hasAttribute("simd_length")) { @@ -205,20 +259,8 @@ namespace occa simd_length = attr.args[0].expr->evaluate(); } - if(0 < simd_length) { - attributeArg_t subgroup_size_arg( - new identifierNode( - new identifierToken(originSource::builtin, "subgroup_size_arg"), - std::to_string(simd_length))); - - attributeToken_t subgroup_size_token(new SubgroupSize(), - new identifierToken(originSource::builtin, "subgroup_size"), - occa::lang::attribute_style::cpp); - subgroup_size_token.args.push_back(subgroup_size_arg); - - sycl_kernel.attributes["subgroup_size"] = subgroup_size_token; - } - + dpcppLambda_t& sycl_kernel = *(new dpcppLambda_t(capture_t::byValue, simd_length)); + sycl_kernel.addArgument(sycl_nditem); sycl_kernel.body->swap(k); lambdaNode sycl_kernel_node(sycl_kernel.source, sycl_kernel);