@@ -2227,6 +2227,54 @@ def prepare_input_tensors(
                                     lora_ids=lora_ids), \
                                     sampling_metadata
 
+    @torch.inference_mode()
+    def prepare_model_input_align_worker(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None,
+        align_worker: bool = False,
+    ) -> ModelInputForHPUWithSamplingMetadata:
+        """Prepare the model input based on a given sequence group, including
+        metadata for the sampling step.
+        The API assumes seq_group_metadata_list is sorted by prefill -> decode.
+        The result tensors and data structure also batches input in prefill
+        -> decode order. For example,
+        - input_tokens[:num_prefill_tokens] contains prefill tokens.
+        - input_tokens[num_prefill_tokens:] contains decode tokens.
+        If cuda graph is required, this API automatically pads inputs.
+        """
+        with self.profiler.record_event('internal', 'prepare_input_tensors'):
+            assert seq_group_metadata_list is not None
+            if self.profiler.enabled:
+                self.profiler_counter_helper.capture_seq_group_metadata_stats(
+                    seq_group_metadata_list=seq_group_metadata_list)
+            model_input, sampling_metadata = self.prepare_input_tensors(
+                seq_group_metadata_list, finished_requests_ids, align_worker)
+            assert model_input.attn_metadata is not None
+            is_prompt = model_input.attn_metadata.is_prompt
+
+        return ModelInputForHPUWithSamplingMetadata(
+            input_tokens=model_input.input_tokens,
+            input_positions=model_input.input_positions,
+            seq_lens=model_input.seq_lens,
+            query_lens=model_input.query_lens,
+            lora_mapping=model_input.lora_mapping,
+            lora_requests=model_input.lora_requests,
+            attn_metadata=model_input.attn_metadata,
+            multi_modal_kwargs=model_input.multi_modal_kwargs,
+            real_batch_size=model_input.real_batch_size,
+            batch_size_padded=model_input.batch_size_padded,
+            virtual_engine=virtual_engine,
+            lora_ids=model_input.lora_ids,
+            async_callback=model_input.async_callback,
+            is_first_multi_step=model_input.is_first_multi_step,
+            is_last_step=model_input.is_last_step,
+            previous_hidden_states=model_input.previous_hidden_states,
+            sampling_metadata=sampling_metadata,
+            is_prompt=is_prompt,
+        )
+
     def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
                          is_prompt: bool):
         '''
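Note on the prefill -> decode layout documented in the docstring above: a consumer of the returned model input can split the flattened token tensor on that boundary. A minimal sketch, not part of this diff; it assumes the attention metadata exposes num_prefill_tokens, as vLLM's attention metadata does on other backends:

    # Illustrative only: split the batched inputs at the prefill/decode
    # boundary guaranteed by the layout contract above.
    num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
    prefill_tokens = model_input.input_tokens[:num_prefill_tokens]
    decode_tokens = model_input.input_tokens[num_prefill_tokens:]
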
@@ -3160,38 +3208,6 @@ def prepare_model_input(
                                                    finished_requests_ids,
                                                    False)
 
-    @torch.inference_mode()
-    def prepare_model_input_align_worker(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        virtual_engine: int = 0,
-        finished_requests_ids: Optional[List[str]] = None,
-        align_worker: bool = False,
-    ) -> ModelInputForHPUWithSamplingMetadata:
-        """Prepare the model input based on a given sequence group, including
-        metadata for the sampling step.
-        The API assumes seq_group_metadata_list is sorted by prefill -> decode.
-        The result tensors and data structure also batches input in prefill
-        -> decode order. For example,
-        - input_tokens[:num_prefill_tokens] contains prefill tokens.
-        - input_tokens[num_prefill_tokens:] contains decode tokens.
-        If cuda graph is required, this API automatically pads inputs.
-        """
-        with self.profiler.record_event('internal', 'prepare_input_tensors'):
-            assert seq_group_metadata_list is not None
-            if self.profiler.enabled:
-                self.profiler_counter_helper.capture_seq_group_metadata_stats(
-                    seq_group_metadata_list=seq_group_metadata_list)
-            model_input, sampling_metadata = self.prepare_input_tensors(
-                seq_group_metadata_list, finished_requests_ids, align_worker)
-            assert model_input.attn_metadata is not None
-            is_prompt = model_input.attn_metadata.is_prompt
-
-        return dataclasses.replace(model_input,
-                                   sampling_metadata=sampling_metadata,
-                                   is_prompt=is_prompt,
-                                   virtual_engine=virtual_engine)
-
     def finish_measurements(self):
         from neural_compressor.torch.quantization import finalize_calibration
         finalize_calibration(self.model.model)
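Taken together, the two hunks move prepare_model_input_align_worker earlier in the file and change its return path: instead of patching the prepared input via dataclasses.replace, it now constructs a fresh ModelInputForHPUWithSamplingMetadata with every field spelled out, so sampling_metadata, virtual_engine, and is_prompt are set in one place. A minimal sketch of the resulting call pattern; the runner and variable names below are hypothetical stand-ins, not code from this commit:

    # Ordinary path: prepare_model_input forwards with align_worker=False,
    # as the retained context lines (finished_requests_ids, False) show.
    model_input = runner.prepare_model_input(
        seq_group_metadata_list,            # sorted prefill -> decode
        virtual_engine=0,
        finished_requests_ids=finished_ids)

    # Callers that need the worker-alignment behavior opt in explicitly.
    aligned_input = runner.prepare_model_input_align_worker(
        seq_group_metadata_list,
        virtual_engine=0,
        finished_requests_ids=finished_ids,
        align_worker=True)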