reference_hf.py
"""
Usage:
python3 reference_hf.py --model TinyLlama/TinyLlama-1.1B-Chat-v0.4
Reference output:
========== Prompt 0 ==========
prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141],
device='cuda:0')
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
========== Prompt 1 ==========
prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742],
device='cuda:0')
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of
========== Prompt 2 ==========
prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609],
device='cuda:0')
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the
"""
import argparse

import torch
from transformers import AutoModelForCausalLM

from sglang.srt.hf_transformers_utils import get_tokenizer


@torch.no_grad()
def normal_text(args):
    """Greedy-decode a few text prompts and print the final prefill logits for each."""
    t = get_tokenizer(args.model_path, trust_remote_code=True)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    max_new_tokens = args.max_new_tokens

    torch.cuda.set_device(0)

    for i, p in enumerate(prompts):
        if isinstance(p, str):
            input_ids = t.encode(p, return_tensors="pt").to("cuda:0")
        else:
            input_ids = torch.tensor([p], device="cuda:0")

        output_ids = m.generate(
            input_ids, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = t.decode(output_ids[0])

        # Logits of the last prompt token from a full prefill forward pass.
        prefill_logits = m.forward(input_ids).logits[0][-1]

        print(f"\n========== Prompt {i} ==========")
        print("prefill logits (final)", prefill_logits)
        print(output_str)
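

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original script: one possible way to
# compare the prefill logits printed above against logits captured from
# another implementation. The function name and the `other_logits` argument
# are hypothetical; both inputs are assumed to be 1-D tensors over the same
# vocabulary, and the tolerance is only a placeholder.
# ---------------------------------------------------------------------------
def compare_prefill_logits(ref_logits, other_logits, atol=2e-2, rtol=0.0):
    # Cast to float32 so float16 rounding differences do not dominate the check.
    ref = ref_logits.float()
    other = other_logits.float()
    print("max abs diff:", (ref - other).abs().max().item())
    return torch.allclose(ref, other, atol=atol, rtol=rtol)
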

@torch.no_grad()
def synthetic_tokens(args):
    """Run a prefill pass plus a short greedy decode loop over synthetic token ids."""
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    m.cuda()
    print(m)

    input_len = 256
    output_len = 8
    prompts = [list(range(5, 5 + input_len))]

    for p in prompts:
        input_ids = p
        for i in range(output_len + 1):
            # Re-run the full forward pass each step (no KV cache) and
            # greedily append the argmax token.
            prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[
                0
            ][-1]
            if i == 0:
                print("prefill logits", prefill_logits)
            else:
                print("decode", i - 1, prefill_logits)
            input_ids.append(torch.argmax(prefill_logits).item())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
        # default="meta-llama/Llama-2-7b-chat-hf",
    )
    parser.add_argument("--max-new-tokens", type=int, default=16)
    parser.add_argument("--dtype", type=str, default="float16")
    args = parser.parse_args()

    normal_text(args)
    # synthetic_tokens(args)
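
# Illustrative usage notes (not part of the original script). The flags are the
# ones defined by the argparse parser above; any model path other than the
# default is an assumption and must be available locally or on the Hugging Face
# Hub.
#
#   python3 reference_hf.py --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4
#   python3 reference_hf.py --model-path meta-llama/Llama-2-7b-chat-hf --max-new-tokens 32
#
# To run the synthetic-token check instead, call synthetic_tokens(args) in the
# __main__ block above (it is currently commented out).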