
fix OOM when inference with llama-3.1-70b #1302

Merged: 4 commits into huggingface:main from fix-oom-infer-llama on Sep 26, 2024

Conversation

@harborn (Contributor) commented Aug 30, 2024

What does this PR do?

Background

When I run inference with the following command:

INPUT=32768
OUTPUT=32768
BATCH_SIZE=12

python gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
    --model_name_or_path Meta-Llama-3.1-70B-Instruct/ \
    --max_input_tokens ${INPUT} \
    --max_new_tokens ${OUTPUT} \
    --bf16 \
    --use_hpu_graphs \
    --use_kv_cache \
    --batch_size ${BATCH_SIZE} \
    --attn_softmax_bf16 \
    --limit_hpu_graphs \
    --trim_logits \
    --flash_attention_causal_mask \
    --flash_attention_recompute \
    --warmup 1 \
    --n_iteration 1 \
    --bucket_internal \
    --bucket_size=512 \
    --use_flash_attention

it runs out of device memory (OOM), while BATCH_SIZE=11 does not.

After debugging with a memory analysis tool, I found that the first creation of the causal attention mask tensor needs too much device memory, which leads to device memory exhaustion.

Details of creating the causal mask tensor

Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, key_value_length) shape and by adding a large negative bias to not-attended positions.

If attention_mask is causal, a causal mask will be added.

The first time this tensor is created, its shape is very large (in my case, [12, 1, 32768, 32768]).
Creating it requires an intermediate mask tensor. That mask tensor's dtype could be torch.bool, but it is actually torch.int, which uses four times the memory (for shape [12, 1, 32768, 32768] it needs about 48 GB of device memory, which causes the peak memory usage).
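
As a sanity check on that figure (a quick back-of-the-envelope calculation, not part of the PR), the element count of a [12, 1, 32768, 32768] mask times the element size gives roughly 48 GiB for a 4-byte integer mask versus roughly 12 GiB for a torch.bool mask:

```python
# Rough peak-memory estimate for the intermediate mask tensor (illustrative only).
bsz, heads, q_len, kv_len = 12, 1, 32768, 32768
numel = bsz * heads * q_len * kv_len          # 12,884,901,888 elements

gib = 2 ** 30
print(f"torch.int32 mask: {numel * 4 / gib:.1f} GiB")  # 4 bytes/elem -> ~48 GiB
print(f"torch.bool mask:  {numel * 1 / gib:.1f} GiB")  # 1 byte/elem  -> ~12 GiB
```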

Fixes

This PR aims to explicitly make the computation of the causal attention mask tensor use less device memory by using a torch.bool mask tensor.

For the code changes, it is enough to override the base class's to_4d function (a sketch of the idea follows).
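
For illustration, here is a minimal sketch of the idea, not the exact diff of this PR: build the padding and causal masks as torch.bool and convert only once into the additive bias in the target dtype. The helper name to_4d_bool_mask and its signature are assumptions made for this sketch; the actual change overrides the base class's to_4d function.

```python
import torch

# Hypothetical helper illustrating the bool-mask approach; the name and
# signature are assumptions, not the actual code of this PR.
def to_4d_bool_mask(attention_mask_2d: torch.Tensor, query_length: int,
                    key_value_length: int, dtype: torch.dtype) -> torch.Tensor:
    """Expand a (bsz, key_value_length) padding mask to a causal
    (bsz, 1, query_length, key_value_length) additive bias, keeping all
    intermediates in torch.bool to reduce peak device memory."""
    device = attention_mask_2d.device

    # Causal part as a bool mask: True means "masked out".
    causal = torch.triu(
        torch.ones((query_length, key_value_length), dtype=torch.bool, device=device),
        diagonal=key_value_length - query_length + 1,
    )

    # Padding part as a bool mask: positions where the 2D mask is 0 are masked.
    padding = (attention_mask_2d == 0)[:, None, None, :]   # (bsz, 1, 1, kv_len)

    # Combine in bool, then convert once into the large negative additive bias.
    masked = causal[None, None, :, :] | padding             # (bsz, 1, q_len, kv_len)
    return masked.to(dtype) * torch.finfo(dtype).min
```

The final output still has the full (bsz, 1, q_len, kv_len) shape, but the integer intermediate that previously caused the ~48 GB peak is replaced by bool intermediates at a quarter of the size.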

Others

But why did BATCH_SIZE=11 not cause device memory exhaustion?
I think this is a bug in the LAZY GRAPH.
In lazy mode, the graph should be able to optimize the computation of this big tensor so that it uses less device memory.
So the best solution for this bug would be further optimization in the LAZY GRAPH.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@harborn requested a review from regisss as a code owner on August 30, 2024 at 07:06
@harborn changed the title from "fix oom when infernece with llama-3.1-70b" to "fix OOM when inference with llama-3.1-70b" on Aug 30, 2024
@yafshar (Contributor) commented Sep 5, 2024

@harborn

  • would you please add more info about this PR and the issue it addresses to the README?
  • please run GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/transformers/tests/models/ -s -v before and after the changes and make sure no new failures are introduced.

@harborn (Contributor, Author) commented Sep 12, 2024

> @harborn
>
> • would you please add more info about this PR and the issue it addresses to the README?
> • please run GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/transformers/tests/models/ -s -v before and after the changes and make sure no new failures are introduced.

I have updated the PR description with the necessary information about these changes.

@yafshar (Contributor) commented Sep 12, 2024

> @harborn
>
> • would you please add more info about this PR and the issue it addresses to the README?
> • please run GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/transformers/tests/models/ -s -v before and after the changes and make sure no new failures are introduced.
>
> I have updated the PR description with the necessary information about these changes.

Thanks, very nice explanation!

@yafshar (Contributor) left a comment

LGTM!

@regisss, would you please check this PR?

@harborn (Contributor, Author) commented Sep 14, 2024

Can you merge this PR? @yafshar @regisss

@yafshar (Contributor) commented Sep 18, 2024

@libinta can you label this PR

@yafshar (Contributor) commented Sep 19, 2024

> @libinta can you label this PR

This PR is ready; can we label it with run-test?

@libinta added the run-test label (Run CI for PRs from external contributors) and removed the review and wip labels on Sep 24, 2024
@harborn (Contributor, Author) commented Sep 25, 2024

Are there any remaining tasks before this PR can be merged? @yafshar

The code quality check failed; please run make style.

@yafshar (Contributor) commented Sep 25, 2024

@harborn, can you run make style and fix any related issues? Also, rebase the code.

@harborn force-pushed the fix-oom-infer-llama branch from df58e06 to 0cdca2e on September 26, 2024 at 02:47
@harborn (Contributor, Author) commented Sep 26, 2024

> @harborn, can you run make style and fix any related issues? Also, rebase the code.

Done.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@regisss merged commit 4baaf3d into huggingface:main on Sep 26, 2024
4 checks passed