Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set SYCL subgroup size via kernel property or @simd_length attribute. #726

Merged
merged 7 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/occa/internal/lang/attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ namespace occa {
bool attributeArg_t::exists() const {
return expr;
}

bool attributeArg_t::canEvaluate() const {
if (!expr) return false;
return expr->canEvaluate();
}
//==================================

//---[ Attribute ]------------------
Expand Down
2 changes: 2 additions & 0 deletions src/occa/internal/lang/attribute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ namespace occa {
void clear();

bool exists() const;

bool canEvaluate() const;
};
//==================================

Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/lang/builtins/attributes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <occa/internal/lang/builtins/attributes/outer.hpp>
#include <occa/internal/lang/builtins/attributes/restrict.hpp>
#include <occa/internal/lang/builtins/attributes/shared.hpp>
#include <occa/internal/lang/builtins/attributes/simdLength.hpp>
#include <occa/internal/lang/builtins/attributes/tile.hpp>

#endif
50 changes: 50 additions & 0 deletions src/occa/internal/lang/builtins/attributes/simdLength.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include <occa/internal/lang/expr.hpp>
#include <occa/internal/lang/parser.hpp>
#include <occa/internal/lang/statement.hpp>
#include <occa/internal/lang/variable.hpp>
#include <occa/internal/lang/builtins/attributes/simdLength.hpp>

namespace occa {
namespace lang {
namespace attributes {

const std::string& simdLength::name() const { return name_;}

bool simdLength::forStatementType(const int sType) const {
return (sType & statementType::for_);
}

bool simdLength::isValid(const attributeToken_t &attr) const {
if (attr.kwargs.size()) {
attr.printError(name_ + " does not take kwargs");
return false;

Check warning on line 20 in src/occa/internal/lang/builtins/attributes/simdLength.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/builtins/attributes/simdLength.cpp#L19-L20

Added lines #L19 - L20 were not covered by tests
}

if (1 != attr.args.size()) {
attr.printError(name_ + " takes one argument");
return false;

Check warning on line 25 in src/occa/internal/lang/builtins/attributes/simdLength.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/builtins/attributes/simdLength.cpp#L24-L25

Added lines #L24 - L25 were not covered by tests
}

const auto& attr_arg = attr.args[0];
if (!attr_arg.canEvaluate()) {
attr.printError(name_ + " cannot evaluate argument");
return false;

Check warning on line 31 in src/occa/internal/lang/builtins/attributes/simdLength.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/builtins/attributes/simdLength.cpp#L30-L31

Added lines #L30 - L31 were not covered by tests
}

primitive value = attr_arg.expr->evaluate();
if (!value.isInteger()) {
attr.printError(name_ + " take an integer argument");
return false;

Check warning on line 37 in src/occa/internal/lang/builtins/attributes/simdLength.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/builtins/attributes/simdLength.cpp#L36-L37

Added lines #L36 - L37 were not covered by tests
}

if(0 > value.to<int>()) {
attr.printError(name_ + " arguments must be postive!");
return false;

Check warning on line 42 in src/occa/internal/lang/builtins/attributes/simdLength.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/builtins/attributes/simdLength.cpp#L41-L42

Added lines #L41 - L42 were not covered by tests
}

return true;
}

}
}
}
24 changes: 24 additions & 0 deletions src/occa/internal/lang/builtins/attributes/simdLength.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef OCCA_INTERNAL_LANG_BUILTINS_ATTRIBUTES_SIMD_LENGTH_HEADER
#define OCCA_INTERNAL_LANG_BUILTINS_ATTRIBUTES_SIMD_LENGTH_HEADER

#include <occa/internal/lang/attribute.hpp>

namespace occa {
namespace lang {
namespace attributes {

class simdLength : public attribute_t {
public:
simdLength() = default;
const std::string& name() const override;
bool forStatementType(const int sType) const override;
bool isValid(const attributeToken_t &attr) const override;
private:
static const inline std::string name_{"simd_length"};
};

}
}
}

#endif
94 changes: 82 additions & 12 deletions src/occa/internal/lang/modes/dpcpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,78 @@
#include <occa/internal/lang/builtins/attributes.hpp>
#include <occa/internal/lang/builtins/types.hpp>
#include <occa/internal/lang/expr.hpp>
// #include <stringstream>
#include <occa/internal/lang/attribute.hpp>

namespace {

class dpcppLambda_t : public occa::lang::lambda_t {
public:
int simd_length{-1};

dpcppLambda_t(occa::lang::capture_t capture_, int simd_length_)
: lambda_t(capture_), simd_length(simd_length_) {}

dpcppLambda_t(const dpcppLambda_t& other)
: lambda_t(other), simd_length(other.simd_length) {}

~dpcppLambda_t() = default;

Check warning on line 22 in src/occa/internal/lang/modes/dpcpp.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/modes/dpcpp.cpp#L22

Added line #L22 was not covered by tests

bool equals(const type_t &other) const override {
const dpcppLambda_t &other_ = other.to<dpcppLambda_t>();
if (simd_length != other_.simd_length) return false;
return lambda_t::equals(other);

Check warning on line 27 in src/occa/internal/lang/modes/dpcpp.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/modes/dpcpp.cpp#L24-L27

Added lines #L24 - L27 were not covered by tests
}

void printDeclaration(occa::lang::printer &pout) const override {
pout << "[";

switch (this->capture) {
case occa::lang::capture_t::byValue:
pout << "=";
break;
case occa::lang::capture_t::byReference:
pout << "&";
break;

Check warning on line 39 in src/occa/internal/lang/modes/dpcpp.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/modes/dpcpp.cpp#L38-L39

Added lines #L38 - L39 were not covered by tests
default:
pout << "???";
break;

Check warning on line 42 in src/occa/internal/lang/modes/dpcpp.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/modes/dpcpp.cpp#L41-L42

Added lines #L41 - L42 were not covered by tests
}

pout << "](";

if (!args.empty()) {
const std::string argIndent = pout.indentFromNewline();
args[0]->printDeclaration(pout);
for (std::size_t i = 1; i < args.size(); ++i) {
pout << ",\n" << argIndent;
args[i]->printDeclaration(pout);

Check warning on line 52 in src/occa/internal/lang/modes/dpcpp.cpp

View check run for this annotation

Codecov / codecov/patch

src/occa/internal/lang/modes/dpcpp.cpp#L51-L52

Added lines #L51 - L52 were not covered by tests
}
}
pout << ") ";

if (0 < simd_length) {
pout << "[[intel::reqd_sub_group_size(";
pout.print(simd_length);
pout << ")]]";
}

pout << " {";

pout.printNewline();
pout.pushInlined(false);
pout.addIndentation();

body->print(pout);

pout.removeIndentation();
pout.popInlined();
pout.printNewline();
pout.printIndentation();
pout << "}\n";
}
};

}

namespace occa
{
Expand All @@ -20,6 +91,7 @@
shared("auto", qualifierType::custom)
{
okl::addOklAttributes(*this);
simd_length_default = settings_.get("simd_length",-1);
}

void dpcppParser::onClear()
Expand Down Expand Up @@ -79,15 +151,7 @@

std::string dpcppParser::launchBoundsAttribute(const int innerDims[3])
{
std::stringstream ss;
ss << "[[sycl::reqd_work_group_size("
<< innerDims[2]
<< ","
<< innerDims[1]
<< ","
<< innerDims[0]
<< ")]]\n";
return ss.str();
return "";
}

// @note: As of SYCL 2020 this will need to change from `CL/sycl.hpp` to `sycl.hpp`
Expand Down Expand Up @@ -188,9 +252,15 @@
lambda_t &cg_function = *(new lambda_t(capture_t::byReference));
cg_function.addArgument(sycl_handler);

lambda_t &sycl_kernel = *(new lambda_t(capture_t::byValue));
sycl_kernel.addArgument(sycl_nditem);

int simd_length = simd_length_default;
if (k.hasAttribute("simd_length")) {
const attributeToken_t& attr = k.attributes["simd_length"];
simd_length = attr.args[0].expr->evaluate();
}

dpcppLambda_t& sycl_kernel = *(new dpcppLambda_t(capture_t::byValue, simd_length));
sycl_kernel.addArgument(sycl_nditem);
sycl_kernel.body->swap(k);

lambdaNode sycl_kernel_node(sycl_kernel.source, sycl_kernel);
Expand Down
3 changes: 2 additions & 1 deletion src/occa/internal/lang/modes/dpcpp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ namespace occa
void setSharedQualifiers();
void setKernelQualifiers(function_t &function);
void migrateLocalDecls(functionDeclStatement &kernelSmnt);
void setLaunchBounds();

void setupAtomics();
static bool transformAtomicBlockStatement(blockStatement &blockSmnt);
static bool transformAtomicBasicExpressionStatement(expressionStatement &exprSmnt);

private:
int simd_length_default;

inline int dpcppDimensionOrder(const int index) { return 2 - index; }
};
} // namespace okl
Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/lang/modes/okl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ namespace occa {
parser.addAttribute<attributes::shared>();
parser.addAttribute<attributes::maxInnerDims>();
parser.addAttribute<attributes::noBarrier>();
parser.addAttribute<attributes::simdLength>();
}

