
Commit 5373362

[SYCL][MATRIX][CUDA] Add support for bf16, (u)int8, and half. (#5009)

Implementation of Nvidia MMAs using bf16, mixed-precision int ((u)int8/int32), and mixed-precision float (half/float).

Signed-off-by: jack.kirk <jack.kirk@codeplay.com>

1 parent 58508ba commit 5373362
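Only the bf16 test is rendered below. For the mixed-precision integer path named in the commit message, the following is a minimal usage sketch along the same lines, assuming the joint_matrix interface shown in this commit's tests; the kernel name int8_m16n16k16 and the buffer setup are illustrative, not taken from the commit:

// Sketch only: int8 A/B fragments accumulating into int32, using the
// WMMA m16n16k16 shape, in the same pattern as the bf16 test below.
#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

constexpr int stride = 16;

int main() {
  // Dummy buffers, as in the checked tests.
  buffer<int8_t, 1> bufA(nullptr, range<1>(1));
  buffer<int8_t, 1> bufB(nullptr, range<1>(1));
  buffer<int32_t, 1> bufC(nullptr, range<1>(1));
  buffer<int32_t, 1> bufD(nullptr, range<1>(1));

  queue q;
  q.submit([&](handler &cgh) {
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class int8_m16n16k16>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          // int32 accumulator with int8 A/B fragments.
          joint_matrix<int32_t, matrix_use::accumulator, 16, 16,
                       matrix_layout::row_major>
              sub_c;
          joint_matrix<int8_t, matrix_use::a, 16, 16,
                       matrix_layout::row_major>
              sub_a;
          joint_matrix<int8_t, matrix_use::b, 16, 16,
                       matrix_layout::row_major>
              sub_b;

          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
          // D = A * B + C; unsigned inputs would use uint8_t fragments
          // instead.
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
        });
  });
  return 0;
}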

9 files changed: +1357 −144 lines

sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-//===-------------- matrix-amx.hpp - SYCL matrix --------------*- C++ -*---===//
+//===------------ matrix-aot-amx.hpp - SYCL matrix ------------*- C++ -*---===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-//==------------------ matrix.hpp - SYCL matrix ----------------*- C++ -*---==//
+//==---------------- matrix-jit.hpp - SYCL matrix --------------*- C++ -*---==//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 387 additions & 137 deletions
Large diffs are not rendered by default.
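Since the header diff is not rendered, here is a minimal sketch of the other new path, mixed-precision float (half inputs, float accumulator, m16n16k16). These are kernel-body lines only, assuming the same surrounding setup (sg, accA/accB/accC/accD, stride, and buffer<half, 1> inputs) as the bf16 test in the next file; they are not taken from the commit:

          // Sketch only: half A/B fragments with a float accumulator.
          joint_matrix<float, matrix_use::accumulator, 16, 16,
                       matrix_layout::row_major>
              sub_c;
          joint_matrix<half, matrix_use::a, 16, 16,
                       matrix_layout::row_major>
              sub_a;
          joint_matrix<half, matrix_use::b, 16, 16,
                       matrix_layout::row_major>
              sub_b;

          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
          // Should lower to the f16 WMMA MMA intrinsics.
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);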
sycl/test/check_device_code/matrix/matrix-nvptx-bf16-test.cpp

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
// REQUIRES: cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

constexpr int stride = 16;

int main() {

  buffer<uint16_t, 1> bufA(nullptr, range<1>(1));
  buffer<uint16_t, 1> bufB(nullptr, range<1>(1));
  buffer<float, 1> bufC(nullptr, range<1>(1));
  buffer<float, 1> bufD(nullptr, range<1>(1));

  queue q;

  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class row_row_m16n16k16>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<float, matrix_use::accumulator, 16, 16,
                       matrix_layout::row_major>
              sub_c;

          joint_matrix<uint16_t, matrix_use::a, 16, 16,
                       matrix_layout::row_major>
              sub_a;

          joint_matrix<uint16_t, matrix_use::b, 16, 16,
                       matrix_layout::row_major>
              sub_b;

          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
        });

    cgh.parallel_for<class col_col_m16n16k16>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<float, matrix_use::accumulator, 16, 16,
                       matrix_layout::col_major>
              sub_c;

          joint_matrix<uint16_t, matrix_use::a, 16, 16,
                       matrix_layout::col_major>
              sub_a;

          joint_matrix<uint16_t, matrix_use::b, 16, 16,
                       matrix_layout::col_major>
              sub_b;

          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
        });

    cgh.parallel_for<class row_row_m32n8k16>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<float, matrix_use::accumulator, 32, 8,
                       matrix_layout::row_major>
              sub_c;

          joint_matrix<uint16_t, matrix_use::a, 32, 16,
                       matrix_layout::row_major>
              sub_a;

          joint_matrix<uint16_t, matrix_use::b, 16, 8, matrix_layout::row_major>
              sub_b;

          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i50.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
        });

    cgh.parallel_for<class col_col_m32n8k16>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<float, matrix_use::accumulator, 32, 8,
                       matrix_layout::col_major>
              sub_c;

          joint_matrix<uint16_t, matrix_use::a, 32, 16,
                       matrix_layout::col_major>
              sub_a;

          joint_matrix<uint16_t, matrix_use::b, 16, 8, matrix_layout::col_major>
              sub_b;

          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i50.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
        });

    cgh.parallel_for<class row_row_m8n32k16>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<float, matrix_use::accumulator, 8, 32,
                       matrix_layout::row_major>
              sub_c;

          joint_matrix<uint16_t, matrix_use::a, 8, 16, matrix_layout::row_major>
              sub_a;

          joint_matrix<uint16_t, matrix_use::b, 16, 32,
                       matrix_layout::row_major>
              sub_b;

          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i50.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
        });

    cgh.parallel_for<class col_col_m8n32k16>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<float, matrix_use::accumulator, 8, 32,
                       matrix_layout::col_major>
              sub_c;

          joint_matrix<uint16_t, matrix_use::a, 8, 16, matrix_layout::col_major>
              sub_a;

          joint_matrix<uint16_t, matrix_use::b, 16, 32,
                       matrix_layout::col_major>
              sub_b;

          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i50.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
          // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
        });
  });

  return 0;
};

sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp

Lines changed: 5 additions & 5 deletions
@@ -1,4 +1,4 @@
-// REQUIRES: gpu, cuda
+// REQUIRES: cuda
 
 // RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
 
@@ -36,8 +36,8 @@ int main() {
     auto accD = bufD.get_access<access::mode::read_write>(cgh);
 
     cgh.parallel_for<class row_row>(
-        nd_range<2>({1, 32}, {1, 32}), [=
-    ](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
           sycl::sub_group sg = item.get_sub_group();
 
           joint_matrix<double, matrix_use::accumulator, M, N,
@@ -70,8 +70,8 @@ int main() {
     auto accD = bufD.get_access<access::mode::read_write>(cgh);
 
     cgh.parallel_for<class col_col>(
-        nd_range<2>({1, 32}, {1, 32}), [=
-    ](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
           sycl::sub_group sg = item.get_sub_group();
 
           joint_matrix<double, matrix_use::accumulator, M, N,
