Skip to content

[SYCL] Re-land "Represent JointMatrixINTEL type as extension type" #9841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 83 additions & 59 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,65 +51,6 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
StringRef suffix) {
SmallString<256> TypeName;
llvm::raw_svector_ostream OS(TypeName);
// If RD is spirv_JointMatrixINTEL type, mangle differently.
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
OS << "spirv.JointMatrixINTEL.";
for (auto &TemplateArg : TemplateArgs) {
OS << "_";
if (TemplateArg.getKind() == TemplateArgument::Type) {
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
if (TTy->isIntegerTy()) {
switch (TTy->getIntegerBitWidth()) {
case 8:
OS << "char";
break;
case 16:
OS << "short";
break;
case 32:
OS << "int";
break;
case 64:
OS << "long";
break;
default:
OS << "i" << TTy->getIntegerBitWidth();
break;
}
} else if (TTy->isHalfTy()) {
OS << "half";
} else if (TTy->isFloatTy()) {
OS << "float";
} else if (TTy->isDoubleTy()) {
OS << "double";
} else if (TTy->isBFloatTy()) {
OS << "bfloat16";
} else if (TTy->isStructTy()) {
StringRef LlvmTyName = TTy->getStructName();
// Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.startswith("class.sycl::") ||
LlvmTyName.startswith("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName != "half" && LlvmTyName != "bfloat16" &&
LlvmTyName != "tf32")
llvm_unreachable("Wrong matrix base type!");
OS << LlvmTyName;
} else {
llvm_unreachable("Wrong matrix base type!");
}
} else if (TemplateArg.getKind() == TemplateArgument::Integral) {
OS << TemplateArg.getAsIntegral();
}
}
Ty->setName(OS.str());
return;
}
}
}
OS << RD->getKindName() << '.';

// FIXME: We probably want to make more tweaks to the printing policy. For
Expand Down Expand Up @@ -460,6 +401,77 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
return ResultType;
}

template <bool NeedTypeInterpret = false>
llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
ArrayRef<TemplateArgument> TemplateArgs,
const unsigned Val = 0) {
// TODO: we should actually have exactly 5 template parameters: 1 for
// type and 4 for type parameters. But in previous version of the SPIR-V
// spec we have Layout matrix type parameter, that was later removed.
// Once we update to the newest version of the spec - this should be updated.
assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) &&
"Wrong JointMatrixINTEL template parameters number");
// This is required to represent optional 'Component Type Interpretation'
// parameter
std::vector<unsigned> Params;
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
"Wrong JointMatrixINTEL template parameter");
Params.push_back(TemplateArgs[I].getAsIntegral().getExtValue());
}
// Don't add type interpretation for legacy matrices.
// Legacy matrices has 5 template parameters, while new representation
// has 6.
if (NeedTypeInterpret && TemplateArgs.size() != 5)
Params.push_back(Val);

return llvm::TargetExtType::get(CompTy->getContext(),
"spirv.JointMatrixINTEL", {CompTy}, Params);
}

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
/// The expected representation is:
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
"1st JointMatrixINTEL template parameter must be type");
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());

// Per JointMatrixINTEL spec the type can have an optional
// 'Component Type Interpretation' parameter. We should emit it in case
// if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
// matrix's components. Yet 'bfloat16' should be represented as 'int16' and
// 'tf32' as 'float' types.
if (CompTy->isStructTy()) {
StringRef LlvmTyName = CompTy->getStructName();
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.startswith("class.sycl::") ||
LlvmTyName.startswith("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName == "half") {
CompTy = llvm::Type::getHalfTy(getLLVMContext());
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
} else if (LlvmTyName == "tf32") {
CompTy = llvm::Type::getFloatTy(getLLVMContext());
// 'tf32' interpretation is mapped to '0'
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0);
} else if (LlvmTyName == "bfloat16") {
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
// 'bfloat16' interpretation is mapped to '1'
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1);
} else {
llvm_unreachable("Wrong matrix base type!");
}
}
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
}

