
Commit 3c081ea

Author: reidliu41 (committed)
[Misc] refactor example eagle
Signed-off-by: reidliu41 <reid201711@gmail.com>
1 parent 97ae6d7 commit 3c081ea

1 file changed (+117, −86 lines)


examples/offline_inference/eagle.py

+117 −86
@@ -7,89 +7,120 @@

 from vllm import LLM, SamplingParams

-parser = argparse.ArgumentParser()
-
-parser.add_argument(
-    "--dataset",
-    type=str,
-    default="./examples/data/gsm8k.jsonl",
-    help="downloaded from the eagle repo " \
-    "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
-)
-parser.add_argument("--max_num_seqs", type=int, default=8)
-parser.add_argument("--num_prompts", type=int, default=80)
-parser.add_argument("--num_spec_tokens", type=int, default=2)
-parser.add_argument("--tp", type=int, default=1)
-parser.add_argument("--draft_tp", type=int, default=1)
-parser.add_argument("--enforce_eager", action='store_true')
-parser.add_argument("--enable_chunked_prefill", action='store_true')
-parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
-parser.add_argument("--temp", type=float, default=0)
-
-args = parser.parse_args()
-
-print(args)
-
-model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
-eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
-
-max_model_len = 2048
-
-tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
-if os.path.exists(args.dataset):
-    prompts = []
-    num_prompts = args.num_prompts
-    with open(args.dataset) as f:
-        for line in f:
-            data = json.loads(line)
-            prompts.append(data["turns"][0])
-else:
-    prompts = ["The future of AI is", "The president of the United States is"]
-
-prompts = prompts[:args.num_prompts]
-num_prompts = len(prompts)
-
-prompt_ids = [
-    tokenizer.apply_chat_template([{
-        "role": "user",
-        "content": prompt
-    }],
-                                  add_generation_prompt=True)
-    for prompt in prompts
-]
-
-llm = LLM(
-    model=model_dir,
-    trust_remote_code=True,
-    tensor_parallel_size=args.tp,
-    enable_chunked_prefill=args.enable_chunked_prefill,
-    max_num_batched_tokens=args.max_num_batched_tokens,
-    enforce_eager=args.enforce_eager,
-    max_model_len=max_model_len,
-    max_num_seqs=args.max_num_seqs,
-    gpu_memory_utilization=0.8,
-    speculative_config={
-        "model": eagle_dir,
-        "num_speculative_tokens": args.num_spec_tokens,
-        "draft_tensor_parallel_size": args.draft_tp,
-        "max_model_len": max_model_len,
-    },
-    disable_log_stats=False,
-)
-
-sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
-
-outputs = llm.generate(prompt_token_ids=prompt_ids,
-                       sampling_params=sampling_params)
-
-# calculate the average number of accepted tokens per forward pass, +1 is
-# to account for the token from the target model that's always going to be
-# accepted
-acceptance_counts = [0] * (args.num_spec_tokens + 1)
-for output in outputs:
-    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
-        acceptance_counts[step] += count
-
-print(f"mean acceptance length: \
-{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
+
+def print_msg(msg):
+    print(f"[debug] {msg}\n")
+
+
+def load_prompts(dataset_path, num_prompts):
+    if os.path.exists(dataset_path):
+        prompts = []
+        try:
+            with open(dataset_path) as f:
+                for line in f:
+                    data = json.loads(line)
+                    prompts.append(data["turns"][0])
+        except Exception as e:
+            print_msg(f"Error reading dataset: {e}")
+            return []
+    else:
+        prompts = [
+            "The future of AI is", "The president of the United States is"
+        ]
+        print_msg(
+            f"Dataset not found at {dataset_path}, using prompts:\n{prompts}.")
+
+    return prompts[:num_prompts]
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="./examples/data/gsm8k.jsonl",
+        help="downloaded from the eagle repo " \
+        "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
+    )
+    parser.add_argument("--max_num_seqs", type=int, default=8)
+    parser.add_argument("--num_prompts", type=int, default=80)
+    parser.add_argument("--num_spec_tokens", type=int, default=2)
+    parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--draft_tp", type=int, default=1)
+    parser.add_argument("--enforce_eager", action='store_true')
+    parser.add_argument("--enable_chunked_prefill", action='store_true')
+    parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
+    parser.add_argument("--temp", type=float, default=0)
+    args = parser.parse_args()
+
+    print_msg(f"Starting inference with the following parameters:\n{args}")
+
+    model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
+    eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
+
+    max_model_len = 2048
+
+    # Initialize tokenizer
+    print_msg(f"Loading tokenizer for model {model_dir}")
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+    # Load prompts
+    prompts = load_prompts(args.dataset, args.num_prompts)
+    print_msg(f"Loaded and tokenized {len(prompts)} prompts")
+
+    # Tokenize prompts
+    prompt_ids = [
+        tokenizer.apply_chat_template([{
+            "role": "user",
+            "content": prompt
+        }],
+                                      add_generation_prompt=True)
+        for prompt in prompts
+    ]
+
+    # Initialize LLM
+    print_msg(
+        f"Initializing model {model_dir} with tensor parallel size {args.tp}")
+    llm = LLM(
+        model=model_dir,
+        trust_remote_code=True,
+        tensor_parallel_size=args.tp,
+        enable_chunked_prefill=args.enable_chunked_prefill,
+        max_num_batched_tokens=args.max_num_batched_tokens,
+        enforce_eager=args.enforce_eager,
+        max_model_len=max_model_len,
+        max_num_seqs=args.max_num_seqs,
+        gpu_memory_utilization=0.8,
+        speculative_config={
+            "model": eagle_dir,
+            "num_speculative_tokens": args.num_spec_tokens,
+            "draft_tensor_parallel_size": args.draft_tp,
+            "max_model_len": max_model_len,
+        },
+        disable_log_stats=False,
+    )
+
+    sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
+
+    # Start inference
+    print_msg("Starting inference...")
+    outputs = llm.generate(prompt_token_ids=prompt_ids,
+                           sampling_params=sampling_params)
+
+    # calculate the average number of accepted tokens per forward pass, +1 is
+    # to account for the token from the target model that's always going to be
+    # accepted
+    acceptance_counts = [0] * (args.num_spec_tokens + 1)
+    for output in outputs:
+        for step, count in enumerate(
+                output.metrics.spec_token_acceptance_counts):
+            acceptance_counts[step] += count
+
+    print("-" * 50)
+    print(f"mean acceptance length: \
+{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
+    print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
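For reference, the acceptance-length metric printed at the end of the example can be reproduced in isolation. The sketch below is illustrative only: the per-request counts (per_request_counts) are invented, not taken from a real run. It shows why dividing by acceptance_counts[0] yields the mean number of tokens accepted per forward pass, including the target-model token that is always accepted.

# Minimal, self-contained illustration of the acceptance-length computation
# used in examples/offline_inference/eagle.py. The counts below are made up
# purely for demonstration.

num_spec_tokens = 2

# Hypothetical per-request acceptance counts: index 0 counts the target-model
# token (always accepted, so it equals the number of forward passes), and
# index i counts how often the i-th speculative token was accepted.
per_request_counts = [
    [10, 7, 3],   # request 1: 10 forward passes
    [12, 9, 5],   # request 2: 12 forward passes
]

acceptance_counts = [0] * (num_spec_tokens + 1)
for counts in per_request_counts:
    for step, count in enumerate(counts):
        acceptance_counts[step] += count

# Total accepted tokens divided by the number of forward passes.
mean_acceptance_length = sum(acceptance_counts) / acceptance_counts[0]
print(f"mean acceptance length: {mean_acceptance_length:.2f}")  # 2.09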
