Skip to content

Commit 0439302

Browse files
committed
chore: revert and exclude CUDA files
1 parent f57e23a commit 0439302

File tree

5 files changed

+493
-581
lines changed

5 files changed

+493
-581
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ repos:
3434
- id: clang-format
3535
exclude: |
3636
(?ix)(
37+
^.+\.(cu|cuh)$|
3738
^.+\.json$
3839
)
3940
- repo: https://github.com/astral-sh/ruff-pre-commit

examples/minference/ops/vertical_slash_index.cu

Lines changed: 129 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -3,142 +3,157 @@
33

#include <assert.h>

#include <cstdint>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <torch/extension.h>

#include <cuda.h>
1212

// Appends the start offset of every block in [range_start, range_end) to
// block_offset, stepping by block_size.
//
//   block_offset  destination array; each offset is stored at
//                 block_offset[block_count] and block_count is then incremented
//   range_start   first offset written
//   range_end     exclusive upper bound of the offsets written
//   block_size    distance between consecutive offsets (must be > 0, otherwise
//                 the loop never terminates)
//   block_count   in/out running number of entries already stored
//
// No capacity check is done here: the caller has to guarantee that
// block_offset has room for every offset in the range.
__device__ __forceinline__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
    for (int idx = range_start; idx < range_end; idx += block_size) {
        block_offset[block_count++] = idx;
    }
}
1918

// Converts per-head vertical and slash sparse indexes into per-row-block lists.
//
// Each thread handles one row block of BLOCK_SIZE_M rows
// (block_idx_m = blockIdx.z * blockDim.x + threadIdx.x) for the head given by
// blockIdx.x and the batch entry given by blockIdx.y. For that row block it
//   * turns every slash index into a BLOCK_SIZE_M-wide range ending at
//     max(end_m - slash, BLOCK_SIZE_M), merges ranges that touch or overlap, and
//     stores the start of every BLOCK_SIZE_N step of each merged range in
//     block_offset (count in block_count);
//   * stores every vertical index that lies before the current slash range in
//     column_index (count in column_count). Vertical indexes that fall inside a
//     slash range are skipped.
//
// No shared memory and no inter-thread synchronisation is used.
//
// NOTE(review): the merge logic only makes sense if slash_indexes is sorted in
// descending order and vertical_indexes in ascending order, both non-negative —
// this is not checked here; confirm against the caller.
__global__ void convert_vertical_slash_indexes_kernel(
    const int* seqlens,           // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,             // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int N_HEADS,
    int N_ROWS,
    int BLOCK_SIZE_M,
    int BLOCK_SIZE_N,
    int NNZ_V,
    int NNZ_S
) {
    const int batch_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int group_idx = blockIdx.z;

    const int seqlen = seqlens[batch_idx];
    const int block_idx_m = group_idx * blockDim.x + threadIdx.x;
    const int start_m = block_idx_m * BLOCK_SIZE_M;
    // The launch may round the z extent up, so trailing threads can map past
    // the last row block; they must not touch the output arrays even when
    // seqlen is larger than the allocated N_ROWS * BLOCK_SIZE_M.
    if (block_idx_m >= N_ROWS || start_m >= seqlen) {
        return;
    }
    // Without a slash entry there is nothing valid to read below and no
    // capacity in block_offset; leave this row block's outputs untouched.
    if (NNZ_S <= 0) {
        return;
    }
    const int end_m = start_m + BLOCK_SIZE_M;

    // 64-bit offsets: BATCH * N_HEADS * N_ROWS * NNZ can exceed INT_MAX for
    // long contexts, which would wrap a 32-bit product.
    const int64_t head_offset = static_cast<int64_t>(batch_idx) * N_HEADS + head_idx;
    vertical_indexes += head_offset * NNZ_V;
    slash_indexes += head_offset * NNZ_S;
    const int64_t row_offset = head_offset * N_ROWS + block_idx_m;
    block_count += row_offset;
    block_offset += row_offset * NNZ_S;
    column_count += row_offset;
    column_index += row_offset * NNZ_V;

    int tmp_col_cnt = 0, tmp_blk_cnt = 0;
    int s = 0, v = 0;
    // end_m + BLOCK_SIZE_M is the "no more verticals" sentinel used by the main
    // loop; start with it directly when there are no vertical entries at all.
    int v_idx = (NNZ_V > 0) ? vertical_indexes[v++] : end_m + BLOCK_SIZE_M;
    int s_idx = slash_indexes[s++];
    // Skip slash entries that are >= end_m. The s < NNZ_S bound stops the scan
    // at the last entry instead of reading past the array when every entry is
    // >= end_m; in that case the clamp below yields the range [0, BLOCK_SIZE_M).
    while (s_idx >= end_m && s < NNZ_S) {
        s_idx = slash_indexes[s++];
    }
    s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
    int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
    while (1) {
        if (v_idx < range_end) {
            // Vertical index before the current slash range: keep it as a
            // column. Inside the range it is covered by the range and dropped.
            if (v_idx < range_start) {
                column_index[tmp_col_cnt++] = v_idx;
            }
            if (v < NNZ_V) {
                v_idx = vertical_indexes[v++];
            } else {
                v_idx = end_m + BLOCK_SIZE_M;
            }
        } else {
            if (s < NNZ_S) {
                s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
            } else {
                // Slash entries exhausted: flush the last range and finish.
                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
                break;
            }
            if (s_idx > range_end + BLOCK_SIZE_M) {
                // Gap to the next range: flush the current one and start anew.
                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
                range_start = s_idx - BLOCK_SIZE_M;
                range_end = s_idx;
            } else if (s_idx > range_end) {
                // Next range touches or overlaps: extend by one row block.
                range_end += BLOCK_SIZE_M;
            }
        }
    }

    block_count[0] = tmp_blk_cnt;
    column_count[0] = tmp_col_cnt;
}
9092

// Launches convert_vertical_slash_indexes_kernel with 64x64 blocks.
//
// All pointer arguments are forwarded to the kernel unchanged, so they must be
// device pointers laid out as noted next to each parameter.
//
// Launch layout: grid = (N_HEADS, BATCH_SIZE, cdiv(N_ROWS, 64)), block = 64
// threads, one thread per row block. The launch is asynchronous and uses the
// default stream (no stream argument is passed).
// NOTE(review): PyTorch's current stream is not consulted — confirm callers
// only use the default stream.
//
// Raises a c10::Error (via TORCH_CHECK) if the launch itself is rejected.
// Faults that happen while the kernel runs are not detected here; they surface
// at a later synchronising CUDA call.
void convert_vertical_slash_indexes_64x64(
    const int* seqlens,           // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,             // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int BATCH_SIZE,
    int N_HEADS,
    int N_ROWS,
    int NNZ_V,
    int NNZ_S
) {
    const int BLOCK_SIZE_M = 64;
    const int BLOCK_SIZE_N = 64;
    const int N_THREADS = 64;

    // A grid with a zero-sized dimension is an invalid launch configuration,
    // and there is nothing to compute in that case anyway.
    if (BATCH_SIZE <= 0 || N_HEADS <= 0 || N_ROWS <= 0) {
        return;
    }

    const dim3 dimBlock(N_THREADS);
    const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
    convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
        seqlens, vertical_indexes, slash_indexes,
        block_count, block_offset, column_count, column_index,
        N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
    );

    // Kernel launches do not return a status; configuration errors are only
    // visible through cudaGetLastError().
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
        launch_status == cudaSuccess,
        "convert_vertical_slash_indexes_kernel launch failed: ",
        cudaGetErrorString(launch_status)
    );
}
110118