/// ConvertType - Convert the specified type to its LLVM form.
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
T = Context.getCanonicalType(T);
Expand Down Expand Up @@ -754,6 +766,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
llvm::Type *PointeeType = ConvertTypeForMem(ETy);
if (PointeeType->isVoidTy())
PointeeType = llvm::Type::getInt8Ty(getLLVMContext());
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
const Type *ClangETy = ETy.getTypePtrOrNull();
if (ClangETy && ClangETy->isStructureOrClassType()) {
RecordDecl *RD = ClangETy->getAsCXXRecordDecl();
if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_JointMatrixINTEL") {
ResultType = ConvertSYCLJointMatrixINTELType(RD);
break;
}
}
}

unsigned AS = getTargetAddressSpace(ETy);
ResultType = llvm::PointerType::get(PointeeType, AS);
break;
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ class CodeGenTypes {
/// memory representation is usually i8 or i32, depending on the target.
llvm::Type *ConvertTypeForMem(QualType T, bool ForBitField = false);

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
/// The expected representation is:
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);

/// GetFunctionType - Get the LLVM function type for \arg Info.
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);

Expand Down
34 changes: 17 additions & 17 deletions clang/test/CodeGenSYCL/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
#include <stdint.h>

namespace __spv {
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
template <typename T, size_t R, size_t C, uint32_t L, uint32_t S, uint32_t U>
struct __spirv_JointMatrixINTEL;
}

// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}
// CHECK: @_Z2f1{{.*}}(target("spirv.JointMatrixINTEL", float, 5, 10, 0, 1, 0)
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1, 0> *matrix) {}

// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f2{{.*}}(target("spirv.JointMatrixINTEL", i64, 10, 2, 0, 0, 0)
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0, 0> *matrix) {}

// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f3{{.*}}(target("spirv.JointMatrixINTEL", i8, 10, 2, 0, 0, 0)
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 0> *matrix) {}

namespace sycl {
class half {};
Expand All @@ -25,17 +25,17 @@ namespace sycl {
}
typedef sycl::half my_half;

// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f4{{.*}}(target("spirv.JointMatrixINTEL", half, 10, 2, 0, 0, 0)
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0, 0> *matrix) {}

// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f5{{.*}}(target("spirv.JointMatrixINTEL", i16, 10, 2, 0, 0, 0, 1)
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0, 0> *matrix) {}

// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f6{{.*}}(target("spirv.JointMatrixINTEL", i128, 10, 2, 0, 0, 0)
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 0> *matrix) {}

// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0> *matrix) {}
// CHECK: @_Z2f7{{.*}}(target("spirv.JointMatrixINTEL", float, 10, 2, 0, 0, 0, 0)
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0, 0> *matrix) {}

// CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1> *matrix) {}
// CHECK: @_Z2f8{{.*}}(target("spirv.JointMatrixINTEL", double, 5, 10, 0, 1, 0)
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1, 0> *matrix) {}
18 changes: 18 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,24 @@ struct joint_matrix {
get_wi_data() {
return wi_data<T, NumRows, NumCols, Layout, Group>(*this);
}

#ifdef __SYCL_DEVICE_ONLY__
#if defined(__SPIR__)
// Generate a non-trivial assignment operator and copy c'tor that prevents
// memcpy from being generated.
// TODO: to remove, when either IGC can handle alloca JointMatrix or
// combination of InstCombine + SROA + mem2reg can remove it
joint_matrix(const joint_matrix &other) {
spvm = other.spvm;
return *this;
}

joint_matrix &operator=(const joint_matrix &rhs) {
spvm = rhs.spvm;
return *this;
}
#endif // defined(__SPIR__)
#endif
};

template <typename Group, typename T, size_t NumRows, size_t NumCols,
Expand Down
17 changes: 17 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,23 @@ struct joint_matrix {
PI_ERROR_INVALID_DEVICE);
#endif
}
#ifdef __SYCL_DEVICE_ONLY__
#if defined(__SPIR__)
// Generate a non-trivial assignment operator and copy c'tor that prevents
// memcpy from being generated.
// TODO: to remove, when either IGC can handle alloca JointMatrix or
// combination of InstCombine + SROA + mem2reg can remove it
joint_matrix(const joint_matrix &other) {
spvm = other.spvm;
return *this;
}

joint_matrix &operator=(const joint_matrix &rhs) {
spvm = rhs.spvm;
return *this;
}
#endif // defined(__SPIR__)
#endif
};

#ifdef __SYCL_DEVICE_ONLY__
Expand Down
9 changes: 6 additions & 3 deletions sycl/test/check_device_code/matrix/matrix_load_store_as.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
// RUN: %clangxx -fsycl-device-only -S -emit-llvm -o - %s | FileCheck %s

// Check that SROA and mem2reg won't leave alloca of matrix type in IR
// CHECK-NOT: alloca target("spirv.JointMatrixINTEL"

// check that correct address spaces are used to load from and store to
#define SYCL_EXT_ONEAPI_MATRIX_VERSION 4
#include <sycl/sycl.hpp>
Expand Down Expand Up @@ -39,16 +42,16 @@ int main(void) {
it.barrier(access::fence_space::local_space);

// A should load from local address space
// CHECK: %{{.*}} = tail call spir_func noundef %spirv.JointMatrixINTEL._short_8_16_0_3_0 addrspace(4)* @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(3)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
// CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 8, 16, 0, 3, 0) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(3)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
joint_matrix_load(
sg, tA,
tileA.template get_multi_ptr<sycl::access::decorated::yes>(), 16);
// B should load from global address space
// CHECK: %{{.*}} = tail call spir_func noundef %spirv.JointMatrixINTEL._short_16_16_2_3_1 addrspace(4)* @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(1)* noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}}
// CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(1)* noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}}
joint_matrix_load(sg, tB, pB, 32);
tC = joint_matrix_mad(sg, tA, tB, tC);
// C should store to global address space
// CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(float addrspace(1)* noundef %{{.*}}, %spirv.JointMatrixINTEL._float_8_16_3_3_2 addrspace(4)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
// CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(float addrspace(1)* noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
joint_matrix_store(sg, tC, pC, 16, layout::row_major);
});
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
// RUN: %clangxx -fsycl-device-only -S -emit-llvm -o - %s | FileCheck %s

// Check that SROA and mem2reg won't leave alloca of matrix type in IR
// CHECK-NOT: alloca target("spirv.JointMatrixINTEL"

// check that correct address spaces are used to load from and store to
#define SYCL_EXT_ONEAPI_MATRIX_VERSION 1
#include <sycl/sycl.hpp>
Expand Down Expand Up @@ -36,17 +39,17 @@ int main(void) {
it.barrier(access::fence_space::local_space);

// A should load from local address space
// CHECK: %{{.*}} = tail call spir_func noundef %spirv.JointMatrixINTEL._short_8_16_0_3 addrspace(4)* @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(3)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
// CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 8, 16, 0, 3) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(3)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
joint_matrix_load(
sg, tA,
tileA.template get_multi_ptr<sycl::access::decorated::yes>(), 16,
matrix_layout::row_major);
// B should load from global address space
// CHECK: %{{.*}} = tail call spir_func noundef %spirv.JointMatrixINTEL._short_16_16_3_3 addrspace(4)* @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(1)* noundef %{{.*}}, i64 noundef 32, i32 noundef [[#]], i32 noundef 3, i32 noundef 0) #{{.*}}
// CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 3, 3) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(1)* noundef %{{.*}}, i64 noundef 32, i32 noundef [[#]], i32 noundef 3, i32 noundef 0) #{{.*}}
joint_matrix_load(sg, tB, pB, 32, matrix_layout::packed_b);
tC = joint_matrix_mad(sg, tA, tB, tC);
// C should store to global address space
// CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(float addrspace(1)* noundef %{{.*}}, %spirv.JointMatrixINTEL._float_8_16_0_3 addrspace(4)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
// CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(float addrspace(1)* noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 0, 3) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
joint_matrix_store(sg, tC, pC, 16, matrix_layout::row_major);
});
});
Expand Down
6 changes: 3 additions & 3 deletions sycl/test/matrix/legacy/matrix-int8-test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 -S -emit-llvm -o - %s | FileCheck %s

// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type opaque
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3)
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 0, 3)
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3)

#include <iostream>
#include <sycl/sycl.hpp>
Expand Down
6 changes: 3 additions & 3 deletions sycl/test/matrix/matrix-int8-test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: %clangxx -fsycl -fsycl-device-only -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -O2 -S -emit-llvm -o - %s | FileCheck %s

// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_0 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_3_3_2 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_2_3_1 = type opaque
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0)
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2)
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1)

#include <iostream>
#include <sycl/sycl.hpp>
Expand Down