Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Represent JointMatrixINTEL type as extension type #8343

Merged
merged 8 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 84 additions & 59 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,65 +51,6 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
StringRef suffix) {
SmallString<256> TypeName;
llvm::raw_svector_ostream OS(TypeName);
// If RD is spirv_JointMatrixINTEL type, mangle differently.
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
OS << "spirv.JointMatrixINTEL.";
for (auto &TemplateArg : TemplateArgs) {
OS << "_";
if (TemplateArg.getKind() == TemplateArgument::Type) {
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
if (TTy->isIntegerTy()) {
switch (TTy->getIntegerBitWidth()) {
case 8:
OS << "char";
break;
case 16:
OS << "short";
break;
case 32:
OS << "int";
break;
case 64:
OS << "long";
break;
default:
OS << "i" << TTy->getIntegerBitWidth();
break;
}
} else if (TTy->isHalfTy()) {
OS << "half";
} else if (TTy->isFloatTy()) {
OS << "float";
} else if (TTy->isDoubleTy()) {
OS << "double";
} else if (TTy->isBFloatTy()) {
OS << "bfloat16";
} else if (TTy->isStructTy()) {
StringRef LlvmTyName = TTy->getStructName();
// Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.startswith("class.sycl::") ||
LlvmTyName.startswith("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName != "half" && LlvmTyName != "bfloat16" &&
LlvmTyName != "tf32")
llvm_unreachable("Wrong matrix base type!");
OS << LlvmTyName;
} else {
llvm_unreachable("Wrong matrix base type!");
}
} else if (TemplateArg.getKind() == TemplateArgument::Integral) {
OS << TemplateArg.getAsIntegral();
}
}
Ty->setName(OS.str());
return;
}
}
}
OS << RD->getKindName() << '.';

// FIXME: We probably want to make more tweaks to the printing policy. For
Expand Down Expand Up @@ -460,6 +401,78 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
return ResultType;
}

template <bool NeedTypeInterpret = false>
llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
ArrayRef<TemplateArgument> TemplateArgs,
const unsigned Val = 0) {
// TODO: we should actually have exactly 5 template parameters: 1 for
// type and 4 for type parameters. But in previous version of the SPIR-V
// spec we have Layout matrix type parameter, that was later removed.
// Once we update to the newest version of the spec - this should be updated.
assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) &&
"Wrong JointMatrixINTEL template parameters number");
// This is required to represent optional Optional
MrSidims marked this conversation as resolved.
Show resolved Hide resolved
// 'Component Type Interpretation' parameter
using ParamsType =
typename std::conditional<NeedTypeInterpret, SmallVector<unsigned, 6>,
SmallVector<unsigned, 5>>::type;
ParamsType Params;
if constexpr (NeedTypeInterpret)
Params = {0, 0, 0, 0, 0, Val};
else
Params = {0, 0, 0, 0, 0};
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
"Wrong JointMatrixINTEL template parameter");
Params[I - 1] = TemplateArgs[I].getAsIntegral().getExtValue();
}
return llvm::TargetExtType::get(CompTy->getContext(),
"spirv.JointMatrixINTEL", {CompTy}, Params);
}

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
/// The expected representation is:
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
"1st JointMatrixINTEL template parameter must be type");
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());

// Per JointMatrixINTEL spec the type can have an Optional
MrSidims marked this conversation as resolved.
Show resolved Hide resolved
// 'Component Type Interpretation' parameter. We should emit it in case
// if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
// matrix's components. Yet bfloat16 should be represented as 'int16' and
MrSidims marked this conversation as resolved.
Show resolved Hide resolved
// 'tf32' as 'float' types.
if (CompTy->isStructTy()) {
StringRef LlvmTyName = CompTy->getStructName();
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.startswith("class.sycl::") ||
LlvmTyName.startswith("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName == "half") {
CompTy = llvm::Type::getHalfTy(getLLVMContext());
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
} else if (LlvmTyName == "tf32") {
CompTy = llvm::Type::getFloatTy(getLLVMContext());
// 'tf32' interpretation is mapped to '0'
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0);
} else if (LlvmTyName == "bfloat16") {
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LLVM has a bfloat type, so surely it would make more sense to use that instead of i16? Or is there something about the JointMatrixINTEL spec that I'm missing?

Copy link
Contributor Author

@MrSidims MrSidims Feb 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is that in SPIR-V we don't have bfloat type (yet?). Instead we have conversion instructions and using short value as bfloat16 storage. Same for SYCL - here bfloat16 class is a wrapper around int16 storage. So for the consistency int16 is used here as well, though there is no other reason not to reuse LLVM's bf16 (just because we still will replace it with int16 in SPIR-V). So I lean towards appying this suggestion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, lets keep i16 to be aligned with SPIR-V spec. Once we have proper SPIR-V bf16 type we can start generating in by DPC++ compiler either here or elsewhere.

// 'bfloat16' interpretation is mapped to '1'
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1);
} else {
llvm_unreachable("Wrong matrix base type!");
}
}
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
}

