Skip to content

[SYCL][SPIR-V] Change the LLVM type name of SPIR-V matrix types. #6535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,54 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
StringRef suffix) {
SmallString<256> TypeName;
llvm::raw_svector_ostream OS(TypeName);
// If RD is spirv_JointMatrixINTEL type, mangle differently.
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
OS << "spirv.JointMatrixINTEL.";
for (auto &TemplateArg : TemplateArgs) {
OS << "_";
if (TemplateArg.getKind() == TemplateArgument::Type) {
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
if (TTy->isIntegerTy()) {
switch (TTy->getIntegerBitWidth()) {
case 8:
OS << "char";
break;
case 16:
OS << "short";
break;
case 32:
OS << "int";
break;
case 64:
OS << "long";
break;
default:
OS << "i" << TTy->getIntegerBitWidth();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__int128 is not legal on SPIR-V backend, and I'm not sure they are legal types for the JointMatrixINTEL type in any case.

Should this default case be changed to an assertion then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little wary putting an assert this deep in the backend when the frontend might allow types to typecheck correctly.

I did find a way to get an i128 type in the test, though not for __int128--I used _BitInt(128) instead.

break;
}
} else if (TTy->isBFloatTy())
OS << "bfloat16";
else if (TTy->isStructTy()) {
StringRef LlvmTyName = TTy->getStructName();
// Emit half/bfloat16 for sycl[::*]::{half,bfloat16}
if (LlvmTyName.startswith("class.sycl::") ||
LlvmTyName.startswith("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
OS << LlvmTyName;
} else
TTy->print(OS, false, true);
} else if (TemplateArg.getKind() == TemplateArgument::Integral)
OS << TemplateArg.getAsIntegral();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and here

}
Ty->setName(OS.str());
return;
}
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly these two as well

OS << RD->getKindName() << '.';

// FIXME: We probably want to make more tweaks to the printing policy. For
Expand Down
34 changes: 34 additions & 0 deletions clang/test/CodeGenSYCL/matrix.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - -no-opaque-pointers | FileCheck %s
// Test that SPIR-V codegen generates the expected LLVM struct name for the
// JointMatrixINTEL type.
#include <stddef.h>
#include <stdint.h>

namespace __spv {
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
struct __spirv_JointMatrixINTEL;
}

// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}

// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test case for the larger that 64 int sizes too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__int128 is not legal on SPIR-V backend, and I'm not sure they are legal types for the JointMatrixINTEL type in any case.

Copy link
Contributor

@MrSidims MrSidims Aug 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's legal to use ExtInt (as well as int128) in SPIR-V with an arbitrary width under ArbitraryPrecisionIntegersINTEL capability. Also there are no restrictions about integer width for element type of the JointMatrix.
So may be it worth it add a case, like:
void f2(__spv::__spirv_JointMatrixINTEL<_ExtInt(42), 10, 2, 0, 0> *matrix) {}
to see if we get ._i42_10_2_0_0 postfix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did find that it was possible to use _ExtInt in a test and added it in a later patch, albeit for _ExtInt(128).


// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}

namespace sycl {
class half {};
class bfloat16 {};
}
typedef sycl::half my_half;

// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}

// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}

// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}
6 changes: 3 additions & 3 deletions sycl/test/matrix/matrix-int8-test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s

// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL" = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* }
// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* }
// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* }
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* }
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* }
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* }
Comment on lines +3 to +5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment would be helpful here. Thanks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a quick look through these tests, and I'm not sure I understand them well enough to put accurate comments as to what they are doing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neither do I. I think it would be unfair to gate this PR because of this - the change here looks local enough and most probably harmless.

I'd still prefer the original author of this test ( @yubingex007-a11y ) to follow up with a PR adding comments explaining what this one (and other added together with it) are supposed to verify.


#include <sycl/sycl.hpp>
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
Expand Down