1717from QEfficient .base .common import AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP , QEFF_MODEL_TYPE , QEFFCommonLoader
1818from QEfficient .base .modeling_qeff import QEFFBaseModel
1919from QEfficient .exporter .export_utils import export_onnx , fix_onnx_fp16 , generate_input_files , run_model_on_ort
20+ from QEfficient .lora .auto import QEffAutoLoraModelForCausalLM
2021from QEfficient .transformers .modeling_utils import get_lists_of_cb_qeff_models
2122from QEfficient .transformers .models .modeling_auto import QEFFAutoModelForCausalLM
2223from QEfficient .utils import load_hf_tokenizer
@@ -149,6 +150,7 @@ def convert_to_cloud_kvstyle(
149150 tokenizer : Union [PreTrainedTokenizer , PreTrainedTokenizerFast ],
150151 onnx_dir_path : str ,
151152 seq_len : int ,
153+ max_num_adapters : int ,
152154) -> str :
153155 """
154156 API to convert model with kv retention and export to ONNX.
@@ -181,7 +183,7 @@ def convert_to_cloud_kvstyle(
181183
182184 # Decide path for saving exported ONNX files.
183185 model_name = export_kvstyle_transformed_model_to_onnx (
184- model_name , qeff_model .model , tokenizer , onnx_dir_path , seq_len
186+ model_name , qeff_model .model , tokenizer , onnx_dir_path , seq_len , max_num_adapters
185187 ) # type: ignore
186188
187189 # return the model path for automation.
@@ -195,6 +197,7 @@ def export_kvstyle_transformed_model_to_onnx(
195197 onnx_dir_path : str ,
196198 seq_len : int ,
197199 full_batch_size : Optional [int ] = None ,
200+ max_num_adapters : Optional [int ] = None ,
198201) -> str :
199202 # Disabling requires_grad on all parameters
200203 for _ , p in enumerate (transformed_model .parameters ()):
@@ -213,6 +216,7 @@ def export_kvstyle_transformed_model_to_onnx(
213216 prompt_len = Constants .PROMPT_LEN ,
214217 ctx_len = seq_len ,
215218 full_batch_size = full_batch_size ,
219+ max_num_adapters = max_num_adapters ,
216220 )
217221
218222 inputs = input_handler .prepare_pytorch_inputs ()
@@ -318,6 +322,7 @@ def export_for_cloud(
318322 onnx_dir_path : str ,
319323 seq_length : int = Constants .SEQ_LEN ,
320324 full_batch_size : Optional [int ] = None ,
325+ max_num_adapters : Optional [int ] = None ,
321326) -> str :
322327 # Check if model architecture is supported for continuous batching.
323328 if full_batch_size and qeff_model .model .config .architectures [0 ] not in get_lists_of_cb_qeff_models .architectures :
@@ -326,14 +331,18 @@ def export_for_cloud(
326331 )
327332
328333 # FIXME: move all this to class instead of here, and just call qeff_model.export here.
329- if AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP .get (qeff_model .__class__ , None ) == QEFF_MODEL_TYPE .CAUSALLM : # type: ignore
334+ if (
335+ AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP .get (qeff_model .__class__ , None ) == QEFF_MODEL_TYPE .CAUSALLM
336+ or qeff_model .__class__ == QEffAutoLoraModelForCausalLM
337+ ): # type: ignore
330338 return export_lm_model_for_cloud (
331339 model_name = model_name ,
332340 qeff_model = qeff_model , # type: ignore
333341 tokenizer = tokenizer ,
334342 onnx_dir_path = onnx_dir_path ,
335343 seq_length = seq_length ,
336344 full_batch_size = full_batch_size ,
345+ max_num_adapters = max_num_adapters ,
337346 )
338347 else :
339348 raise NotImplementedError (
@@ -348,6 +357,7 @@ def export_lm_model_for_cloud(
348357 onnx_dir_path : str ,
349358 seq_length : int ,
350359 full_batch_size : Optional [int ] = None ,
360+ max_num_adapters : Optional [int ] = None ,
351361) -> str :
352362 if os .path .exists (onnx_dir_path ):
353363 logger .warning (f"Overriding { onnx_dir_path } " )
@@ -361,6 +371,7 @@ def export_lm_model_for_cloud(
361371 onnx_dir_path = onnx_dir_path ,
362372 seq_len = seq_length ,
363373 full_batch_size = full_batch_size ,
374+ max_num_adapters = max_num_adapters ,
364375 ) # type: ignore
365376
366377 else :
@@ -386,6 +397,7 @@ def qualcomm_efficient_converter(
386397 kv : bool = True ,
387398 form_factor : str = "cloud" ,
388399 full_batch_size : Optional [int ] = None ,
400+ max_num_adapters : Optional [int ] = None ,
389401) -> Tuple [str , str ]:
390402 """
391403 This method is an alias for ``QEfficient.export``.
@@ -466,6 +478,7 @@ def qualcomm_efficient_converter(
466478 onnx_dir_path = onnx_dir_path ,
467479 seq_length = seq_length ,
468480 full_batch_size = full_batch_size ,
481+ max_num_adapters = max_num_adapters ,
469482 )
470483 return onnx_dir_path , generated_onnx_model_path
471484 else :
0 commit comments