-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
executable file
·55 lines (43 loc) · 1.63 KB
/
Copy pathapi.py
File metadata and controls
executable file
·55 lines (43 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import List
from pydantic import BaseModel
import uvicorn
from fastapi import FastAPI
from transformers import GPTNeoXForCausalLM, AutoTokenizer
import argparse
# Command-line options: the address and port handed to uvicorn when this
# file is run as a script.
parser = argparse.ArgumentParser(
    description="FastAPI for testing the chip_1.4B_instruct_alpha model.",
)
parser.add_argument(
    "--host",
    type=str,
    default="127.0.0.1",
    help="Host IP address",
)
parser.add_argument(
    "--port",
    type=int,
    default=8000,
    help="Host port",
)
class Data(BaseModel):
    # Request body accepted by POST /predict.
    # input_prompt: raw user text; the handler prepends "User: " to it
    # before tokenizing.
    input_prompt: str
# Hugging Face checkpoint shared by the model and its tokenizer. Kept in a
# single constant so the two can never be loaded from different repositories.
MODEL_NAME = "Rallio67/chip_1.4B_instruct_alpha"

# Creating an instance of the app
app = FastAPI(
    title="Fast Chippy",
    description="FastAPI for testing the chip_1.4B_instruct_alpha model.",
)

# The weights are loaded once, at import time. device_map="auto" leaves device
# placement to the library instead of pinning the model to a specific device.
model = GPTNeoXForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    # load_in_8bit=True  # optional 8-bit loading, currently disabled
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@app.get('/health')
async def service_health():
    """Return service health"""
    # NOTE(review): this is a set literal, not a dict. FastAPI presumably
    # serializes it as the JSON array ["ok"] -- confirm whether a mapping
    # was intended before changing the payload, since clients may rely on
    # the current shape.
    health_marker = {"ok"}
    return health_marker
@app.post('/predict')
async def model_predict(data: Data):
    """Generate a completion for the submitted prompt and return it as text.

    The prompt is prefixed with "User: " before tokenization. The first
    returned sequence is decoded and every "<|endoftext|>" marker is removed
    from the resulting string.

    Args:
        data: Request body carrying ``input_prompt``.

    Returns:
        The decoded text with "<|endoftext|>" markers stripped.
    """
    prompt = "User: " + data.input_prompt

    # Send the encoded prompt to the device the model actually lives on.
    # The model is loaded with device_map="auto", so hard-coding "cuda" here
    # fails on hosts without a GPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # NOTE(review): do_sample is not set, so generation presumably runs
    # greedily and top_p / temperature / top_k have no effect -- confirm
    # whether sampling was intended before changing these settings.
    tokens = model.generate(
        **inputs,
        top_p=0.95,
        temperature=0.5,
        top_k=4,
        repetition_penalty=1.03,
        max_length=100,
        early_stopping=True,
    )

    output = tokenizer.decode(tokens[0])
    return output.replace("<|endoftext|>", "")
if __name__ == '__main__':
    # Script entry point: read --host/--port from the command line and
    # serve the app with uvicorn on that address.
    cli_args = parser.parse_args()
    uvicorn.run(app, host=cli_args.host, port=cli_args.port)