
Gemma: enabled HPU Graphs and Flash Attention #1173

Merged: 7 commits into huggingface:main on Sep 24, 2024

Conversation

dsmertin (Contributor)

What does this PR do?

This PR fixes HPU Graphs usage and Flash Attention for the Gemma model.
The changes are based on the Starcoder 2 and Qwen 2 implementations.
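
For context, a minimal sketch of how these features are typically exercised on Gaudi, assuming an HPU machine with habana_frameworks and an optimum-habana build that includes this PR; the generate() keywords below mirror the --use_flash_attention flag and the flash_attention_recompute argument discussed later in this thread and should be treated as assumptions about the exact API, not as part of this diff:

import torch
import habana_frameworks.torch as ht
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # swap in the Gaudi-optimized model implementations

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.bfloat16).to("hpu")
model = ht.hpu.wrap_in_hpu_graph(model)  # capture forward passes as HPU Graphs

inputs = tokenizer("Hello, Gemma!", return_tensors="pt").to("hpu")
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    use_flash_attention=True,        # assumed kwarg, mirrors --use_flash_attention in the CI test
    flash_attention_recompute=True,  # new argument added by this PR, per the diff below
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))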

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

yafshar (Contributor) commented Sep 6, 2024

@dsmertin can you rebase this PR on main?

imangohari1 (Contributor) left a comment

Hi @dsmertin
Thank you for this PR.
Could you please address the following and push the changes?

  • Please rebase/sync on top of OH main; this branch is currently 72 commits behind.
  • Please make sure to run make style.
  • Please share before/after results for Gemma with these changes.
  • We need to test/update the CI tests with these changes.

q_len = query_layer.size(-2)
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
q_padding = q_tiles * q_block_size - q_len
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
Contributor comment on the diff above:

make style complains here that F is not defined.
Please take a look.

q_padding = q_tiles * q_block_size - q_len
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
if attention_mask is not None:
attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0)
Contributor comment on the diff above:

Same here about make style.
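
As a reference for both hunks above, a minimal, self-contained sketch of the tiling/padding logic with the import that make style reports as missing; the tensor shapes and q_block_size below are illustrative assumptions, not values from the actual diff:

import math

import torch
import torch.nn.functional as F  # the import that `make style` reports as missing

# Illustrative shapes: (batch, heads, seq_len, head_dim) and a block size of 1024.
query_layer = torch.randn(1, 8, 1500, 64)
attention_mask = torch.zeros(1, 1, 1500, 1500)
q_block_size = 1024

q_len = query_layer.size(-2)
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
q_padding = q_tiles * q_block_size - q_len

# Pad the queries with zeros up to a multiple of the block size, and pad the mask
# with a large negative value so the padded positions are suppressed by the softmax.
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
if attention_mask is not None:
    attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0)

print(query_layer.shape)     # torch.Size([1, 8, 2048, 64])
print(attention_mask.shape)  # torch.Size([1, 1, 2048, 1500])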

imangohari1 (Contributor) left a comment

I did a bit more testing here and added a few more comments.
This PR also has a merge conflict with main.
Please make sure to address them during the rebase, and run make style.
Thanks.

- add new arg flash_attention_recompute
"""
if "padding_mask" in kwargs:
warnings.warn(
Contributor comment on the diff above:

make style complains here that warnings is not defined.
Shouldn't this be a logger.warning_once?
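
A sketch of the pattern being suggested here, using the transformers logging helper instead of the warnings module; the message text is a placeholder rather than the wording in the diff:

from transformers.utils import logging

logger = logging.get_logger(__name__)

def forward_snippet(**kwargs):
    # Instead of warnings.warn(...), emit the deprecation notice once per process:
    if "padding_mask" in kwargs:
        logger.warning_once(
            "Passing `padding_mask` is deprecated; please use `attention_mask` instead."  # placeholder
        )

forward_snippet(padding_mask=None)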

imangohari1 (Contributor) left a comment

I fixed some issues, updated the code with make style, and added --use_flash_attention to the Gemma CI test.
Please apply the attached fix with git am < 0001* and push.
Thanks.
0001-fix-gemma-make-style.-minor-fixes.patch

yafshar (Contributor) commented Sep 16, 2024

@dsmertin can you address the comments?

dsmertin (Contributor, Author)

@yafshar @imangohari1
This PR was created a month and a half ago for a different version of the software stack. There was a difference and an improvement in performance for this model, but I'm not sure if that's still the case. A colleague told me that there is not much difference on the 17.1 version, and I need to double-check it now. If there is no improvement, I'll close the PR; otherwise, I'll go through the comments.

imangohari1 (Contributor)

@dsmertin
Thanks.
I have tested this PR with the latest release and we would like it to be included in the next release.

Please work through the changes I have suggested here and push them so we can test this further on the RC.

Thanks.


dsmertin force-pushed the ds/gemma-optimization branch from 686bf2d to 1f71df0 on September 18, 2024, 15:09
dsmertin (Contributor, Author)

I've updated the branch with the rebase and your patch, @imangohari1.
I'm a little confused, though: what else should be done?

imangohari1 (Contributor) left a comment

LGTM!
@regisss, please take a look. Thanks.

libinta added the run-test label (Run CI for PRs from external contributors) on Sep 18, 2024

The code quality check failed; please run make style.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

regisss (Collaborator) commented Sep 23, 2024

CI failed with:

E     File "/root/workspace/optimum/habana/transformers/models/gemma/modeling_gemma.py", line 552
E       use_cache = use_cache if use_cache is not None else self.config.use_cache
E   IndentationError: unexpected indent
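
For illustration only (this is not the actual diff), an IndentationError like this means the statement is indented one level deeper than the surrounding function body; a minimal stand-in showing the broken and corrected indentation:

class GemmaConfigStub:  # hypothetical stand-in, not the real config class
    use_cache = True

class GemmaModelStub:   # hypothetical stand-in, not the real modeling_gemma.py
    config = GemmaConfigStub()

    def forward(self, use_cache=None):
        output_attentions = False
        # Broken version (extra indent, rejected at import time with
        # "IndentationError: unexpected indent"):
        #         use_cache = use_cache if use_cache is not None else self.config.use_cache
        # Fixed version, flush with the rest of the function body:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return use_cache

print(GemmaModelStub().forward())  # True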

regisss merged commit 00dd5bf into huggingface:main on Sep 24, 2024
3 of 4 checks passed