[Bugfix] Fix EAGLE vocab embedding construction for Llama 70B #19033

Merged · 8 commits · Jun 6, 2025
Changes from all commits
2 changes: 1 addition & 1 deletion benchmarks/kernels/bench_fp8_gemm.py
@@ -4,11 +4,11 @@
import itertools

import torch
import triton
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
from vllm.triton_utils import triton


@triton.testing.perf_report(
64 changes: 40 additions & 24 deletions tests/v1/spec_decode/test_eagle.py
@@ -8,6 +8,7 @@
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.v1.spec_decode.eagle import EagleProposer

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -112,21 +113,26 @@ def test_prepare_inputs():
assert torch.equal(token_indices, expected_token_indices)


@pytest.mark.parametrize(
"method,proposer_helper,draft_model_dir,target_attribute_path", [
("eagle", lambda k: _create_proposer("eagle", k), eagle_dir,
('lm_head', )),
("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
('model', 'embed_tokens')),
])
@pytest.mark.parametrize("method,proposer_helper", [
("eagle", lambda k: _create_proposer("eagle", k)),
("eagle3", lambda k: _create_proposer("eagle3", k)),
])
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
proposer_helper, draft_model_dir, target_attribute_path):

# Setup model mock
proposer_helper, pp_size, use_distinct_embed_tokens):
# Setup draft model mock
mock_model = mock.MagicMock()
if use_distinct_embed_tokens:
# Some models can have a different hidden size than the target model,
# so we test that their embed_tokens doesn't get overwritten
mock_model.model.embed_tokens.weight.shape = (131072, 2048)
else:
mock_model.model.embed_tokens.weight.shape = (131072, 4096)

mock_get_model.return_value = mock_model

# Setup mocks for attention layers
@@ -144,22 +150,24 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,

# Setup mock for pp group to return the appropriate value for world size
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 2 if method == "eagle" else 1
mock_pp_group.world_size = pp_size
mock_get_pp_group.return_value = mock_pp_group

# Setup target model with the appropriate attributes
target_model = mock.MagicMock()
# Setup the target model mock with a custom class so that
# isinstance() checks match the expected type.
class _TargetModelStub(LlamaForCausalLM):
model: mock.MagicMock
lm_head: mock.MagicMock

# Create the necessary attributes on the target model
current_obj = target_model
for i, attr in enumerate(target_attribute_path):
if i == len(target_attribute_path) - 1:
# Set the last attribute in the path to a MagicMock
setattr(current_obj, attr, mock.MagicMock())
else:
# Create intermediate objects if needed
setattr(current_obj, attr, mock.MagicMock())
current_obj = getattr(current_obj, attr)
target_model = mock.create_autospec(_TargetModelStub, instance=True)
target_model.model = mock.MagicMock()
target_model.model.embed_tokens.weight.shape = (131072, 4096)

from vllm.model_executor.models import SupportsMultiModal
assert not isinstance(target_model, SupportsMultiModal)

if method == "eagle":
target_model.lm_head = mock.MagicMock()

# Create proposer using the helper function
proposer = proposer_helper(k=8)
Expand All @@ -170,10 +178,18 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
# Verify common interactions
mock_get_model.assert_called_once()

# Verify the specific attribute sharing based on the method
# Verify that EAGLE models gain the lm head from the target model
if method == "eagle":
assert proposer.model.lm_head == target_model.lm_head

# Verify that the embed tokens are set correctly
# If pp_size is > 1, the embed tokens should be distinct
if pp_size > 1 or use_distinct_embed_tokens:
assert proposer.model.model.embed_tokens != \
target_model.model.embed_tokens
else:
# When pp_size is 1 and the draft and target models have
# embed_tokens of the same shape, they should be shared.
assert proposer.model.model.embed_tokens == \
target_model.model.embed_tokens

14 changes: 6 additions & 8 deletions vllm/model_executor/models/llama_eagle.py
@@ -54,13 +54,11 @@ def __init__(
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size

# if PP disabled then draft will share embed with target
if get_pp_group().world_size > 1:
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)

self.layers = nn.ModuleList([
LlamaDecoderLayer(
@@ -163,4 +161,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
return loader.load_weights(model_weights.items())
loader.load_weights(model_weights.items())
24 changes: 14 additions & 10 deletions vllm/model_executor/models/llama_eagle3.py
@@ -9,7 +9,6 @@

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -94,13 +93,11 @@ def __init__(
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size

# if PP disabled then draft will share embed with target
if get_pp_group().world_size > 1:
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)

self.layers = nn.ModuleList([
LlamaDecoderLayer(
@@ -239,6 +236,7 @@ def combine_hidden_states(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
model_weights = {}
includes_draft_id_mapping = False
includes_embed_tokens = False
for name, loaded_weight in weights:
if "t2d" in name:
continue
@@ -247,12 +245,18 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
includes_draft_id_mapping = True
elif "lm_head" not in name:
name = "model." + name
if "embed_tokens" in name:
includes_embed_tokens = True
model_weights[name] = loaded_weight

skip_substrs = []
if not includes_draft_id_mapping:
skip_substrs.append("draft_id_to_target_id")
if not includes_embed_tokens:
skip_substrs.append("embed_tokens")
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
skip_substrs=["draft_id_to_target_id"] \
if not includes_draft_id_mapping else None,
skip_substrs=skip_substrs,
)
loader.load_weights(model_weights.items())
1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
@@ -157,6 +157,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
) -> float:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device)

1 change: 1 addition & 0 deletions vllm/utils.py
@@ -898,6 +898,7 @@ def __init__(self, device: Optional[torch.types.Device] = None):
def current_memory_usage(self) -> float:
# Return the memory usage in bytes.
from vllm.platforms import current_platform
gc.collect()

ekagra-ranjan (Contributor) commented on Jun 2, 2025:

@WoosukKwon, can you share whether gc.collect() and torch.cuda.empty_cache() are fine here? Maybe there is a reason they were not added before.

I believe this was added because we delete some torch tensors after allocation. If for some reason we decide it is better to avoid these new gc calls, an alternative approach would be to first load the draft model weights from the checkpoint, determine whether the draft vocab is needed, and then pass that information to the draft model's constructor, which could skip allocating the draft vocab and achieve the same objective.
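
For illustration, a minimal sketch of the alternative flow described in the comment above: scan the draft checkpoint's weight names first, then tell the draft model constructor whether it needs its own vocab embedding at all. This is not code from this PR; ToyDraftModel and the allocate_embed_tokens flag are hypothetical placeholders, and only the safetensors name scan uses a real API.

import torch.nn as nn
from safetensors import safe_open


def draft_checkpoint_has_embed_tokens(path: str) -> bool:
    # Read only the tensor names from the .safetensors file; no weights are
    # materialized, so this check is cheap to run before model construction.
    with safe_open(path, framework="pt") as f:
        return any("embed_tokens" in name for name in f.keys())


class ToyDraftModel(nn.Module):
    """Toy stand-in for the EAGLE head, not the real vLLM class."""

    def __init__(self, vocab_size: int, hidden_size: int,
                 allocate_embed_tokens: bool):
        super().__init__()
        # Allocate a vocab embedding only if the draft checkpoint provides
        # one; otherwise leave it unset and share the target model's
        # embed_tokens later, so no throwaway allocation has to be deleted
        # and reclaimed via gc.collect().
        self.embed_tokens = (nn.Embedding(vocab_size, hidden_size)
                             if allocate_embed_tokens else None)

Under such a scheme the post-hoc del of the draft embed_tokens in eagle.py would be unnecessary, at the cost of peeking at the checkpoint before constructing the draft model.
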


benchislett (Collaborator, Author) commented on Jun 2, 2025:

I think the current approach makes sense, as enforcing GC and clearing the torch cache seem like natural choices to improve the accuracy of the memory profiler.

If we foresee any issues with calling GC/cleanup in this way, then I'm on board with doing it the other way.


A collaborator commented:

I agree with @ekagra-ranjan, though I didn't see a clear problem. Let's keep this in mind and revisit if any issue arises.

return current_platform.get_current_memory_usage(self.device)

def __enter__(self):
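
A standalone sketch (not vLLM code; the function name is illustrative) of how the two one-line additions above combine when taking a memory reading: collect cyclic garbage first so tensors kept alive only by unreachable Python objects are freed, then return cached-but-unused blocks to the device before resetting the peak counters.

import gc
from typing import Optional

import torch


def measured_cuda_memory_bytes(
        device: Optional[torch.types.Device] = None) -> float:
    # Collect cyclic garbage so tensors kept alive only by reference
    # cycles (e.g. a draft vocab embedding that was replaced) are freed.
    gc.collect()
    # Return cached-but-unused blocks to the device so they do not distort
    # device-level memory readings taken around this point.
    torch.cuda.empty_cache()
    # Reset the peak counter so the value below reflects memory that is
    # still allocated right now, not a stale high-water mark.
    torch.cuda.reset_peak_memory_stats(device)
    return torch.cuda.max_memory_allocated(device)
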
11 changes: 7 additions & 4 deletions vllm/v1/spec_decode/eagle.py
@@ -329,16 +329,19 @@ def load_model(self, target_model: nn.Module) -> None:
self.attn_layer_names = list(draft_attn_layer_names)

# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
if get_pp_group().world_size == 1 \
and self.model.model.embed_tokens.weight.shape \
== target_model.model.embed_tokens.weight.shape:
logger.info(
"The EAGLE head shares the same vocab embedding" \
"Assuming the EAGLE head shares the same vocab embedding" \
" with the target model."
)
del self.model.model.embed_tokens
self.model.model.embed_tokens = target_model.model.embed_tokens
else:
logger.info(
"Since PP > 1, the EAGLE head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."
"The EAGLE head's vocab embedding will be loaded separately" \
" from the target model."
)

# share lm_head with the target model if needed