Split the graphs to run with flash_attention on 1x #75
Conversation
```
@@ -676,6 +677,9 @@ def forward(
        next_decoder_cache = () if not use_new_cache else None

        for layer_idx, decoder_layer in enumerate(self.layers):
            if torch.distributed.is_initialized() == False:
                htcore.mark_step()
```
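For context, a minimal sketch of the mechanism (assuming a Gaudi device with the Habana PyTorch bridge installed; the tensors and loop below are stand-ins, not the PR's code): in lazy mode, HPU ops are only recorded until `htcore.mark_step()` is called, which compiles and launches the accumulated ops as one graph (a "recipe"). Calling it once per decoder layer therefore produces many small recipes instead of a single large one.

```python
# Minimal sketch, not the PR's code: mark_step() as a graph boundary in lazy mode.
import torch
import habana_frameworks.torch.core as htcore

x = torch.randn(8, 1024, device="hpu")
w = torch.randn(1024, 1024, device="hpu")

for _ in range(4):          # stand-in for the decoder-layer loop
    htcore.mark_step()      # ops recorded so far compile into their own recipe
    x = torch.tanh(x @ w)   # stand-in for one decoder layer's compute

print(x.sum().item())       # .item() forces execution of the remaining ops
```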
@kalyanjk what's the impact for input/output sizes that don't run into OOM? Should we add a command-line argument for this in text-generation?
@kalyanjk, why only mark_step() for 1x?
For 8x, the mark_step will be introduced through a collective call.
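A rough sketch of that point, assuming `torch.distributed` was initialized with the HCCL backend on HPU (the helper name and the reduction are illustrative, not the model's actual code): on a multi-card run the per-layer all_reduce is already a synchronization point, and, per the comment above, that collective is what introduces the mark_step, so no explicit call is needed there.

```python
# Illustrative only: on 8x the collective call is where the graph break comes
# from, so the explicit htcore.mark_step() is restricted to single-card runs.
import torch
import torch.distributed as dist

def reduce_layer_output(hidden_states: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not part of the model code.
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(hidden_states, op=dist.ReduceOp.SUM)  # collective call
    return hidden_states
```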
> @kalyanjk what's the impact for input/output sizes that don't run into OOM? Should we add a command-line argument for this in text-generation?

The issue is not OOM. The real issue is that the recipe size becomes too large and the compilation time is too high.
Please update as below
if lazy_mode and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1):
Wait, we should not put the mark_step after the start of the loop. It will create more graphs and perf will be lower.
Verified the change; it is required for faster recipe compilation.
```
@@ -23,6 +23,7 @@
    _gaudi_prepare_4d_causal_attention_mask,
)

import habana_frameworks.torch.core as htcore
```
If you rebase to the latest, this htcore import is not required, as it is already part of PR#65.
lgtm
* Split the graphs to run with flash_attention on 1x
* Added lazy_mode check and removed additional htcore import

Co-authored-by: Kalyan <kkumar@habana.ai>
This PR solves the actual issue #126.
With flash attention enabled for larger batch sizes, the recipe ARC HBM memory size exceeds the QueueComputeScal ARC HBM memory. Hence, split the graph on 1x so that each decoder layer compiles into a smaller recipe.
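Putting the thread together, a rough reconstruction of the per-layer hook after the review comments (a sketch assuming `lazy_mode` is available inside `forward()` and `htcore` is imported at module level as in PR#65; this is not a verbatim copy of the merged diff):

```python
# Reconstruction from the review thread, not the exact merged code.
for layer_idx, decoder_layer in enumerate(self.layers):
    if lazy_mode and (
        not torch.distributed.is_initialized()
        or torch.distributed.get_world_size() == 1
    ):
        # Single-card lazy mode: cut the graph here so each decoder layer
        # compiles into its own, smaller recipe.
        htcore.mark_step()

    # Placeholder call; the real signature also takes attention masks,
    # position ids, the KV cache, and so on.
    hidden_states = decoder_layer(hidden_states)
```

On multi-card runs the guard is skipped because, as noted in the discussion, the collective calls already introduce the graph breaks.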