
Commit 39c27b7

mzusman (Mor Zusman) and tomeras91
committed
N support (vllm-project#8)
* Return support for other models apart from jamba
* Support n>1
* A little cleanup
* Rename
* Apply whitespace suggestions from code review
* Add max batch size to the main func
* Fixed attention kv cache bug
* log where requests id are deleted from the dict to debug mode
* Fix typo
* Align with v0.3.3 vllm code
* Remove comments
* Take out model config from CUDAGraph object
* Fix
* Fix typo
* Make the kv cache selection cleaner
* Another typo
* Took the num layers calc outside
* Remove the -1
* Set as num layer / period

---------

Co-authored-by: Mor Zusman <morz@ai21.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
1 parent 00bce1f commit 39c27b7

File tree

5 files changed: +130 -87 lines changed
Lines changed: 4 additions & 3 deletions
@@ -1,8 +1,9 @@
-from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Tuple
+from typing import List
+
 import torch

+
 @dataclass
 class MambaCacheParams:
     is_prompt: bool = False
@@ -13,6 +14,6 @@ class MambaCacheParams:
 @dataclass
 class RequestInfo:
     request_id: str = ''
-    n: int = 1
+    seqs_id: List[int] = field(default_factory=list)
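The switch from n: int to seqs_id: List[int] is what lets the runner track one Mamba cache slot per sequence instead of one per request. A minimal sketch (not the committed code) of how the dataclass is now filled; the seq_data dict here is a stand-in for SequenceGroupMetadata.seq_data, not real engine state:

    # Illustrative only: populate RequestInfo with explicit sequence ids.
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class RequestInfo:
        request_id: str = ''
        seqs_id: List[int] = field(default_factory=list)

    seq_data = {7: object(), 8: object()}   # two sequences, e.g. sampling with n=2
    info = RequestInfo(request_id="req-0", seqs_id=list(seq_data.keys()))
    assert info.seqs_id == [7, 8]           # each sequence can now own its own cache slot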

vllm/model_executor/models/jamba.py

Lines changed: 4 additions & 2 deletions
@@ -504,10 +504,12 @@ def forward(
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-
+            kv_cache = None
+            if isinstance(layer, JambaAttentionDecoderLayer):
+                kv_cache = kv_caches[(i - self.config.attn_layer_offset) // self.config.attn_layer_period]
             hidden_states, residual = layer(positions=positions,
                                             hidden_states=hidden_states,
-                                            kv_cache=kv_caches[i],
+                                            kv_cache=kv_cache,
                                             input_metadata=input_metadata,
                                             residual=residual,
                                             conv_state=conv_state,
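Only Jamba's attention layers carry a KV cache, so kv_caches holds one entry per attention layer rather than one per decoder layer, and the dense index is recovered from the layer position. A small sketch of that mapping; the attn_layer_offset and attn_layer_period values below are illustrative assumptions, while the real code reads them from the HF config:

    # Sketch of the layer-index -> kv_caches-index mapping used in the diff above.
    attn_layer_offset = 4     # assumed
    attn_layer_period = 8     # assumed
    num_hidden_layers = 32    # assumed

    attention_layers = [i for i in range(num_hidden_layers)
                        if i % attn_layer_period == attn_layer_offset]
    for kv_index, layer_idx in enumerate(attention_layers):
        # the formula in the diff recovers the dense KV-cache index
        assert (layer_idx - attn_layer_offset) // attn_layer_period == kv_index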

vllm/worker/cache_engine.py

Lines changed: 14 additions & 8 deletions
@@ -30,7 +30,7 @@ def __init__(
         self.parallel_config = parallel_config

         self.head_size = model_config.get_head_size()
-        self.num_layers = model_config.get_num_layers(parallel_config)
+        self.num_layers = CacheEngine.get_num_attention_layers(model_config, parallel_config)
         self.num_heads = model_config.get_num_kv_heads(parallel_config)

         self.block_size = cache_config.block_size
@@ -80,6 +80,18 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None:
     def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
         self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

+    @staticmethod
+    def get_num_attention_layers(
+        model_config:ModelConfig,
+        parallel_config:ParallelConfig
+    ):
+        num_layers = model_config.get_num_layers(parallel_config)
+        is_mamba = model_config.hf_config.model_type == "jamba"
+        if is_mamba:
+            attention_period = model_config.hf_config.attn_layer_period
+            num_layers = num_layers // attention_period
+        return num_layers
+
     @staticmethod
     def get_cache_block_size(
         cache_config: CacheConfig,
@@ -88,13 +100,7 @@ def get_cache_block_size(
     ) -> int:
         head_size = model_config.get_head_size()
         num_heads = model_config.get_num_kv_heads(parallel_config)
-        num_layers = model_config.get_num_layers(parallel_config)
-        is_mamba = model_config.hf_config.model_type == "jamba"
-
-        if is_mamba:
-            attention_period = model_config.hf_config.attn_layer_period
-            num_layers = max(num_layers // attention_period, 1)
-
+        num_layers = CacheEngine.get_num_attention_layers(model_config,parallel_config)
         key_cache_block = cache_config.block_size * num_heads * head_size
         value_cache_block = key_cache_block
         total = num_layers * (key_cache_block + value_cache_block)
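Since only the attention layers consume KV-cache memory, get_cache_block_size now scales with num_layers // attn_layer_period instead of the full layer count. Rough arithmetic with illustrative numbers; the block size, head counts, layer counts and dtype size below are assumptions, not values from a real config:

    # Illustrative cache-block sizing with the attention-only layer count.
    block_size = 16            # tokens per cache block (assumed)
    num_kv_heads = 8           # assumed
    head_size = 128            # assumed
    total_layers = 32          # assumed
    attn_layer_period = 8      # one attention layer every 8 layers (assumed)

    num_attention_layers = total_layers // attn_layer_period   # 4 instead of 32
    key_cache_block = block_size * num_kv_heads * head_size
    value_cache_block = key_cache_block
    total_elements = num_attention_layers * (key_cache_block + value_cache_block)
    dtype_size = 2             # fp16
    print(total_elements * dtype_size)   # 262144 bytes per block, 8x less than with all 32 layers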

vllm/worker/model_runner.py

Lines changed: 103 additions & 71 deletions
@@ -154,7 +154,7 @@ def __init__(
         # cache in_wsl result
         self.mamba_cache = None
         self.mamba_cache4gc = None
-        self.request_id2index = {}
+        self.request_id2index: Dict[str, Dict[int, int]] = {}
         self.in_wsl = in_wsl()
         self.kv_cache_dtype = kv_cache_dtype

@@ -441,7 +441,7 @@ def _prepare_prompt(
             requests_info=[
                 RequestInfo(
                     request_id=req.request_id,
-                    n=req.sampling_params.n
+                    seqs_id=list(req.seq_data.keys())
                 )
                 for req in seq_group_metadata_list
             ]
@@ -579,10 +579,9 @@ def _prepare_decode(
             requests_info=[
                 RequestInfo(
                     request_id=req.request_id,
-                    n=req.sampling_params.n
+                    seqs_id=list(req.seq_data.keys())
                 )
                 for req in seq_group_metadata_list]
-
         )
         return PrepareDecodeMetadata(
             input_tokens=input_tokens,
@@ -790,13 +789,7 @@ def prepare_input_tensors(
                 "slot_mapping": slot_mapping,
                 "num_prefills": num_prefills,
                 "batch_type": batch_type,
-                "requests_info": [
-                    RequestInfo(
-                        request_id=req.request_id,
-                        n=req.sampling_params.n
-                    )
-                    for req in seq_group_metadata_list
-                ]
+                "requests_info": input_metadata.requests_info
             }
             if prefill_attn_metadata is not None:
                 metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
@@ -922,22 +915,29 @@ def execute_model(
         if not sampling_metadata.perform_sampling:
             return None

-        if self.mamba_cache is None:
-            self.prepare_contiguous_mamba_cache(self.model_config.dtype)
-
-        conv_state, ssm_state, indecies = self._prepare_request_mamba_cache(input_metadata, input_tokens.shape[0])
-
-        hidden_states = model_executable(
-            input_ids=input_tokens,
-            positions=input_positions,
-            kv_caches=kv_caches,
-            input_metadata=input_metadata,
-            conv_state=conv_state,
-            ssm_state=ssm_state
-        )
-        for i,offset in enumerate(indecies):
-            self.mamba_cache[0][:,offset] = conv_state[:,i]
-            self.mamba_cache[1][:,offset] = ssm_state[:,i]
+        is_mamba = self.model_config.hf_config.model_type == "jamba"
+        indices = []
+        conv_state = None
+        model_inputs = {
+            "input_ids":input_tokens,
+            "positions":input_positions,
+            "kv_caches":kv_caches,
+            "input_metadata":input_metadata,
+        }
+        if is_mamba:
+            if self.mamba_cache is None:
+                self.prepare_contiguous_mamba_cache(self.model_config.dtype)
+            conv_state, ssm_state, indices = self._prepare_request_mamba_cache(input_metadata, input_tokens.shape[0])
+            model_inputs = {
+                **model_inputs,
+                "conv_state":conv_state,
+                "ssm_state":ssm_state,
+            }
+        hidden_states = model_executable(**model_inputs)
+        if is_mamba:
+            for i, offset in enumerate(indices):
+                self.mamba_cache[0][:, offset] = conv_state[:, i]
+                self.mamba_cache[1][:, offset] = ssm_state[:, i]

         # Sample the next token.
         output = self.model.sample(
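The batch's Mamba states are gathered from the contiguous cache into conv_state/ssm_state before the forward pass, and the updated values are scattered back to the same slots afterwards. A toy version of that gather/update/scatter cycle with a single cache tensor and illustrative shapes (not the real per-layer layout):

    # Toy gather/scatter against a contiguous cache, mirroring the indices loop above.
    import torch

    mamba_cache = torch.zeros(3, 8, 4)            # (layers, max_batch_slots, state) - illustrative
    indices = [5, 2]                              # slots owned by this batch's sequences
    conv_state = mamba_cache[:, indices].clone()  # gather a batch-local copy

    conv_state += 1.0                             # stand-in for the model updating the state
    for i, offset in enumerate(indices):
        mamba_cache[:, offset] = conv_state[:, i] # scatter each sequence's state back to its slot

    assert torch.all(mamba_cache[:, 5] == 1.0) and torch.all(mamba_cache[:, 2] == 1.0)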
@@ -946,6 +946,13 @@ def execute_model(
         )
         return output

+    def _get_first_free_mamba_cache_index(self):
+        max_possible_bs = self.mamba_cache[0].shape[1]
+        occupied = [id for seq_ids in self.request_id2index.values() for id in seq_ids.values()]
+        first_free_index = [i not in occupied for i in range(max_possible_bs)].index(True)
+        return first_free_index
+
+
     def _prepare_request_mamba_cache(
         self,
         input_metadata: InputMetadata,
@@ -955,13 +962,26 @@
         max_possible_bs = self.mamba_cache[0].shape[1]
         for request_info in input_metadata.requests_info:
             if request_info.request_id not in self.request_id2index:
-                first_free_index = [i not in self.request_id2index.values() for i in range(max_possible_bs)].index(True)
-                self.request_id2index[request_info.request_id] = first_free_index
-            indices.append(self.request_id2index[request_info.request_id])
+                self.request_id2index[request_info.request_id] = {}
+                for seq_id in request_info.seqs_id:
+                    first_free_index = self._get_first_free_mamba_cache_index()
+                    self.request_id2index[request_info.request_id][seq_id] = first_free_index
+                    indices.append(first_free_index)
+            else:
+                for seq_id in request_info.seqs_id:
+                    if seq_id not in self.request_id2index[request_info.request_id]:
+                        first_free_index = self._get_first_free_mamba_cache_index()
+                        ## case of decoding n>1
+                        if len(self.request_id2index[request_info.request_id].keys()) > 0:
+                            self.mamba_cache[0][:,first_free_index].copy_(self.mamba_cache[0][:,list(self.request_id2index[request_info.request_id].values())[0]])
+                            self.mamba_cache[1][:,first_free_index].copy_(self.mamba_cache[1][:,list(self.request_id2index[request_info.request_id].values())[0]])
+                        self.request_id2index[request_info.request_id][seq_id] = first_free_index
+                    indices.append(self.request_id2index[request_info.request_id][seq_id])
         ## Pad the batch incase of running batch that was not captured via CG
         padded_indices = indices
         for _ in range(batch_size - len(indices)):
-            padded_indices += [[i not in set(self.request_id2index.values()).union(padded_indices) for i in range(max_possible_bs)].index(True)]
+            occupied = [id for seq_ids in self.request_id2index.values() for id in seq_ids.values()]
+            padded_indices += [[i not in set(occupied).union(padded_indices) for i in range(max_possible_bs)].index(True)]

         conv_state = self.mamba_cache[0][:,padded_indices]
         ssm_state = self.mamba_cache[1][:,padded_indices]
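When a request shows up with a sequence id the runner has not seen before (the n>1 decode case, where new sequences fork off the prompt's state), a free slot is found and the parent sequence's state is copied into it before the new index is used. A simplified, framework-free sketch of that allocation policy; names, shapes and the tiny cache tensor are illustrative:

    # Simplified slot allocator mirroring _get_first_free_mamba_cache_index /
    # _prepare_request_mamba_cache; request_id2index maps request_id -> {seq_id: slot}.
    import torch

    max_slots = 8
    cache = torch.zeros(max_slots, 4)                       # illustrative per-slot state
    request_id2index = {}

    def first_free_slot():
        occupied = {s for seqs in request_id2index.values() for s in seqs.values()}
        return next(i for i in range(max_slots) if i not in occupied)

    def assign(request_id, seq_ids):
        slots = request_id2index.setdefault(request_id, {})
        for seq_id in seq_ids:
            if seq_id not in slots:
                slot = first_free_slot()
                if slots:                                   # n>1: fork from an existing sequence
                    cache[slot].copy_(cache[next(iter(slots.values()))])
                slots[seq_id] = slot
        return [slots[s] for s in seq_ids]

    assert assign("req-0", [7]) == [0]                      # prompt phase: one sequence
    cache[0] += 3.0                                         # pretend the prompt updated slot 0
    assert assign("req-0", [7, 8]) == [0, 1]                # decode with n=2: seq 8 gets slot 1
    assert torch.equal(cache[1], cache[0])                  # seq 8 starts from seq 7's state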
@@ -1140,23 +1160,26 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
                     kv_cache_dtype=self.kv_cache_dtype,
                 )

+                is_mamba = self.model_config.hf_config.model_type == "jamba"
                 if self.lora_config:
                     lora_mapping = LoRAMapping(
                         [0] * batch_size,
                         [0] * batch_size,
                     )
                     self.set_active_loras(set(), lora_mapping)

-                graph_runner = CUDAGraphRunner(self.model)
-                graph_runner.capture(
-                    input_tokens[:batch_size],
-                    input_positions[:batch_size],
-                    kv_caches,
-                    attn_metadata,
-                    memory_pool=self.graph_memory_pool,
-                    conv_state=self.mamba_cache4gc[0][:, :batch_size],
-                    ssm_state=self.mamba_cache4gc[1][:, :batch_size]
-                )
+                graph_runner = CUDAGraphRunner(self.model,is_mamba)
+                capture_inputs = {
+                    "input_ids" : input_tokens[:batch_size],
+                    "positions" :input_positions[:batch_size],
+                    "kv_caches": kv_caches,
+                    "attn_metadata": attn_metadata,
+                    "memory_pool":self.graph_memory_pool,
+                }
+                if is_mamba:
+                    capture_inputs["conv_state"]=self.mamba_cache4gc[0][:, :batch_size]
+                    capture_inputs["ssm_state"]=self.mamba_cache4gc[1][:, :batch_size]
+                graph_runner.capture(**capture_inputs)
                 self.graph_memory_pool = graph_runner.graph.pool()
                 self.graph_runners[batch_size] = graph_runner

@@ -1182,11 +1205,12 @@ def vocab_size(self) -> int:

 class CUDAGraphRunner:

-    def __init__(self, model: nn.Module):
+    def __init__(self, model: nn.Module, is_mamba: bool):
         self.model = model
         self.graph = None
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}
+        self.is_mamba = is_mamba

     def capture(
         self,
@@ -1197,40 +1221,38 @@ def capture(
         conv_state: torch.Tensor,
         ssm_state: torch.Tensor,
         memory_pool,
+        conv_state: Optional[torch.Tensor] = None,
+        ssm_state: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> None:
         assert self.graph is None
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with _maybe_pynccl():
-            self.model(
-                input_ids,
-                positions,
-                kv_caches,
-                attn_metadata,
-                conv_state,
-                ssm_state
-                **kwargs,
-            )
+        model_inputs = {
+            "input_ids":input_ids,
+            "positions":positions,
+            "kv_caches":kv_caches,
+            "attn_metadata":attn_metadata,
+        }
+        if self.is_mamba:
+            model_inputs = {
+                **model_inputs,
+                "conv_state":conv_state,
+                "ssm_state":ssm_state,
+            }
+
+        with _maybe_cupy_nccl():
+            self.model(**model_inputs)
         torch.cuda.synchronize()

         # Capture the graph.
         # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
-            with _maybe_pynccl():
-                hidden_states = self.model(
-                    input_ids,
-                    positions,
-                    kv_caches,
-                    attn_metadata,
-                    input_metadata,
-                    conv_state,
-                    ssm_state
-                    **kwargs,
-                )
+            with _maybe_cupy_nccl():
+                hidden_states = self.model(**model_inputs)
         torch.cuda.synchronize()

         # Save the input and output buffers.
@@ -1244,6 +1266,13 @@ def capture(
             "conv_state": conv_state,
             "ssm_state": ssm_state
         }
+        if self.is_mamba:
+            self.input_buffers = {
+                **self.input_buffers,
+                "conv_state": conv_state,
+                "ssm_state": ssm_state,
+            }
+
         self.output_buffers = {"hidden_states": hidden_states}
         return

@@ -1253,8 +1282,8 @@
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        conv_state:torch.Tensor,
-        ssm_state:torch.Tensor
+        conv_state:Optional[torch.Tensor] = None,
+        ssm_state:Optional[torch.Tensor] = None
         **kwargs,
     ) -> torch.Tensor:
         # KV caches are fixed tensors, so we don't need to copy them.
@@ -1269,16 +1298,19 @@ def forward(
                 attn_metadata.decode_metadata.context_lens, non_blocking=True)
             self.input_buffers["block_tables"].copy_(
                 attn_metadata.decode_metadata.block_tables, non_blocking=True)
-        self.input_buffers["conv_state"].copy_(conv_state,
-                                               non_blocking=True)
-        self.input_buffers["ssm_state"].copy_(ssm_state,
-                                              non_blocking=True)
+        if self.is_mamba:
+            self.input_buffers["conv_state"].copy_(conv_state,
+                                                   non_blocking=True)
+            self.input_buffers["ssm_state"].copy_(ssm_state,
+                                                  non_blocking=True)
+
         # Run the graph.
         self.graph.replay()

         # in-place edit of the mamba cache states as in the KV cache
-        ssm_state.copy_(self.input_buffers["ssm_state"])
-        conv_state.copy_(self.input_buffers["conv_state"])
+        if self.is_mamba:
+            ssm_state.copy_(self.input_buffers["ssm_state"])
+            conv_state.copy_(self.input_buffers["conv_state"])

         # Return the output tensor.
         return self.output_buffers["hidden_states"]
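Because a captured CUDA graph always reads and writes the exact tensors it was captured with, the runner copies the current conv/ssm states into the graph's input buffers before replay and copies the in-place updated buffers back out afterwards, matching the KV-cache pattern. A stripped-down sketch of that copy-in / replay / copy-out discipline; there is no real CUDA graph here, the fake replay just bumps the buffer so the example stays self-contained:

    # Static-buffer discipline around graph replay (illustrative).
    import torch

    input_buffer = torch.zeros(4)        # tensor address baked into the captured graph

    def fake_replay():
        input_buffer.add_(1.0)           # stands in for graph.replay() mutating the buffer in place

    def run_step(conv_state: torch.Tensor) -> None:
        input_buffer.copy_(conv_state)   # stage this step's state into the captured buffer
        fake_replay()                    # kernels only ever see input_buffer
        conv_state.copy_(input_buffer)   # publish the updated state back to the caller's cache view

    state = torch.full((4,), 2.0)
    run_step(state)
    assert torch.all(state == 3.0)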

vllm/worker/worker.py

Lines changed: 5 additions & 3 deletions
@@ -184,7 +184,9 @@ def _init_cache_engine(self):
                                              self.parallel_config)
         self.gpu_cache = self.cache_engine.gpu_cache
         self.model_runner.set_block_size(self.cache_engine.block_size)
-        self.model_runner.prepare_contiguous_mamba_cache(self.cache_engine.dtype)
+        is_mamba = self.model_config.hf_config.model_type == "jamba"
+        if is_mamba:
+            self.model_runner.prepare_contiguous_mamba_cache(self.cache_engine.dtype)

     def _warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
@@ -212,8 +214,8 @@ def cache_swap(
     def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
         for req_id in finished_seq_groups_req_ids:
             if req_id in self.model_runner.request_id2index:
-                index = self.model_runner.request_id2index.pop(req_id)
-                logger.info(f"deleted { req_id } from mamba_cache with index = {index}")
+                indices = self.model_runner.request_id2index.pop(req_id)
+                logger.debug(f"Deleted { req_id } from mamba_cache with indices = {indices}")


     @torch.inference_mode()
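On the worker side, finishing a sequence group just drops its request entry, which frees every per-sequence slot it owned for the next first-free-index scan. A minimal sketch, with request_id2index shaped as in the runner changes above and purely illustrative ids:

    # Minimal sketch of the release path; the mapping is request_id -> {seq_id: slot}.
    from typing import Dict, List

    request_id2index: Dict[str, Dict[int, int]] = {"req-0": {7: 0, 8: 1}, "req-1": {9: 2}}

    def release_mamba_cache(finished_req_ids: List[str]) -> None:
        for req_id in finished_req_ids:
            request_id2index.pop(req_id, None)              # frees all slots owned by the request
            # a real implementation would log this at debug level, as the diff does

    release_mamba_cache(["req-0"])
    assert list(request_id2index) == ["req-1"]              # slots 0 and 1 are free again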
