Skip to content

Commit

Permalink
[Bugfix] Fix ray workers profiling with nsight (vllm-project#4095)
Browse files Browse the repository at this point in the history
  • Loading branch information
rickyyx authored and joerunde committed Apr 18, 2024
1 parent a53c384 commit 121f0aa
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions vllm/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ def _init_executor(self) -> None:
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()

def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
runtime_env.update({
"nsight": {
"t": "cuda,cudnn,cublas",
"o": "'worker_process_%p'",
"cuda-graph-trace": "node",
}
})

return ray_remote_kwargs

def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
Expand All @@ -63,6 +78,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerVllm] = []

if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)

# Create the workers.
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
Expand Down

0 comments on commit 121f0aa

Please sign in to comment.