void setOklLoopIndices(functionDeclStatement &kernelSmnt) {
Expand Down
4 changes: 4 additions & 0 deletions src/occa/internal/lang/modes/withLauncher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ namespace occa {
forStatement &newForSmnt = (forStatement&) forSmnt.clone();
newKernelSmnt.set(newForSmnt);

if (newForSmnt.hasAttribute("simd_length")) {
newKernelSmnt.addAttribute(newForSmnt.attributes["simd_length"]);
}

bool addLaunchBoundsAttribute{true};
int kernelInnerDims[3] = {1,1,1};
if (newForSmnt.hasAttribute("max_inner_dims")) {
Expand Down
56 changes: 56 additions & 0 deletions tests/src/internal/lang/modes/dpcpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void testSharedAnnotation();
void testBarriers();
void testAtomic();
void testSource();
void testSimdLength();

int main(const int argc, const char **argv) {
parser.settings["okl/validate"] = true;
Expand All @@ -38,6 +39,7 @@ int main(const int argc, const char **argv) {
testSharedAnnotation();
testBarriers();
testSource();
testSimdLength();

return 0;
}
Expand Down Expand Up @@ -163,3 +165,57 @@ void testSource() {
"}\n"
);
}

void testSimdLengthAttribute() {
const std::string kernel_source = R"(
@kernel void f() {
@outer @simd_length(16)
for (int o = 0; o < 1; ++o) {
@inner for (int i = 0; i < 32; ++i) {
int j = i + o;
}
}
}
)";

parser.parseSource(kernel_source);
ASSERT_TRUE(parser.success);

printer pout;
parser.root.print(pout);
const std::string translated_source = pout.str();

auto pos = translated_source.find("[[intel::reqd_sub_group_size(16)]]");
ASSERT_TRUE(std::string::npos != pos);
}

void testSimdLengthProperty() {
const std::string kernel_source = R"(
@kernel void f() {
@outer for (int o = 0; o < 1; ++o) {
@inner for (int i = 0; i < 32; ++i) {
int j = i + o;
}
}
}
)";

occa::json properties;
properties["simd_length"] = 16;
occa::lang::okl::dpcppParser dpcpp_parser(properties);

dpcpp_parser.parseSource(kernel_source);
ASSERT_TRUE(parser.success);

printer pout;
dpcpp_parser.root.print(pout);
const std::string translated_source = pout.str();

auto pos = translated_source.find("[[intel::reqd_sub_group_size(16)]]");
ASSERT_TRUE(std::string::npos != pos);
}

void testSimdLength() {
testSimdLengthAttribute();
testSimdLengthProperty();
}