Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[Bugfix] Update profile example to new add request interface + fix profiler not picking up kernels within cudagraphs #332

Merged
merged 6 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions examples/offline_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
print(f" {key} = {value}")

# Create sampling params
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=8,
ignore_eos=True)

# Create LLM
llm = LLM(
Expand Down Expand Up @@ -74,14 +77,14 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
sys.exit(-1)

for i in range(batch_size):
prompt_token_ids = torch.randint(
llm.llm_engine.model_config.get_vocab_size(),
size=(prompt_len, )).tolist()

llm.llm_engine.add_request(
request_id=f"seq{i}",
prompt=None,
prompt_token_ids=torch.randint(
128, # 128 to skip over special tokens
llm.llm_engine.model_config.get_vocab_size() // 2,
size=(prompt_len, )).tolist(),
sampling_params=sampling_params)
inputs={'prompt_token_ids': prompt_token_ids},
params=sampling_params)

with nm_profile() as prefill_prof:
llm.llm_engine.step() # First step is prefill
Expand Down
3 changes: 2 additions & 1 deletion vllm/profiler/nm_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def _get_kineto_gpu_event(self, node: _ModuleTreeNode):
correlated_kineto_events = self._kineto_event_correlation_map.get(
node.event.correlation_id, [])
iterator = (x for x in correlated_kineto_events
if x.device_type() == DeviceType.CUDA)
if x.device_type() == DeviceType.CUDA
and x.name() == node.event.name)
return next(iterator, None)

def _cumulative_cuda_time(self, node: _ModuleTreeNode):
Expand Down
11 changes: 7 additions & 4 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file has been modified by Neural Magic

import gc
import time
import warnings
Expand Down Expand Up @@ -969,9 +971,13 @@ def vocab_size(self) -> int:
return self.model_config.get_vocab_size()


class CUDAGraphRunner:
# NOTE: this is nn.Module so the profiler can properly capture/group
# kernels calls made within the graph
class CUDAGraphRunner(nn.Module):
LucasWilkinson marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, model: nn.Module):
super().__init__()

self.model = model
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
Expand Down Expand Up @@ -1067,9 +1073,6 @@ def forward(
# Return the output tensor.
return self.output_buffers["hidden_states"]

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)


def _get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.
Expand Down
Loading