Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[misc] add forward context for attention #9029

Merged
merged 11 commits into from
Oct 3, 2024
Prev Previous commit
Next Next commit
add draft model runner
  • Loading branch information
youkaichao committed Oct 2, 2024
commit 01e5a7da0025ceab48c0986d19886b155d9554d2
22 changes: 12 additions & 10 deletions vllm/spec_decode/draft_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch

from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.sampler import SamplerOutput

try:
Expand Down Expand Up @@ -291,16 +292,17 @@ def execute_model(
if previous_hidden_states is not None else {}

# Run model
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**kwargs,
)
with set_forward_context(model_input.attn_metadata):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**kwargs,
)

# Compute the logits.
logits = self.model.compute_logits(hidden_states,
Expand Down