Skip to content

Unit Tests for On Device Sampling #463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

quic-sanising
Copy link
Contributor

@quic-sanising quic-sanising commented Jun 18, 2025

This PR adds the following Unit Tests for On Device Sampling:

  1. test_sampler_transform: Test if SamplerTransform adds nodes at the output of a QEffForCausalLM model to enable the sampling of next tokens at the device (instead of the host) and returns the next tokens and/or probability distributions.
  2. test_greedy_sampling: Test greedy sampling with QPC compiled with and without On Device Sampling.
  3. test_random_sampling: Test random sampling with QPC compiled with and without On Device Sampling.

Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also run the two session with the fixed prompt and make sure the outputs don't match with each other. But match with a golden output that we know matches with pytorch execution as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refer to the test test_random_sampling below.

@quic-rishinr
Copy link
Contributor

@quic-sanising can you add a small feature description under /docs/source/quick_start.md supported feature section? also provide the example script link in the description

sanising added 3 commits June 30, 2025 13:20
Signed-off-by: sanising <sanising@qti.qualcomm.com>
Signed-off-by: sanising <sanising@qti.qualcomm.com>
@quic-sanising
Copy link
Contributor Author

@quic-sanising can you add a small feature description under /docs/source/quick_start.md supported feature section? also provide the example script link in the description

Done

sanising added 3 commits June 30, 2025 18:40
Signed-off-by: sanising <sanising@qti.qualcomm.com>
Signed-off-by: sanising <sanising@qti.qualcomm.com>
Signed-off-by: sanising <sanising@qti.qualcomm.com>
Copy link
Contributor

@quic-amitraj quic-amitraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix lint error.

@@ -19,7 +19,8 @@ To achieve this, we have 2 levels of APIs, with different levels of abstraction.
| [Vision Language Model](QEFFAutoModelForImageTextToText) | Provides support for the AutoModelForImageTextToText class from the transformers library, enabling advanced vision-language tasks. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/image_text_to_text_inference.py) for more **details**. |
| [Speech Sequence to Sequence Model](QEFFAutoModelForSpeechSeq2Seq) | Provides support for the QEFFAutoModelForSpeechSeq2Seq Facilitates speech-to-text sequence models. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/speech_to_text/run_whisper_speech_to_text.py) for more **details**. |
| Support for FP8 Execution | Enables execution with FP8 precision, significantly improving performance and reducing memory usage for computational tasks. |
| Prefill caching | Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. |
| Prefix caching | Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. |
| On Device Sampling | Enables sampling operations to be executed directly on the QAIC device rather than the host CPU for QEffForCausalLM models. This enhancement significantly reduces host-device communication overhead and improves inference throughput and scalability. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/on_device_sampling.py) for more **details**. |
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Link seems broken, please fix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The link points to an example file that will be added by this PR. So, the link will be available when the PR is merged.

sanising added 2 commits July 2, 2025 18:43
Signed-off-by: sanising <sanising@qti.qualcomm.com>
Signed-off-by: sanising <sanising@qti.qualcomm.com>
elif count < len(sampler_inputs):
raise ValueError(
"The provided QPC does not have the required number of inputs to run sampling "
f"on the QAIC device (only {count}/{len(sampler_inputs)} inputs provided). Partial "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should do count % sampler_inputs here. If we divide count by len(sampler_inputs) then it would return 0.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only a print statement. We are not actually dividing here. So, if count = 5 and len(sampler_inputs) = 10, it would print (only 5/10 inputs provided).

count = 0
for session_input_name in self._session.input_names:
if session_input_name in sampler_inputs:
count += 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can there be a case where user provides the same session_input_names multiple times. In that case how we will catch it in this code.
count variable will keep on incrementing and may satisfy the condition

Copy link
Contributor Author

@quic-sanising quic-sanising Jul 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._session.input_names comes from the exported ONNX file. If there are duplicate names, say abc, the ONNX will convert them to something like abc_0, abc_1, so on. So, we would never get the same name multiple times.

However, if accuracy is the only priority here and performance is not, I could use set() but it would add a slight overhead of O(n).

count += 1
if count == len(sampler_inputs):
break
if count == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can avoid this if.. else block.
at line 455 by default set self.include_sampler = False.
Then at line 458 before break set it to True.
At line 462 just check for error condition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case the user provides include_sampler as input, self.include_sampler is not set to False. That is why, we need the check in line 460.

We can only avoid the else block in line 468.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make the change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

sanising added 3 commits July 3, 2025 13:44
Signed-off-by: sanising <sanising@qti.qualcomm.com>
Signed-off-by: sanising <sanising@qti.qualcomm.com>
Signed-off-by: sanising <sanising@qti.qualcomm.com>
@quic-sanising quic-sanising marked this pull request as ready for review July 3, 2025 19:08
Signed-off-by: sanising <sanising@qti.qualcomm.com>
@quic-sanising
Copy link
Contributor Author

quic-sanising commented Jul 3, 2025

Please fix lint error.

@quic-amitraj The lint failures were happening because the linter is installing ruff v0.12.2 whereas the .pre-commit-config.yaml file has an older version of v0.5.2.

To fix the errors, we need to either install ruff v0.5.2 in the linter or update the .pre-commit-config.yaml file to version v0.12.2.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants