Commit b6c502a

reidliu41 and reidliu41 authored
[Misc] refactor example eagle (#16100)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
1 parent 9ca710e commit b6c502a

File tree

1 file changed, +99 -86 lines changed

1 file changed

+99
-86
lines changed

examples/offline_inference/eagle.py (+99 -86)
@@ -7,89 +7,102 @@
 
 from vllm import LLM, SamplingParams
 
-parser = argparse.ArgumentParser()
-
-parser.add_argument(
-    "--dataset",
-    type=str,
-    default="./examples/data/gsm8k.jsonl",
-    help="downloaded from the eagle repo " \
-    "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
-)
-parser.add_argument("--max_num_seqs", type=int, default=8)
-parser.add_argument("--num_prompts", type=int, default=80)
-parser.add_argument("--num_spec_tokens", type=int, default=2)
-parser.add_argument("--tp", type=int, default=1)
-parser.add_argument("--draft_tp", type=int, default=1)
-parser.add_argument("--enforce_eager", action='store_true')
-parser.add_argument("--enable_chunked_prefill", action='store_true')
-parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
-parser.add_argument("--temp", type=float, default=0)
-
-args = parser.parse_args()
-
-print(args)
-
-model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
-eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
-
-max_model_len = 2048
-
-tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
-if os.path.exists(args.dataset):
-    prompts = []
-    num_prompts = args.num_prompts
-    with open(args.dataset) as f:
-        for line in f:
-            data = json.loads(line)
-            prompts.append(data["turns"][0])
-else:
-    prompts = ["The future of AI is", "The president of the United States is"]
-
-prompts = prompts[:args.num_prompts]
-num_prompts = len(prompts)
-
-prompt_ids = [
-    tokenizer.apply_chat_template([{
-        "role": "user",
-        "content": prompt
-    }],
-                                  add_generation_prompt=True)
-    for prompt in prompts
-]
-
-llm = LLM(
-    model=model_dir,
-    trust_remote_code=True,
-    tensor_parallel_size=args.tp,
-    enable_chunked_prefill=args.enable_chunked_prefill,
-    max_num_batched_tokens=args.max_num_batched_tokens,
-    enforce_eager=args.enforce_eager,
-    max_model_len=max_model_len,
-    max_num_seqs=args.max_num_seqs,
-    gpu_memory_utilization=0.8,
-    speculative_config={
-        "model": eagle_dir,
-        "num_speculative_tokens": args.num_spec_tokens,
-        "draft_tensor_parallel_size": args.draft_tp,
-        "max_model_len": max_model_len,
-    },
-    disable_log_stats=False,
-)
-
-sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
-
-outputs = llm.generate(prompt_token_ids=prompt_ids,
-                       sampling_params=sampling_params)
-
-# calculate the average number of accepted tokens per forward pass, +1 is
-# to account for the token from the target model that's always going to be
-# accepted
-acceptance_counts = [0] * (args.num_spec_tokens + 1)
-for output in outputs:
-    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
-        acceptance_counts[step] += count
-
-print(f"mean acceptance length: \
-{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
+
+def load_prompts(dataset_path, num_prompts):
+    if os.path.exists(dataset_path):
+        prompts = []
+        try:
+            with open(dataset_path) as f:
+                for line in f:
+                    data = json.loads(line)
+                    prompts.append(data["turns"][0])
+        except Exception as e:
+            print(f"Error reading dataset: {e}")
+            return []
+    else:
+        prompts = [
+            "The future of AI is", "The president of the United States is"
+        ]
+
+    return prompts[:num_prompts]
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="./examples/data/gsm8k.jsonl",
+        help="downloaded from the eagle repo " \
+        "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
+    )
+    parser.add_argument("--max_num_seqs", type=int, default=8)
+    parser.add_argument("--num_prompts", type=int, default=80)
+    parser.add_argument("--num_spec_tokens", type=int, default=2)
+    parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--draft_tp", type=int, default=1)
+    parser.add_argument("--enforce_eager", action='store_true')
+    parser.add_argument("--enable_chunked_prefill", action='store_true')
+    parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
+    parser.add_argument("--temp", type=float, default=0)
+    args = parser.parse_args()
+
+    model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
+    eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
+
+    max_model_len = 2048
+
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+    prompts = load_prompts(args.dataset, args.num_prompts)
+
+    prompt_ids = [
+        tokenizer.apply_chat_template([{
+            "role": "user",
+            "content": prompt
+        }],
+                                      add_generation_prompt=True)
+        for prompt in prompts
+    ]
+
+    llm = LLM(
+        model=model_dir,
+        trust_remote_code=True,
+        tensor_parallel_size=args.tp,
+        enable_chunked_prefill=args.enable_chunked_prefill,
+        max_num_batched_tokens=args.max_num_batched_tokens,
+        enforce_eager=args.enforce_eager,
+        max_model_len=max_model_len,
+        max_num_seqs=args.max_num_seqs,
+        gpu_memory_utilization=0.8,
+        speculative_config={
+            "model": eagle_dir,
+            "num_speculative_tokens": args.num_spec_tokens,
+            "draft_tensor_parallel_size": args.draft_tp,
+            "max_model_len": max_model_len,
+        },
+        disable_log_stats=False,
+    )
+
+    sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
+
+    outputs = llm.generate(prompt_token_ids=prompt_ids,
+                           sampling_params=sampling_params)
+
+    # calculate the average number of accepted tokens per forward pass, +1 is
+    # to account for the token from the target model that's always going to be
+    # accepted
+    acceptance_counts = [0] * (args.num_spec_tokens + 1)
+    for output in outputs:
+        for step, count in enumerate(
+                output.metrics.spec_token_acceptance_counts):
+            acceptance_counts[step] += count
+
+    print("-" * 50)
+    print(f"mean acceptance length: \
+{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
+    print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
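For reference, load_prompts() in the refactored example expects the --dataset file to be JSONL with one JSON object per line, each carrying a "turns" list whose first entry is the user prompt (the code reads data["turns"][0]). The snippet below is a small, self-contained sketch of that shape inferred from the diff, not part of the commit; the record contents and the temporary file are made-up placeholders.

import json
import tempfile

# Two made-up records in the same shape the example expects from the
# EAGLE gsm8k.jsonl file: one JSON object per line with a "turns" list.
records = [
    {"turns": ["Placeholder question one?"]},
    {"turns": ["Placeholder question two?"]},
]

with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    for record in records:
        f.write(json.dumps(record) + "\n")
    dataset_path = f.name

# Mirror the reading loop from load_prompts() on the temporary file.
prompts = []
with open(dataset_path) as fh:
    for line in fh:
        prompts.append(json.loads(line)["turns"][0])

print(prompts)  # ['Placeholder question one?', 'Placeholder question two?']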

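On the metric printed at the end: acceptance_counts[0] is incremented once per forward pass, because the target model's own token is always accepted, so sum(acceptance_counts) / acceptance_counts[0] gives the average number of tokens produced per pass. Below is a toy worked example of that arithmetic with made-up numbers standing in for output.metrics.spec_token_acceptance_counts, under the assumption (taken from the example's own comment) that index k counts how often the k-th token of a pass was accepted.

num_spec_tokens = 2

# Index 0 is the target model's token (always accepted); indices 1..2 are
# the draft model's speculative tokens.
acceptance_counts = [0] * (num_spec_tokens + 1)

# Made-up counts for three forward passes.
fake_spec_token_acceptance_counts = [
    [1, 1, 1],  # both speculative tokens accepted
    [1, 1, 0],  # only the first speculative token accepted
    [1, 0, 0],  # no speculative tokens accepted
]
for per_pass in fake_spec_token_acceptance_counts:
    for step, count in enumerate(per_pass):
        acceptance_counts[step] += count

# acceptance_counts == [3, 2, 1]; (3 + 2 + 1) / 3 = 2.00 tokens per pass.
print(f"mean acceptance length: "
      f"{sum(acceptance_counts) / acceptance_counts[0]:.2f}")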