Skip to content

Commit 13f4b40

Browse files
tlrmchlsmth garg-amit
authored and committed
[Kernel] Change interface to Mamba selective_state_update for continuous batching (vllm-project#8039)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent 28b9c8a commit 13f4b40

File tree

2 files changed

+174
-3
lines changed

2 files changed

+174
-3
lines changed

tests/kernels/test_mamba_ssm.py

+146
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,149 @@ def test_selective_state_update(dim, dstate, has_z, itype):
323323

324324
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
325325
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
326+
327+
328+
@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
    """Check selective_state_update when state rows are picked by index.

    A state pool with ``10 * batch_size`` entries is allocated and
    ``batch_size`` distinct rows are selected through ``state_batch_indices``.
    The kernel result is compared against ``selective_state_update_ref`` run
    on a gathered copy of the selected rows.  In addition, every row that was
    NOT selected must come back bit-for-bit unchanged, so stray writes into
    the rest of the pool are detected.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 7e-2, 7e-2
        if torch.version.hip:
            atol *= 2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16

    total_entries = 10 * batch_size
    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device)

    # Snapshot of the whole pool plus a mask of the rows the kernel must not
    # touch.  Neither operation consumes random numbers, so the inputs
    # generated below are the same as without these two statements.
    state_before = state.detach().clone()
    untouched = torch.ones(total_entries, dtype=torch.bool, device=device)
    untouched[state_indices.long()] = False

    x = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    # The reference works on a dense copy of just the selected rows.
    state_ref = state[state_indices, :].detach().clone()
    out = selective_state_update(state,
                                 x,
                                 dt,
                                 A,
                                 B,
                                 C,
                                 D=D,
                                 z=z,
                                 dt_bias=dt_bias,
                                 dt_softplus=True,
                                 state_batch_indices=state_indices)
    out_ref = selective_state_update_ref(state_ref,
                                         x,
                                         dt,
                                         A,
                                         B,
                                         C,
                                         D=D,
                                         z=z,
                                         dt_bias=dt_bias,
                                         dt_softplus=True)

    assert torch.allclose(state[state_indices, :],
                          state_ref,
                          rtol=rtol,
                          atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
    # Rows outside state_indices are never addressed, so they must be exactly
    # equal to their pre-call contents (no tolerance).
    assert torch.equal(state[untouched], state_before[untouched])
385+
386+
387+
@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads_with_batch_indices(
        dim, dstate, ngroups, has_z, tie_hdim, itype):
    """Check the multi-head selective_state_update with indexed state rows.

    The state pool has shape ``(10 * batch_size, nheads, headdim, dstate)``
    and ``batch_size`` distinct rows are selected via ``state_batch_indices``.
    With ``tie_hdim`` set, ``dt``, ``dt_bias``, ``A`` and ``D`` are built per
    head and expanded along the head dimension with ``repeat``; otherwise they
    are generated with a full head-dimension axis.  Results are compared
    against ``selective_state_update_ref`` on a gathered copy of the selected
    rows, and the unselected rows must be returned bit-for-bit unchanged.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    headdim = 64
    nheads = dim // headdim

    total_entries = 10 * batch_size
    state = torch.randn(total_entries,
                        nheads,
                        headdim,
                        dstate,
                        dtype=itype,
                        device=device)
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device)

    # Snapshot of the whole pool plus a mask of the rows the kernel must not
    # touch.  Neither operation consumes random numbers, so the inputs
    # generated below are the same as without these two statements.
    state_before = state.detach().clone()
    untouched = torch.ones(total_entries, dtype=torch.bool, device=device)
    untouched[state_indices.long()] = False

    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
    if not tie_hdim:
        dt = torch.randn(batch_size,
                         nheads,
                         headdim,
                         device=device,
                         dtype=itype)
        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
        D = torch.randn(nheads, headdim, device=device)
    else:
        # One value per head, expanded across headdim (and dstate for A).
        dt = repeat(torch.randn(batch_size, nheads, device=device,
                                dtype=itype),
                    "b h -> b h p",
                    p=headdim)
        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
                         "h -> h p",
                         p=headdim)
        A = repeat(-torch.rand(nheads, device=device) - 1.0,
                   "h -> h p n",
                   p=headdim,
                   n=dstate)
        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
    B = torch.randn(batch_size, ngroups, dstate, device=device)
    C = torch.randn(batch_size, ngroups, dstate, device=device)
    z = torch.randn_like(x) if has_z else None
    # The reference works on a dense copy of just the selected rows.
    state_ref = state[state_indices, :].detach().clone()
    out = selective_state_update(state,
                                 x,
                                 dt,
                                 A,
                                 B,
                                 C,
                                 D=D,
                                 z=z,
                                 dt_bias=dt_bias,
                                 dt_softplus=True,
                                 state_batch_indices=state_indices)
    out_ref = selective_state_update_ref(state_ref,
                                         x,
                                         dt,
                                         A,
                                         B,
                                         C,
                                         D=D,
                                         z=z,
                                         dt_bias=dt_bias,
                                         dt_softplus=True)

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    assert torch.allclose(state[state_indices, :],
                          state_ref,
                          rtol=rtol,
                          atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
    # Rows outside state_indices are never addressed, so they must be exactly
    # equal to their pre-call contents (no tolerance).
    assert torch.equal(state[untouched], state_before[untouched])

vllm/model_executor/layers/mamba/ops/mamba_ssm.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) 2024, Tri Dao, Albert Gu.
2+
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
23

