Merged
10 changes: 5 additions & 5 deletions QEfficient/base/modeling_qeff.py
@@ -245,8 +245,11 @@ def _compile(
qpc_path = compile_dir / "qpc"
if not onnx_path.is_file():
raise FileNotFoundError(f"ONNX file not found at: {onnx_path}")

command = constants.COMPILER + [f"-m={onnx_path}"]
if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
mdp_ts_num_devices = None
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")

for key, value in compiler_options.items():
option = "-" + key.replace("_", "-")
if isinstance(value, bool):
@@ -262,9 +265,6 @@ def _compile(
if custom_io is not None:
compile_hash.update(to_hashable(custom_io))

if mdp_ts_num_devices > 1:
compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))

if num_speculative_tokens:
compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))

@@ -300,7 +300,7 @@ def _compile(
command.append(f"-custom-IO-list-file={custom_io_yaml}")

# Write mdp_config.json file
if mdp_ts_num_devices > 1:
if not mdp_ts_json_path and mdp_ts_num_devices > 1:
num_cores = compiler_options.get("aic_num_cores", 16)
mdp_ts_json = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
with open(mdp_ts_json, "w") as fp:
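Note on the change above: the new `mdp_ts_json_path` option gives a user-supplied partition config precedence over the auto-generated one. It is popped from `compiler_options` before the generic option loop (so it is not serialized a second time), forwarded as `-mdp-load-partition-config`, and the `mdp_ts_num_devices`-based JSON is then only written when no such config was passed. A minimal, self-contained sketch of that precedence logic, assuming an illustrative compiler binary name in place of `constants.COMPILER`:

from pathlib import Path
from typing import List

def build_compile_command(onnx_path: Path, **compiler_options) -> List[str]:
    command = ["qaic-exec", f"-m={onnx_path}"]  # "qaic-exec" stands in for constants.COMPILER
    # A user-supplied partition config wins; pop it so the generic
    # key/value loop below never sees it.
    if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
        command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
    for key, value in compiler_options.items():
        option = "-" + key.replace("_", "-")
        if isinstance(value, bool):
            if value:
                command.append(option)
        else:
            command.append(f"{option}={value}")
    return command

print(build_compile_command(Path("model.onnx"), mdp_ts_json_path="parts.json", aic_num_cores=16))
# -> ['qaic-exec', '-m=model.onnx', '-mdp-load-partition-config=parts.json', '-aic-num-cores=16']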
219 changes: 135 additions & 84 deletions QEfficient/transformers/models/modeling_auto.py
@@ -561,7 +561,12 @@ def onnx_path(self):

@property
def qpc_path(self):
return [self.vision_model.qpc_path, self.lang_model.qpc_path]
if self.vision_model.qpc_path and self.lang_model.qpc_path:
return [self.vision_model.qpc_path, self.lang_model.qpc_path]
elif self.vision_model.qpc_path:
return self.vision_model.qpc_path
else:
return self.lang_model.qpc_path

def export(
self,
@@ -600,6 +605,8 @@ def compile(
num_speculative_tokens: Optional[int] = None,
enable_qnn: bool = False,
qnn_config: Optional[str] = None,
skip_vision: Optional[bool] = False,
skip_lang: Optional[bool] = False,
**compiler_options,
) -> str:
if (
@@ -615,6 +622,9 @@
f"enable_qnn={enable_qnn}, qnn_config={qnn_config}"
)

if skip_lang and skip_vision:
raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False")

output_names = self.model.get_output_names(kv_offload=True)

specializations, compiler_options = self.model.get_specializations(
@@ -642,41 +652,43 @@
):
self.export()

self.vision_model._compile(
compile_dir,
compile_only=True,
specializations=specializations["vision"],
convert_to_fp16=True,
mxfp6_matmul=mxfp6_matmul,
mdp_ts_num_devices=num_devices,
aic_num_cores=num_cores,
custom_io=custom_io_vision,
**compiler_options,
)
if not skip_vision:
self.vision_model._compile(
compile_dir,
compile_only=True,
specializations=specializations["vision"],
convert_to_fp16=True,
mxfp6_matmul=mxfp6_matmul,
mdp_ts_num_devices=num_devices,
aic_num_cores=num_cores,
custom_io=custom_io_vision,
**compiler_options,
)

custom_io_lang = {}
# Inputs
for output_name in output_names["lang"]:
if output_name.endswith("_RetainedState"):
custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
if not skip_lang:
custom_io_lang = {}
# Inputs
for output_name in output_names["lang"]:
if output_name.endswith("_RetainedState"):
custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype

# outputs
for output_name in output_names["lang"]:
if output_name.endswith("_RetainedState"):
custom_io_lang[output_name] = kv_cache_dtype
# outputs
for output_name in output_names["lang"]:
if output_name.endswith("_RetainedState"):
custom_io_lang[output_name] = kv_cache_dtype

self.lang_model._compile(
compile_dir,
compile_only=True,
retained_state=True,
specializations=specializations["lang"],
convert_to_fp16=True,
mxfp6_matmul=mxfp6_matmul,
mdp_ts_num_devices=num_devices,
aic_num_cores=num_cores,
custom_io=custom_io_lang,
**compiler_options,
)
self.lang_model._compile(
compile_dir,
compile_only=True,
retained_state=True,
specializations=specializations["lang"],
convert_to_fp16=True,
mxfp6_matmul=mxfp6_matmul,
mdp_ts_num_devices=num_devices,
aic_num_cores=num_cores,
custom_io=custom_io_lang,
**compiler_options,
)
return self.qpc_path
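
A hedged usage sketch of the new skip flags (the model card and kwargs here are illustrative): recompile only the language QPC when the vision QPC already exists, after which the `qpc_path` property returns the single available path instead of a two-element list.

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", kv_offload=True
)
model.compile(num_cores=16, num_devices=1, skip_vision=True)  # language model only
# skip_vision=True together with skip_lang=True raises ValueError (see above).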

def generate(
@@ -711,6 +723,9 @@ def kv_offload_generate(
device_ids: List[int] = None,
generation_len: int = None,
):
if not self.vision_model.qpc_path or not self.lang_model.qpc_path:
raise TypeError("Please run the compile API for the vision and language models first!")

lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False)

vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids)
@@ -1461,6 +1476,51 @@ def export(self, export_dir: Optional[str] = None) -> str:
export_dir=export_dir,
)

def build_prefill_specialization(
self,
prefill_seq_len: int = 32,
ctx_len: int = 128,
batch_size: int = 1,
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
):
spec = {
"batch_size": 1 if self.continuous_batching else batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"num_logits_to_keep": 1 if self.is_tlm else None,
}
if self.continuous_batching:
spec["full_batch_size"] = kv_cache_batch_size
else:
spec["batch_size"] = kv_cache_batch_size
if full_batch_size:
spec["full_batch_exec_size"] = full_batch_size
return {k: v for k, v in spec.items() if v is not None}

def build_decode_specialization(
self,
prefill_seq_len: int = 32,
ctx_len: int = 128,
batch_size: int = 1,
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
):
if prefill_seq_len == 1 and not self.continuous_batching:
return None # Avoid duplication with prefill
spec = {
"batch_size": full_batch_size if self.continuous_batching else batch_size,
"seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
"ctx_len": ctx_len,
"num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
}
if self.continuous_batching:
spec["full_batch_size"] = kv_cache_batch_size
else:
spec["batch_size"] = kv_cache_batch_size
return {k: v for k, v in spec.items() if v is not None}
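
To make the refactor concrete, here is a standalone stub that mirrors the prefill builder above (a hypothetical stand-in, not the real class) together with the dictionaries the two builders produce for a continuous-batching, non-TLM model:

class SpecBuilderDemo:
    # Hypothetical stub exposing only the two flags the builders read.
    def __init__(self, continuous_batching: bool, is_tlm: bool):
        self.continuous_batching = continuous_batching
        self.is_tlm = is_tlm

    def build_prefill_specialization(self, prefill_seq_len=32, ctx_len=128,
                                     batch_size=1, kv_cache_batch_size=None,
                                     full_batch_size=None):
        spec = {
            "batch_size": 1 if self.continuous_batching else batch_size,
            "seq_len": prefill_seq_len,
            "ctx_len": ctx_len,
            "num_logits_to_keep": 1 if self.is_tlm else None,
        }
        if self.continuous_batching:
            spec["full_batch_size"] = kv_cache_batch_size
        else:
            spec["batch_size"] = kv_cache_batch_size
        if full_batch_size:
            spec["full_batch_exec_size"] = full_batch_size
        return {k: v for k, v in spec.items() if v is not None}

demo = SpecBuilderDemo(continuous_batching=True, is_tlm=False)
print(demo.build_prefill_specialization(kv_cache_batch_size=4, full_batch_size=4))
# -> {'batch_size': 1, 'seq_len': 32, 'ctx_len': 128,
#     'full_batch_size': 4, 'full_batch_exec_size': 4}
# The decode builder, given the same inputs, would yield
# {'batch_size': 4, 'seq_len': 1, 'ctx_len': 128, 'full_batch_size': 4}.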

def compile(
self,
onnx_path: Optional[str] = None,
@@ -1478,6 +1538,7 @@ def compile(
num_speculative_tokens: Optional[int] = None,
enable_qnn: bool = False,
qnn_config: Optional[str] = None,
prefill_only: Optional[bool] = None,
**compiler_options,
) -> str:
"""
@@ -1501,74 +1562,63 @@ def compile(
:aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
:prefill_only (bool, optional): If ``True``, compile for prefill only; if ``False``, compile for decode only. ``Defaults to None``, which compiles for both prefill and decode.
:compiler_options (dict, optional): Any other options that the `qaic-exec` takes. ``Defaults to None``.

Returns:
:str: Path of the compiled ``qpc`` package.
"""
# --- Validation ---
if prefill_only is not None and not isinstance(prefill_only, bool):
raise TypeError("`prefill_only` must be a boolean.")

if self.is_tlm:
# assert num_speculative_tokens cfg is acceptable if defined
if num_speculative_tokens is None:
raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.")
if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2:
ValueError(
f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
)
num_logits_to_keep = num_speculative_tokens + 1
if prefill_seq_len < num_logits_to_keep:
raise TypeError("`num_speculative_tokens` is required when `is_tlm=True`.")
if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
raise ValueError("`num_speculative_tokens` must be an integer >= 2.")
if prefill_seq_len < (num_speculative_tokens + 1):
raise ValueError(
f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
f"`prefill_seq_len` must be at least `num_speculative_tokens + 1` "
f"({num_speculative_tokens + 1}), got {prefill_seq_len}."
)

if self.continuous_batching and full_batch_size is None:
raise TypeError("missing required argument: 'full_batch_size'")
raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")

if kv_cache_batch_size and not full_batch_size:
raise ValueError(
"Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
"KV caching requires continuous batching. Please set `full_batch_size` and "
"enable `continuous_batching=True` in `from_pretrained`."
)

kv_cache_batch_size = (
kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
)
# Define prefill specialization
prefill_specialization = {
# Prefill is always run with single BS for continuous batching.
"batch_size": 1 if self.continuous_batching else batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
# TODO: should be renamed to kv_cache_batch_size in specialization too
}
prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
if self.continuous_batching:
prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
else:
prefill_specialization.update({"batch_size": kv_cache_batch_size})
prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
specializations = [
prefill_specialization,
]
# Infer kv_cache_batch_size if not provided
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

# Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
if prefill_seq_len != 1 or self.continuous_batching:
decode_specialization = {
"batch_size": full_batch_size if self.continuous_batching else batch_size,
"seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
"ctx_len": ctx_len,
}
if self.continuous_batching:
decode_specialization.update({"full_batch_size": kv_cache_batch_size})
else:
decode_specialization.update({"batch_size": kv_cache_batch_size})
decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
specializations.append(decode_specialization)
# --- Specializations ---
specializations = []

if prefill_only is None or prefill_only or prefill_seq_len == 1:
specializations.append(
self.build_prefill_specialization(
prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
)
)
if prefill_only is None or not prefill_only:
decode_spec = self.build_decode_specialization(
prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
)
if decode_spec:
specializations.append(decode_spec)

# --- Compilation ---
if enable_qnn:
if compiler_options:
logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
logger.warning("Extra arguments to QNN compilation are ignored. Use `qnn_config.json`.")

qpc_path = self._qnn_compile(
onnx_path,
compile_dir,
onnx_path=onnx_path,
compile_dir=compile_dir,
specializations=specializations,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
@@ -1582,17 +1632,17 @@ def compile(
kv_cache_batch_size=kv_cache_batch_size,
)
else:
# Custom IO
custom_io = {}
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
custom_io = {}

for suffix in ["", "_RetainedState"]:
for i in range(self.num_layers):
for kv in ["key", "value"]:
custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
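# For example, with num_layers == 2 and mxint8_kv_cache=True, the loop
# above yields (illustrative, not an exhaustive dump):
#   {"past_key.0": "mxint8", "past_value.0": "mxint8",
#    "past_key.1": "mxint8", "past_value.1": "mxint8",
#    "past_key.0_RetainedState": "mxint8", ..., "past_value.1_RetainedState": "mxint8"}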

qpc_path = self._compile(
onnx_path,
compile_dir,
onnx_path=onnx_path,
compile_dir=compile_dir,
compile_only=True,
retained_state=True,
specializations=specializations,
Expand All @@ -1604,6 +1654,7 @@ def compile(
aic_num_cores=num_cores,
**compiler_options,
)

return qpc_path
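
The `prefill_only` switch enables a two-QPC, disaggregated serving setup: one package specialized for prefill and another for decode. A hedged usage sketch (the model card and shape arguments are illustrative):

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
prefill_qpc = model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16,
                            prefill_only=True)
decode_qpc = model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16,
                           prefill_only=False)
# prefill_only=None (the default) preserves the old behavior: a single QPC
# carrying both the prefill and decode specializations.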

# FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
3 changes: 2 additions & 1 deletion QEfficient/utils/_utils.py
@@ -466,7 +466,8 @@ def wrapper(self, *args, **kwargs):
**{
k: v
for k, v in kwargs.items()
if k not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io"]
if k
not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io", "onnx_path"]
},
)
return result
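The change above only widens the exclusion list; the underlying pattern is a plain dict comprehension. A minimal standalone illustration (function name hypothetical):

EXCLUDED_KEYS = ["specializations", "mdp_ts_num_devices", "num_speculative_tokens",
                 "custom_io", "onnx_path"]

def filter_export_kwargs(kwargs: dict) -> dict:
    # Drop keys that the wrapped export helper must not receive.
    return {k: v for k, v in kwargs.items() if k not in EXCLUDED_KEYS}

print(filter_export_kwargs({"onnx_path": "model.onnx", "opset_version": 17}))
# -> {'opset_version': 17}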