Unit Tests for On Device Sampling #463

Open
wants to merge 13 commits into base: main
195 changes: 168 additions & 27 deletions QEfficient/generation/text_generation_inference.py
@@ -10,7 +10,7 @@
from collections import deque
from dataclasses import dataclass
from time import perf_counter
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import transformers
@@ -322,6 +322,9 @@ def cloud_ai_100_exec_kv(
automation=False,
prompt_to_lora_id_mapping: Optional[List[int]] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -342,6 +345,15 @@
:Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
:automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
:prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
:include_sampler (bool): Enable/Disable sampling of next tokens on the QAIC device.
:return_pdfs (bool): Return probability distributions along with the sampled
next tokens. For a Speculative Decoding Target Language Model,
`return_pdfs`=True always. Otherwise, `return_pdfs`=True for a Speculative
Decoding Draft Language Model and `return_pdfs`=False for a regular model.
:sampling_params (Dict[str, Any]): A dictionary of sampling parameters supported by the QAIC backend.
The dictionary should contain the following keys:
`repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
`min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1).

Returns:
:CloudAI100ExecInfo: Object holding execution output and performance details.
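[Editor's note] For reference, a minimal sketch of how a caller might assemble `sampling_params` and pass it to this API. The dict keys and array shapes follow the docstring above; the concrete dtypes, parameter values, tokenizer, model name, and QPC path are illustrative assumptions, not part of this PR.

import numpy as np
from transformers import AutoTokenizer
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

batch_size = 1  # assumption: QPC compiled with batch size 1 and sampler support
sampling_params = {
    "repetition_penalties": np.full((batch_size, 1), 1.2, dtype=np.float32),
    "presence_penalties": np.zeros((batch_size, 1), dtype=np.float32),
    "temperatures": np.full((batch_size, 1), 0.8, dtype=np.float32),
    "top_ks": np.full((batch_size, 1), 40, dtype=np.int32),
    "top_ps": np.full((batch_size, 1), 0.95, dtype=np.float32),
    "min_ps": np.zeros((batch_size, 1), dtype=np.float32),
    "random_numbers": np.random.rand(batch_size, 1).astype(np.float32),
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="/path/to/qpc",          # placeholder: a QPC compiled with sampler support
    prompt=["Hello"],
    include_sampler=True,
    return_pdfs=False,                # regular (non-speculative) model
    sampling_params=sampling_params,
)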
@@ -372,6 +384,9 @@ def cloud_ai_100_exec_kv(
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
sampling_params=sampling_params,
)
if full_batch_size is None:
exec_info = [
@@ -411,14 +426,59 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: Optional[int] = None,
include_sampler: bool = False,
return_pdfs: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._ctx_len = ctx_len
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.include_sampler = include_sampler
self.return_pdfs = return_pdfs
self.sampling_params = sampling_params

# Load QPC
self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)

# Validate sampler inputs for On-Device Sampling
sampler_inputs = [
"last_accepted_output_tokens",
"repetition_penalties",
"presence_penalties",
"temperatures",
"top_ks",
"top_ps",
"min_ps",
"random_numbers",
]
count = 0
for session_input_name in self._session.input_names:
if session_input_name in sampler_inputs:
count += 1
Contributor:
Can there be a case where the user provides the same session_input_name multiple times? In that case, how will we catch it in this code? The count variable will keep incrementing and may satisfy the condition.

Contributor Author (@quic-sanising, Jul 3, 2025):
self._session.input_names comes from the exported ONNX file. If there are duplicate names, say abc, ONNX will convert them to something like abc_0, abc_1, and so on. So we would never get the same name multiple times.

However, if accuracy is the only priority here and performance is not, I could use set(), but it would add a slight overhead of O(n).
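[Editor's note] A minimal sketch of the set()-based variant mentioned above (not the code in this PR); it assumes self._session.input_names is the list of ONNX input names used elsewhere in this file.

sampler_inputs = {
    "last_accepted_output_tokens",
    "repetition_penalties",
    "presence_penalties",
    "temperatures",
    "top_ks",
    "top_ps",
    "min_ps",
    "random_numbers",
}
# Set intersection collapses any duplicate names, so each sampler input
# is counted at most once.
count = len(sampler_inputs & set(self._session.input_names))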

if count == len(sampler_inputs):
self.include_sampler = True
break
if count == 0:
Contributor:
I think we can avoid this if/else block: at line 455, set self.include_sampler = False by default; then at line 458, set it to True before the break; at line 462, just check for the error condition.

Contributor Author:
In case the user provides include_sampler as an input, self.include_sampler is not set to False. That is why we need the check at line 460.

We can only avoid the else block at line 468.

Contributor Author:
I'll make the change.

Contributor Author:
Done
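[Editor's note] A rough sketch of the simplified validation this thread converges on (default to False, flip to True only when every sampler input is present, and raise only on a partial match). This illustrates the idea under the PR's variable names; it is not necessarily the exact code that landed.

self.include_sampler = False
count = sum(name in sampler_inputs for name in self._session.input_names)
if count == len(sampler_inputs):
    self.include_sampler = True
elif count > 0:
    raise ValueError(
        f"Only {count}/{len(sampler_inputs)} sampler inputs were found in the QPC, "
        "so partial sampling support is not available."
    )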

self.include_sampler = False
elif count < len(sampler_inputs):
raise ValueError(
"The provided QPC does not have the required number of inputs to run sampling "
f"on the QAIC device (only {count}/{len(sampler_inputs)} inputs provided). Partial "
Contributor:
I think we should do count % sampler_inputs here. If we divide count by len(sampler_inputs), then it would return 0.

Contributor Author:
This is only a formatted message string; we are not actually dividing here. So, if count = 5 and len(sampler_inputs) = 10, it would print "only 5/10 inputs provided".

"sampling support is not available. Please check the QPC and try again."
)
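[Editor's note] To make the point above concrete, a tiny standalone illustration with hypothetical values:

count, total = 5, 10
# The "/" inside the f-string is literal text, not a division operator.
message = f"only {count}/{total} inputs provided"
print(message)  # -> only 5/10 inputs provided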

if include_sampler and not self.include_sampler:
logger.warning_once(
"User entered `include_sampler`=True. But the provided QPC is not compiled "
"to run sampling on the QAIC device. Falling back to the PyTorch backend."
)
elif (include_sampler is None or not include_sampler) and self.include_sampler:
raise ValueError(
"The provided QPC is compiled to run sampling on the QAIC device. "
"But the user did not enter `include_sampler`=True. Please make sure the input "
"is specified correctly."
)

# Fetch the variables from the QPC
self._vocab_size = self._fetch_vocab_size() # Fetch Vocab size
self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len()
@@ -523,10 +583,17 @@ def _fetch_vocab_size(
Returns:
vocab_size: The vocabulary size fetched from the session's allowed shapes.
"""
key = (
"probs"
if self.include_sampler and self.return_pdfs
else "next_tokens"
if self.include_sampler
else "logits"
)
if self._session.allowed_shapes:
return [x[self._session.binding_index_map["logits"]] for x in self._session.allowed_shapes][0][1][2]
return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2]

return self._session.bindings[self._session.binding_index_map["logits"]].dims[2]
return self._session.bindings[self._session.binding_index_map[key]].dims[2]

def _fetch_generation_len(self, generation_len, max_gen_len):
"""
@@ -574,6 +641,21 @@ def prepare_decode_inputs(self):
decode_inputs["position_ids"] = self.decode_pos_ids
if self.batch_index is not None:
decode_inputs["batch_index"] = self.batch_index
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
for op in [
"repetition_penalties",
"presence_penalties",
"temperatures",
"top_ks",
"top_ps",
"min_ps",
"random_numbers",
]:
if self.batch_index is not None:
decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
else:
decode_inputs[op] = self.sampling_params[op]

if self._prompt_to_lora_id_mapping_decode:
if self.full_batch_size:
@@ -589,21 +671,24 @@

def _fetch_next_token_id(self, outputs):
"""
Fetches the next token ID from the model's output logits.
The method identifies the token with the highest probability using argmax along the last dimension.
Fetches the next token ID from the model's output.

Args:
outputs (dict): A dictionary containing the model's output logits. The key "logits" should map to a numpy array of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size).
outputs (dict): A dictionary containing the model's output.

Returns:
numpy.ndarray: An array of the next token IDs for each sequence in the batch.
"""
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)

# Get output token
next_token_id = logits.argmax(2)
return next_token_id
if self.include_sampler:
if self.return_pdfs:
return outputs["probs"].argmax(2)
else:
return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1])
else:
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
return logits.argmax(2)

def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length):
"""
@@ -673,6 +758,23 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):

_ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1):
"""
Sets the sizes of the output buffers.

Args:
batch_size (int): The batch size.
sequence_length (int): The sequence length for which output buffers are allocated.
"""
if self.include_sampler:
if self.return_pdfs:
probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
self._session.set_buffers({"probs": probs_out_placeholder})
next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64)
self._session.set_buffers({"next_tokens": next_tokens_out_placeholder})
else:
logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
self._session.set_buffers({"logits": logits_out_placeholder})

def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
"""
Runs prefill for a given prompt and generation length.
@@ -702,9 +804,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
max_gen_len = self._ctx_len - position_ids.max()
generation_len = self._fetch_generation_len(generation_len, max_gen_len)

# Set the prefill logic buffer
logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32)
self._session.set_buffers({"logits": logits_out_placeholder})
# Set the prefill output buffers
self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)

inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
@@ -714,6 +815,21 @@
inputs["batch_index"] = decode_batch_id
if self.is_tlm:
inputs["num_logits_to_keep"] = np.zeros((1, 1))
if self.include_sampler:
inputs["last_accepted_output_tokens"] = inputs["input_ids"]
for op in [
"repetition_penalties",
"presence_penalties",
"temperatures",
"top_ks",
"top_ps",
"min_ps",
"random_numbers",
]:
if decode_batch_id is not None:
inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
inputs[op] = self.sampling_params[op]

if self._prompt_to_lora_id_mapping_prefill:
if self.full_batch_size:
@@ -732,6 +848,8 @@
chunk_inputs["position_ids"] = inputs["position_ids"][
:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
]
if self.include_sampler:
chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
outputs = self._session.run(chunk_inputs)
if self._write_io_dir is not None:
write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
@@ -753,11 +871,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

"""

# Set logits placeholder for decode
logits_out_placeholder = np.zeros(
(self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
# Set output placeholders for decode
self._set_output_buffers(
batch_size=self.full_batch_size,
sequence_length=self._decode_seq_len,
)
self._session.set_buffers({"logits": logits_out_placeholder})

# Generate flag for tracking progress for each batch ID
current_decode_ongoing = np.full((self.full_batch_size, 1), True)

@@ -775,10 +894,7 @@
outputs = self._session.run(decode_inputs)

# Prepare inputs for next iteration
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)
next_token_id = self._fetch_next_token_id(outputs)

for decode_batch_id in range(self.full_batch_size):
if (
@@ -800,7 +916,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1)
generated_id_current_index[decode_batch_id] = 1

self._session.set_buffers({"logits": logits_out_placeholder})
self._set_output_buffers(
batch_size=self.full_batch_size,
sequence_length=self._decode_seq_len,
)
decode_pause_time += perf_counter() - start

if self._prompt_to_lora_id_mapping_decode:
Expand All @@ -817,6 +936,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
next_token_id[decode_batch_id, -1]
)
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

generated_id_current_index[decode_batch_id] += 1

@@ -840,6 +961,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
(self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
)
self._session.set_buffers({"logits": logits_out_placeholder})
else:
self._set_output_buffers(
batch_size=self.batch_size,
sequence_length=self._decode_seq_len,
)
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
num_token = 0
for num_token in range(1, generation_len):
@@ -852,10 +978,12 @@
self._write_io_dir = None

# Prepare inputs for next iteration
decode_inputs["input_ids"] = outputs["logits"].argmax(2)
decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
decode_inputs["position_ids"][:, -1] += 1
self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

if finished_sequences.all():
break
@@ -905,9 +1033,22 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._qaic_model = QEffTextGenerationBase(
tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
tokenizer=tokenizer,
qpc_path=qpc_path,
full_batch_size=full_batch_size,
ctx_len=ctx_len,
device_id=device_id,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
sampling_params=sampling_params,
)
self._full_batch_size = self._qaic_model.full_batch_size
self._tokenizer = self._qaic_model.tokenizer
1 change: 1 addition & 0 deletions QEfficient/transformers/models/modeling_auto.py
@@ -1893,6 +1893,7 @@ def generate(
device_id=device_id,
generation_len=generation_len,
is_tlm=self.is_tlm,
**kwargs,
)
else:
raise NotImplementedError("Only AI_100 runtime is supported right now via generate API")
3 changes: 3 additions & 0 deletions QEfficient/transformers/sampler/sampler.py
@@ -193,6 +193,9 @@ def sampler_forward(
batch_size, spec_length, vocab_size = logits.shape
logits = logits.reshape(-1, vocab_size) # Reshape tensor to 2D

if batch_index is None: # Regular model execution
batch_index = torch.arange(batch_size).view(-1, 1)

batch_index_reshaped = batch_index.view(-1)
# Prefill
past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path(
3 changes: 2 additions & 1 deletion docs/source/quick_start.md
@@ -26,7 +26,8 @@ To achieve this, we have 2 levels of APIs, with different levels of abstraction.
| [Vision Language Model](QEFFAutoModelForImageTextToText) | Provides support for the AutoModelForImageTextToText class from the transformers library, enabling advanced vision-language tasks. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/image_text_to_text_inference.py) for more **details**. |
| [Speech Sequence to Sequence Model](QEFFAutoModelForSpeechSeq2Seq) | Provides support for the QEFFAutoModelForSpeechSeq2Seq class, facilitating speech-to-text sequence models. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/speech_to_text/run_whisper_speech_to_text.py) for more **details**. |
| Support for FP8 Execution | Enables execution with FP8 precision, significantly improving performance and reducing memory usage for computational tasks. |
| Prefill caching | Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. |
| Prefix caching | Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. |
| On Device Sampling | Enables sampling operations to be executed directly on the QAIC device rather than the host CPU for QEffForCausalLM models. This enhancement significantly reduces host-device communication overhead and improves inference throughput and scalability. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/on_device_sampling.py) for more **details**. |
Contributor:
Link seems broken, please fix.

Contributor Author:
The link points to an example file that will be added by this PR. So, the link will be available when the PR is merged.

| Prompt-Lookup Decoding | Speeds up text generation by using overlapping parts of the input prompt and the generated text, making the process faster without losing quality. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/pld_spd_inference.py) for more **details**. |
| [PEFT LoRA support](QEffAutoPeftModelForCausalLM) | Enables parameter-efficient fine-tuning using low-rank adaptation techniques, reducing the computational and memory requirements for fine-tuning large models. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/peft_models.py) for more **details**. |
| [QNN support](#qnn-compilation) | Enables compilation using QNN SDK, making Qeff adaptable for various backends in the future. |