
Commit c39ae01

Added CB support for Mistral3
Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
1 parent 1220cf9 commit c39ae01

4 files changed (+82, -40 lines)


QEfficient/generation/embedding_handler.py

Lines changed: 3 additions & 0 deletions
@@ -168,6 +168,9 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -
         else:
             image = Image.open(image_url)
 
+        if "mistral3" in self._qeff_model.model.config.model_type:
+            image = image.resize((1540, 1540))
+
         # Prepare conversation format
         conversation = [
             {
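
The handler change above is Mistral3-specific: before the conversation prompt is built, the input image is resized to a fixed 1540x1540. A minimal standalone sketch of that behavior, assuming a PIL image and a plain model_type string (the helper name is illustrative, not part of the handler's API):

from PIL import Image


def maybe_resize_for_mistral3(image: Image.Image, model_type: str) -> Image.Image:
    # Mirrors the handler change: only "mistral3"-type models get the fixed-size resize.
    if "mistral3" in model_type:
        return image.resize((1540, 1540))
    return image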

QEfficient/transformers/models/internvl/modeling_internvl.py

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ def get_specializations(
                     "batch_size": full_batch_size if continuous_batching else batch_size,
                     "seq_len": "1",
                     "ctx_len": ctx_len,
-                    "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
+                    "comp_ctx_lengths": comp_ctx_lengths_decode[i],
                     "num_patches": num_patches,
                     "img_size": img_size,
                     "vision_size": vision_size,

QEfficient/transformers/models/mistral3/modeling_mistral3.py

Lines changed: 77 additions & 38 deletions
@@ -176,6 +176,7 @@ def forward(
         image_idx,
         past_key_values,
         comp_ctx_lengths: Optional[List[int]] = None,
+        batch_index: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)
         vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
@@ -190,6 +191,7 @@ def forward(
             position_ids=position_ids,
             past_key_values=past_key_values,
             comp_ctx_lengths=comp_ctx_lengths,
+            batch_index=batch_index,
         )
 
         # Cast to int32 to avoid ONNXRT issue
@@ -250,7 +252,7 @@ def forward(
 
         return logits, pixel_values, image_idx, outputs.past_key_values
 
-    def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs):
+    def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False, **kwargs):
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         height = self.config.vision_config.image_size
@@ -290,10 +292,14 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl
             .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
         )
         lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)
+
+        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
+
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
-            config=self.language_model.config,
-            batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            config=self.model.config.text_config,
+            batch_size=fbs if continuous_batching else bs,
             seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
 
@@ -304,6 +310,8 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl
 
         if comp_ctx_lengths is not None:
             lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+        if continuous_batching:
+            lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
 
         inputs = {}
         if kv_offload:
@@ -324,6 +332,9 @@ def get_specializations(
         comp_ctx_lengths_prefill: Optional[List[int]] = None,
         comp_ctx_lengths_decode: Optional[List[int]] = None,
         kv_offload: bool = False,
+        continuous_batching: bool = False,
+        kv_cache_batch_size: Optional[int] = None,
+        full_batch_size: Optional[int] = None,
         **compiler_options,
     ):
         if img_size is None and hasattr(self.config.vision_config, "image_size"):
@@ -352,46 +363,65 @@ def get_specializations(
             lang = []
 
             for i in range(0, len(comp_ctx_lengths_prefill)):
-                lang.append(
-                    {
-                        "batch_size": batch_size,
-                        "seq_len": prefill_seq_len,
-                        "ctx_len": ctx_len,
-                        "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
-                        "image_size": img_size,
-                        "vision_size": vision_size,
-                    }
-                )
-
-            # Remaining elements use comp_ctx_lengths[1:] in a loop
-            for i in range(0, len(comp_ctx_lengths_decode)):
-                lang.append(
-                    {
-                        "batch_size": batch_size,
-                        "seq_len": "1",
-                        "ctx_len": ctx_len,
-                        "comp_ctx_lengths": comp_ctx_lengths_decode[i],
-                        "image_size": img_size,
-                        "vision_size": vision_size,
-                    }
-                )
-        else:
-            lang = [
-                {
-                    "batch_size": batch_size,
+                lang_prefill = {
+                    "batch_size": 1 if continuous_batching else batch_size,
                     "seq_len": prefill_seq_len,
                     "ctx_len": ctx_len,
+                    "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
                     "image_size": img_size,
                     "vision_size": vision_size,
-                },
-                {
-                    "batch_size": batch_size,
+                }
+                if continuous_batching:
+                    lang_prefill["full_batch_size"] = kv_cache_batch_size
+                else:
+                    lang_prefill["batch_size"] = kv_cache_batch_size
+                if full_batch_size:
+                    lang_prefill["full_batch_exec_size"] = full_batch_size
+                lang.append(lang_prefill)
+
+            # Remaining elements use comp_ctx_lengths[1:] in a loop
+            for i in range(0, len(comp_ctx_lengths_decode)):
+                lang_decode = {
+                    "batch_size": full_batch_size if continuous_batching else batch_size,
                     "seq_len": "1",
                     "ctx_len": ctx_len,
+                    "comp_ctx_lengths": comp_ctx_lengths_decode[i],
                     "image_size": img_size,
                     "vision_size": vision_size,
-                },
-            ]
+                }
+
+                if continuous_batching:
+                    lang_decode["full_batch_size"] = kv_cache_batch_size
+                else:
+                    lang_decode["batch_size"] = kv_cache_batch_size
+                lang.append(lang_decode)
+        else:
+            lang_prefill = {
+                "batch_size": 1 if continuous_batching else batch_size,
+                "seq_len": prefill_seq_len,
+                "ctx_len": ctx_len,
+                "image_size": img_size,
+                "vision_size": vision_size,
+            }
+            if continuous_batching:
+                lang_prefill["full_batch_size"] = kv_cache_batch_size
+            else:
+                lang_prefill["batch_size"] = kv_cache_batch_size
+            if full_batch_size:
+                lang_prefill["full_batch_exec_size"] = full_batch_size
+
+            lang_decode = {
+                "batch_size": full_batch_size if continuous_batching else batch_size,
+                "seq_len": "1",
+                "ctx_len": ctx_len,
+                "image_size": img_size,
+                "vision_size": vision_size,
+            }
+
+            if continuous_batching:
+                lang_decode["full_batch_size"] = kv_cache_batch_size
+            else:
+                lang_decode["batch_size"] = kv_cache_batch_size
 
         specializations = {}
 
@@ -404,7 +434,7 @@ def get_specializations(
             lang[1].pop("vision_size")
             return lang, compiler_options
 
-    def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False):
+    def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False):
         # Define dynamic axes
         num_layers = self.config.text_config.num_hidden_layers
 
@@ -417,9 +447,18 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv
             "vision_embeds": {0: "vision_size"},
         }
 
+        if continuous_batching:
+            lang_dynamic_axes["batch_index"] = {0: "batch_size"}
+
         for i in range(num_layers):
-            lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
-            lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}
+            lang_dynamic_axes[f"past_key.{i}"] = {
+                0: "full_batch_size" if continuous_batching else "batch_size",
+                2: "ctx_len",
+            }
+            lang_dynamic_axes[f"past_value.{i}"] = {
+                0: "full_batch_size" if continuous_batching else "batch_size",
+                2: "ctx_len",
+            }
 
         if comp_ctx_lengths is not None:
             lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}

examples/internvl_CB_example.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@
     num_cores=16,
     num_devices=4,
     batch_size=1,
-    full_batch_size=1,
+    full_batch_size=4,
     mxfp6_matmul=True,
     mxint8_kv_cache=True,
     aic_enable_depth_first=True,
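
A hypothetical end-to-end sketch of exercising the new Mistral3 continuous-batching path, modeled on examples/internvl_CB_example.py. The model card and the from_pretrained keyword arguments are assumptions (only the compile settings mirror the example above), so treat this as an outline rather than a verified script:

from QEfficient import QEFFAutoModelForImageTextToText

# Assumed model card; any Mistral3-based vision-language checkpoint would do.
model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

# kv_offload/continuous_batching at load time are assumptions based on how the
# other CB-enabled flows in the repo are driven; they are not shown in this diff.
model = QEFFAutoModelForImageTextToText.from_pretrained(
    model_id,
    kv_offload=True,
    continuous_batching=True,
)

# Compile settings mirror the updated examples/internvl_CB_example.py.
model.compile(
    num_cores=16,
    num_devices=4,
    batch_size=1,
    full_batch_size=4,
    mxfp6_matmul=True,
    mxint8_kv_cache=True,
    aic_enable_depth_first=True,
)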
