-
Notifications
You must be signed in to change notification settings - Fork 792
[SYCL][SPIR-V] Change the LLVM type name of SPIR-V matrix types. #6535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6a449f2
da02eb3
53a12bc
290b8d2
a6e689c
cbf7da4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,6 +51,54 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, | |
StringRef suffix) { | ||
SmallString<256> TypeName; | ||
llvm::raw_svector_ostream OS(TypeName); | ||
// If RD is spirv_JointMatrixINTEL type, mangle differently. | ||
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) { | ||
if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") { | ||
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) { | ||
ArrayRef<TemplateArgument> TemplateArgs = | ||
TemplateDecl->getTemplateArgs().asArray(); | ||
OS << "spirv.JointMatrixINTEL."; | ||
for (auto &TemplateArg : TemplateArgs) { | ||
OS << "_"; | ||
if (TemplateArg.getKind() == TemplateArgument::Type) { | ||
llvm::Type *TTy = ConvertType(TemplateArg.getAsType()); | ||
if (TTy->isIntegerTy()) { | ||
switch (TTy->getIntegerBitWidth()) { | ||
case 8: | ||
OS << "char"; | ||
break; | ||
case 16: | ||
OS << "short"; | ||
break; | ||
case 32: | ||
OS << "int"; | ||
break; | ||
case 64: | ||
OS << "long"; | ||
break; | ||
default: | ||
OS << "i" << TTy->getIntegerBitWidth(); | ||
break; | ||
} | ||
} else if (TTy->isBFloatTy()) | ||
OS << "bfloat16"; | ||
else if (TTy->isStructTy()) { | ||
StringRef LlvmTyName = TTy->getStructName(); | ||
// Emit half/bfloat16 for sycl[::*]::{half,bfloat16} | ||
if (LlvmTyName.startswith("class.sycl::") || | ||
LlvmTyName.startswith("class.__sycl_internal::")) | ||
LlvmTyName = LlvmTyName.rsplit("::").second; | ||
OS << LlvmTyName; | ||
} else | ||
TTy->print(OS, false, true); | ||
} else if (TemplateArg.getKind() == TemplateArgument::Integral) | ||
OS << TemplateArg.getAsIntegral(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and here |
||
} | ||
Ty->setName(OS.str()); | ||
return; | ||
} | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Possibly these two as well |
||
OS << RD->getKindName() << '.'; | ||
|
||
// FIXME: We probably want to make more tweaks to the printing policy. For | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - -no-opaque-pointers | FileCheck %s | ||
// Test that SPIR-V codegen generates the expected LLVM struct name for the | ||
// JointMatrixINTEL type. | ||
#include <stddef.h> | ||
#include <stdint.h> | ||
elizabethandrews marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
namespace __spv { | ||
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S> | ||
struct __spirv_JointMatrixINTEL; | ||
} | ||
|
||
// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1 | ||
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {} | ||
|
||
// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0 | ||
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test case for the larger that 64 int sizes too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's legal to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did find that it was possible to use |
||
|
||
// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0 | ||
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {} | ||
|
||
namespace sycl { | ||
class half {}; | ||
class bfloat16 {}; | ||
} | ||
typedef sycl::half my_half; | ||
|
||
// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0 | ||
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {} | ||
|
||
// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 | ||
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {} | ||
|
||
// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0 | ||
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s | ||
|
||
// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL" = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* } | ||
// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* } | ||
// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* } | ||
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* } | ||
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* } | ||
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* } | ||
Comment on lines
+3
to
+5
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comment would be helpful here. Thanks There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I took a quick look through these tests, and I'm not sure I understand them well enough to put accurate comments as to what they are doing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Neither do I. I think it would be unfair to gate this PR because of this - the change here looks local enough and most probably harmless. I'd still prefer the original author of this test ( @yubingex007-a11y ) to follow up with a PR adding comments explaining what this one (and other added together with it) are supposed to verify. |
||
|
||
#include <sycl/sycl.hpp> | ||
#if (SYCL_EXT_ONEAPI_MATRIX == 2) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this default case be changed to an assertion then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little wary putting an assert this deep in the backend when the frontend might allow types to typecheck correctly.
I did find a way to get an i128 type in the test, though not for __int128--I used
_BitInt(128)
instead.