
[V1] Only print cudagraph tqdm on rank 0 with is_global_first_rank #19516


Open

wants to merge 2 commits into main
31 changes: 31 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -1315,6 +1315,37 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
    return [x == 1 for x in aggregated_data.tolist()]


def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
        Returns True if distributed is not initialized (single process).
    """
    try:
        # If world group is available, use it for the most accurate check
        global _WORLD
        if _WORLD is not None:
            return _WORLD.is_first_rank

        # If torch distributed is not initialized, assume single process
        if not torch.distributed.is_initialized():
            return True

        # Fallback to torch's global rank
        return torch.distributed.get_rank() == 0

    except Exception:
        # If anything goes wrong, assume this is the first rank
        return True
Comment on lines +1344 to +1346
Contributor

medium

The use of a bare `except Exception:` can mask various underlying issues from `torch.distributed` or the `_WORLD` state. This might lead to unexpected behavior if, for example, multiple processes erroneously assume they are the global first rank due to an unrelated exception during rank determination.

To improve robustness and diagnosability, consider the following:

  1. Capture the exception instance: Change `except Exception:` to `except Exception as e:`. This allows access to the specific error.
  2. Log the exception: Logging the error `e` would be beneficial for debugging. This helps in understanding why rank determination might have failed. (Assuming a logger is available or can be set up in this module, similar to other `vllm.distributed` files.)
  3. Docstring update: The current docstring explains that the function returns `True` if distributed is not initialized. It would be helpful to also document the behavior of returning `True` in case of any other exception during rank checking, if this is the intended safe default for all scenarios.

If returning `True` on any error is a deliberate and safe design choice for all potential uses of this function, explicitly stating this rationale in a comment within the `except` block or the docstring would improve clarity.

Suggested change

-    except Exception:
-        # If anything goes wrong, assume this is the first rank
-        return True
+    except Exception as e:  # Capture the specific exception instance.
+        # Consider logging 'e' for debugging to understand potential failures.
+        # For example, if a logger is configured:
+        # logger.warning("is_global_first_rank() encountered an error, defaulting to True: %s", e, exc_info=True)
+        # If anything goes wrong, assume this is the first rank
+        return True

Member

logging a debug message at least would be nice here
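
For illustration, a minimal sketch (not part of this PR) of what that debug log could look like, assuming the module sets up a logger with vLLM's `init_logger` helper and keeps the same fall-back-to-True behavior; only the plain `torch.distributed` path is shown:

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


def is_global_first_rank() -> bool:
    # Simplified variant of the helper, shown only to demonstrate the logging.
    try:
        if not torch.distributed.is_initialized():
            return True
        return torch.distributed.get_rank() == 0
    except Exception as e:
        # Log at debug level so failures in rank determination stay visible
        # without spamming every rank's console, then keep the safe default.
        logger.debug("is_global_first_rank() failed, defaulting to True: %s", e)
        return True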



def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
    """
    Returns the total number of nodes in the process group.
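
Before the runner change below, a brief sketch of the intended call pattern for the new helper; the wrapper function and its arguments here are hypothetical, not taken from this diff:

from tqdm import tqdm

from vllm.distributed.parallel_state import is_global_first_rank


def warm_up(batch_sizes, run_one_batch):
    # Wrap the iterable in tqdm only on the global first rank so that a
    # multi-process run prints a single progress bar instead of one per rank.
    cases = list(batch_sizes)
    if is_global_first_rank():
        cases = tqdm(cases, desc="Warming up batch sizes")
    for batch_size in cases:
        run_one_batch(batch_size)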
11 changes: 7 additions & 4 deletions vllm/v1/worker/gpu_model_runner.py
@@ -25,7 +25,7 @@
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
-    get_pp_group, get_tp_group, graph_capture,
+    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
    prepare_communication_buffer_for_model)
from vllm.forward_context import (DPMetadata, get_forward_context,
                                  set_forward_context)
@@ -2207,9 +2207,12 @@ def capture_model(self) -> None:
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            full_cg = self.full_cuda_graph
-            for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes),
-                                   desc="Capturing CUDA graphs",
-                                   total=len(self.cudagraph_batch_sizes)):
+            # Only rank 0 should print progress bar during capture
+            compilation_cases = reversed(self.cudagraph_batch_sizes)
+            if is_global_first_rank():
+                compilation_cases = tqdm(list(compilation_cases),
+                                         desc="Capturing CUDA graph shapes")
+            for num_tokens in compilation_cases:
                for _ in range(
                        self.compilation_config.cudagraph_num_of_warmups):
                    self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)