
[V1][Spec Decode] EAGLE-3 Support #16937

Merged · 21 commits · Apr 25, 2025

14 changes: 11 additions & 3 deletions examples/offline_inference/eagle.py
@@ -52,8 +52,8 @@ def main():

args = parse_args()

model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"

max_model_len = 2048

@@ -81,7 +81,7 @@ def main():
max_num_seqs=args.max_num_seqs,
gpu_memory_utilization=0.8,
speculative_config={
"method": "eagle",
"method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"draft_tensor_parallel_size": args.draft_tp,
@@ -95,6 +95,9 @@ def main():
outputs = llm.generate(prompt_token_ids=prompt_ids,
sampling_params=sampling_params)

+if not hasattr(outputs, "metrics") or outputs.metrics is None:
+return

# calculate the average number of accepted tokens per forward pass, +1 is
# to account for the token from the target model that's always going to be
# accepted
@@ -109,6 +112,11 @@
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
print("-" * 50)

+# print acceptance at each token position
+for i in range(len(acceptance_counts)):
+print(f"acceptance at token {i}:"
+f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}")


if __name__ == "__main__":
main()
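
For readers trying the updated example, here is a minimal end-to-end sketch of the same configuration (model names and the method-selection heuristic are taken from the diff above; the prompt, token budget, and sampling settings are illustrative only, not part of the PR):

# Hedged sketch: run the target model with an EAGLE-3 draft head, mirroring
# the speculative_config built in examples/offline_inference/eagle.py above.
from vllm import LLM, SamplingParams

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"

llm = LLM(
    model=model_dir,
    speculative_config={
        # "eagle3" is selected whenever the draft checkpoint name contains it
        "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
        "model": eagle_dir,
        "num_speculative_tokens": 2,
    },
    max_model_len=2048,
)

outputs = llm.generate(
    ["The future of speculative decoding is"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
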
4 changes: 4 additions & 0 deletions tests/models/registry.py
@@ -392,6 +392,10 @@ def check_available_online(
trust_remote_code=True,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
"Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
}

_TRANSFORMERS_MODELS = {
24 changes: 16 additions & 8 deletions tests/v1/e2e/test_spec_decode.py
@@ -50,12 +50,15 @@ def sampling_config():

@pytest.fixture
def model_name():
return "meta-llama/Meta-Llama-3-8B-Instruct"
return "meta-llama/Llama-3.1-8B-Instruct"


@pytest.fixture
def eagle_model_name():
return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"


+def eagle3_model_name():
+return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


def test_ngram_correctness(
@@ -102,12 +105,13 @@ def test_ngram_correctness(
del spec_llm


+@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
-eagle_model_name: str,
+use_eagle3: bool,
):
'''
Compare the outputs of an original LLM and a speculative LLM
@@ -116,18 +120,22 @@ def test_eagle_correctness(
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

-ref_llm = LLM(model=model_name, max_model_len=1024)
+ref_llm = LLM(model=model_name, max_model_len=2048)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm

+spec_model_name = eagle3_model_name(
+) if use_eagle3 else eagle_model_name()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
speculative_config={
"method": "eagle",
"model": eagle_model_name,
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
},
-max_model_len=1024,
+max_model_len=2048,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
13 changes: 10 additions & 3 deletions vllm/config.py
@@ -2338,9 +2338,10 @@ def __post_init__(self):
)

# Automatically detect the method
-if self.method == 'eagle':
+if self.method in ('eagle', 'eagle3'):
pass
elif "eagle-" in self.draft_model_config.model.lower():
elif "eagle-" in self.draft_model_config.model.lower() or \
"eagle3-" in self.draft_model_config.model.lower():
self.method = "eagle"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
@@ -2351,7 +2352,7 @@ def __post_init__(self):
self.method = "draft_model"

# Replace hf_config for EAGLE draft_model
if self.method == "eagle":
if self.method in ("eagle", "eagle3"):
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
raise ValueError(
"Chunked prefill and EAGLE are not compatible "
@@ -2548,6 +2549,12 @@ def _verify_args(self) -> None:
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")

if self.method == "eagle3" and self.target_model_config and \
"llama" not in self.target_model_config.hf_text_config.model_type:
raise ValueError(
"Eagle3 is only supported for Llama models. "
f"Got {self.target_model_config.hf_text_config.model_type=}")

@property
def num_lookahead_slots(self) -> int:
"""The number of additional slots the scheduler should allocate per
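
To summarize the config changes above, here is a small standalone restatement of the EAGLE-related detection and validation logic (an illustrative helper, not vLLM's actual SpeculativeConfig code): an explicit method of "eagle" or "eagle3" is kept as-is, a name-based match still resolves to "eagle", and "eagle3" is rejected for non-Llama targets.

# Hedged restatement of the diff's logic, for illustration only.
def resolve_spec_method(method, draft_model_name, target_model_type):
    name = draft_model_name.lower()
    if method not in ("eagle", "eagle3"):
        if "eagle-" in name or "eagle3-" in name:
            # Name-based detection still yields "eagle"; "eagle3" has to be
            # requested explicitly via the speculative config.
            method = "eagle"
    if method == "eagle3" and "llama" not in target_model_type:
        raise ValueError("Eagle3 is only supported for Llama models.")
    return method

assert resolve_spec_method(None, "yuhuili/EAGLE-LLaMA3-Instruct-8B", "llama") == "eagle"
assert resolve_spec_method("eagle3", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "llama") == "eagle3"
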
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -1456,7 +1456,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
if speculative_method:
if speculative_method in ("ngram", "[ngram]"):
is_ngram_enabled = True
elif speculative_method == "eagle":
elif speculative_method in ("eagle", "eagle3"):
is_eagle_enabled = True
else:
speculative_model = self.speculative_config.get("model")
18 changes: 17 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -330,6 +330,8 @@ def __init__(self,
else:
self.norm = PPMissingLayer()

+self.aux_hidden_state_layers: tuple[int] = tuple()

self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
@@ -355,7 +357,11 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

-for layer in self.layers[self.start_layer:self.end_layer]:
+aux_hidden_states = []
+for idx, layer in enumerate(
+self.layers[self.start_layer:self.end_layer]):
+if idx in self.aux_hidden_state_layers:
+aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)

if not get_pp_group().is_last_rank:
@@ -365,6 +371,9 @@ def forward(
})

hidden_states, _ = self.norm(hidden_states, residual)

+if len(aux_hidden_states) > 0:
+return hidden_states, aux_hidden_states
return hidden_states

def load_weights(self, weights: Iterable[Tuple[str,
@@ -517,6 +526,13 @@ def __init__(self,
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

+def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
+self.model.aux_hidden_state_layers = layers
+
+def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
+num_layers = len(self.model.layers)
+return (2, num_layers // 2, num_layers - 3)

def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
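
As a quick illustration of the auxiliary-hidden-state hook added to the Llama model above: EAGLE-3 consumes hidden states from an early, a middle, and a late decoder layer of the target model, selected by the same formula as get_eagle3_aux_hidden_state_layers (the 32-layer count below is just an example, matching Llama-3.1-8B's depth):

# Hedged sketch: the diff's layer-selection formula, shown in isolation.
def eagle3_aux_layers(num_layers: int) -> tuple[int, int, int]:
    return (2, num_layers // 2, num_layers - 3)

# For a 32-layer target model, decoder layers 2, 16 and 29 contribute the
# auxiliary hidden states that forward() now returns alongside the final
# hidden states.
print(eagle3_aux_layers(32))  # (2, 16, 29)
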
3 changes: 2 additions & 1 deletion vllm/model_executor/models/llama_eagle.py
@@ -82,7 +82,8 @@ def forward(
hidden_states,
residual,
)
-return hidden_states + residual
+hidden_states = hidden_states + residual
+return hidden_states, hidden_states

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: