Unit Tests for On Device Sampling #463
base: main
Changes from all commits
8417d8f
27d8dd5
067f9b5
931860f
79b6c95
75eac30
eb6e2eb
48b35e3
c83a631
f698a24
7b34a07
5fa7269
0ee201a
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
from collections import deque | ||
from dataclasses import dataclass | ||
from time import perf_counter | ||
from typing import Dict, List, Optional, Tuple, Union | ||
from typing import Any, Dict, List, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import transformers | ||
|
@@ -322,6 +322,9 @@ def cloud_ai_100_exec_kv( | |
automation=False, | ||
prompt_to_lora_id_mapping: Optional[List[int]] = None, | ||
is_tlm: bool = False, | ||
include_sampler: bool = False, | ||
return_pdfs: bool = False, | ||
sampling_params: Optional[Dict[str, Any]] = None, | ||
): | ||
""" | ||
This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. | ||
|
@@ -342,6 +345,15 @@ def cloud_ai_100_exec_kv( | |
:Write_io_dir (str): Path to write the input and output files. ``Defaults to None``. | ||
:automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``. | ||
:prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter. | ||
:include_sampler (bool): Enable/Disable sampling of next tokens. | ||
:return_pdfs (bool): Return probability distributions along with sampled | ||
next tokens. For Speculative Decoding Target Language Model, | ||
`return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative | ||
Decoding Draft Language Model and `return_pdfs`=False for regular model. | ||
:sampling_params (Dict[str, Any]): A dictionary of sampling parameters supported by the QAIC backend. | ||
The dictionary should contain the following keys: | ||
`repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`, | ||
`min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1). | ||
|
||
Returns: | ||
:CloudAI100ExecInfo: Object holding execution output and performance details. | ||
|
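For illustration, a minimal sketch of how the new sampling arguments might be wired into a `cloud_ai_100_exec_kv` call. The tokenizer, QPC path, prompt, and all parameter values below are placeholders; the import path is assumed from this file's location, and only the keys listed in the docstring above are used.

```python
import numpy as np

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

batch_size = 1  # placeholder; must match the compiled QPC's batch size

# Every key maps to a numpy array of shape (batch_size, 1), per the docstring above.
# The concrete values and dtypes here are illustrative only.
sampling_params = {
    "repetition_penalties": np.full((batch_size, 1), 1.1, dtype=np.float32),
    "presence_penalties": np.full((batch_size, 1), 0.0, dtype=np.float32),
    "temperatures": np.full((batch_size, 1), 0.8, dtype=np.float32),
    "top_ks": np.full((batch_size, 1), 40, dtype=np.int32),
    "top_ps": np.full((batch_size, 1), 0.95, dtype=np.float32),
    "min_ps": np.full((batch_size, 1), 0.0, dtype=np.float32),
    "random_numbers": np.random.rand(batch_size, 1).astype(np.float32),
}

exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,            # placeholder: a tokenizer loaded elsewhere
    qpc_path="path/to/qpc",         # placeholder: QPC compiled with sampler support
    prompt="Hello",                 # placeholder prompt
    include_sampler=True,
    return_pdfs=False,
    sampling_params=sampling_params,
)
```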
@@ -372,6 +384,9 @@ def cloud_ai_100_exec_kv( | |
write_io_dir=write_io_dir, | ||
full_batch_size=full_batch_size, | ||
is_tlm=is_tlm, | ||
include_sampler=include_sampler, | ||
return_pdfs=return_pdfs, | ||
sampling_params=sampling_params, | ||
) | ||
if full_batch_size is None: | ||
exec_info = [ | ||
|
@@ -411,14 +426,59 @@ def __init__( | |
enable_debug_logs: bool = False, | ||
write_io_dir: Optional[str] = None, | ||
is_tlm: Optional[int] = None, | ||
include_sampler: bool = False, | ||
return_pdfs: bool = False, | ||
sampling_params: Optional[Dict[str, Any]] = None, | ||
) -> None: | ||
self._ctx_len = ctx_len | ||
self._write_io_dir = write_io_dir | ||
self.is_tlm = is_tlm | ||
self.include_sampler = include_sampler | ||
self.return_pdfs = return_pdfs | ||
self.sampling_params = sampling_params | ||
|
||
# Load QPC | ||
self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs) | ||
|
||
# Validate sampler inputs for On-Device Sampling | ||
sampler_inputs = [ | ||
"last_accepted_output_tokens", | ||
"repetition_penalties", | ||
"presence_penalties", | ||
"temperatures", | ||
"top_ks", | ||
"top_ps", | ||
"min_ps", | ||
"random_numbers", | ||
] | ||
count = 0 | ||
for session_input_name in self._session.input_names: | ||
if session_input_name in sampler_inputs: | ||
count += 1 | ||
Review comment: Can there be a case where the user provides the same session_input_name multiple times? In that case, how will we catch it in this code?
Reply: However, if accuracy is the only priority here and performance is not, I could use …
||
if count == len(sampler_inputs): | ||
self.include_sampler = True | ||
break | ||
if count == 0: | ||
Review comment: I think we can avoid this if..else block.
Reply: In case the user provides …, we can only avoid the else block in line 468.
Reply: I'll make the change.
Reply: Done.
||
self.include_sampler = False | ||
elif count < len(sampler_inputs): | ||
raise ValueError( | ||
"The provided QPC does not have the required number of inputs to run sampling " | ||
f"on the QAIC device (only {count}/{len(sampler_inputs)} inputs provided). Partial " | ||
Review comment: I think we should do count % sampler_inputs here. If we divide count by len(sampler_inputs), then it would return 0.
Reply: This is only a print statement; we are not actually dividing here. So, if …
||
"sampling support is not available. Please check the QPC and try again." | ||
) | ||
|
||
if include_sampler and not self.include_sampler: | ||
logger.warning_once( | ||
"User entered `include_sampler`=True. But the provided QPC is not compiled " | ||
"to run sampling on the QAIC device. Falling back to the PyTorch backend." | ||
) | ||
elif (include_sampler is None or not include_sampler) and self.include_sampler: | ||
raise ValueError( | ||
"The provided QPC is compiled to run sampling on the QAIC device. " | ||
"But the user did not enter `include_sampler`=True. Please make sure the input " | ||
"is specified correctly." | ||
) | ||
|
||
# Fetch the variables from the QPC | ||
self._vocab_size = self._fetch_vocab_size() # Fetch Vocab size | ||
self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len() | ||
|
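Relating to the review thread above about duplicate input names: a set-based variant of the same check would collapse duplicates automatically. This is only a sketch of that alternative, written as a free function over the session's input names (the function name is hypothetical), not the code merged in this diff.

```python
def detect_sampler_support(input_names):
    """Return True/False for full/no sampler support; raise on a partial set of inputs."""
    sampler_inputs = {
        "last_accepted_output_tokens", "repetition_penalties", "presence_penalties",
        "temperatures", "top_ks", "top_ps", "min_ps", "random_numbers",
    }
    present = sampler_inputs & set(input_names)  # duplicates collapse in the set
    if present == sampler_inputs:
        return True
    if not present:
        return False
    missing = sorted(sampler_inputs - present)
    raise ValueError(
        f"QPC exposes only {len(present)}/{len(sampler_inputs)} sampler inputs; missing: {missing}"
    )

# e.g. include_sampler = detect_sampler_support(session.input_names) inside __init__
```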
@@ -523,10 +583,17 @@ def _fetch_vocab_size( | |
Returns: | ||
vocab_size: The vocabulary size fetched from the session's allowed shapes. | ||
""" | ||
key = ( | ||
"probs" | ||
if self.include_sampler and self.return_pdfs | ||
else "next_tokens" | ||
if self.include_sampler | ||
else "logits" | ||
) | ||
if self._session.allowed_shapes: | ||
return [x[self._session.binding_index_map["logits"]] for x in self._session.allowed_shapes][0][1][2] | ||
return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2] | ||
|
||
return self._session.bindings[self._session.binding_index_map["logits"]].dims[2] | ||
return self._session.bindings[self._session.binding_index_map[key]].dims[2] | ||
|
||
def _fetch_generation_len(self, generation_len, max_gen_len): | ||
""" | ||
|
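For readers untangling the chained conditional expression above, here is an equivalent if/elif spelling of the key selection; it is a readability aid only, not a proposed change to the diff.

```python
# Equivalent form of the `key` selection used by _fetch_vocab_size.
if self.include_sampler and self.return_pdfs:
    key = "probs"        # sampler returns probability distributions
elif self.include_sampler:
    key = "next_tokens"  # sampler returns sampled token ids only
else:
    key = "logits"       # no on-device sampling: raw logits
```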
@@ -574,6 +641,21 @@ def prepare_decode_inputs(self): | |
decode_inputs["position_ids"] = self.decode_pos_ids | ||
if self.batch_index is not None: | ||
decode_inputs["batch_index"] = self.batch_index | ||
if self.include_sampler: | ||
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] | ||
for op in [ | ||
"repetition_penalties", | ||
"presence_penalties", | ||
"temperatures", | ||
"top_ks", | ||
"top_ps", | ||
"min_ps", | ||
"random_numbers", | ||
]: | ||
if self.batch_index is not None: | ||
decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()] | ||
else: | ||
decode_inputs[op] = self.sampling_params[op] | ||
|
||
if self._prompt_to_lora_id_mapping_decode: | ||
if self.full_batch_size: | ||
|
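The per-slot lookup of sampling parameters in `prepare_decode_inputs` relies on NumPy integer-array indexing. A small standalone illustration with hypothetical values (not the class code itself):

```python
import numpy as np

# One row per decode slot, shape (full_batch_size, 1), as required by sampling_params.
temperatures = np.array([[0.7], [0.8], [0.9], [1.0]], dtype=np.float32)

# batch_index lists the slots active in this decode step, e.g. slots 2 and 0.
batch_index = np.array([[2], [0]], dtype=np.int64)

selected = temperatures[batch_index.flatten()]
print(selected)  # [[0.9]
                 #  [0.7]] -> rows reordered to match the active slots
```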
@@ -589,21 +671,24 @@ def prepare_decode_inputs(self): | |
|
||
def _fetch_next_token_id(self, outputs): | ||
""" | ||
Fetches the next token ID from the model's output logits. | ||
The method identifies the token with the highest probability using argmax along the last dimension. | ||
Fetches the next token ID from the model's output. | ||
|
||
Args: | ||
outputs (dict): A dictionary containing the model's output logits. The key "logits" should map to a numpy array of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size). | ||
outputs (dict): A dictionary containing the model's output. | ||
|
||
Returns: | ||
numpy.ndarray: An array of the next token IDs for each sequence in the batch. | ||
""" | ||
logits = outputs["logits"] | ||
if len(logits.shape) == 2: | ||
logits = np.expand_dims(logits, 1) | ||
|
||
# Get output token | ||
next_token_id = logits.argmax(2) | ||
return next_token_id | ||
if self.include_sampler: | ||
if self.return_pdfs: | ||
return outputs["probs"].argmax(2) | ||
else: | ||
return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1]) | ||
else: | ||
logits = outputs["logits"] | ||
if len(logits.shape) == 2: | ||
logits = np.expand_dims(logits, 1) | ||
return logits.argmax(2) | ||
|
||
def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length): | ||
""" | ||
|
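A toy walk-through of the three return paths in the updated `_fetch_next_token_id`, with hypothetical shapes (batch_size=2, decode sequence length=1, vocab_size=5):

```python
import numpy as np

batch, seq, vocab = 2, 1, 5

# include_sampler=True, return_pdfs=True: argmax over the returned distributions.
probs = np.random.rand(batch, seq, vocab)
ids_from_probs = probs.argmax(2)                       # shape (2, 1)

# include_sampler=True, return_pdfs=False: the device already sampled; drop the last dim.
next_tokens = np.array([[[3]], [[1]]])                 # shape (2, 1, 1)
ids_from_sampler = next_tokens.reshape(batch, seq)     # shape (2, 1)

# include_sampler=False: fall back to argmax over logits.
logits = np.random.rand(batch, vocab)                  # some QPCs emit (batch, vocab)
ids_from_logits = np.expand_dims(logits, 1).argmax(2)  # shape (2, 1)
```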
@@ -673,6 +758,23 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): | |
|
||
_ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) | ||
|
||
def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1): | ||
""" | ||
Sets the sizes of the output buffers. | ||
|
||
Args: | ||
batch_size (int): The batch size. | ||
sequence_length (int): The sequence length. | ||
""" | ||
if self.include_sampler: | ||
if self.return_pdfs: | ||
probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) | ||
self._session.set_buffers({"probs": probs_out_placeholder}) | ||
next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64) | ||
self._session.set_buffers({"next_tokens": next_tokens_out_placeholder}) | ||
else: | ||
logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) | ||
self._session.set_buffers({"logits": logits_out_placeholder}) | ||
|
||
def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): | ||
""" | ||
Runs prefill for a given prompt and generation length. | ||
|
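The placeholder shapes that `_set_output_buffers` registers, restated as a standalone sketch; the dtypes follow the code above and `vocab_size` is an illustrative number.

```python
import numpy as np

batch_size, sequence_length, vocab_size = 1, 1, 32000  # vocab_size is illustrative

# include_sampler=True: "next_tokens" is always registered; "probs" only when return_pdfs=True.
probs_placeholder = np.zeros((batch_size, sequence_length, vocab_size), dtype=np.float32)
next_tokens_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64)

# include_sampler=False: only the raw "logits" buffer is registered.
logits_placeholder = np.zeros((batch_size, sequence_length, vocab_size), dtype=np.float32)
```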
@@ -702,9 +804,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i | |
max_gen_len = self._ctx_len - position_ids.max() | ||
generation_len = self._fetch_generation_len(generation_len, max_gen_len) | ||
|
||
# Set the prefill logic buffer | ||
logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32) | ||
self._session.set_buffers({"logits": logits_out_placeholder}) | ||
# Set the prefill output buffers | ||
self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) | ||
|
||
inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) | ||
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) | ||
|
@@ -714,6 +815,21 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i | |
inputs["batch_index"] = decode_batch_id | ||
if self.is_tlm: | ||
inputs["num_logits_to_keep"] = np.zeros((1, 1)) | ||
if self.include_sampler: | ||
inputs["last_accepted_output_tokens"] = inputs["input_ids"] | ||
for op in [ | ||
"repetition_penalties", | ||
"presence_penalties", | ||
"temperatures", | ||
"top_ks", | ||
"top_ps", | ||
"min_ps", | ||
"random_numbers", | ||
]: | ||
if decode_batch_id is not None: | ||
inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] | ||
else: | ||
inputs[op] = self.sampling_params[op] | ||
|
||
if self._prompt_to_lora_id_mapping_prefill: | ||
if self.full_batch_size: | ||
|
@@ -732,6 +848,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i | |
chunk_inputs["position_ids"] = inputs["position_ids"][ | ||
:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len | ||
] | ||
if self.include_sampler: | ||
chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] | ||
outputs = self._session.run(chunk_inputs) | ||
if self._write_io_dir is not None: | ||
write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) | ||
|
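Chunked prefill slices the padded prompt into windows of `_prefill_seq_len` tokens, and with the sampler enabled each chunk also carries its own slice as `last_accepted_output_tokens`. A standalone illustration of that slicing with hypothetical sizes:

```python
import numpy as np

prefill_seq_len = 4
padded_len = 12  # padded prompt length; a multiple of prefill_seq_len
input_ids = np.arange(padded_len).reshape(1, padded_len)

for i in range(padded_len // prefill_seq_len):
    chunk_ids = input_ids[:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
    chunk_inputs = {
        "input_ids": chunk_ids,
        # with on-device sampling enabled, the chunk's own ids are repeated here
        "last_accepted_output_tokens": chunk_ids,
    }
    print(i, chunk_inputs["input_ids"])
```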
@@ -753,11 +871,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): | |
|
||
""" | ||
|
||
# Set logits placeholder for decode | ||
logits_out_placeholder = np.zeros( | ||
(self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 | ||
# Set output placeholders for decode | ||
self._set_output_buffers( | ||
batch_size=self.full_batch_size, | ||
sequence_length=self._decode_seq_len, | ||
) | ||
self._session.set_buffers({"logits": logits_out_placeholder}) | ||
|
||
# Generate flag for tracking progress for each batch ID | ||
current_decode_ongoing = np.full((self.full_batch_size, 1), True) | ||
|
||
|
@@ -775,10 +894,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): | |
outputs = self._session.run(decode_inputs) | ||
|
||
# Prepare inputs for next iteration | ||
logits = outputs["logits"] | ||
if len(logits.shape) == 2: | ||
logits = np.expand_dims(logits, 1) | ||
next_token_id = logits.argmax(2) | ||
next_token_id = self._fetch_next_token_id(outputs) | ||
|
||
for decode_batch_id in range(self.full_batch_size): | ||
if ( | ||
|
@@ -800,7 +916,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): | |
self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1) | ||
generated_id_current_index[decode_batch_id] = 1 | ||
|
||
self._session.set_buffers({"logits": logits_out_placeholder}) | ||
self._set_output_buffers( | ||
batch_size=self.full_batch_size, | ||
sequence_length=self._decode_seq_len, | ||
) | ||
decode_pause_time += perf_counter() - start | ||
|
||
if self._prompt_to_lora_id_mapping_decode: | ||
|
@@ -817,6 +936,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): | |
self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = ( | ||
next_token_id[decode_batch_id, -1] | ||
) | ||
if self.include_sampler: | ||
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] | ||
|
||
generated_id_current_index[decode_batch_id] += 1 | ||
|
||
|
@@ -840,6 +961,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform | |
(self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 | ||
) | ||
self._session.set_buffers({"logits": logits_out_placeholder}) | ||
else: | ||
self._set_output_buffers( | ||
batch_size=self.batch_size, | ||
sequence_length=self._decode_seq_len, | ||
) | ||
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id | ||
num_token = 0 | ||
for num_token in range(1, generation_len): | ||
|
@@ -852,10 +978,12 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform | |
self._write_io_dir = None | ||
|
||
# Prepare inputs for next iteration | ||
decode_inputs["input_ids"] = outputs["logits"].argmax(2) | ||
decode_inputs["input_ids"] = self._fetch_next_token_id(outputs) | ||
decode_inputs["position_ids"][:, -1] += 1 | ||
self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] | ||
finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id | ||
if self.include_sampler: | ||
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] | ||
|
||
if finished_sequences.all(): | ||
break | ||
|
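In both decode paths the freshly chosen tokens are fed back as `last_accepted_output_tokens`, so the device-side penalties (repetition, presence) can see the generation history. A self-contained schematic of that feedback, with the QAIC session replaced by a stub:

```python
import numpy as np

def fake_session_run(inputs, vocab_size=32000):
    """Stand-in for the inference session's run(): returns random sampled token ids."""
    batch = inputs["input_ids"].shape[0]
    return {"next_tokens": np.random.randint(0, vocab_size, size=(batch, 1, 1))}

batch, generation_len = 2, 4
decode_inputs = {
    "input_ids": np.zeros((batch, 1), dtype=np.int64),
    "position_ids": np.zeros((batch, 1), dtype=np.int64),
    "last_accepted_output_tokens": np.zeros((batch, 1), dtype=np.int64),
}

for _ in range(1, generation_len):
    outputs = fake_session_run(decode_inputs)
    # sampled ids come back as (batch, seq, 1); reshape to (batch, seq) as _fetch_next_token_id does
    decode_inputs["input_ids"] = outputs["next_tokens"].reshape(batch, 1)
    decode_inputs["position_ids"][:, -1] += 1
    # feed the accepted tokens back so the on-device sampler sees the history
    decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
```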
@@ -905,9 +1033,22 @@ def __init__( | |
enable_debug_logs: bool = False, | ||
write_io_dir: Optional[str] = None, | ||
is_tlm: bool = False, | ||
include_sampler: bool = False, | ||
return_pdfs: bool = False, | ||
sampling_params: Optional[Dict[str, Any]] = None, | ||
) -> None: | ||
self._qaic_model = QEffTextGenerationBase( | ||
tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm | ||
tokenizer=tokenizer, | ||
qpc_path=qpc_path, | ||
full_batch_size=full_batch_size, | ||
ctx_len=ctx_len, | ||
device_id=device_id, | ||
enable_debug_logs=enable_debug_logs, | ||
write_io_dir=write_io_dir, | ||
is_tlm=is_tlm, | ||
include_sampler=include_sampler, | ||
return_pdfs=return_pdfs, | ||
sampling_params=sampling_params, | ||
) | ||
self._full_batch_size = self._qaic_model.full_batch_size | ||
self._tokenizer = self._qaic_model.tokenizer | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,8 @@ To achieve this, we have 2 levels of APIs, with different levels of abstraction. | |
| [Vision Language Model](QEFFAutoModelForImageTextToText) | Provides support for the AutoModelForImageTextToText class from the transformers library, enabling advanced vision-language tasks. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/image_text_to_text_inference.py) for more **details**. | | ||
| [Speech Sequence to Sequence Model](QEFFAutoModelForSpeechSeq2Seq) | Provides support for the QEFFAutoModelForSpeechSeq2Seq Facilitates speech-to-text sequence models. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/speech_to_text/run_whisper_speech_to_text.py) for more **details**. | | ||
| Support for FP8 Execution | Enables execution with FP8 precision, significantly improving performance and reducing memory usage for computational tasks. | | ||
| Prefill caching | Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. | | ||
| Prefix caching | Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. | | ||
| On Device Sampling | Enables sampling operations to be executed directly on the QAIC device rather than the host CPU for QEffForCausalLM models. This enhancement significantly reduces host-device communication overhead and improves inference throughput and scalability. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/on_device_sampling.py) for more **details**. | | ||
Review comment: Link seems broken, please fix.
Reply: The link points to an example file that will be added by this PR, so the link will be available once the PR is merged.
||
|Prompt-Lookup Decoding | Speeds up text generation by using overlapping parts of the input prompt and the generated text, making the process faster without losing quality. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/pld_spd_inference.py) for more **details**.| | ||
| [PEFT LoRA support](QEffAutoPeftModelForCausalLM) | Enables parameter-efficient fine-tuning using low-rank adaptation techniques, reducing the computational and memory requirements for fine-tuning large models. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/peft_models.py) for more **details**. | | ||
| [QNN support](#qnn-compilation) | Enables compilation using QNN SDK, making Qeff adaptable for various backends in the future. | | ||
|