Skip to content

Commit 9011f4a

Browse files
[Lint]: [pre-commit.ci] auto fixes [...]
1 parent 592ddd9 commit 9011f4a

File tree

4 files changed

+581
-492
lines changed

4 files changed

+581
-492
lines changed

examples/minference/ops/vertical_slash_index.cu

Lines changed: 114 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -3,157 +3,142 @@
33

44
#include <assert.h>
55

6+
#include <pybind11/numpy.h>
67
#include <pybind11/pybind11.h>
78
#include <pybind11/stl.h>
8-
#include <pybind11/numpy.h>
99
#include <torch/extension.h>
1010

1111
#include <cuda.h>
1212

// Append the starting offset of every `block_size`-wide block covering the
// half-open range [range_start, range_end) to `block_offset`, advancing the
// caller's running count in `block_count`.
__device__ void save_blocks(int *block_offset, int range_start, int range_end,
                            int block_size, int &block_count) {
  int offset = range_start;
  while (offset < range_end) {
    block_offset[block_count++] = offset;
    offset += block_size;
  }
}
1819

1920
// One thread handles one BLOCK_SIZE_M-tall row block of one (batch, head)
// pair.  Launch layout (see the 64x64 launcher in this file):
//   grid = (N_HEADS, BATCH, cdiv(N_ROWS, blockDim.x)), block = (blockDim.x).
// For its row block [start_m, end_m) the thread merges two sorted index
// streams: slash_indexes are converted (via `end_m - s_idx`) into column
// ranges covered inside this row block and emitted as block offsets through
// save_blocks(); vertical_indexes falling outside any such range are emitted
// individually into column_index.  Per-row totals land in block_count /
// column_count.
__global__ void convert_vertical_slash_indexes_kernel(
    const int *seqlens,          // [BATCH, ]
    const int *vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    const int *slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int *block_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int *block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int *column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int *column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int N_HEADS, int N_ROWS, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int NNZ_V,
    int NNZ_S) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  int seqlen = seqlens[batch_idx];
  int block_idx_m = group_idx * blockDim.x + threadIdx.x;
  int start_m = block_idx_m * BLOCK_SIZE_M;
  // Tail guard: threads whose row block starts past this batch's sequence
  // length have nothing to do.
  if (start_m >= seqlen) {
    return;
  }
  int end_m = start_m + BLOCK_SIZE_M;
  // Advance all pointers to this (batch, head) slice / this row block's slot.
  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
  int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  int tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int s = 0, v = 0;
  int v_idx = vertical_indexes[v++];
  int s_idx = slash_indexes[s++];
  // Skip slash indexes that do not reach this row block.
  // NOTE(review): no `s < NNZ_S` bound here — if every slash index were
  // >= end_m this reads past the row; presumably callers guarantee at least
  // one in-range slash index per row block. Confirm against the caller.
  while (s_idx >= end_m) {
    s_idx = slash_indexes[s++];
  }
  // Map the slash index to the exclusive column end it covers in this row
  // block, clamped to at least one full BLOCK_SIZE_M range.
  s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
  int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  while (1) {
    if (v_idx < range_end) {
      // Vertical column not covered by the current slash range: record it
      // as a standalone column.
      if (v_idx < range_start) {
        column_index[tmp_col_cnt++] = v_idx;
      }
      if (v < NNZ_V) {
        v_idx = vertical_indexes[v++];
      } else {
        // Sentinel past every possible range_end so this branch never
        // fires again once the vertical stream is exhausted.
        v_idx = end_m + BLOCK_SIZE_M;
      }
    } else {
      if (s < NNZ_S) {
        s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else {
        // Slash stream exhausted: flush the last open range and stop.
        save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N,
                    tmp_blk_cnt);
        break;
      }
      if (s_idx > range_end + BLOCK_SIZE_M) {
        // Gap between ranges: flush the finished one, open a new range.
        save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N,
                    tmp_blk_cnt);
        range_start = s_idx - BLOCK_SIZE_M;
        range_end = s_idx;
      } else if (s_idx > range_end) {
        // Adjacent/overlapping range: extend the current one instead.
        range_end += BLOCK_SIZE_M;
      }
    }
  }

  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}
