Split the graphs to run with flash_attention on 1x #75

Merged
3 commits merged into HabanaAI:habana-main on Mar 4, 2024

Conversation

kalyanjk

With flash attention enabled for larger batch sizes, the recipe ARC HBM memory size exceeds the QueueComputeScal ARC HBM memory. Hence, split the graph on 1x.

kalyanjk requested a review from a user on February 26, 2024, 17:07
@@ -676,6 +677,9 @@ def forward(
         next_decoder_cache = () if not use_new_cache else None

         for layer_idx, decoder_layer in enumerate(self.layers):
+            if torch.distributed.is_initialized() == False:
+                htcore.mark_step()
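
For readers unfamiliar with HPU lazy mode, here is a minimal, self-contained illustration of what htcore.mark_step() does: each call flushes the operations accumulated so far into their own compiled graph (recipe), so calling it once per decoder layer keeps each recipe small. This sketch assumes a Gaudi device with the Habana PyTorch bridge installed; the workload and shapes are arbitrary, not taken from this PR.

import torch
import habana_frameworks.torch.core as htcore

# Without mark_step(), all four iterations accumulate into one large lazy
# graph; with it, each iteration is compiled and executed as its own graph.
x = torch.randn(8, 8).to("hpu")
for _ in range(4):
    x = torch.relu(x @ x)
    htcore.mark_step()  # cut the lazy graph here
print(x.sum().item())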
Collaborator
@kalyanjk what's the impact for input/output that has not introduced OOM? Should we add a command-line argument in text-generation?

@kalyanjk, why only mark_step() for 1x?

Author (kalyanjk)

For 8x, mark_step is introduced through a collective call.

Author (kalyanjk)

> @kalyanjk what's the impact for input/output that has not introduced OOM? Should we add a command-line argument in text-generation?

The issue is not OOM. The real issue is that the recipe size is too large and the compilation time is too high.


Please update as below:
if lazy_mode and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1):
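
For illustration only, a minimal sketch of how the suggested guard could be factored out. The helper name maybe_split_graph is hypothetical, and lazy_mode is the flag referred to in the review comment; this is not the merged implementation.

import torch
import habana_frameworks.torch.core as htcore

def maybe_split_graph(lazy_mode: bool) -> None:
    # In lazy mode on a single card (no initialized process group, or a world
    # size of 1), force a graph boundary so each decoder layer compiles as its
    # own recipe, keeping recipe size and compile time down.
    if lazy_mode and (
        not torch.distributed.is_initialized()
        or torch.distributed.get_world_size() == 1
    ):
        htcore.mark_step()

Calling such a helper at the top of the for layer_idx, decoder_layer in enumerate(self.layers) loop would reproduce the per-layer graph split discussed above; on 8x the condition is false and the collective call already introduces the mark_step.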

@puneeshkhanna

@kalyanjk - we can abandon this PR. I have handled the change in #65.
This also helps 8x inference.
I'm checking 1x perf results too.
I still need to check the fine-tuning script once.

@puneeshkhanna

Wait, we should not put mark_step after the start of the loop. It will create more graphs and perf will be lower.

@kalyanjk (Author)

> Wait, we should not put mark_step after the start of the loop. It will create more graphs and perf will be lower.
@puneeshkhanna
On G3 we were seeing good perf with mark_step inside the for loop. With mark_step outside the for loop, we are not able to run on a single card. This issue is also present on G2.

@msinnha1 left a comment

Verified the change; it is required for faster recipe compilation.

@@ -23,6 +23,7 @@
     _gaudi_prepare_4d_causal_attention_mask,
 )
 
+import habana_frameworks.torch.core as htcore

If you rebase to the latest, this htcore import is not required, as it is already part of PR #65.

@msinnha1 left a comment

lgtm

ghost merged commit eec5b3f into HabanaAI:habana-main on Mar 4, 2024
astachowiczhabana pushed commits that referenced this pull request on Apr 5, Apr 19, Apr 22, and Apr 24, 2024:

* Split the graphs to run with flash_attention on 1x

* Added lazy_mode check and removed additional htcore import

Co-authored-by: Kalyan <kkumar@habana.ai>
@kalyanjk (Author)

This PR solves the actual issue #126

@astachowiczhabana commented on Jun 12, 2024

huggingface#875

kalyanjk deleted the decoder_mark_step branch on July 5, 2024, 11:47