[Core][Bugfix] New way for full cudagraph, add support for FA2 and FlashInfer #20059

Open · wants to merge 13 commits into main
17 changes: 13 additions & 4 deletions vllm/compilation/backends.py
@@ -563,10 +563,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
         self._called = True
 
-        if not self.compilation_config.use_cudagraph or \
-                not self.compilation_config.cudagraph_copy_inputs:
-            return self.split_gm
-
         # if we need to copy input buffers for cudagraph
         from torch._guards import detect_fake_mode
         fake_mode = detect_fake_mode()
@@ -585,6 +581,19 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             any(is_symbolic(d) for d in x.size())
         ]
 
+        if self.compilation_config.full_cuda_graph:
+            assert self.compilation_config.use_cudagraph, \
+                "full_cuda_graph mode requires use_cudagraph to be True"
+            fullgraph_wrapper = resolve_obj_by_qualname(
+                current_platform.get_fullgraph_wrapper_cls())
+            self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config,
+                                              self.graph_pool,
+                                              self.sym_tensor_indices)
+
+        if not self.compilation_config.use_cudagraph or \
+                not self.compilation_config.cudagraph_copy_inputs:
+            return self.split_gm
+
         # compiler managed cudagraph input buffers
         # we assume the first run with symbolic shapes
         # has the maximum size among all the tensors

Review comment (Collaborator), on the fullgraph_wrapper resolution added above:

I don't see why this has to be platform-specific. If it doesn't, let's create it directly?

Reply (Author):

Here, I follow the convention in the class PiecewiseCompileInterpreter, where the piecewise_backend is resolved per platform. It seems CUDAPiecewiseBackend support is limited to the CUDA and ROCm platforms.
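For context on that reply: the convention resolves a class from a dotted path supplied by the current platform object. Below is a minimal sketch of the pattern; MyPlatform and the returned path are hypothetical, and only the method name get_fullgraph_wrapper_cls comes from the diff above.

import importlib


def resolve_obj_by_qualname(qualname: str):
    # Resolve "pkg.module.ClassName" to the class object; vLLM's helper of
    # the same name works along these lines.
    module_name, obj_name = qualname.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), obj_name)


class MyPlatform:
    # Hypothetical platform class illustrating the convention.

    @classmethod
    def get_fullgraph_wrapper_cls(cls) -> str:
        # Each platform points at its own wrapper implementation; the
        # backend resolves and instantiates it, as in the diff above.
        return "my_plugin.compilation.MyFullgraphWrapper"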
43 changes: 43 additions & 0 deletions vllm/compilation/base_piecewise_backend.py
@@ -70,3 +70,46 @@ def __call__(self, *args) -> Any:
         or a replayed static graph.
         """
         raise NotImplementedError
+
+
+class AbstractFullgraphWrapper(Protocol):
+    """
+    FullgraphWrapper interface that allows platforms to wrap the piecewise
+    graph so it can be viewed or captured as a full graph.
+    """
+
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, sym_shape_indices: list[int], **kwargs):
+        """
+        Initializes the FullgraphWrapper class with compilation and
+        execution-related configurations.
+
+        Args:
+            graph (fx.GraphModule): The graph represented in fx.
+            vllm_config (VllmConfig): Global configuration for vLLM.
+            graph_pool (Any):
+                Graph memory pool handle, e.g.,
+                `torch.cuda.graph_pool_handle()`.
+            sym_shape_indices (list[int]):
+                Indices of symbolic-shape inputs.
+
+        Keyword Args:
+            kwargs: Additional keyword arguments reserved for future
+                extensions or custom platforms.
+
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args) -> Any:
+        """
+        Executes the wrapped graph for the given input args.
+
+        Args:
+            *args: Variable-length input arguments to be passed into the
+                graph. The symbolic shape is expected to be in position
+                `sym_shape_indices[0]`.
+
+        Returns:
+            Any: Output of the executed wrapped graph.
+        """
+        raise NotImplementedError
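To make the new interface concrete, here is a minimal sketch of a class satisfying AbstractFullgraphWrapper. It is not the PR's actual wrapper: it naively captures the whole piecewise graph into one CUDA graph per runtime shape and replays it afterwards, assuming the usual cudagraph contract that input tensor storages are reused across calls.

from typing import Any

import torch
from torch import fx

from vllm.config import VllmConfig


class NaiveFullgraphWrapper:

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, sym_shape_indices: list[int], **kwargs):
        self.graph = graph
        self.graph_pool = graph_pool
        self.sym_shape_indices = sym_shape_indices
        # One captured CUDAGraph (plus its static outputs) per runtime shape.
        self.captured: dict[int, tuple[torch.cuda.CUDAGraph, Any]] = {}

    def __call__(self, *args) -> Any:
        # Per the protocol docstring, the symbolic (runtime) shape sits at
        # position sym_shape_indices[0].
        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.captured:
            self.graph(*args)  # warm-up run in eager mode
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g, pool=self.graph_pool):
                static_out = self.graph(*args)
            self.captured[runtime_shape] = (g, static_out)
            return static_out
        g, static_out = self.captured[runtime_shape]
        g.replay()
        return static_out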
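Finally, a usage sketch of how the full_cuda_graph flag gated in backends.py above would be switched on from user code. The field names follow the diff; the model name is a placeholder, and the exact config surface is an assumption here.

from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    compilation_config=CompilationConfig(
        use_cudagraph=True,    # required: the new assert enforces this
        full_cuda_graph=True,  # wrap the piecewise graph as one full graph
    ),
)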