9290

9391
// Host-side launcher specialized for 64x64 (BLOCK_SIZE_M x BLOCK_SIZE_N)
// tiles.  All pointer arguments are device pointers with the layouts
// annotated below.  Grid mapping: x = head, y = batch, z = group of
// N_THREADS row blocks; each thread processes one row block.
void convert_vertical_slash_indexes_64x64(
    const int *seqlens,          // [BATCH, ]
    const int *vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    const int *slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int *block_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int *block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int *column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int *column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int BATCH_SIZE, int N_HEADS, int N_ROWS, int NNZ_V, int NNZ_S) {
  const int BLOCK_SIZE_M = 64;
  const int BLOCK_SIZE_N = 64;
  const int N_THREADS = 64;
  const dim3 dimBlock(N_THREADS);
  // Ceil-divide N_ROWS over N_THREADS so every row block gets a thread.
  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
  convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
      seqlens, vertical_indexes, slash_indexes, block_count, block_offset,
      column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
      NNZ_V, NNZ_S);
  // NOTE(review): no cudaGetLastError() after the launch — a bad launch
  // configuration would be dropped silently; consider adding a check.
}
118110

119111
// Build block-sparse attention index metadata from per-head vertical and
// slash index lists.
//
// Args (CUDA tensors; data is read via data_ptr<int>, so int32 and
// contiguous layout are assumed — confirm at the Python call site):
//   seqlens          [BATCH]                 per-batch sequence lengths
//   vertical_indexes [BATCH, N_HEADS, NNZ_V] sorted column indexes
//   slash_indexes    [BATCH, N_HEADS, NNZ_S] sorted slash indexes
//   context_size     N_CTX; num_rows = cdiv(context_size, block_size_M)
//   block_size_M/N   tile sizes; only 64x64 is compiled in this file
//
// Returns {block_count, block_offset, column_count, column_index}, all
// zero-initialized with seqlens' options (device/dtype).
std::vector<at::Tensor> convert_vertical_slash_indexes(
    torch::Tensor seqlens,          // [BATCH, ]
    torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int context_size, int block_size_M, int block_size_N) {
  // Fix: the original used assert(), which is compiled out under NDEBUG
  // (the usual release build for extensions), silently launching the 64x64
  // kernel with mismatched block geometry.  TORCH_CHECK always fires and
  // reports a readable error to the Python caller.
  TORCH_CHECK(block_size_M == 64,
              "convert_vertical_slash_indexes: only block_size_M == 64 is "
              "supported, got ", block_size_M);
  TORCH_CHECK(block_size_N == 64,
              "convert_vertical_slash_indexes: only block_size_N == 64 is "
              "supported, got ", block_size_N);

  cudaSetDevice(seqlens.get_device());

  int batch_size = slash_indexes.size(0);
  int num_heads = slash_indexes.size(1);
  int nnz_slash = slash_indexes.size(2);
  int nnz_vertical = vertical_indexes.size(2);
  // Ceil-divide: number of BLOCK_SIZE_M-tall row blocks in the context.
  int num_rows = (context_size + block_size_M - 1) / block_size_M;

  // Outputs are zeroed so rows skipped by the kernel's tail guard read as
  // "no blocks / no columns".
  torch::Tensor block_count =
      torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
  torch::Tensor block_offset = torch::zeros(
      {batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
  torch::Tensor column_count =
      torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
  torch::Tensor column_index = torch::zeros(
      {batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());

  convert_vertical_slash_indexes_64x64(
      seqlens.data_ptr<int>(), vertical_indexes.data_ptr<int>(),
      slash_indexes.data_ptr<int>(), block_count.data_ptr<int>(),
      block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
      column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
      nnz_vertical, nnz_slash);

  return {block_count, block_offset, column_count, column_index};
}

0 commit comments

Comments
 (0)