fix OOM when inference with llama-3.1-70b #1302
Conversation
I have updated the necessary information of this update in the PR description.
Thanks, very nice explanation!
LGTM!
@regisss, would you please check this PR?
@libinta, can you label this PR?
This PR is ready; can we label it with run_test?
Are there any remaining tasks to finish before merging this PR? @yafshar
The code quality check failed, please run
@harborn, can you run
df58e06 to 0cdca2e
Done.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
Background
When I run inference with my command (BATCH_SIZE=12, with a maximum sequence length of 32768), it OOMs, while it does not OOM if BATCH_SIZE=11.
After debugging with a memory analysis tool, I found that the first creation of the causal attention mask tensor needs too much device memory, which leads to device memory exhaustion.
Details of creating the causal mask tensor
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, key_value_length) shape and by adding a large negative bias to not-attended positions.
If attention_mask is causal, a causal mask will be added.
The first time this tensor is created, its shape is very large (in my case, [12, 1, 32768, 32768]).
During the creation of this tensor, an intermediate mask tensor is needed. This mask tensor's dtype could be torch.bool, but it is actually torch.int, which uses four times the memory: for shape [12, 1, 32768, 32768], it needs 48 GB of device memory, which causes the peak memory usage.
Fixes
This PR aims to explicitly make the computation of the causal attention mask tensor use less device memory by using a torch.bool mask tensor.
For the code change, it simply overrides the base class's to_4d function.
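As an illustration of the approach, here is a minimal sketch of such an override. It is not the PR's actual implementation; it assumes the to_4d signature shown below matches the installed transformers version of AttentionMaskConverter, and that attention_mask_2d already spans the full key/value length.

```python
from typing import Optional

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class BoolMaskConverter(AttentionMaskConverter):
    # Illustrative override: keep every full-size intermediate mask in torch.bool.
    def to_4d(
        self,
        attention_mask_2d: torch.Tensor,
        query_length: int,
        dtype: torch.dtype,
        key_value_length: Optional[int] = None,
    ) -> torch.Tensor:
        bsz = attention_mask_2d.shape[0]
        kv_len = key_value_length if key_value_length is not None else attention_mask_2d.shape[-1]
        device = attention_mask_2d.device

        # Build the causal part directly in torch.bool (True = masked position),
        # so the big intermediate tensor costs 1 byte per element instead of 4.
        past_len = kv_len - query_length
        q_idx = torch.arange(query_length, device=device)[:, None] + past_len
        k_idx = torch.arange(kv_len, device=device)[None, :]
        causal = k_idx > q_idx  # (query_length, kv_len), bool

        # Padding positions from the 2D mask: True where the token must not be attended.
        padding = (attention_mask_2d == 0)[:, None, None, :]

        # Combined bool mask, broadcast to (bsz, 1, query_length, kv_len).
        mask_bool = causal[None, None, :, :] | padding

        # Only at the very end, materialize the additive mask in the model dtype:
        # 0 for attended positions, a large negative value for masked ones.
        expanded = torch.zeros(bsz, 1, query_length, kv_len, dtype=dtype, device=device)
        return expanded.masked_fill(mask_bool, torch.finfo(dtype).min)
```

The key point is the same as in this PR: every intermediate tensor at the full (bsz, 1, query_length, key_value_length) shape stays in torch.bool, and the additive mask in the model dtype is only materialized once at the end.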
Others
But why does BATCH_SIZE=11 not cause device memory exhaustion? I think this is a bug in the lazy graph mode.
In lazy graph mode, the computation of this large tensor should be optimized to use less device memory.
So the best solution for this bug would be further optimization in the lazy graph mode.
Before submitting