Refactor dead code - Removing all flash_xxx.py files #2166
Conversation
# Not used anymore
# def decode(self, decoder_ids: List[int]) -> str:
#     return self.tokenizer.decode(
#         decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
#     )
Isn't it surprising that it's not used anymore?
Shouldn't those flags be used somewhere for those models? Do we have tests that cover the raison d'être of this code?
I think this code was to allow models to specify the value of `skip_special_tokens`. For example, for santacoder you needed the fill-in-the-middle special tokens to correctly display the outputs.
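To make the `skip_special_tokens` point concrete, here is a rough sketch (not the removed implementation; it assumes SantaCoder's tokenizer uses FIM special tokens named `<fim-prefix>`/`<fim-suffix>`/`<fim-middle>`):

```python
from transformers import AutoTokenizer

# Sketch: why a model-specific decode() that keeps special tokens can matter.
tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")

# Fill-in-the-middle style prompt; the markers are special tokens (assumed names).
ids = tokenizer.encode("<fim-prefix>def add(a, b):<fim-suffix>\n<fim-middle>")

# With skip_special_tokens=True the FIM markers disappear from the decoded text,
# so the output can no longer be displayed or reassembled correctly.
print(tokenizer.decode(ids, skip_special_tokens=True))

# Keeping them (skip_special_tokens=False) preserves the structure.
print(tokenizer.decode(ids, skip_special_tokens=False))
```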
super().__init__()
config.transpose = config.architectures[0].startswith("GPT2")
self.transformer = FlashSantacoderModel(config, weights)
I think this will also need to be `self.model` instead, otherwise the iteration in `adapter_target_to_layer` from `flash_causal_lm.py` does not work.
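Roughly the failure mode being described, as a hypothetical sketch (the attribute walk below is illustrative, not the actual `adapter_target_to_layer` code from `flash_causal_lm.py`):

```python
# Hypothetical sketch: if the shared helper builds its adapter mapping by
# walking `model.model.layers`, a model that exposes its stack as
# `self.transformer` instead will fail this attribute lookup.
def adapter_target_to_layer(model):
    layer_weights = {}
    for i, layer in enumerate(model.model.layers):  # AttributeError when only `model.transformer` exists
        layer_weights[(i, "qkv_proj")] = layer.attention  # illustrative target name
    return layer_weights
```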
I don't think LoRAs for Santacoder even exist, right?
Even if that's the case, ideally we want to push the logic about layer loads into the model itself (which makes more sense than keeping around random layer names in `flash_causal_lm.py`).
Yeah, indeed seems like a better place than `flash_causal_lm.py`.
weights._set_gptq_params(model_id, revision)

prefix = ""
model = model_class(prefix, config, weights)
Currently breaks with Gemma because the `FlashGemmaForCausalLM` takes an extra `causal` argument.
Good catch, fixed it by making it a default (since it seems too niche to be worth an extra flag).
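A minimal sketch of what "making it a default" looks like (assumed constructor shape; the real class takes more arguments):

```python
import torch


class FlashGemmaForCausalLM(torch.nn.Module):
    # Sketch: with a default value for `causal`, the generic
    # `model = model_class(prefix, config, weights)` call no longer breaks for
    # Gemma, while callers that need the non-causal variant can still pass
    # causal=False explicitly.
    def __init__(self, prefix, config, weights, causal: bool = True):
        super().__init__()
        self.causal = causal
```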
Force-pushed from f89c2b4 to 425f348.
* Refactor dead code.
* First working step.
* Remove a lot of duplicated code.
* More dead code.
* More cleanup.
* Fix Santacoder test.
* Fixing the simple tests.
* Fixing sharding.
* Fixes for VLM.
* Fixing santacoder (num_kv_heads hardcoded).
* Removing more dead code.
* Fixing `config.n_head`.
* Stopping earlier because of `<end_of_utterance>` in idefics2.
* Addresses comments.
* Removing the dead code.
* Fuse back mistral into FlashCausalLM.
* Finish removal.
* Fixing docs + causal_lm `batch_class`.
* Fixing docs + causal.lm.
* Add default to Gemma Causality.
* Default value for gemma/gemma2.
* Wrong default.
What does this PR do?
Refactors dead code: removes all the flash_xxx.py files.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.