-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[feat] cuda graph support and refactor non-functional api #5434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
LRY89757
merged 15 commits into
hpcaitech:feature/colossal-infer
from
LRY89757:colossal-infer-cuda-graph
Mar 25, 2024
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
cefaeb5
[feat] cuda graph support and refactor non-functional api
LRY89757 b2c0d9f
[fix] multi graphs capture error
LRY89757 9dec66f
[fix] multi graphs capture error
LRY89757 633e95b
[doc] add doc
LRY89757 1821a6d
[fix] pytest and fix dyn grid bug
LRY89757 ae24b4f
diverse tests
LRY89757 d02e257
Merge branch 'feature/colossal-infer' into colossal-infer-cuda-graph
LRY89757 6e30248
[fix] tmp for test
LRY89757 aabc9fb
[feat] add use_cuda_kernel option
LRY89757 4eafe0c
[fix] unused option
LRY89757 606603b
Merge branch 'feature/colossal-infer' of https://github.com/hpcaitech…
LRY89757 5b017d6
[fix]
LRY89757 9fe61b4
[fix]
LRY89757 ff4998c
[fix] remove unused comment
LRY89757 68e9396
[fix] merge conflicts
LRY89757 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| from typing import Dict, List | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
||
| from colossalai.inference.config import InputMetaData | ||
| from colossalai.logging import get_dist_logger | ||
|
|
||
|
|
||
| class CUDAGraphRunner: | ||
| def __init__(self, model: nn.Module): | ||
| self.model = model | ||
| self.graph = None | ||
| self.input_buffers: Dict[str, torch.Tensor] = {} | ||
| self.output_buffers: Dict[str, torch.Tensor] = {} | ||
| self.logger = get_dist_logger(__name__) | ||
|
|
||
| def capture( | ||
| self, | ||
| input_tokens_ids: torch.Tensor, | ||
| output_tensor: torch.Tensor, | ||
| inputmetadata: InputMetaData, | ||
| k_caches: List[torch.Tensor] = None, | ||
| v_caches: List[torch.Tensor] = None, | ||
| memory_pool=None, | ||
| ) -> None: | ||
| assert self.graph is None | ||
|
|
||
| # run kernel once to cache the kernel, avoid stream capture error | ||
| hidden_states_origin_model = self.model( | ||
| input_tokens_ids, | ||
| output_tensor, | ||
| inputmetadata, | ||
| k_caches, | ||
| v_caches, | ||
| ) | ||
| torch.cuda.synchronize() | ||
|
|
||
| # Capture the graph. | ||
| # self.logger.info(f"begin capture model...") | ||
| self.graph = torch.cuda.CUDAGraph() | ||
| with torch.cuda.graph(self.graph, pool=memory_pool): | ||
| hidden_states_cuda_graph = self.model( | ||
| input_tokens_ids, | ||
| output_tensor, | ||
| inputmetadata, | ||
| k_caches, | ||
| v_caches, | ||
| ) | ||
| torch.cuda.synchronize() | ||
|
|
||
| # Save the input and output buffers, because replay always uses the same virtual memory space | ||
| self.input_buffers = { | ||
| "input_tokens_ids": input_tokens_ids, | ||
| "output_tensor": output_tensor, | ||
| "block_tables": inputmetadata.block_tables, | ||
| "sequence_lengths": inputmetadata.sequence_lengths, | ||
| # "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output, | ||
| # "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse, | ||
| "k_caches": k_caches, | ||
| "v_caches": v_caches, | ||
| } | ||
| self.output_buffers = {"logits": hidden_states_cuda_graph} | ||
| return | ||
|
|
||
| def forward( | ||
| self, | ||
| input_tokens_ids: torch.Tensor, | ||
| output_tensor: torch.Tensor, | ||
| inputmetadata: InputMetaData, | ||
| k_caches: List[torch.Tensor] = None, | ||
| v_caches: List[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| # Copy the input tensors to the input buffers. | ||
| self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True) | ||
| self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True) | ||
|
|
||
| # for flexible block_table | ||
| self.input_buffers["block_tables"].fill_(-1) | ||
| M, N = inputmetadata.block_tables.shape | ||
| self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True) | ||
|
|
||
| self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True) | ||
|
|
||
| # we only have a global fd_inter_tensor so we don't need to copy them | ||
| # self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True) | ||
| # self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True) | ||
|
|
||
| # KV caches are fixed tensors, so we don't need to copy them. | ||
| # self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True) | ||
| # self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True) | ||
|
|
||
| # Run the graph. | ||
| self.graph.replay() | ||
|
|
||
| # Return the output tensor. | ||
| return self.output_buffers["logits"] | ||
|
|
||
| def __call__(self, *args, **kwargs): | ||
| return self.forward(*args, **kwargs) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.