Skip to content

[SYCL][SPIR-V] Drop Unnecessary Matrix Use parameter #6923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,15 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
constexpr size_t NumOfMatrixParameters = 6;
const size_t TemplateArgsSize = TemplateArgs.size();
assert(TemplateArgsSize == NumOfMatrixParameters &&
"Incorrect number of template parameters for JointMatrixINTEL");
OS << "spirv.JointMatrixINTEL.";
for (auto &TemplateArg : TemplateArgs) {
OS << "_";
if (TemplateArg.getKind() == TemplateArgument::Type) {
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
for (size_t I = 0; I != TemplateArgsSize; ++I) {
if (TemplateArgs[I].getKind() == TemplateArgument::Type) {
OS << "_";
llvm::Type *TTy = ConvertType(TemplateArgs[I].getAsType());
if (TTy->isIntegerTy()) {
switch (TTy->getIntegerBitWidth()) {
case 8:
Expand Down Expand Up @@ -91,8 +95,16 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
OS << LlvmTyName;
} else
TTy->print(OS, false, true);
} else if (TemplateArg.getKind() == TemplateArgument::Integral)
OS << TemplateArg.getAsIntegral();
} else if (TemplateArgs[I].getKind() == TemplateArgument::Integral) {
const auto IntTemplateParam = TemplateArgs[I].getAsIntegral();
// Last template parameter of __spirv_JointMatrixINTEL 'Use' is
// optional in SPIR-V, so If it has 'Unnecessary' value - skip it.
// MatrixUse::Unnecessary defined as '3' in spirv_types.hpp.
constexpr size_t Unnecessary = 3;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding can you explain this? What is the significance of '3' here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the significance of '3' here?

comment would be helpful. Thank you

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the comment

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still a little confused. My question was why is "3" = Unecessary. Is this defined by the spec? What are possible legal values for Use parameter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there is no Unnecessary in both SPIR-V and SYCL spec. Legal values on user-visible API level are 0, 1 and 2 (same for SPIR-V). Yet Unnecessary is defined in intermediate level of SYCL headers to keep compatibility with previous version of the extension, where we didn't have Use parameter making to be a default Use value.

While answering your question another idea came up in my mind, which might be cleaner. What if we declare struct __spirv_JointMatrixINTEL depending on SYCL_EXT_ONEAPI_MATRIX , WDYT @yubingex007-a11y @dkhaldi ? Like this:

#if SYCL_EXT_ONEAPI_MATRIX == 1
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
          Scope::Flag S = Scope::Flag::Subgroup>
struct __spirv_JointMatrixINTEL {
  T(*Value)
  [R][C][static_cast<size_t>(L) + 1][static_cast<size_t>(S) + 1];
};
#else
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
          Scope::Flag S = Scope::Flag::Subgroup,
          MatrixUse U = MatrixUse::Unnecessary>
struct __spirv_JointMatrixINTEL {
  T(*Value)
  [R][C][static_cast<size_t>(L) + 1][static_cast<size_t>(S) + 1]
     [static_cast<size_t>(U) + 1];
};
#endif // SYCL_EXT_ONEAPI_MATRIX

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this idea as long this does not break current SPIRV (use is still optional).

Copy link
Contributor Author

@MrSidims MrSidims Oct 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eh, there is quite a lot of things to change in the headers with a lot of code duplication to make it working. @elizabethandrews are we good to have the patch in clang or additional code commentary is required?
Basically every code which contains __spirv_JointMatrixINTEL should be duplicated and it includes all SPIR-V matrix instructions (10 of them) declarations with their appropriate calls in high-level SYCL API. So it does look for me that this clang modification is easier to support then two copies of the matrix-related code in the headers.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do something like:

template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
          Scope::Flag S = Scope::Flag::Subgroup
#if SYCL_EXT_ONEAPI_MATRIX == 1
          ,MatrixUse U = MatrixUse::Unnecessary>
#else
>
#endif
struct __spirv_JointMatrixINTEL {
  T(*Value)
  [R][C][static_cast<size_t>(L) + 1][static_cast<size_t>(S) + 1]
#if SYCL_EXT_ONEAPI_MATRIX == 1
     [static_cast<size_t>(U) + 1];
#else
  ;
#endif // SYCL_EXT_ONEAPI_MATRIX
};

Also, I see that Unnecessary template parameter is a default value, why can't we just declare it? How it will be incompatible with the previous version of the extension?

if (!(I == NumOfMatrixParameters &&
IntTemplateParam == Unnecessary))
OS << "_" << IntTemplateParam;
}
}
Ty->setName(OS.str());
return;
Expand Down
17 changes: 10 additions & 7 deletions clang/test/CodeGenSYCL/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
#include <stdint.h>

namespace __spv {
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
template <typename T, size_t R, size_t C, uint32_t L, uint32_t S, uint32_t U>
struct __spirv_JointMatrixINTEL;
}

// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1, 3> *matrix) {}

// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0, 3> *matrix) {}

// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 3> *matrix) {}

namespace sycl {
class half {};
Expand All @@ -25,10 +25,13 @@ namespace sycl {
typedef sycl::half my_half;

// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0, 3> *matrix) {}

// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0, 3> *matrix) {}

// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 3> *matrix) {}

// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0_1
void f7(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 1> *matrix) {}