
Fix nccl regression on PyTorch 2.3 upgrade #2099

Merged 5 commits into main on Jul 8, 2024

Conversation

fxmarty (Contributor) commented on Jun 20, 2024:

As per the title, this fixes NVIDIA/nccl#1251 in TGI's CUDA image, a regression introduced in #1730 & #1833.

We hit this issue e.g. with the Llama 3 70B model with TP=4 or TP=8 on H100 and default CUDA graphs. The hang can be reproduced for example with text-generation-benchmark --tokenizer-name meta-llama/Meta-Llama-3-70B-Instruct --sequence-length 128 --decode-length 10 --warmups 2 --runs 100 -b 1, where the shards hang in:

Thread 1302975 (active): "MainThread"
    sched_yield (libc.so.6)
    ncclLaunchKernelBefore_NoUncapturedCuda (enqueue.cc:968)
    doLaunches (group.cc:161)
    groupLaunch (group.cc:339)
    ncclGroupEndInternal (group.cc:418)
    ncclGroupEndInternal (group.cc:368)
    ncclEnqueueCheck (enqueue.cc:1981)
    ncclAllReduce (collectives.cc:49)
    c10d::ProcessGroupNCCL::collective<c10d::ProcessGroupNCCL::allreduce_impl(at::Tensor&, c10d::AllreduceOptions const&)::{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}, c10d::ProcessGroupNCCL::collective<c10d::ProcessGroupNCCL::allreduce_impl(at::Tensor&, c10d::AllreduceOptions const&)::{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}>(at::Tensor&, at::Tensor&, c10d::ProcessGroupNCCL::allreduce_impl(at::Tensor&, c10d::AllreduceOptions const&)::{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}, c10d::OpType, char const*, bool)::{lambda(c10::cuda::CUDAStream&, c10::intrusive_ptr<c10d::ProcessGroupNCCL::WorkNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL::WorkNCCL> >&)#1}, c10d::ProcessGroupNCCL::collective<c10d::ProcessGroupNCCL::allreduce_impl(at::Tensor&, c10d::AllreduceOptions const&)::{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}>(at::Tensor&, at::Tensor&, c10d::ProcessGroupNCCL::allreduce_impl(at::Tensor&, c10d::AllreduceOptions const&)::{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}, c10d::OpType, char const*, bool)::{lambda(c10::cuda::CUDAStream&, c10::intrusive_ptr<c10d::ProcessGroupNCCL::WorkNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL::WorkNCCL> >&)#2}> (libtorch_cuda.so)
    c10d::ProcessGroupNCCL::allreduce_impl (libtorch_cuda.so)
    c10d::ProcessGroupNCCL::allreduce (libtorch_cuda.so)
    c10d::ops::(anonymous namespace)::allreduce_CUDA (libtorch_cpu.so)

PyTorch 2.3 has a hard requirement on NCCL 2.20.5, so I am not completely sure this fix is fine. We could also choose to downgrade instead.

An interesting read as well: https://pytorch.slack.com/archives/C3PDTEV8E/p1713223950622429?thread_ts=1712807088.459829&cid=C3PDTEV8E

I will wait for the build to run to re-check TGI's benchmark and look for any potential regression.
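
For context, the failing pattern is an NCCL all-reduce launched in combination with CUDA graphs across tensor-parallel shards. Below is a minimal sketch of that pattern (illustrative only, assuming a single-node torchrun launch with one process per GPU; not a verified standalone reproduction of the hang):

```python
# Illustrative sketch only: exercises the same general code path as TGI's
# decode step with CUDA graphs enabled, i.e. an NCCL all-reduce captured in a
# CUDA graph and then replayed. Launch e.g. with
#   torchrun --nproc-per-node=8 repro_sketch.py   (hypothetical file name)
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")
    # Assumes a single node, so rank == local rank.
    torch.cuda.set_device(dist.get_rank())

    x = torch.ones(1024, device="cuda")

    # Warm up the communicator outside of graph capture.
    dist.all_reduce(x)
    torch.cuda.synchronize()

    # Capture the collective in a CUDA graph, then replay it.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        dist.all_reduce(x)

    for _ in range(10):
        graph.replay()
    torch.cuda.synchronize()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```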

fxmarty requested review from OlivierDehaene and Hugoch on June 20, 2024 18:05
@@ -232,7 +234,8 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
+    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \
+    pip install nvidia-nccl-cu12==2.22.3
fxmarty (Contributor Author) commented on the diff:

I would have liked to use pyproject.toml for this, but poetry does not handle this kind of dependency conflict; see python-poetry/poetry#697 (comment).
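
To double-check the pin above inside the built image, a small sanity check could look like the following (an illustrative sketch, not part of the PR; the path matches the LD_PRELOAD value added elsewhere in this PR):

```python
# Illustrative check: report which NCCL version PyTorch resolves at runtime and
# whether the libnccl.so.2 from the pinned nvidia-nccl-cu12 wheel is present.
import pathlib
import torch

# NCCL version PyTorch resolves at runtime, as a tuple such as (2, 22, 3);
# if the preload takes effect, this should be the wheel's copy.
print("torch reports NCCL", torch.cuda.nccl.version())

# Path of the wheel-provided libnccl.so.2, taken from the LD_PRELOAD value
# used in this PR.
wheel_lib = pathlib.Path(
    "/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2"
)
print("wheel libnccl present:", wheel_lib.exists())
```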

Narsil (Collaborator) commented on Jun 24, 2024:

Thanks a lot for the find, the fix and the details.

I'm more inclined to wait for torch to fix it (2.3.1 hasn't fixed it yet), since as far as I know this does NOT affect production.
If it did, I'd be 100% behind your solution (it seems better than downgrading for the time being, since torch 2.3 still received some nice upgrades).

fxmarty (Contributor Author) commented on Jun 25, 2024:

As you'd like. I am using this fix to benchmark.

Hugoch (Member) commented on Jul 1, 2024:

Nice fix @fxmarty!
I confirm that upgrading NCCL as proposed fixes the systematic hang on 8xH100 P5 instances, where TGI freezes without crashing. PyTorch 2.4 should be released this month; let's check whether it updates NCCL, otherwise it would be nice to merge this patch.

OlivierDehaene (Member) left a comment:

Since this affects real deployments, let's merge it.

pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \
pip install nvidia-nccl-cu12==2.22.3

ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2

A reviewer (Member) commented:

Why do we need to preload?

fxmarty (Contributor Author) replied on Jul 8, 2024:

Otherwise the shared object is not used. TGI's current base Docker image is nvidia/cuda:12.1.0-base-ubuntu22.04, which contains no libnccl.so anywhere, and none is loaded by PyTorch either, although we do have /opt/conda/lib/libcudart.so.12.1.105 etc. COPY --from=pytorch-install /opt/conda /opt/conda does not seem to copy any libnccl.so. Weird.
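
For what it's worth, a quick way to confirm the preload takes effect is to inspect the process memory map (an illustrative sketch, assuming the Linux container above; not part of the PR):

```python
# Illustrative check: list which libnccl shared objects are mapped into the
# current process. With LD_PRELOAD set as in the Dockerfile, the wheel's
# libnccl.so.2 should appear here even before importing torch.
import os

print("LD_PRELOAD =", os.environ.get("LD_PRELOAD"))

with open("/proc/self/maps") as maps:
    mapped = sorted({line.split()[-1] for line in maps if "libnccl" in line})

print("mapped libnccl objects:", mapped or "none")
```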

Hugoch mentioned this pull request on Jul 8, 2024
OlivierDehaene merged commit 4c50b6d into main on Jul 8, 2024
8 of 9 checks passed
OlivierDehaene deleted the fix-nccl-regression branch on July 8, 2024 15:52
HoKim98 mentioned this pull request on Jul 11, 2024
ErikKaum pushed a commit that referenced this pull request Jul 26, 2024
* fix nccl issue

* add note in dockerfile

* use v2.22.3 that also fixes @samsamoa's repro

* poetry actually can't handle the conflict between torch and nccl

* set LD_PRELOAD
yuanwu2017 pushed a commit to yuanwu2017/tgi-gaudi that referenced this pull request Sep 26, 2024
* fix nccl issue

* add note in dockerfile

* use v2.22.3 that also fixes @samsamoa's repro

* poetry actually can't handle the conflict between torch and nccl

* set LD_PRELOAD
Successfully merging this pull request may close these issues.

Leak in FIFO queue