[core] Refactor the Cache logic to make it simpler and more general #39797
Conversation
Slow tests are the same on this PR and on `main`.
@Cyrilvallez 🔥🔥 PR nano requests:
1 + 2 = if users open issues about manual compilation, we can link to the docstrings :D
Hey @gante! Done with 1. 👌 Run the following

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "google/gemma-2-2b"
torch_device = 0

EXPECTED_COMPLETIONS = [
    " the people, the food, the culture, the history, the music, the art, the architecture",
    ", green, yellow, orange, purple, pink, brown, black, white, gray, silver",
]

input_text = [
    "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
    "A list of colors: red, blue",  # This will be almost all padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="sdpa", torch_dtype=torch.bfloat16, device_map=torch_device
)

# Compile the full forward! So the call to `lazy_initialization` will be compiled as well
model.compile(fullgraph=True)

# Make sure the prefill is larger than the sliding window
input_size = inputs.input_ids.shape[-1]
assert input_size > model.config.sliding_window

out = model.generate(**inputs, max_new_tokens=20, cache_implementation="hybrid")[:, input_size:]
output_text = tokenizer.batch_decode(out)
assert output_text == EXPECTED_COMPLETIONS
```

and you'll see it works! Only compilation with CUDA graphs will fail, e.g.
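For reference, CUDA-graph compilation here would typically mean `torch.compile`'s `reduce-overhead` mode; a hypothetical variant of the snippet above (continuing from it, and not taken from the thread) would be:

```python
# Hypothetical CUDA-graph variant (continuing from the snippet above):
# `mode="reduce-overhead"` makes torch.compile capture CUDA graphs,
# which is the setup reported to fail.
model.compile(fullgraph=True, mode="reduce-overhead")
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="hybrid")
```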
@Cyrilvallez fair! (I would still make a note of it though, for power users that don't rely on …)
Nothing much to say: very very clean! 🤗
- compiling the full call may not be the best: w.r.t. accelerate and hooks in general, I would rather just compile the forward!
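As a generic sketch of that alternative (not code from this PR): compiling only the forward body leaves the module's `__call__`, and therefore any hooks that accelerate attaches around the module call, outside the compiled region.

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", torch_dtype=torch.bfloat16)

# Compile only the forward body. Hooks (e.g. the ones accelerate registers on the
# module) keep running eagerly, outside the compiled region.
model.forward = torch.compile(model.forward, fullgraph=True)
```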
```python
def offload(self):
    """Offload this layer's data to CPU device."""
    if self.keys is not None:
        self.keys = self.keys.to("cpu", non_blocking=True)
        self.values = self.values.to("cpu", non_blocking=True)
```
I think for CUDA we needed / have better perf with a different stream
```
offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
    If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
```
nice comment
```python
with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
    self.layers[layer_idx].prefetch()
```
ah ok there's the stream
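For background on why a dedicated stream helps here, a generic PyTorch sketch (not the `cache_utils` internals): the copy is issued on a side stream so it can overlap with compute, and the consuming stream waits on it before touching the data.

```python
import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    # Pinned CPU memory so the non_blocking host-to-device copy can actually be asynchronous
    cpu_keys = torch.randn(2, 8, 128, 64, pin_memory=True)

    # Issue the prefetch copy on the side stream
    with torch.cuda.stream(copy_stream):
        gpu_keys = cpu_keys.to("cuda", non_blocking=True)

    # The compute (current) stream must wait for the copy before using the tensor
    torch.cuda.current_stream().wait_stream(copy_stream)
    out = gpu_keys * 2
```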
run-slow: bamba, dia, falcon_h1, gptj, granitemoehybrid, jamba, kyutai_speech_to_text, lfm2, musicgen, musicgen_melody, rag, zamba, zamba2
This comment contains run-slow, running the specified jobs: models: ['models/bamba', 'models/dia', 'models/falcon_h1', 'models/gptj', 'models/granitemoehybrid', 'models/jamba', 'models/kyutai_speech_to_text', 'models/lfm2', 'models/musicgen', 'models/musicgen_melody', 'models/rag', 'models/zamba', 'models/zamba2']
[For maintainers] Suggested jobs to run (before merge):
run-slow: bamba, dia, falcon_h1, gptj, granitemoehybrid, jamba, kyutai_speech_to_text, lfm2, musicgen, musicgen_melody, rag, zamba, zamba2
Slow tests are similar to `main`, merging!
Resolves current CI errors with prefix tuning. Due to some recent changes in transformers (surfaced by huggingface/transformers#39797), checking `hasattr(cache, "max_cache_len")` results in an error:

```python
>>> cache = DynamicCache()
>>> hasattr(cache, "max_cache_len")
Traceback (most recent call last):
  File "/home/name/work/forks/transformers/foo.py", line 9, in <module>
    hasattr(cache, "max_cache_len")
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/cache_utils.py", line 916, in max_cache_len
    return max(values)
           ^^^^^^^^^^^
ValueError: max() iterable argument is empty
```

This has been reported and will be fixed in transformers. On the PEFT side, it is safest to check the cache type and avoid accessing this attribute in the first place, which is what this PR does. Moreover, that PR also changed the argument order to initialize HybridCache (which will probably also be reverted in transformers); this is also taken into account in this PR by only using keyword arguments.
Resolves current CI errors with prefix tuning. Due to some recent changes in transformers (surfaced by huggingface/transformers#39797), checking `hasattr(cache, "max_cache_len")` results in an error. This PR fixes it. Moreover, that PR also changed the argument order to initialize HybridCache (which will probably also be reverted in transformers); this is also taken into account in this PR by only using keyword arguments. Finally, HybridCache will be deprecated and later removed, so move the import inside a version guard.
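A minimal sketch of the kind of workaround described (the helper name, the `default` fallback, and the version bound are illustrative, not the actual PEFT diff):

```python
from packaging import version

import transformers
from transformers import DynamicCache


def get_max_cache_len(cache, default=None):
    # Avoid `hasattr(cache, "max_cache_len")`: on the affected transformers versions
    # the property itself raises ValueError for an empty DynamicCache, and hasattr
    # only swallows AttributeError. Check the cache type instead.
    if isinstance(cache, DynamicCache):
        return default
    return cache.max_cache_len


# HybridCache is slated for deprecation/removal, so only import it where it still exists
# (the version bound below is a placeholder, not the real cut-off).
if version.parse(transformers.__version__) < version.parse("5.0.0"):
    from transformers import HybridCache  # noqa: F401
```

Passing only keyword arguments when constructing `HybridCache` similarly guards against the argument-order change mentioned above.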
What does this PR do?
Big simplifications everywhere, but most notably:

- Caches no longer need to be set up ahead of time in `generate`: all properties are derived at first `update` time (see the toy sketch below) -> simpler and more efficient (no device copies)
- `early_initialization` provides a way to init everything before `update` is called -> this is needed for `export`, as we can't trace correctly if initialization is lazy
- Takes `chunk_attention_size` into account correctly again (it was lost before, which would break Llama4)
- Paves the way for removing `cache_position`, which would simplify the library a lot, and will come in a follow-up PR
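To make the lazy-initialization idea concrete, here is a toy sketch (not the actual transformers implementation): a per-layer KV cache whose dtype, device and shapes are all derived from the first `update` call, which is also the kind of data-dependent first-call branch that makes an explicit initialization path necessary for export/tracing.

```python
import torch


class LazyKVLayer:
    """Toy per-layer KV cache: everything is derived from the first update()."""

    def __init__(self):
        self.keys = None
        self.values = None

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor):
        if self.keys is None:
            # Lazy init: dtype, device and shapes come from the first incoming states,
            # so nothing has to be configured up front and nothing is copied across devices.
            self.keys = key_states
            self.values = value_states
        else:
            # Subsequent calls append along the sequence dimension.
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values


# The first update defines everything; an export path would instead initialize eagerly
# before tracing, since the branch above depends on runtime state.
layer = LazyKVLayer()
k = v = torch.randn(1, 8, 4, 64)
keys, values = layer.update(k, v)
```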