[SYCL][CUDA] Implementation of matrix ext using new "unified" interface #7077
sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
@@ -0,0 +1,221 @@
//===------- matrix-unified.hpp - SYCL matrix extension ----*- C++ -*------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===--------------------------------------------------------------------===//

#pragma once
#include <sycl/ext/oneapi/matrix/matrix-tensorcores.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext {
namespace oneapi {
namespace experimental {
namespace matrix {

template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
          layout Layout>
struct joint_matrix {

#if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
  // TODO: Intel case here: we use the ext_oneapi_cuda case also for the host,
  // because the Intel SPIRV functions will not be host compilable.
#else
  sycl::ext::oneapi::detail::joint_matrix_cuda<T, Use, Rows, Cols, Layout>
      cuda_impl;
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)

  joint_matrix() {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif
  }
};
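For orientation, a minimal sketch of what declarations of this unified joint_matrix type might look like inside a kernel; the sub_group scope and the 16x16 half/float shapes are assumptions, not something this header mandates. Note that the default constructor above throws on the host, so these objects are only meant to be created in device code.

using namespace sycl::ext::oneapi::experimental::matrix;
// Fragments for the three roles: A and B carry a compile-time layout, while
// the accumulator is declared with layout::dynamic and picks its layout at
// load/store time.
joint_matrix<sycl::sub_group, sycl::half, use::a, 16, 16, layout::row_major> A;
joint_matrix<sycl::sub_group, sycl::half, use::b, 16, 16, layout::row_major> B;
joint_matrix<sycl::sub_group, float, use::accumulator, 16, 16, layout::dynamic> C;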
template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
          layout Layout>
inline __SYCL_ALWAYS_INLINE wi_data<Group, T, Use, Rows, Cols, Layout>
get_wi_data(Group sg, joint_matrix<Group, T, Use, Rows, Cols, Layout> &jm) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  return wi_data(jm);
#else
  // TODO add Intel impl.
#endif // defined(__NVPTX__)
#endif // defined(__SYCL_DEVICE_ONLY__)
}

Inline review comment (on the Intel TODO branch): @AerialMantis @JackAKirk @dkhaldi
template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
          layout Layout, typename T2>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_fill(Group sg,
                  joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &res,
                  const T2 &v) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  res.cuda_impl.wi_marray = v;
#endif // defined(__NVPTX__)
#else
  std::ignore = sg;
  std::ignore = res;
  std::ignore = v;
  throw runtime_error(
      "This version of the matrix extension is only currently supported on "
      "Nvidia devices",
      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

Inline review comment: I noticed you don't have any test with joint_matrix_fill. Is that on purpose?
Reply: We only have device code tests in intel/llvm for functions that call NVPTX builtins.
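Following up on that exchange, here is a rough, hypothetical device-only snippet that exercises joint_matrix_fill; the queue q, the output pointer out, and the 16x16 float accumulator shape are assumptions, and wi_data is assumed to expose length() and operator[] as in the extension specification.

using namespace sycl::ext::oneapi::experimental::matrix;
q.parallel_for(sycl::nd_range<2>({1, 32}, {1, 32}), [=](sycl::nd_item<2> item) {
  auto sg = item.get_sub_group();
  joint_matrix<sycl::sub_group, float, use::accumulator, 16, 16,
               layout::dynamic> acc;
  joint_matrix_fill(sg, acc, 42.0f); // every element of the fragment becomes 42.0f
  auto data = get_wi_data(sg, acc);  // per-work-item view of the fragment
  for (size_t i = 0; i < data.length(); ++i)
    out[item.get_global_linear_id() * data.length() + i] = data[i];
});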
template <
    typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
    access::address_space Space, access::decorated IsDecorated,
    std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
        true>
inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
    Group sg,
    joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
                 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
    multi_ptr<T, Space, IsDecorated> src, size_t stride,
    sycl::ext::oneapi::experimental::matrix::layout Layout) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  sycl::ext::oneapi::detail::load_accumulator_cuda(res.cuda_impl, src, stride,
                                                   Layout);
#endif // defined(__NVPTX__)
#else
  std::ignore = sg;
  std::ignore = res;
  std::ignore = src;
  std::ignore = stride;
  throw runtime_error(
      "This version of the matrix extension is only currently supported on "
      "Nvidia devices",
      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}
template <
    typename Group, typename S, typename T, use Use, size_t NumRows,
    size_t NumCols, matrix::layout Layout, access::address_space Space,
    access::decorated IsDecorated,
    std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
                         (std::is_same<S, precision::tf32>::value &&
                          std::is_same<std::remove_const_t<T>, float>::value),
                     bool> = true>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_load(Group sg,
                  joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &res,
                  multi_ptr<T, Space, IsDecorated> src, size_t stride) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  sycl::ext::oneapi::detail::load_multiplicand_cuda<S, T, NumRows, NumCols, Use,
                                                    Layout, Space>(
      res.cuda_impl, src, stride);
#endif // defined(__NVPTX__)
#else
  std::ignore = sg;
  std::ignore = res;
  std::ignore = src;
  std::ignore = stride;
  throw runtime_error(
      "This version of the matrix extension is only currently supported on "
      "Nvidia devices",
      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}
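The second branch of the enable_if above is what allows a joint_matrix whose element type is precision::tf32 to be loaded from plain float storage. A hypothetical call is sketched below; the 16x8 A-fragment shape (based on the usual tf32 m16n16k8 MMA shape) and the multi_ptr name srcA are assumptions.

// srcA is assumed to be a multi_ptr<float, ...>; the stride is in elements.
joint_matrix<sycl::sub_group, precision::tf32, use::a, 16, 8, layout::row_major> A_tf32;
joint_matrix_load(sg, A_tf32, srcA, 8);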
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          access::address_space Space, access::decorated IsDecorated>
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
    Group sg,
    joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
                 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
    multi_ptr<T, Space, IsDecorated> dst, size_t stride,
    sycl::ext::oneapi::experimental::matrix::layout Layout) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  sycl::ext::oneapi::detail::joint_matrix_store_cuda<T, NumRows, NumCols,
                                                     Space>(src.cuda_impl, dst,
                                                            stride, Layout);
#endif // defined(__NVPTX__)
#else
  std::ignore = sg;
  std::ignore = src;
  std::ignore = dst;
  std::ignore = stride;
  throw runtime_error(
      "This version of the matrix extension is only currently supported on "
      "Nvidia devices",
      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}
template <typename Group, typename Ta, typename Tb, typename Tc, std::size_t M,
          std::size_t K, std::size_t N, layout LayoutA, layout LayoutB>
inline __SYCL_ALWAYS_INLINE
    joint_matrix<Group, Tc, use::accumulator, M, N,
                 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
    joint_matrix_mad(
        Group sg, joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
        joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
        joint_matrix<Group, Tc, use::accumulator, M, N,
                     sycl::ext::oneapi::experimental::matrix::layout::dynamic>
            &C) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  if constexpr (std::is_same<Ta, Tb>::value) {
    joint_matrix<Group, Tc, use::accumulator, M, N,
                 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
        D;
    sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, M, K, N, LayoutA,
                                                     LayoutB>(
        D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl);
    return D;
  } else {
    assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad "
                    "requires that joint_matrix data types Ta and Tb match");
  }
#endif // defined(__NVPTX__)
#else
  std::ignore = sg;
  std::ignore = A;
  std::ignore = B;
  std::ignore = C;
  throw runtime_error(
      "This version of the matrix extension is only currently supported on "
      "Nvidia devices",
      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}
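Taken together, the entry points above compose into a single-tile multiply-accumulate. A rough sketch of how a kernel might chain them on the CUDA backend, reusing the A/B/C declarations from the earlier sketch and assuming ptrA/ptrB/ptrC are multi_ptrs to the 16x16 tiles and sg is the sub-group (names, shape, and strides are illustrative only):

joint_matrix_load(sg, A, ptrA, 16);                    // multiplicand loads use the compile-time layout
joint_matrix_load(sg, B, ptrB, 16);
joint_matrix_load(sg, C, ptrC, 16, layout::row_major); // accumulator load/store take a runtime layout
C = joint_matrix_mad(sg, A, B, C);                     // returns the D = A * B + C fragment by value
joint_matrix_store(sg, C, ptrC, 16, layout::row_major);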
// This function rounds the bottom 13 bits up or down, and then zeros out the
// bottom bits
inline __SYCL_ALWAYS_INLINE float round_to_tf32(float &a) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  int32_t tmp_int = __nvvm_f2tf32_rna(a);
  return __nvvm_bitcast_i2f(tmp_int);
#else
  uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
  tmp_uint += 0x1000u;
  tmp_uint &= 0xFFFFE000u;
  float ret = reinterpret_cast<float &>(tmp_uint);
  return ret;
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}
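A small, hypothetical host-side check of the fallback path, illustrating the round-up case; the bit patterns are chosen for illustration, and <cstring>/<cassert> are assumed to be included.

float x;
uint32_t bits = 0x3F801800u;        // 1.0f with 0x1800 in the low 13 mantissa bits
std::memcpy(&x, &bits, sizeof(x));
float r = round_to_tf32(x);         // 0x3F801800 + 0x1000 = 0x3F802800, then masked to 0x3F802000
uint32_t rbits;
std::memcpy(&rbits, &r, sizeof(rbits));
assert(rbits == 0x3F802000u);       // low 13 bits rounded up to the nearest TF32-representable value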
} // namespace matrix
} // namespace experimental
} // namespace oneapi
} // namespace ext
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
Review comment: @AerialMantis @JackAKirk
Sorry, I remember that previously it was: on the Intel side we can't let host compilation use sycl::ext::oneapi::detail::joint_matrix_cuda, so I went back to the previous code and the CUDA test cases still pass.