Co-Locating vLLM Instances with Training Processes Via External Launcher #3105

Closed
wants to merge 21 commits into from
12 changes: 12 additions & 0 deletions trl/trainer/grpo_config.py
@@ -102,6 +102,10 @@ class GRPOConfig(TrainingArguments):
support this feature.
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
vllm_external_launcher (`bool`, *optional*, defaults to `False`):
Whether to use an external launcher for distributed vLLM execution. If set to `True`, a vLLM instance is
initialized in **all processes**, each assigned to its respective device. This allows multi-GPU or
multi-node inference with vLLM's `external_launcher` distributed executor backend, co-locating generation
with training instead of reserving a dedicated device for vLLM.

> Parameters that control the training

@@ -276,6 +280,14 @@ class GRPOConfig(TrainingArguments):
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_external_launcher: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to use an external launcher for distributed vLLM execution. If set to `True`, vLLM will be "
"initialized in all processes, each assigned to its respective device. This enables optimized "
"multi-GPU or multi-node inference."
},
)

# Parameters that control the training
learning_rate: float = field(
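For reference, here is a minimal usage sketch of the new option (hypothetical script; the model name, dataset, and reward function are placeholders). It assumes the config field lands as defined in this diff and that the job is started with one process per GPU, e.g. via `torchrun` or `accelerate launch`:

```python
# Sketch: enable co-located vLLM generation via the external launcher.
# Model name, dataset, and reward function below are placeholders.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    # Toy reward: prefer shorter completions.
    return [-float(len(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train")

config = GRPOConfig(
    output_dir="grpo-colocated-vllm",
    use_vllm=True,                   # generate completions with vLLM
    vllm_external_launcher=True,     # new flag: one vLLM instance per training process/device
    per_device_train_batch_size=1,
    num_generations=8,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=config,
    train_dataset=dataset,
)
trainer.train()
```

Launched with, say, `torchrun --nproc_per_node=8 train_grpo.py`, every rank initializes its own vLLM engine on its own device (via `distributed_executor_backend="external_launcher"`), instead of the main process reserving a separate GPU for a single vLLM instance.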
165 changes: 84 additions & 81 deletions trl/trainer/grpo_trainer.py
@@ -59,7 +59,6 @@
selective_log_softmax,
)


if is_peft_available():
from peft import PeftConfig, get_peft_model

@@ -445,61 +444,63 @@ def data_collator(features): # No data collation is needed in GRPO
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
"`pip install vllm` to use it."
)

if self.accelerator.is_main_process or self.args.vllm_external_launcher:

if self.accelerator.is_main_process:
vllm_device = self.args.vllm_device
device_type = PartialState().default_device.type
device_module = getattr(torch, device_type)
if vllm_device == "auto":
if device_module.device_count() == 1:
vllm_device = f"{device_type}:0" # particular case when training with onyl 1 device: share it
else:
vllm_device = f"{device_type}:{self.accelerator.num_processes}" # take the next GPU idx
# Check that the requested device is available
if (
vllm_device.split(":")[0] == f"{device_type}"
and int(vllm_device.split(":")[1]) >= device_module.device_count()
):
raise ValueError(
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
"value lower than the number of GPUs available on your machine—typically, reducing it by one "
f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`."
)
# Check that the requested device is not also used for training
if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}:
warnings.warn(
f"The requested device {vllm_device} is also being used for training. For higher throughput "
"and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
"If this is intentional, you may ignore this warning but should adjust "
"`vllm_gpu_memory_utilization` accordingly."
# Determine the device
if self.args.vllm_external_launcher:
# External launcher mode: Assign vLLM to the current process's device
vllm_device = f"{device_type}:{self.accelerator.process_index}"
else:
vllm_device = self.args.vllm_device
if vllm_device == "auto":
if device_module.device_count() == 1:
vllm_device = f"{device_type}:0" # particular case when training with onyl 1 device: share it
else:
vllm_device = f"{device_type}:{self.accelerator.num_processes}" # take the next GPU idx

# Check that the requested device is available
if (
vllm_device.split(":")[0] == f"{device_type}"
and int(vllm_device.split(":")[1]) >= device_module.device_count()
):
raise ValueError(
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
"value lower than the number of GPUs available on your machine—typically, reducing it by one "
f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`."
)
# vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
# model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
# setting (profiling_patch).
world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
profiling_patch = patch(
"vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
)

# For Ascend NPU (torch-npu), collective communication requires the establishment of a communication
# group, and different processes must hold the same group number. However, multiple process groups will
# be created internally within vLLM. This will cause the group id of the communication group on rank 0
to be different from that of other ranks, causing the backward pass to hang because the communication
# domain cannot be established. So we need to patch it to make sure the group id of different ranks in
# the training phase are the same.
@contextlib.contextmanager
def new_group_context():
new_group = torch.distributed.new_group
try:
torch.distributed.new_group = functools.partial(new_group, use_local_synchronization=True)
torch.npu.mem_get_info = functools.partial(torch.npu.mem_get_info, device=vllm_device)
yield
finally:
torch.distributed.new_group = new_group

new_group_patch = new_group_context() if device_type == "npu" else contextlib.nullcontext()
with world_size_patch, profiling_patch, new_group_patch:
# Check that the requested device is not also used for training
if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}:
warnings.warn(
f"The requested device {vllm_device} is also being used for training. For higher throughput "
"and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
"If this is intentional, you may ignore this warning but should adjust "
"`vllm_gpu_memory_utilization` accordingly."
)

# Prepare the list of context managers (patches) to apply, depending on the device type and on whether vLLM uses the external launcher
cmanagers = []
if not self.args.vllm_external_launcher:
cmanagers.extend([
functools.partial(patch, "torch.distributed.get_world_size", return_value=1),
functools.partial(patch, "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None),
])

if device_type == "npu":
cmanagers.extend([
functools.partial(patch, "torch.distributed.new_group", functools.partial(torch.distributed.new_group, use_local_synchronization=True)),
functools.partial(patch, "torch.npu.mem_get_info", functools.partial(torch.npu.mem_get_info, device=vllm_device)),
])

# Apply the required patches through an ExitStack before initializing the LLM
with contextlib.ExitStack() as stack:
for cm in cmanagers:
stack.enter_context(cm())

# Initialize the LLM (set distributed_executor_backend = external_launcher if requested)
self.llm = LLM(
model=model.name_or_path,
device=vllm_device,
@@ -510,6 +511,7 @@ def new_group_context():
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=self.args.vllm_enable_prefix_caching,
max_model_len=self.args.vllm_max_model_len,
distributed_executor_backend="external_launcher" if self.args.vllm_external_launcher else None,
)

# Guided decoding, if enabled
@@ -518,11 +520,10 @@
else:
guided_decoding = None

# Sampling parameters
self.sampling_params = SamplingParams(
max_tokens=self.max_completion_length,
guided_decoding=guided_decoding,
n=args.num_generations,
n=1 if self.args.vllm_external_launcher else args.num_generations, # in external_launcher mode, each co-located vLLM instance generates a single completion per prompt
temperature=args.temperature,
top_p=args.top_p,
top_k=-1 if args.top_k is None else args.top_k,
@@ -531,11 +532,11 @@
)

self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation

# When using vLLM, the main process is responsible for loading the model weights. This can cause process
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
# synchronize all processes after vLLM has been fully initialized.
self.accelerator.wait_for_everyone()
if not self.args.vllm_external_launcher:
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
# synchronize all processes after vLLM has been fully initialized.
self.accelerator.wait_for_everyone()
else:
self.generation_config = GenerationConfig(
max_new_tokens=self.max_completion_length,
@@ -691,7 +692,7 @@ def _move_model_to_vllm(self):
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
if self.accelerator.is_main_process or self.args.vllm_external_launcher:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
# Unmerge the adapter to restore the model to its original state.
@@ -737,31 +738,33 @@ def _generate_and_score_completions(
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
# Generate completions using vLLM: gather all prompts and generate them in a single call on the main process, unless the external launcher is enabled (then each process generates for its own prompts)
all_prompts_text = gather_object(prompts_text) if not self.args.vllm_external_launcher else None

if self.args.vllm_external_launcher or self.accelerator.is_main_process:
if self.args.vllm_external_launcher:
prompts_to_use = prompts_text # Each GPU handles its own batch
else:
prompts_to_use = all_prompts_text[::self.num_generations] # Unique prompts for generation

# Generate completions
with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(
ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
prompts_to_use, sampling_params=self.sampling_params, use_tqdm=False
)
completion_ids = []
for outputs in all_outputs:
for output in outputs.outputs:
completion_ids.append(output.token_ids)
completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
else:
completion_ids = [None] * len(all_prompts_text)
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]

if not self.args.vllm_external_launcher:
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]

# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
@@ -898,7 +901,7 @@ def _generate_and_score_completions(
completions_to_log = gather_object(completions_text)
rewards_to_log = rewards.tolist()

if self.accelerator.is_main_process:
if self.accelerator.is_main_process: # TODO: decide what to report in external_launcher mode (e.g., aggregate stats from all vLLM instances)
if is_rich_available():
print_prompt_completions_sample(
prompts_to_log,
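To make the control flow above concrete, here is a small standalone sketch (plain Python, no vLLM or torch; the rank count, prompts, and "generation" are made up) of how completions line up with processes in the two modes. In the default path, the main process generates `num_generations` completions per unique prompt and every rank keeps the slice `[rank * len(prompts) : (rank + 1) * len(prompts)]`; with the external launcher, each rank generates one completion for each of its own (already duplicated) prompts.

```python
# Standalone sketch of the two completion layouts (toy data only).
num_processes = 2
num_generations = 4    # each unique prompt is repeated num_generations times in the global batch
prompts_per_rank = 4   # len(prompts) on each rank

def fake_generate(prompt: str, n: int) -> list[str]:
    # Stand-in for llm.generate(...): n completions for one prompt.
    return [f"{prompt}-gen{i}" for i in range(n)]

# Global batch as GRPO builds it: unique prompts duplicated num_generations times,
# then split contiguously across ranks.
unique_prompts = ["p0", "p1"]
all_prompts = [p for p in unique_prompts for _ in range(num_generations)]
per_rank_prompts = [
    all_prompts[r * prompts_per_rank : (r + 1) * prompts_per_rank] for r in range(num_processes)
]

# Default path: the main process generates n=num_generations for unique prompts only,
# flattens the outputs, broadcasts them, and each rank keeps its slice.
flat = [c for p in unique_prompts for c in fake_generate(p, num_generations)]
default_per_rank = [
    flat[r * prompts_per_rank : (r + 1) * prompts_per_rank] for r in range(num_processes)
]

# External-launcher path: each rank generates n=1 for each of its own prompts.
external_per_rank = [
    [fake_generate(p, 1)[0] for p in per_rank_prompts[r]] for r in range(num_processes)
]

for r in range(num_processes):
    print(f"rank {r}: default={default_per_rank[r]}  external={external_per_rank[r]}")
# Both layouts assign completions for the same prompts to the same rank; only the place
# where generation happens (main process vs. each rank's co-located vLLM) changes.
```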