Skip to content

Commit

Permalink
Disable custom all reduce by default (vllm-project#2808)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Feb 8, 2024
1 parent 79f09ae commit d366b67
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,16 +388,26 @@ def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is not supported yet.")
if is_hip():
if not self.disable_custom_all_reduce and self.world_size > 1:
if is_hip():
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.")
elif self.pipeline_parallel_size > 1:
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism.")

# FIXME(woosuk): Fix the stability issues and re-enable the custom
# all-reduce kernel.
if not self.disable_custom_all_reduce and self.world_size > 1:
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.")
elif self.pipeline_parallel_size > 1:
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism.")
"Custom all-reduce kernels are temporarily disabled due to "
"stability issues. We will re-enable them once the issues are "
"resolved.")


class SchedulerConfig:
Expand Down

0 comments on commit d366b67

Please sign in to comment.