// PyTorch entry point: builds block/column index tensors from vertical and
// slash sparse indexes.
//
// Arguments
//   seqlens           int32 CUDA tensor, one sequence length per batch entry
//   vertical_indexes  int32 CUDA tensor [BATCH, N_HEADS, NNZ_V]
//   slash_indexes     int32 CUDA tensor [BATCH, N_HEADS, NNZ_S]
//   context_size      context length; the number of row blocks is
//                     cdiv(context_size, block_size_M)
//   block_size_M/N    must both be 64 (only the 64x64 launcher exists)
//
// Returns {block_count, block_offset, column_count, column_index}, all
// allocated with seqlens' options and zero-initialised:
//   block_count   [BATCH, N_HEADS, num_rows]
//   block_offset  [BATCH, N_HEADS, num_rows, NNZ_S]
//   column_count  [BATCH, N_HEADS, num_rows]
//   column_index  [BATCH, N_HEADS, num_rows, NNZ_V]
//
// Raises c10::Error (TORCH_CHECK) on invalid arguments or CUDA failures.
// Side effect: the calling thread's current CUDA device is set to seqlens'
// device and is not restored afterwards.
std::vector<at::Tensor> convert_vertical_slash_indexes(
    torch::Tensor seqlens,           // [BATCH, ]
    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int context_size,
    int block_size_M,
    int block_size_N
) {
    // TORCH_CHECK instead of assert(): assert is compiled out when NDEBUG is
    // defined, which would let unsupported block sizes through silently.
    TORCH_CHECK(block_size_M == 64, "convert_vertical_slash_indexes: block_size_M must be 64, got ", block_size_M);
    TORCH_CHECK(block_size_N == 64, "convert_vertical_slash_indexes: block_size_N must be 64, got ", block_size_N);

    TORCH_CHECK(
        seqlens.is_cuda() && vertical_indexes.is_cuda() && slash_indexes.is_cuda(),
        "convert_vertical_slash_indexes: all inputs must be CUDA tensors"
    );
    TORCH_CHECK(
        vertical_indexes.get_device() == seqlens.get_device() &&
            slash_indexes.get_device() == seqlens.get_device(),
        "convert_vertical_slash_indexes: all inputs must be on the same device"
    );
    TORCH_CHECK(
        seqlens.scalar_type() == torch::kInt32 &&
            vertical_indexes.scalar_type() == torch::kInt32 &&
            slash_indexes.scalar_type() == torch::kInt32,
        "convert_vertical_slash_indexes: all inputs must be int32 tensors"
    );
    TORCH_CHECK(
        vertical_indexes.dim() == 3 && slash_indexes.dim() == 3,
        "convert_vertical_slash_indexes: vertical_indexes and slash_indexes must be 3-D"
    );
    TORCH_CHECK(
        vertical_indexes.size(0) == slash_indexes.size(0) &&
            vertical_indexes.size(1) == slash_indexes.size(1),
        "convert_vertical_slash_indexes: vertical_indexes and slash_indexes disagree on BATCH / N_HEADS"
    );
    TORCH_CHECK(
        seqlens.numel() >= slash_indexes.size(0),
        "convert_vertical_slash_indexes: seqlens must hold one entry per batch element"
    );
    // The kernel reads the first vertical and slash entry of every head
    // unconditionally, so empty index lists cannot be processed.
    TORCH_CHECK(
        vertical_indexes.size(2) > 0 && slash_indexes.size(2) > 0,
        "convert_vertical_slash_indexes: NNZ_V and NNZ_S must both be > 0"
    );

    const cudaError_t device_status = cudaSetDevice(seqlens.get_device());
    TORCH_CHECK(
        device_status == cudaSuccess,
        "convert_vertical_slash_indexes: cudaSetDevice failed: ",
        cudaGetErrorString(device_status)
    );

    // The kernel indexes the raw buffers as densely packed row-major arrays;
    // contiguous() is a no-op for tensors that already are.
    seqlens = seqlens.contiguous();
    vertical_indexes = vertical_indexes.contiguous();
    slash_indexes = slash_indexes.contiguous();

    const int batch_size = static_cast<int>(slash_indexes.size(0));
    const int num_heads = static_cast<int>(slash_indexes.size(1));
    const int nnz_slash = static_cast<int>(slash_indexes.size(2));
    const int nnz_vertical = static_cast<int>(vertical_indexes.size(2));
    const int num_rows = (context_size + block_size_M - 1) / block_size_M;

    torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
    torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
    torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
    torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());

    convert_vertical_slash_indexes_64x64(
        seqlens.data_ptr<int>(),
        vertical_indexes.data_ptr<int>(),
        slash_indexes.data_ptr<int>(),
        block_count.data_ptr<int>(),
        block_offset.data_ptr<int>(),
        column_count.data_ptr<int>(),
        column_index.data_ptr<int>(),
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        nnz_slash
    );

    return {block_count, block_offset, column_count, column_index};
}

0 commit comments

Comments
 (0)