 
 from vllm import LLM, SamplingParams
 
-parser = argparse.ArgumentParser()
-
-parser.add_argument(
-    "--dataset",
-    type=str,
-    default="./examples/data/gsm8k.jsonl",
-    help="downloaded from the eagle repo " \
-        "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
-)
-parser.add_argument("--max_num_seqs", type=int, default=8)
-parser.add_argument("--num_prompts", type=int, default=80)
-parser.add_argument("--num_spec_tokens", type=int, default=2)
-parser.add_argument("--tp", type=int, default=1)
-parser.add_argument("--draft_tp", type=int, default=1)
-parser.add_argument("--enforce_eager", action='store_true')
-parser.add_argument("--enable_chunked_prefill", action='store_true')
-parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
-parser.add_argument("--temp", type=float, default=0)
-
-args = parser.parse_args()
-
-print(args)
-
-model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
-eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
-
-max_model_len = 2048
-
-tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
-if os.path.exists(args.dataset):
-    prompts = []
-    num_prompts = args.num_prompts
-    with open(args.dataset) as f:
-        for line in f:
-            data = json.loads(line)
-            prompts.append(data["turns"][0])
-else:
-    prompts = ["The future of AI is", "The president of the United States is"]
-
-prompts = prompts[:args.num_prompts]
-num_prompts = len(prompts)
-
-prompt_ids = [
-    tokenizer.apply_chat_template([{
-        "role": "user",
-        "content": prompt
-    }],
-                                  add_generation_prompt=True)
-    for prompt in prompts
-]
-
-llm = LLM(
-    model=model_dir,
-    trust_remote_code=True,
-    tensor_parallel_size=args.tp,
-    enable_chunked_prefill=args.enable_chunked_prefill,
-    max_num_batched_tokens=args.max_num_batched_tokens,
-    enforce_eager=args.enforce_eager,
-    max_model_len=max_model_len,
-    max_num_seqs=args.max_num_seqs,
-    gpu_memory_utilization=0.8,
-    speculative_config={
-        "model": eagle_dir,
-        "num_speculative_tokens": args.num_spec_tokens,
-        "draft_tensor_parallel_size": args.draft_tp,
-        "max_model_len": max_model_len,
-    },
-    disable_log_stats=False,
-)
-
-sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
-
-outputs = llm.generate(prompt_token_ids=prompt_ids,
-                       sampling_params=sampling_params)
-
-# calculate the average number of accepted tokens per forward pass, +1 is
-# to account for the token from the target model that's always going to be
-# accepted
-acceptance_counts = [0] * (args.num_spec_tokens + 1)
-for output in outputs:
-    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
-        acceptance_counts[step] += count
-
-print(f"mean acceptance length: \
-    {sum(acceptance_counts) / acceptance_counts[0]:.2f}")
+
+def load_prompts(dataset_path, num_prompts):
+    if os.path.exists(dataset_path):
+        prompts = []
+        try:
+            with open(dataset_path) as f:
+                for line in f:
+                    data = json.loads(line)
+                    prompts.append(data["turns"][0])
+        except Exception as e:
+            print(f"Error reading dataset: {e}")
+            return []
+    else:
+        prompts = [
+            "The future of AI is", "The president of the United States is"
+        ]
+
+    return prompts[:num_prompts]
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="./examples/data/gsm8k.jsonl",
+        help="downloaded from the eagle repo " \
+            "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
+    )
+    parser.add_argument("--max_num_seqs", type=int, default=8)
+    parser.add_argument("--num_prompts", type=int, default=80)
+    parser.add_argument("--num_spec_tokens", type=int, default=2)
+    parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--draft_tp", type=int, default=1)
+    parser.add_argument("--enforce_eager", action='store_true')
+    parser.add_argument("--enable_chunked_prefill", action='store_true')
+    parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
+    parser.add_argument("--temp", type=float, default=0)
+    args = parser.parse_args()
+
+    model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
+    eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
+
+    max_model_len = 2048
+
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+    prompts = load_prompts(args.dataset, args.num_prompts)
+
+    prompt_ids = [
+        tokenizer.apply_chat_template([{
+            "role": "user",
+            "content": prompt
+        }],
+                                      add_generation_prompt=True)
+        for prompt in prompts
+    ]
+
+    llm = LLM(
+        model=model_dir,
+        trust_remote_code=True,
+        tensor_parallel_size=args.tp,
+        enable_chunked_prefill=args.enable_chunked_prefill,
+        max_num_batched_tokens=args.max_num_batched_tokens,
+        enforce_eager=args.enforce_eager,
+        max_model_len=max_model_len,
+        max_num_seqs=args.max_num_seqs,
+        gpu_memory_utilization=0.8,
+        speculative_config={
+            "model": eagle_dir,
+            "num_speculative_tokens": args.num_spec_tokens,
+            "draft_tensor_parallel_size": args.draft_tp,
+            "max_model_len": max_model_len,
+        },
+        disable_log_stats=False,
+    )
+
+    sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
+
+    outputs = llm.generate(prompt_token_ids=prompt_ids,
+                           sampling_params=sampling_params)
+
+    # calculate the average number of accepted tokens per forward pass, +1 is
+    # to account for the token from the target model that's always going to be
+    # accepted
+    acceptance_counts = [0] * (args.num_spec_tokens + 1)
+    for output in outputs:
+        for step, count in enumerate(
+                output.metrics.spec_token_acceptance_counts):
+            acceptance_counts[step] += count
+
+    print("-" * 50)
+    print(f"mean acceptance length: \
+        {sum(acceptance_counts) / acceptance_counts[0]:.2f}")
+    print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
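The acceptance-length metric printed at the end of the script is easy to misread, so here is a small standalone sketch of the same arithmetic with made-up counts (the numbers and the variable name mean_acceptance_length below are illustrative only and do not come from running this example):

# Illustrative counts only: position 0 is the target model's own token,
# which is emitted on every forward pass; positions 1..num_spec_tokens are
# the speculative tokens proposed by the EAGLE draft model.
acceptance_counts = [100, 60, 25]  # 100 passes; 60 and 25 draft acceptances

# Dividing by acceptance_counts[0] (the number of forward passes) gives the
# mean number of tokens emitted per pass; the target model's token is the
# "+1" that is always accepted.
mean_acceptance_length = sum(acceptance_counts) / acceptance_counts[0]
print(f"mean acceptance length: {mean_acceptance_length:.2f}")  # -> 1.85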