/// ConvertType - Convert the specified type to its LLVM form.
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
T = Context.getCanonicalType(T);
Expand Down Expand Up @@ -745,6 +758,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
llvm::Type *PointeeType = ConvertTypeForMem(ETy);
if (PointeeType->isVoidTy())
PointeeType = llvm::Type::getInt8Ty(getLLVMContext());
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
const Type *ClangETy = ETy.getTypePtrOrNull();
if (ClangETy && ClangETy->isStructureOrClassType()) {
RecordDecl *RD = ClangETy->getAsCXXRecordDecl();
if (RD &&
RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
ResultType = ConvertSYCLJointMatrixINTELType(RD);
break;
}
}
}

unsigned AS = getTargetAddressSpace(ETy);
ResultType = llvm::PointerType::get(PointeeType, AS);
break;
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ class CodeGenTypes {
/// memory representation is usually i8 or i32, depending on the target.
llvm::Type *ConvertTypeForMem(QualType T, bool ForBitField = false);

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
/// The expected representation is:
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);

/// GetFunctionType - Get the LLVM function type for \arg Info.
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);

Expand Down
34 changes: 17 additions & 17 deletions clang/test/CodeGenSYCL/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
#include <stdint.h>

namespace __spv {
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
template <typename T, size_t R, size_t C, uint32_t L, uint32_t S, uint32_t U>
struct __spirv_JointMatrixINTEL;
}

// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}
// CHECK: @_Z2f1{{.*}}(target("spirv.JointMatrixINTEL", float, 5, 10, 0, 1, 0)
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1, 0> *matrix) {}

// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f2{{.*}}(target("spirv.JointMatrixINTEL", i64, 10, 2, 0, 0, 0)
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0, 0> *matrix) {}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aelovikov-intel here is the test for unsigned. Would you mind if I won't duplicate it in sycl headers?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'm not sure what's the purpose of the SYCL RT tests then, but yes, no need for unsigned there.

Copy link
Contributor Author

@MrSidims MrSidims Apr 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, matrix API is changing overtime, so having several compilation tests using real headers is good to have. One the extension is stable we will remove them (note, the E2E tests were in different repo and require specific hardware)


// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f3{{.*}}(target("spirv.JointMatrixINTEL", i8, 10, 2, 0, 0, 0)
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 0> *matrix) {}

namespace sycl {
class half {};
Expand All @@ -25,17 +25,17 @@ namespace sycl {
}
typedef sycl::half my_half;

// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f4{{.*}}(target("spirv.JointMatrixINTEL", half, 10, 2, 0, 0, 0)
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0, 0> *matrix) {}

// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f5{{.*}}(target("spirv.JointMatrixINTEL", i16, 10, 2, 0, 0, 0, 1)
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0, 0> *matrix) {}

// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f6{{.*}}(target("spirv.JointMatrixINTEL", i128, 10, 2, 0, 0, 0)
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 0> *matrix) {}

// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f7{{.*}}(target("spirv.JointMatrixINTEL", float, 10, 2, 0, 0, 0, 0)
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0, 0> *matrix) {}

// CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1> *matrix) {}
// CHECK: @_Z2f8{{.*}}(target("spirv.JointMatrixINTEL", double, 5, 10, 0, 1, 0)
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1, 0> *matrix) {}
6 changes: 3 additions & 3 deletions sycl/test/matrix/legacy/matrix-int8-test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 -S -emit-llvm -o - %s | FileCheck %s

// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type opaque
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0)
aelovikov-intel marked this conversation as resolved.
Show resolved Hide resolved
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 0, 3, 0)
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3, 0)

#include <iostream>
#include <sycl/sycl.hpp>
Expand Down
6 changes: 3 additions & 3 deletions sycl/test/matrix/matrix-int8-test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: %clangxx -fsycl -fsycl-device-only -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -O2 -S -emit-llvm -o - %s | FileCheck %s

// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_0 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_3_3_2 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_2_3_1 = type opaque
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0)
aelovikov-intel marked this conversation as resolved.
Show resolved Hide resolved
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2)
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1)

#include <iostream>
#include <sycl/sycl.hpp>
Expand Down