
Commit 0ceff5e

mzusman authored and LeiWang1999 committed
[Kernel][Model] Varlen prefill + Prefill chunking support for mamba kernels and Jamba model (vllm-project#8533)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 37257f3 commit 0ceff5e

File tree: 13 files changed, +1176 −894 lines

csrc/mamba/causal_conv1d/causal_conv1d.cu (+211 −316)

Large diff not rendered by default.

csrc/mamba/causal_conv1d/causal_conv1d.h (+10 −0)

@@ -24,6 +24,7 @@ struct ConvParamsBase {
     index_t out_c_stride;
     index_t out_l_stride;
 
+    int conv_state_len;
     index_t conv_state_batch_stride;
     index_t conv_state_c_stride;
     index_t conv_state_l_stride;
@@ -35,6 +36,10 @@ struct ConvParamsBase {
     void *__restrict__ out_ptr;
 
     void *__restrict__ conv_state_ptr;
+    void *__restrict__ query_start_loc_ptr;
+    void *__restrict__ has_initial_state_ptr;
+    void *__restrict__ cache_indices_ptr;
+    int32_t *__restrict__ cache_seqlens;
 
     // For the continuous batching case. Makes it so that the mamba state for
     // the current batch doesn't need to be a contiguous tensor.
@@ -52,6 +57,11 @@ struct ConvParamsBase {
     index_t final_states_batch_stride;
     index_t final_states_l_stride;
     index_t final_states_c_stride;
+
+    void * conv_states_ptr;
+    index_t conv_states_batch_stride;
+    index_t conv_states_l_stride;
+    index_t conv_states_c_stride;
 };
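
Taken together, the new fields replace the old one-row-per-sequence layout with a packed (varlen) batch plus an indirection into the state cache. Below is a minimal sketch of how a kernel might consume them, assuming query_start_loc follows the usual cumulative-offsets convention; the kernel name and locals are illustrative, not the PR's actual code.

// Sketch only (not the PR's kernel). One block per sequence; params is the
// ConvParamsBase from the header above.
__global__ void causal_conv1d_varlen_sketch(ConvParamsBase params) {
    const int seq_id = blockIdx.x;

    // Packed batch: query_start_loc[i] is the offset of sequence i's first
    // token, with a trailing total, e.g. {0, 5, 9, 14} for lengths 5/4/5.
    const int* query_start_loc =
        reinterpret_cast<const int*>(params.query_start_loc_ptr);
    const int seq_start = query_start_loc[seq_id];
    const int seqlen = query_start_loc[seq_id + 1] - seq_start;

    // Indirection so this sequence's conv state can live in any cache slot
    // (continuous batching): no contiguity requirement across the batch.
    const int* cache_indices =
        reinterpret_cast<const int*>(params.cache_indices_ptr);
    const int slot = cache_indices ? cache_indices[seq_id] : seq_id;

    // Chunked prefill: a later chunk seeds the convolution window from the
    // cached state instead of zeros.
    const bool* has_initial_state =
        reinterpret_cast<const bool*>(params.has_initial_state_ptr);
    const bool seeded = has_initial_state && has_initial_state[seq_id];

    // ... read tokens seq_start..seq_start+seqlen, convolve, and write the
    // final window back via params.conv_states_ptr at `slot` ...
    (void)seqlen; (void)seeded; (void)slot;
}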

csrc/mamba/mamba_ssm/selective_scan.h (+9 −20)

@@ -54,10 +54,14 @@ struct SSMParamsBase {
     void *__restrict__ delta_ptr;
     void *__restrict__ delta_bias_ptr;
     void *__restrict__ out_ptr;
-    void *__restrict__ x_ptr;
+    void *__restrict__ ssm_states_ptr;
     void *__restrict__ z_ptr;
     void *__restrict__ out_z_ptr;
-    void *__restrict__ index_ptr;
+
+    void *__restrict__ query_start_loc_ptr;
+    void *__restrict__ cache_indices_ptr;
+    void *__restrict__ has_initial_state_ptr;
+
 };
@@ -201,7 +205,7 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
                                   typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
                                   typename Ktraits::BlockLoadT::TempStorage &smem_load,
                                   int seqlen) {
-    if constexpr (Ktraits::kIsEvenLen) {
+    if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
         auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
         using vec_t = typename Ktraits::vec_t;
         typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
@@ -217,21 +221,6 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
     }
 }
 
-template<typename Ktraits>
-inline __device__ void load_index(int *u,
-                                  int (&u_vals)[Ktraits::kNItems],
-                                  typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
-                                  int seqlen) {
-    if constexpr (Ktraits::kIsEvenLen) {
-        auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
-        Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
-            reinterpret_cast<uint4*>(u),
-            reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
-        );
-    } else {
-        Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
-    }
-}
 
 template<typename Ktraits>
 inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
@@ -240,7 +229,7 @@ inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
                                    int seqlen) {
     constexpr int kNItems = Ktraits::kNItems;
     typename Ktraits::input_t B_vals_load[kNItems];
-    if constexpr (Ktraits::kIsEvenLen) {
+    if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
         auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
         using vec_t = typename Ktraits::vec_t;
         typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
@@ -263,7 +252,7 @@ inline __device__ void store_output(typename Ktraits::input_t *out,
     typename Ktraits::input_t write_vals[Ktraits::kNItems];
     #pragma unroll
     for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
-    if constexpr (Ktraits::kIsEvenLen) {
+    if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
         auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
         using vec_t = typename Ktraits::vec_t;
         typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
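
Every hunk after the struct edit is the same one-line change: the vectorized BlockLoad/BlockStore fast path assumes a fixed, vector-aligned sequence length across the whole batch, which no longer holds once sequences are packed back to back, so kVarlen forces the bounds-checked element-wise path. load_index goes away together with index_ptr; per-token position indices are superseded by the query_start_loc offsets. A host-side analogue of the guard, sketch only with made-up names:

#include <cstring>

// Why the guard needs both flags: a full-width vector copy is only safe when
// every row has the same even length; in a packed varlen batch it could run
// past this sequence's end into the next sequence's tokens.
template <bool kIsEvenLen, bool kVarlen>
void load_sketch(const float* u, float (&vals)[4], int seqlen) {
    if constexpr (kIsEvenLen && !kVarlen) {
        std::memcpy(vals, u, sizeof(vals));  // vectorized fast path
    } else {
        for (int i = 0; i < 4; ++i)          // guarded slow path
            vals[i] = i < seqlen ? u[i] : 0.0f;
    }
}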

csrc/mamba/mamba_ssm/selective_scan_fwd.cu (+179 −118)

Large diff not rendered by default.
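
The kernel diff itself is too large to render, but the metadata it consumes follows from the headers above. An illustrative batch mixing varlen prefill with prefill chunking; the values are made up, and query_start_loc is assumed to use the cumulative convention:

// Three packed sequences: A = fresh 4-token prefill, B = a later 3-token
// chunk of a longer prompt (resumes from cached state), C = fresh 2 tokens.
int  query_start_loc[]   = {0, 4, 7, 9};          // batch_size + 1 offsets
bool has_initial_state[] = {false, true, false};  // only B is seeded
int  cache_indices[]     = {5, 0, 2};             // per-sequence cache slots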

csrc/ops.h (+18 −13)

@@ -215,25 +215,30 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad);
 
-std::vector<torch::Tensor> selective_scan_fwd(
-    const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
-    const torch::Tensor& B, const torch::Tensor& C,
-    const c10::optional<torch::Tensor>& D_,
-    const c10::optional<torch::Tensor>& z_,
-    const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
-    const c10::optional<torch::Tensor>& index_,
-    const c10::optional<torch::Tensor>& x);
+void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
+                        const torch::Tensor& A, const torch::Tensor& B,
+                        const torch::Tensor& C,
+                        const c10::optional<torch::Tensor>& D_,
+                        const c10::optional<torch::Tensor>& z_,
+                        const c10::optional<torch::Tensor>& delta_bias_,
+                        bool delta_softplus,
+                        const c10::optional<torch::Tensor>& query_start_loc,
+                        const c10::optional<torch::Tensor>& cache_indices,
+                        const c10::optional<torch::Tensor>& has_initial_state,
+                        const torch::Tensor& ssm_states);
 
 at::Tensor causal_conv1d_update(
     const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
-    const c10::optional<at::Tensor>& bias, bool silu_activation,
-    const c10::optional<at::Tensor>& conv_state_indices);
+    const c10::optional<at::Tensor>& bias_, bool silu_activation,
+    const c10::optional<at::Tensor>& cache_seqlens_,
+    const c10::optional<at::Tensor>& conv_state_indices_);
 
 at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                              const c10::optional<at::Tensor>& bias_,
-                             const c10::optional<at::Tensor>& seq_idx_,
-                             const c10::optional<at::Tensor>& initial_states_,
-                             const c10::optional<at::Tensor>& final_states_out_,
+                             const c10::optional<at::Tensor>& conv_states,
+                             const c10::optional<at::Tensor>& query_start_loc,
+                             const c10::optional<at::Tensor>& cache_indices,
+                             const c10::optional<at::Tensor>& has_initial_state,
                              bool silu_activation);
 
 #ifndef USE_ROCM
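
The headline change is the selective_scan_fwd contract: instead of allocating and returning output tensors, it now returns void and scatters each sequence's final SSM state into a caller-owned ssm_states buffer, which is what lets state persist across prefill chunks. A call sketch; the surrounding variables and shapes are assumptions, not code from the PR:

// ssm_states outlives the call; num_cache_slots/dim/dstate are placeholder
// dimensions, and u, delta, A, B, C, the optionals, and the varlen metadata
// tensors are assumed to be in scope.
torch::Tensor ssm_states = torch::zeros(
    {num_cache_slots, dim, dstate}, u.options().dtype(torch::kFloat32));
selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
                   /*delta_softplus=*/true,
                   query_start_loc, cache_indices, has_initial_state,
                   ssm_states);  // returns void; final states land in ssm_states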

csrc/torch_bindings.cpp (+11 −6)

@@ -273,26 +273,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "selective_scan_fwd(Tensor! u, Tensor! delta,"
       "Tensor! A, Tensor! B, Tensor! C,"
-      "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
+      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
       "bool delta_softplus,"
-      "Tensor? index_, Tensor!? x) -> Tensor[]");
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "Tensor! ssm_states) -> ()");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
 
   ops.def(
       "causal_conv1d_update(Tensor! x,"
      "Tensor! conv_state,"
      "Tensor! weight,"
-     "Tensor? bias,"
+     "Tensor? bias_,"
      "bool silu_activation,"
+     "Tensor? cache_seqlens_,"
      "Tensor? conv_state_indices) -> Tensor");
   ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
 
   ops.def(
       "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
       "Tensor? bias_,"
-      "Tensor? seq_idx_,"
-      "Tensor? initial_states_,"
-      "Tensor!? final_states_out_,"
+      "Tensor!? conv_states,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
       "bool silu_activation) -> Tensor");
   ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
 #endif
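
In these schema strings, ? marks an optional argument and ! one the op may mutate, so "Tensor!? z_" is optional and written in place, and "-> ()" declares no return value, matching the new void C++ signature. A minimal registration sketch using the same vocabulary (the sketch_ops namespace and scan_into op are made up):

#include <torch/library.h>

// `Tensor!` = mutated in place, `Tensor?` = optional, `Tensor!?` = optional
// and mutated, `-> ()` = no return value.
TORCH_LIBRARY(sketch_ops, m) {
  m.def("scan_into(Tensor u, Tensor!? z, Tensor! ssm_states) -> ()");
}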
