
Refactor dead code - Removing all flash_xxx.py files. #2166

Merged · 22 commits · Jul 5, 2024
Conversation

Narsil (Collaborator) commented Jul 2, 2024

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

OlivierDehaene previously approved these changes Jul 3, 2024
OlivierDehaene self-requested a review Jul 3, 2024 09:08
server/text_generation_server/models/causal_lm.py (outdated; resolved)
server/text_generation_server/models/causal_lm.py (outdated; resolved)
server/text_generation_server/models/mpt.py (resolved)
Comment on lines 671 to 675
# Not used anymore
# def decode(self, decoder_ids: List[int]) -> str:
# return self.tokenizer.decode(
# decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
# )
Member:
Suggested change
# Not used anymore
# def decode(self, decoder_ids: List[int]) -> str:
# return self.tokenizer.decode(
# decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
# )

Narsil (Collaborator, Author):
Isn't it surprising that it's not used anymore?

Shouldn't those flags be used somewhere for those models? Do we have tests that cover the raison d'être of this code?

Member:

I think this code was there to allow models to specify the value of skip_special_tokens. For example, for santacoder you needed the fill-in-the-middle special tokens to correctly display the outputs.
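To illustrate the point above, here is a minimal sketch (not the TGI implementation; all class names and the toy tokenizer are hypothetical) of why a per-model `decode` override can matter: a model whose outputs rely on special tokens, like santacoder's fill-in-the-middle markers, must keep them, while the generic path strips them.

```python
class FakeTokenizer:
    """Toy stand-in for a real tokenizer, just for demonstration."""
    SPECIAL = {"<fim_prefix>", "<fim_suffix>"}
    VOCAB = ["<fim_prefix>", "hello", "<fim_suffix>", "world"]

    def decode(self, token_ids, skip_special_tokens=True):
        tokens = [self.VOCAB[i] for i in token_ids]
        if skip_special_tokens:
            tokens = [t for t in tokens if t not in self.SPECIAL]
        return " ".join(tokens)


class GenericModel:
    # Default behaviour: hide special tokens from the user.
    SKIP_SPECIAL_TOKENS = True

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def decode(self, token_ids):
        return self.tokenizer.decode(
            token_ids, skip_special_tokens=self.SKIP_SPECIAL_TOKENS
        )


class SantacoderLike(GenericModel):
    # Fill-in-the-middle markers are meaningful output, so keep them.
    SKIP_SPECIAL_TOKENS = False
```

Removing the commented-out `decode` is safe only if no model still needs to flip this knob, which is exactly the concern raised above.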

super().__init__()
config.transpose = config.architectures[0].startswith("GPT2")
self.transformer = FlashSantacoderModel(config, weights)
Member:
I think this will also need to be self.model instead, otherwise the iteration in adapter_target_to_layer from flash_causal_lm.py does not work.

Narsil (Collaborator, Author) commented Jul 4, 2024:
I don't think LoRAs for Santacoder even exist, right?

Even if that's the case, ideally we want to push the logic about layer loads into the model itself (which makes more sense than keeping random layer names around in flash_causal_lm.py).

Member:
Yeah, that indeed seems like a better place than flash_causal_lm.py.
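A minimal sketch of the failure mode discussed above (names are illustrative, not the actual TGI code): if a shared helper walks `self.model` to map adapter targets to layers, a subclass that stores its stack under a different attribute such as `self.transformer` silently yields nothing.

```python
class Layer:
    def __init__(self, name):
        self.name = name


class Inner:
    """Toy layer stack."""
    def __init__(self):
        self.layers = [Layer("q_proj"), Layer("v_proj")]


def adapter_target_to_layer(model_wrapper):
    # Shared helper: expects the layer stack under `.model.layers`.
    inner = getattr(model_wrapper, "model", None)
    if inner is None:
        # Wrong attribute name -> no adapter targets found, silently.
        return {}
    return {layer.name: layer for layer in inner.layers}


class Conforming:
    def __init__(self):
        self.model = Inner()        # visible to the helper


class NonConforming:
    def __init__(self):
        self.transformer = Inner()  # invisible to the helper
```

This is why renaming `self.transformer` to `self.model` (or moving the layer-load logic into each model, as suggested above) matters even if no Santacoder LoRA exists today.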

weights._set_gptq_params(model_id, revision)

prefix = ""
model = model_class(prefix, config, weights)
danieldk (Member) commented Jul 4, 2024:
Currently breaks with Gemma because the FlashGemmaForCausalLM takes an extra causal argument.

Narsil (Collaborator, Author):
Good catch, fixed it by making it a default (since it seems too niche to be worth an extra flag).
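The fix can be sketched as follows (a hypothetical simplification; class names and the `causal` parameter placement are illustrative, not the exact TGI signatures): the generic loader calls every model class with the same `(prefix, config, weights)` arguments, so an extra required argument like Gemma's `causal` breaks it, while giving it a default restores the shared call site.

```python
class FlashModelBase:
    def __init__(self, prefix, config, weights):
        self.prefix, self.config, self.weights = prefix, config, weights


class GemmaLike(FlashModelBase):
    # `causal=True` as a default: the generic loader never passes it,
    # so a required positional argument here would raise a TypeError.
    def __init__(self, prefix, config, weights, causal=True):
        super().__init__(prefix, config, weights)
        self.causal = causal


def load_model(model_class, config, weights):
    prefix = ""
    # Generic call site: no model-specific flags are passed here.
    return model_class(prefix, config, weights)
```

Callers that do need non-causal behaviour can still construct the class directly with `causal=False`; the default only protects the generic path.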

@Narsil Narsil force-pushed the refactor_dead_code branch from f89c2b4 to 425f348 Compare July 4, 2024 14:38
@Narsil Narsil merged commit fb2f74e into main Jul 5, 2024
8 of 9 checks passed
@Narsil Narsil deleted the refactor_dead_code branch July 5, 2024 08:29
ErikKaum pushed a commit that referenced this pull request Jul 26, 2024
* Refactor dead code.

* First working step.

* Remove a lot of duplicated code.

* More dead code.

* More cleanup.

* Fix Santacoder test.

* Fixing the simple tests.

* Fixing sharding.

* Fixes for VLM.

* Fixing santacoder (num_kv_heads hardcoded).

* Removing more dead code.

* Fixing `config.n_head`.

* Stopping earlier because of `<end_of_utterance>` in idefics2.

* Addresses comments.

* Removing the dead code.

* Fuse back mistral into FlashCausalLM.

* Finish removal.

* Fixing docs + causal_lm `batch_class`.

* Fixing docs + causal.lm.

* Add default to Gemma Causality.

* Default value for gemma/gemma2.

* Wrong default.
yuanwu2017 pushed a commit to yuanwu2017/tgi-gaudi that referenced this pull request Sep 26, 2024