
Commit 2eb4859

ochougule authored and platero97 committed
SwiftKV backup PR (quic#367)
* SwiftKV support added for CB as well as non-cb
1 parent 161975e commit 2eb4859

File tree: 2 files changed, +123 −40 lines changed

QEfficient/transformers/modeling_utils.py

Lines changed: 0 additions & 16 deletions
@@ -10,8 +10,6 @@

 import torch
 import torch.nn as nn
-import transformers.models.auto.modeling_auto as mapping
-from transformers import AutoModelForCausalLM
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
     CodeGenBlock,
@@ -279,20 +277,6 @@
     WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration,
 }

-# Map of model type to config class, Modelling class and transformer model architecture class
-MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = {
-    "llama_swiftkv": [QEffLlamaSwiftKVConfig, QEffLlamaSwiftKVForCausalLM, AutoModelForCausalLM],
-}
-
-
-MODEL_CLASS_MAPPING = {
-    **{architecture: "QEFFAutoModelForCausalLM" for architecture in mapping.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()},
-    **{
-        architecture: "QEFFAutoModelForImageTextToText"
-        for architecture in mapping.MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
-    },
-}
-

 def _prepare_cross_attention_mask(
     cross_attention_mask: torch.Tensor,
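
Note on the removal above: MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS mapped the non-HF model type "llama_swiftkv" to its config class, QEff modelling class, and transformers auto class, while MODEL_CLASS_MAPPING mapped Hugging Face architecture names to QEfficient auto-class names. As a rough illustration of how such lookup tables are typically consumed (a sketch with assumed entries and an assumed helper, not code from this repository or this commit):

# Illustrative sketch only: dispatch on config.architectures the way
# MODEL_CLASS_MAPPING was keyed. The entries and the helper are assumptions.
from transformers import AutoConfig

ARCH_TO_QEFF_CLASS = {
    "LlamaForCausalLM": "QEFFAutoModelForCausalLM",                       # assumed entry
    "MllamaForConditionalGeneration": "QEFFAutoModelForImageTextToText",  # assumed entry
}


def qeff_class_name_for(model_name: str) -> str:
    """Return the QEfficient auto-class name registered for a checkpoint's architecture."""
    config = AutoConfig.from_pretrained(model_name)
    architecture = config.architectures[0]  # e.g. "LlamaForCausalLM"
    return ARCH_TO_QEFF_CLASS.get(architecture, "QEFFAutoModelForCausalLM")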

tests/transformers/models/test_causal_lm_models.py

Lines changed: 123 additions & 24 deletions
@@ -47,13 +47,6 @@
     "ibm-granite/granite-3.1-1b-a400m-base",
 ]

-test_models_qnn = [
-    "mistralai/Mixtral-8x7B-Instruct-v0.1",
-    "meta-llama/Llama-3.2-1B",
-    "unsloth/gemma-2b",
-    "ibm-granite/granite-guardian-3.1-2b",
-]
-
 spd_test_models = [
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     "Qwen/Qwen2-0.5B",
@@ -122,7 +115,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     )

     pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
-
     is_tlm = False if num_speculative_tokens is None else True
     qaic_config = None
     if is_tlm:
@@ -156,22 +148,16 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         qnn_config=qnn_config,
     )
     exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
-    cloud_ai_100_tokens = exec_info.generated_ids[0][
-        :, :gen_len
-    ]  # Because we always run for single input and single batch size
-    if prefill_only:
-        assert (ort_tokens[0][0] == cloud_ai_100_tokens[0][0]).all(), (
-            "prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output."
-        )
-    else:
-        assert (ort_tokens == cloud_ai_100_tokens).all(), (
-            "Tokens don't match for ONNXRT output and Cloud AI 100 output."
-        )
-    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
-    if prefill_only is not None:
-        return
+    cloud_ai_100_tokens = exec_info.generated_ids[0]  # Because we always run for single input and single batch size
+    gen_len = ort_tokens.shape[-1]
+    assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), (
+        "Tokens don't match for ONNXRT output and Cloud AI 100 output."
+    )
+    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
+
     # testing for CB models
     model_hf, _ = load_causal_lm_model(model_config)
+    config = model_hf.config
     full_batch_size = 4
     fbs_prompts = Constants.INPUT_STR * 4
     api_runner = ApiRunner(
@@ -187,7 +173,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf)
     pytorch_hf_tokens = np.vstack(pytorch_hf_tokens)

-    qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm)
+    qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=False)
     onnx_model_path = qeff_model.export()

     if not get_available_device_id():
@@ -198,7 +184,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
         num_cores=14,
-        mxfp6=False,
+        mxfp6_matmul=False,
         aic_enable_depth_first=False,
         full_batch_size=full_batch_size,
         num_speculative_tokens=num_speculative_tokens,
@@ -216,6 +202,103 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))


+def check_non_hf_kv_vs_ort_vs_ai100(
+    model_name: str,
+    prompt_len: int = Constants.PROMPT_LEN,
+    ctx_len: int = Constants.CTX_LEN,
+    n_layer: int = 1,
+    num_speculative_tokens: Optional[int] = None,
+):
+    """
+    Validate the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+        :prompt_len (int): Prompt length for the model to compile.
+        :ctx_len (int): Maximum context length to compile the model.
+        :n_layers (int): Number of layers for the Model.
+    """
+    replace_transformers_quantizers()
+    model_config = {"model_name": model_name}
+    model_config["n_layer"] = n_layer
+
+    model_hf, _ = load_causal_lm_model(model_config)
+
+    tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
+    config = model_hf.config
+    batch_size = len(Constants.INPUT_STR)
+    api_runner = ApiRunner(
+        batch_size,
+        tokenizer,
+        config,
+        Constants.INPUT_STR,
+        Constants.PROMPT_LEN,
+        Constants.CTX_LEN,
+    )
+
+    is_tlm = False if num_speculative_tokens is None else True
+
+    qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm)
+    pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)
+
+    onnx_model_path = qeff_model.export()
+    ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
+
+    assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output."
+
+    if not get_available_device_id():
+        pytest.skip("No available devices to run model on Cloud AI 100")
+
+    qpc_path = qeff_model.compile(
+        prefill_seq_len=prompt_len,
+        ctx_len=ctx_len,
+        num_cores=14,
+        mxfp6=False,
+        aic_enable_depth_first=False,
+        num_speculative_tokens=num_speculative_tokens,
+    )
+
+    exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
+    cloud_ai_100_tokens = exec_info.generated_ids[0]  # Because we always run for single input and single batch size
+    gen_len = ort_tokens.shape[-1]
+
+    assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), (
+        "Tokens don't match for ONNXRT output and Cloud AI 100 output."
+    )
+    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
+
+    # testing for CB models
+    model_hf, _ = load_causal_lm_model(model_config)
+    config = model_hf.config
+    full_batch_size = 4
+    fbs_prompts = Constants.INPUT_STR * 4
+
+    qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm)
+    onnx_model_path = qeff_model.export()
+
+    if not get_available_device_id():
+        pytest.skip("No available devices to run model on Cloud AI 100")
+
+    qpc_path = qeff_model.compile(
+        prefill_seq_len=prompt_len,
+        ctx_len=ctx_len,
+        num_cores=14,
+        mxfp6=False,
+        aic_enable_depth_first=False,
+        full_batch_size=full_batch_size,
+        num_speculative_tokens=num_speculative_tokens,
+    )
+
+    exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
+
+    assert all(
+        [
+            all(pt_token[:24] == cloud_token[:24])
+            for pt_token, cloud_token in zip(ort_tokens, exec_info_fbs.generated_ids)
+        ]
+    ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output."
+    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
+
+
 # FIXME: there should be a CB test here
 @pytest.mark.parametrize("model_name", ["gpt2"], ids=lambda x: x)
 def test_causal_lm_export_with_deprecated_api(model_name):
@@ -262,6 +345,22 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
     check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)


+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_name", swiftkv_test_models)
+def test_non_hf_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
+    """
+    Test function to validate the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    if model_name == "Snowflake/Llama-3.1-SwiftKV-8B-Instruct":
+        n_layer = 32
+    else:
+        n_layer = 2
+
+    check_non_hf_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)
+
+
 @pytest.mark.on_qaic
 @pytest.mark.qnn
 @pytest.mark.parametrize("model_name", test_models_qnn)
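
The new test parametrizes over swiftkv_test_models, which is not visible in this diff; the n_layer branch above implies it includes the Snowflake SwiftKV checkpoint. A minimal sketch of what that list presumably looks like (the exact contents beyond the Snowflake entry are an assumption, not part of this commit):

# Assumed to be defined near the other model lists at the top of
# tests/transformers/models/test_causal_lm_models.py; only the Snowflake
# entry is implied by the test body above, any other entry is a guess.
swiftkv_test_models = [
    "Snowflake/Llama-3.1-SwiftKV-8B-Instruct",
]

With that list in place, pytest collects one test_non_hf_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100 case per entry, and the on_qaic marker lets CI restrict those cases to runners with a Cloud AI 100 device.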
