Gemma: enabled HPU Graphs and Flash Attention #1173
Conversation
@dsmertin can you rebase this PR onto main?
Hi @dsmertin
Thank you for this PR.
Could you please address the following and push the changes?
- Please rebase/sync on top of OH main. This branch is currently 72 commits behind.
- Please make sure to run make style.
- Please share the Gemma results before/after these changes.
- We need to test/update the CI tests with these changes.
q_len = query_layer.size(-2)
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
q_padding = q_tiles * q_block_size - q_len
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
make style complains here that F is not defined.
Please take a look.
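For reference, here is a minimal self-contained sketch of this padding step with the missing import added; it assumes F refers to torch.nn.functional, and the tensor shape and block size below are made up for illustration only:

import math

import torch
import torch.nn.functional as F  # the import make style reports as missing

# Illustrative shapes: (batch, num_heads, seq_len, head_dim) and an arbitrary block size.
query_layer = torch.randn(1, 8, 3000, 128)
q_block_size = 1024

q_len = query_layer.size(-2)
# Number of q_block_size tiles needed to cover q_len.
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
# Right-pad the sequence dimension with zeros so its length becomes a multiple of q_block_size.
q_padding = q_tiles * q_block_size - q_len
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
print(query_layer.shape)  # torch.Size([1, 8, 3072, 128])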
q_padding = q_tiles * q_block_size - q_len
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
if attention_mask is not None:
    attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0)
Same comment here about make style.
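For context on this hunk: the attention mask is padded with a large negative value so that, after softmax, the padded positions receive effectively zero attention weight. A minimal sketch under the same assumptions as above (illustrative shapes, F = torch.nn.functional, and an additive mask where 0 means attend):

import torch
import torch.nn.functional as F

q_block_size = 1024
query_layer = torch.randn(1, 8, 3000, 128)      # (batch, num_heads, q_len, head_dim)
attention_mask = torch.zeros(1, 1, 3000, 3000)  # additive mask: 0 = attend, large negative = masked

q_len = query_layer.size(-2)
q_padding = -q_len % q_block_size  # same value as the tile-based computation above
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
if attention_mask is not None:
    # Pad the query dimension of the mask with -10000.0 so the extra rows are ignored by softmax.
    attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0)
print(attention_mask.shape)  # torch.Size([1, 1, 3072, 3000])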
I did a bit more testing here and added a few more comments.
This PR also has merge conflicts with main.
Please make sure to address them during the rebase, and run make style.
Thanks.
- add new arg flash_attention_recompute
"""
if "padding_mask" in kwargs:
    warnings.warn(
make style complains here that warnings is not defined.
Shouldn't this be a logger.warning_once?
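A minimal sketch of the logger.warning_once alternative, assuming the transformers logging utilities used elsewhere in the modeling code; the function body and message text are illustrative, not taken from the file:

from transformers.utils import logging

logger = logging.get_logger(__name__)

def forward(hidden_states, **kwargs):
    if "padding_mask" in kwargs:
        # Review suggestion: use the transformers logger so the notice is emitted only once per process.
        logger.warning_once(
            "Passing padding_mask is deprecated, please use attention_mask instead."  # illustrative message
        )
    return hidden_states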
I fixed some issues, updated the code with make style, and added --use_flash_attention to the Gemma CI test.
Please apply the attached fix with git am < 0001* and push.
Thanks.
0001-fix-gemma-make-style.-minor-fixes.patch
@dsmertin can you address the comments?
@yafshar @imangohari1
@dsmertin Please work through the changes I have suggested here and push them so we can test this further on the RC. Thanks.
The branch was force-pushed from 686bf2d to 1f71df0.
I've updated the branch with the rebase and your patch, @imangohari1.
LGTM!
@regisss please take a look. Thanks.
The code quality check failed, please run make style.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
CI failed with:
What does this PR do?
This PR fixes HPU Graphs usage and Flash Attention for the Gemma model.
The changes are based on the Starcoder 2 and Qwen 2 implementations.
Before submitting