inference.py
from llama.tokenizer import Tokenizer
from llama.model import ModelArgs, Llama
import torch


def inference():
    torch.manual_seed(1)

    # Paths to the base Llama 2 7B checkpoint, its tokenizer, and the trained LoRA weights
    tokenizer_path = "/home1/ichuncha/llama/llama2-7b/tokenizer.model"
    model_path = "/home1/ichuncha/llama/llama2-7b/consolidated.00.pth"
    lora_weights_path = "/home1/ichuncha/llama/weights/lora_weights_32.pth"

    tokenizer = Tokenizer(tokenizer_path)
    checkpoint = torch.load(model_path, map_location="cpu")
    lora_state_dict = torch.load(lora_weights_path, map_location="cpu")

    model_args = ModelArgs()
    torch.set_default_tensor_type(torch.cuda.HalfTensor)  # load model in fp16
    model = Llama(model_args)
    model.load_state_dict(checkpoint, strict=False)
    # model.load_state_dict(lora_state_dict, strict=False)
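    # Sketch (not in the original file): to actually apply the LoRA weights
    # loaded above, one option is to overlay them onto the base checkpoint and
    # load both in a single call. This assumes lora_weights_32.pth keys its
    # tensors by the same parameter names the model uses:
    #
    #     merged = {**checkpoint, **lora_state_dict}
    #     model.load_state_dict(merged, strict=False)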
    model.to("cuda")

    prompts = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        "Describe the structure of an atom.",
        "Make the following sentence more concise: I have a really bad cold and it is making me feel really miserable.",
        "I believe the meaning of life is",
        "Simply put, the theory of relativity states that ",
        """A brief message congratulating the team on the launch:

        Hi everyone,

        I just """,
        # Few-shot prompt (providing a few examples before asking the model to complete more)
        """Translate English to French:

        sea otter => loutre de mer
        peppermint => menthe poivrée
        plush girafe => girafe peluche
        cheese =>""",
    ]

    model.eval()
    results = model.generate(tokenizer, prompts, max_gen_len=64, temperature=0.6, top_p=0.9)

    for prompt, result in zip(prompts, results):
        print(prompt)
        print(f"> {result['generation']}")
        print("\n==================================\n")


if __name__ == "__main__":
    inference()