adding Context Length Specialization (CCL) #466
base: main
Conversation
Signed-off-by: vjanfaza <vjanfaza@apex-scl01-giga-linux.qualcomm.com>
How much time do the tests take?
We can choose to test only one model per KV cache type, i.e., chunked, hybrid, sliding window, etc.:
chunked -> global + local -> llama4
hybrid -> sliding window + global -> gemma3
sliding window -> mistral
For the above categories, we need to handle CCL differently. For local or sliding-window layers, the complete CCL probably won't apply once the context goes beyond the sliding-window length, while the full CCL applies for global layers. This support needs to be added.
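A minimal sketch of the per-layer handling described above, using a hypothetical helper (names and layer-type labels are illustrative, not from the PR): global layers use the requested CCL as-is, while local/sliding-window layers cap it at the sliding-window length.

def effective_ccl(ccl: int, layer_type: str, sliding_window: int, ctx_len: int) -> int:
    # Hypothetical helper: local/sliding-window layers never attend beyond
    # the window, so the compute context length can be capped there.
    if layer_type in ("sliding_window", "local", "chunked_local"):
        return min(ccl, sliding_window)
    # Global layers attend over the full requested compute context length.
    return min(ccl, ctx_len)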
@@ -102,6 +102,7 @@ def main(
    full_batch_size: Optional[int] = None,
    prompt_len: int = 32,
    ctx_len: int = 128,
    comp_ctx_lengths: Optional[List[int]] = None,
Please add this to the docstring of the function as well.
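For example, the docstring entry could read roughly like this (wording and formatting are a suggestion, not taken from the PR):

    :comp_ctx_lengths (Optional[List[int]]): List of compute context lengths (CCL) to specialize the compiled model for; if None, only the full ctx_len is used. Defaults to None.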
@@ -1489,6 +1491,8 @@ def from_pretrained(

    kv_offload = kwargs.pop("kv_offload", None)
I think comp_ctx_lengths should be handled as an explicit parameter in from_pretrained rather than inside kwargs. Then there will be no need to pop that variable from kwargs, and we can add a proper docstring for it here as well.
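A sketch of what that suggestion could look like (class name, parameter order, and docstring wording are assumptions; other arguments elided):

from typing import List, Optional

class QEFFAutoModelForCausalLM:  # sketch only; the real class already exists in QEfficient
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, kv_offload: Optional[bool] = None, comp_ctx_lengths: Optional[List[int]] = None, *args, **kwargs):
        """
        comp_ctx_lengths (Optional[List[int]]): Compute context lengths (CCL) to
            specialize the model for; None means only the full ctx_len is used.
        """
        ...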
self.comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None)
Why? I don't think there is any need for this.
We can use kwargs.get instead of pop. We are planning to use these kwargs for creating the model hash.
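i.e., roughly (a sketch of this suggestion):

# get() leaves the key in kwargs so it still feeds into the model-hash computation.
self.comp_ctx_lengths = kwargs.get("comp_ctx_lengths", None)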
for i in range(1, len(self.comp_ctx_lengths)):
    decode_spec = self.build_decode_specialization(
        prefill_seq_len=prefill_seq_len,
        ctx_len=ctx_len,
There is no need for the if/else condition; please handle this for loop inside build_decode_specialization.
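A possible shape for that refactor (a sketch only; the real build_decode_specialization signature and specialization keys may differ):

def build_decode_specialization(self, prefill_seq_len: int, ctx_len: int):
    # Sketch: iterate over the configured CCL values internally and return
    # one decode specialization per compute context length.
    specializations = []
    for ccl in (self.comp_ctx_lengths or [ctx_len]):
        specializations.append(
            {
                "batch_size": 1,
                "seq_len": 1,  # decode emits one token at a time
                "ctx_len": ctx_len,
                "comp_ctx_lengths": ccl,
            }
        )
    return specializations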
@@ -29,6 +30,16 @@
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


@dataclass
class QEffBaseModelOutputWithPast(BaseModelOutputWithPast):
Since this dataclass is common across all the modeling files, it is better to keep it in modeling_utils.py.
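For example (a sketch of where it could live; the extra field is hypothetical and only illustrates why a shared definition helps):

# QEfficient/transformers/modeling_utils.py
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import BaseModelOutputWithPast


@dataclass
class QEffBaseModelOutputWithPast(BaseModelOutputWithPast):
    # Hypothetical extra field; the fields actually added in this PR may differ.
    comp_ctx_lengths: Optional[torch.Tensor] = None

Importing it from modeling_utils.py in each modeling file avoids duplicating the class definition.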
The Context Length Specialization (CCL) technique optimizes the throughput of large language models (LLMs) on Qualcomm devices when handling very large context lengths. The current ahead-of-time (AOT) compilation on Qualcomm devices cannot predict the number of tokens needed, which leads to significant throughput drops during the prefill and decode phases because the system performs attention calculations over the full, large context length. To address this, we introduce Compute Context Length (CCL), an additional ONNX variable that allows dynamic context-length specialization. By generating tokens using smaller, more manageable compute context lengths, we reduce memory reads and attention calculations, thereby improving throughput.
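The core idea can be illustrated with a small standalone sketch (the bucket values and helper name are hypothetical): during generation, the runtime picks the smallest compiled CCL bucket that covers the tokens produced so far, so attention and KV-cache reads only span that bucket instead of the full ctx_len.

# Hypothetical CCL buckets compiled as specializations, plus the full context length.
comp_ctx_lengths = [512, 1024, 2048, 4096]
ctx_len = 4096

def pick_ccl(cache_position: int) -> int:
    # Return the smallest compute context length that can hold the current cache.
    for ccl in sorted(comp_ctx_lengths):
        if cache_position <= ccl:
            return ccl
    return ctx_len

assert pick_ccl(100) == 512     # short sequences pay for a small attention window
assert pick_ccl(1500) == 2048
assert pick_ccl(4000) == 4096   # long sequences fall back to the full context length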