@@ -154,7 +154,7 @@ def __init__(
         # cache in_wsl result
         self.mamba_cache = None
         self.mamba_cache4gc = None
-        self.request_id2index = {}
+        self.request_id2index: Dict[str, Dict[int, int]] = {}
         self.in_wsl = in_wsl()
         self.kv_cache_dtype = kv_cache_dtype
 
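The annotation change above is the heart of the PR's bookkeeping: one cache slot per request becomes a nested map of request id to {sequence id: slot}, so a request that samples several sequences owns several slots. A minimal standalone sketch of the structure (toy ids and slots, not from the PR):

from typing import Dict

# request_id -> {seq_id -> mamba cache slot}; mirrors the annotation in the diff.
request_id2index: Dict[str, Dict[int, int]] = {}

# A request with n=2 sequences (seq ids 7 and 8) pinned to cache slots 0 and 3.
request_id2index["req-abc"] = {7: 0, 8: 3}

# Flattening the occupied slots, as the new helper below does:
occupied = [slot for seq_ids in request_id2index.values() for slot in seq_ids.values()]
print(occupied)  # [0, 3]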
@@ -441,7 +441,7 @@ def _prepare_prompt(
         requests_info = [
             RequestInfo(
                 request_id=req.request_id,
-                n=req.sampling_params.n
+                seqs_id=list(req.seq_data.keys())
             )
             for req in seq_group_metadata_list
         ]
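RequestInfo now carries the concrete sequence ids of the group instead of only the count `n`, which is what lets the runner key cache slots per sequence. The dataclass itself is defined elsewhere in the PR; a plausible shape, stated as an assumption, is:

from dataclasses import dataclass, field
from typing import List

@dataclass
class RequestInfo:
    request_id: str
    # Concrete sequence ids of the group; replaces the former `n` count so the
    # runner can assign (and later free) one mamba cache slot per sequence.
    seqs_id: List[int] = field(default_factory=list)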
@@ -579,10 +579,9 @@ def _prepare_decode(
         requests_info = [
             RequestInfo(
                 request_id=req.request_id,
-                n=req.sampling_params.n
+                seqs_id=list(req.seq_data.keys())
             )
             for req in seq_group_metadata_list]
-
         )
         return PrepareDecodeMetadata(
             input_tokens=input_tokens,
@@ -790,13 +789,7 @@ def prepare_input_tensors(
             "slot_mapping": slot_mapping,
             "num_prefills": num_prefills,
             "batch_type": batch_type,
-            "requests_info": [
-                RequestInfo(
-                    request_id=req.request_id,
-                    n=req.sampling_params.n
-                )
-                for req in seq_group_metadata_list
-            ]
+            "requests_info": input_metadata.requests_info
         }
         if prefill_attn_metadata is not None:
             metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
@@ -922,22 +915,29 @@ def execute_model(
         if not sampling_metadata.perform_sampling:
             return None
 
-        if self.mamba_cache is None:
-            self.prepare_contiguous_mamba_cache(self.model_config.dtype)
-
-        conv_state, ssm_state, indecies = self._prepare_request_mamba_cache(input_metadata, input_tokens.shape[0])
-
-        hidden_states = model_executable(
-            input_ids=input_tokens,
-            positions=input_positions,
-            kv_caches=kv_caches,
-            input_metadata=input_metadata,
-            conv_state=conv_state,
-            ssm_state=ssm_state
-        )
-        for i, offset in enumerate(indecies):
-            self.mamba_cache[0][:, offset] = conv_state[:, i]
-            self.mamba_cache[1][:, offset] = ssm_state[:, i]
+        is_mamba = self.model_config.hf_config.model_type == "jamba"
+        indices = []
+        conv_state = None
+        model_inputs = {
+            "input_ids": input_tokens,
+            "positions": input_positions,
+            "kv_caches": kv_caches,
+            "input_metadata": input_metadata,
+        }
+        if is_mamba:
+            if self.mamba_cache is None:
+                self.prepare_contiguous_mamba_cache(self.model_config.dtype)
+            conv_state, ssm_state, indices = self._prepare_request_mamba_cache(input_metadata, input_tokens.shape[0])
+            model_inputs = {
+                **model_inputs,
+                "conv_state": conv_state,
+                "ssm_state": ssm_state,
+            }
+        hidden_states = model_executable(**model_inputs)
+        if is_mamba:
+            for i, offset in enumerate(indices):
+                self.mamba_cache[0][:, offset] = conv_state[:, i]
+                self.mamba_cache[1][:, offset] = ssm_state[:, i]
 
         # Sample the next token.
         output = self.model.sample(
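The rewritten execute_model gathers per-sequence mamba states into batch order, runs the model (which updates them in place), and scatters them back to their owners' slots. A reduced, runnable sketch of that gather/run/scatter round trip, with toy tensors standing in for the model (shapes are illustrative assumptions):

import torch

# Contiguous caches, shaped (layers, max_batch_slots, state_dim); toy sizes.
mamba_cache = (torch.zeros(2, 8, 4), torch.zeros(2, 8, 4))
indices = [5, 2]  # cache slots owned by the two sequences in this batch

# Gather per-sequence state into batch order
# (what _prepare_request_mamba_cache returns).
conv_state = mamba_cache[0][:, indices]
ssm_state = mamba_cache[1][:, indices]

# Stand-in for the model call, which updates the states in place.
conv_state += 1
ssm_state += 2

# Scatter the updated states back to each sequence's slot,
# as the loop after model_executable does.
for i, offset in enumerate(indices):
    mamba_cache[0][:, offset] = conv_state[:, i]
    mamba_cache[1][:, offset] = ssm_state[:, i]

assert mamba_cache[0][:, 5].eq(1).all() and mamba_cache[1][:, 2].eq(2).all()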
@@ -946,6 +946,13 @@ def execute_model(
         )
         return output
 
+    def _get_first_free_mamba_cache_index(self):
+        max_possible_bs = self.mamba_cache[0].shape[1]
+        occupied = [id for seq_ids in self.request_id2index.values() for id in seq_ids.values()]
+        first_free_index = [i not in occupied for i in range(max_possible_bs)].index(True)
+        return first_free_index
+
+
     def _prepare_request_mamba_cache(
         self,
         input_metadata: InputMetadata,
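The new _get_first_free_mamba_cache_index helper flattens the nested map and scans for the first unclaimed slot. An equivalent set-based version of the same lookup (a sketch, not part of the PR) makes the membership test O(1) and the out-of-slots failure explicit:

from typing import Dict

def first_free_index(request_id2index: Dict[str, Dict[int, int]],
                     max_possible_bs: int) -> int:
    # Same answer as the diff's boolean-list scan, with an O(1) membership test.
    occupied = {slot for seq_ids in request_id2index.values()
                for slot in seq_ids.values()}
    for i in range(max_possible_bs):
        if i not in occupied:
            return i
    raise RuntimeError("no free mamba cache slot")

print(first_free_index({"a": {0: 0, 1: 2}}, 4))  # 1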
@@ -955,13 +962,26 @@ def _prepare_request_mamba_cache(
         max_possible_bs = self.mamba_cache[0].shape[1]
         for request_info in input_metadata.requests_info:
             if request_info.request_id not in self.request_id2index:
-                first_free_index = [i not in self.request_id2index.values() for i in range(max_possible_bs)].index(True)
-                self.request_id2index[request_info.request_id] = first_free_index
-            indices.append(self.request_id2index[request_info.request_id])
+                self.request_id2index[request_info.request_id] = {}
+                for seq_id in request_info.seqs_id:
+                    first_free_index = self._get_first_free_mamba_cache_index()
+                    self.request_id2index[request_info.request_id][seq_id] = first_free_index
+                    indices.append(first_free_index)
+            else:
+                for seq_id in request_info.seqs_id:
+                    if seq_id not in self.request_id2index[request_info.request_id]:
+                        first_free_index = self._get_first_free_mamba_cache_index()
+                        ## case of decoding with n > 1: seed the new slot from a sibling
+                        if len(self.request_id2index[request_info.request_id].keys()) > 0:
+                            self.mamba_cache[0][:, first_free_index].copy_(self.mamba_cache[0][:, list(self.request_id2index[request_info.request_id].values())[0]])
+                            self.mamba_cache[1][:, first_free_index].copy_(self.mamba_cache[1][:, list(self.request_id2index[request_info.request_id].values())[0]])
+                        self.request_id2index[request_info.request_id][seq_id] = first_free_index
+                    indices.append(self.request_id2index[request_info.request_id][seq_id])
         ## Pad the batch in case of running a batch that was not captured via CG
         padded_indices = indices
         for _ in range(batch_size - len(indices)):
-            padded_indices += [[i not in set(self.request_id2index.values()).union(padded_indices) for i in range(max_possible_bs)].index(True)]
+            occupied = [id for seq_ids in self.request_id2index.values() for id in seq_ids.values()]
+            padded_indices += [[i not in set(occupied).union(padded_indices) for i in range(max_possible_bs)].index(True)]
 
         conv_state = self.mamba_cache[0][:, padded_indices]
         ssm_state = self.mamba_cache[1][:, padded_indices]
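The n > 1 branch above handles a known request arriving with a sequence id it has not registered yet: best-of/n sampling forks sibling sequences at decode time, so the newly claimed slot is seeded with a copy of an existing sibling's state. A toy reproduction of that fork (slot layout and sizes are made up):

import torch

mamba_cache = (torch.zeros(1, 4, 3), torch.zeros(1, 4, 3))
request_id2index = {"req-0": {10: 0}}  # seq 10 already owns slot 0
mamba_cache[0][:, 0] = 7.0             # pretend seq 10 accumulated some state

# A sibling sequence (id 11) appears for the same request: claim a free slot
# and copy the parent's state so both branches continue from the same prefix.
parent_slot = list(request_id2index["req-0"].values())[0]
new_slot = 1
mamba_cache[0][:, new_slot].copy_(mamba_cache[0][:, parent_slot])
mamba_cache[1][:, new_slot].copy_(mamba_cache[1][:, parent_slot])
request_id2index["req-0"][11] = new_slot

assert mamba_cache[0][:, 1].eq(7.0).all()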
@@ -1140,23 +1160,26 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
             kv_cache_dtype=self.kv_cache_dtype,
         )
 
+        is_mamba = self.model_config.hf_config.model_type == "jamba"
         if self.lora_config:
             lora_mapping = LoRAMapping(
                 [0] * batch_size,
                 [0] * batch_size,
             )
             self.set_active_loras(set(), lora_mapping)
 
-        graph_runner = CUDAGraphRunner(self.model)
-        graph_runner.capture(
-            input_tokens[:batch_size],
-            input_positions[:batch_size],
-            kv_caches,
-            attn_metadata,
-            memory_pool=self.graph_memory_pool,
-            conv_state=self.mamba_cache4gc[0][:, :batch_size],
-            ssm_state=self.mamba_cache4gc[1][:, :batch_size]
-        )
+        graph_runner = CUDAGraphRunner(self.model, is_mamba)
+        capture_inputs = {
+            "input_ids": input_tokens[:batch_size],
+            "positions": input_positions[:batch_size],
+            "kv_caches": kv_caches,
+            "attn_metadata": attn_metadata,
+            "memory_pool": self.graph_memory_pool,
+        }
+        if is_mamba:
+            capture_inputs["conv_state"] = self.mamba_cache4gc[0][:, :batch_size]
+            capture_inputs["ssm_state"] = self.mamba_cache4gc[1][:, :batch_size]
+        graph_runner.capture(**capture_inputs)
         self.graph_memory_pool = graph_runner.graph.pool()
         self.graph_runners[batch_size] = graph_runner
 
@@ -1182,11 +1205,12 @@ def vocab_size(self) -> int:
 
 class CUDAGraphRunner:
 
-    def __init__(self, model: nn.Module):
+    def __init__(self, model: nn.Module, is_mamba: bool):
         self.model = model
         self.graph = None
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}
+        self.is_mamba = is_mamba
 
     def capture(
         self,
@@ -1197,40 +1221,38 @@ def capture(
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
         memory_pool,
+        conv_state: Optional[torch.Tensor] = None,
+        ssm_state: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> None:
         assert self.graph is None
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with _maybe_pynccl():
-            self.model(
-                input_ids,
-                positions,
-                kv_caches,
-                attn_metadata,
-                conv_state,
-                ssm_state,
-                **kwargs,
-            )
+        model_inputs = {
+            "input_ids": input_ids,
+            "positions": positions,
+            "kv_caches": kv_caches,
+            "attn_metadata": attn_metadata,
+        }
+        if self.is_mamba:
+            model_inputs = {
+                **model_inputs,
+                "conv_state": conv_state,
+                "ssm_state": ssm_state,
+            }
+
+        with _maybe_cupy_nccl():
+            self.model(**model_inputs)
         torch.cuda.synchronize()
 
         # Capture the graph.
         # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
-            with _maybe_pynccl():
-                hidden_states = self.model(
-                    input_ids,
-                    positions,
-                    kv_caches,
-                    attn_metadata,
-                    input_metadata,
-                    conv_state,
-                    ssm_state,
-                    **kwargs,
-                )
+            with _maybe_cupy_nccl():
+                hidden_states = self.model(**model_inputs)
         torch.cuda.synchronize()
 
         # Save the input and output buffers.
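capture keeps the usual two-phase CUDA-graph discipline: one eager warm-up call so autotuning kernel launches don't get recorded, then a replayable capture into a shared memory pool. A self-contained miniature of that flow using the public torch.cuda.graph API (needs a CUDA device; the model here is a stand-in, and the runner additionally passes pool=memory_pool):

import torch

assert torch.cuda.is_available()
static_x = torch.zeros(8, device="cuda")

def model(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1

# Warm-up run outside the graph, then synchronize (as the diff does).
model(static_x)
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):          # the runner also passes pool=memory_pool
    static_out = model(static_x)   # records kernels against fixed buffers
torch.cuda.synchronize()

# Replays reuse the same input/output storage: refill the input, replay, read.
static_x.fill_(3.0)
g.replay()
torch.cuda.synchronize()
print(static_out)  # tensor of 7s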
@@ -1244,6 +1266,13 @@ def capture(
-            "conv_state": conv_state,
-            "ssm_state": ssm_state
         }
+        if self.is_mamba:
+            self.input_buffers = {
+                **self.input_buffers,
+                "conv_state": conv_state,
+                "ssm_state": ssm_state,
+            }
+
         self.output_buffers = {"hidden_states": hidden_states}
         return
 
@@ -1253,8 +1282,8 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        conv_state: Optional[torch.Tensor] = None,
+        ssm_state: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         # KV caches are fixed tensors, so we don't need to copy them.
@@ -1269,16 +1298,19 @@ def forward(
             attn_metadata.decode_metadata.context_lens, non_blocking=True)
         self.input_buffers["block_tables"].copy_(
             attn_metadata.decode_metadata.block_tables, non_blocking=True)
-        self.input_buffers["conv_state"].copy_(conv_state,
-                                               non_blocking=True)
-        self.input_buffers["ssm_state"].copy_(ssm_state,
-                                              non_blocking=True)
+        if self.is_mamba:
+            self.input_buffers["conv_state"].copy_(conv_state,
+                                                   non_blocking=True)
+            self.input_buffers["ssm_state"].copy_(ssm_state,
+                                                  non_blocking=True)
+
         # Run the graph.
         self.graph.replay()
 
         # in-place edit of the mamba cache states as in the KV cache
-        ssm_state.copy_(self.input_buffers["ssm_state"])
-        conv_state.copy_(self.input_buffers["conv_state"])
+        if self.is_mamba:
+            ssm_state.copy_(self.input_buffers["ssm_state"])
+            conv_state.copy_(self.input_buffers["conv_state"])
 
         # Return the output tensor.
         return self.output_buffers["hidden_states"]
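At replay time, forward never re-enters Python model code: it copies fresh inputs into the tensors the graph was captured on, replays, and, on the Jamba path, mirrors the mutated states back to the caller's tensors. A schematic of that contract with hypothetical buffer names (x, hidden_states), not the runner's real keys:

import torch
from typing import Dict, Optional

def graph_forward(graph: torch.cuda.CUDAGraph,
                  input_buffers: Dict[str, torch.Tensor],
                  output_buffers: Dict[str, torch.Tensor],
                  x: torch.Tensor,
                  conv_state: Optional[torch.Tensor] = None,
                  is_mamba: bool = False) -> torch.Tensor:
    # 1. Copy volatile inputs into the static tensors the graph was captured on.
    input_buffers["x"].copy_(x, non_blocking=True)
    if is_mamba and conv_state is not None:
        input_buffers["conv_state"].copy_(conv_state, non_blocking=True)
    # 2. Replay the recorded kernels; no Python-side model code runs.
    graph.replay()
    # 3. Mirror in-place state updates back to the caller's tensor, as the
    #    runner does for conv_state/ssm_state after replay.
    if is_mamba and conv_state is not None:
        conv_state.copy_(input_buffers["conv_state"])
    # 4. Results live in the captured output buffers.
    return output_buffers["hidden_states"]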