34
import torch
45
import triton
@@ -27,6 +28,10 @@ def softplus(dt):
2728
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
2829
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
2930
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
31+
@triton.heuristics({
32+
"HAS_STATE_BATCH_INDICES":
33+
lambda args: args["state_batch_indices_ptr"] is not None
34+
})
3035
@triton.heuristics(
3136
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
3237
@triton.jit
@@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
4247
D_ptr,
4348
z_ptr,
4449
out_ptr,
50+
state_batch_indices_ptr,
4551
# Matrix dimensions
4652
batch,
4753
nheads,
@@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
8591
HAS_DT_BIAS: tl.constexpr,
8692
HAS_D: tl.constexpr,
8793
HAS_Z: tl.constexpr,
94+
HAS_STATE_BATCH_INDICES: tl.constexpr,
8895
BLOCK_SIZE_DSTATE: tl.constexpr,
8996
):
9097
pid_m = tl.program_id(axis=0)
9198
pid_b = tl.program_id(axis=1)
9299
pid_h = tl.program_id(axis=2)
93-
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
100+
101+
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
102+
# is taken from the state_batch_indices_ptr. Otherwise, the state coordinate
103+
# is the same as the batch id.
104+
if HAS_STATE_BATCH_INDICES:
105+
state_batch_indices_ptr += pid_b
106+
state_batch_idx = tl.load(state_batch_indices_ptr)
107+
state_ptr += (state_batch_idx * stride_state_batch +
108+
pid_h * stride_state_head)
109+
else:
110+
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
111+
94112
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
95113
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
96114
if HAS_DT_BIAS:
@@ -177,7 +195,8 @@ def selective_state_update(state,
177195
D=None,
178196
z=None,
179197
dt_bias=None,
180-
dt_softplus=False):
198+
dt_softplus=False,
199+
state_batch_indices=None):
181200
"""
182201
Argument:
183202
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -211,7 +230,10 @@ def selective_state_update(state,
211230
z = z.unsqueeze(1)
212231
if dt_bias is not None and dt_bias.dim() == 1:
213232
dt_bias = dt_bias.unsqueeze(0)
214-
batch, nheads, dim, dstate = state.shape
233+
234+
_, nheads, dim, dstate = state.shape
235+
batch = x.shape[0]
236+
215237
assert x.shape == (batch, nheads, dim)
216238
assert dt.shape == x.shape
217239
assert A.shape == (nheads, dim, dstate)
@@ -225,6 +247,8 @@ def selective_state_update(state,
225247
assert z.shape == x.shape
226248
if dt_bias is not None:
227249
assert dt_bias.shape == (nheads, dim)
250+
if state_batch_indices is not None:
251+
assert state_batch_indices.shape == (batch, )
228252
out = torch.empty_like(x)
229253
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
230254
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
@@ -249,6 +273,7 @@ def selective_state_update(state,
249273
D,
250274
z,
251275
out,
276+
state_batch_indices,
252277
batch,
253278
nheads,
254279
dim,

0 commit comments

Comments
 (0)