Skip to content

Commit b7fc464

Browse files
committed
add orion 14b chat model
1 parent 28ff665 commit b7fc464

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
{
    "name": "Orion-14B-Chat-Int4",
    "implementation": "model.Orion14BChatInt4"
}

models/Orion-14B-Chat-Int4/model.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
import json

from typing import Dict, Any
from mlserver import MLModel, types
from mlserver.codecs import StringCodec


class Orion14BChatInt4(MLModel):
    """MLServer runtime for the OrionStarAI Orion-14B-Chat-Int4 chat model.

    Loads the int4-quantized chat model from the Hugging Face Hub and serves
    chat completions over the MLServer V2 inference protocol. Requests carry a
    JSON body with a ``"messages"`` list; responses return the assistant reply
    as UTF-8 encoded JSON bytes.
    """

    # Hugging Face Hub identifier of the served model checkpoint.
    MODEL_NAME = "OrionStarAI/Orion-14B-Chat-Int4"

    async def load(self) -> bool:
        """Initialise tokenizer, model weights and generation config.

        Heavy ML imports are deliberately local so the module can be imported
        (e.g. for settings discovery) without torch/transformers installed.

        Returns:
            ``True`` once the base class marks the model ready.
        """
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from transformers.generation.utils import GenerationConfig

        # trust_remote_code is required: Orion ships custom modeling code.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL_NAME, use_fast=False, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        self.model.generation_config = GenerationConfig.from_pretrained(self.MODEL_NAME)
        return await super().load()

    async def predict(self, payload: types.InferenceRequest) -> types.InferenceResponse:
        """Run one chat turn and return the assistant reply.

        Args:
            payload: V2 inference request whose inputs decode to a JSON object
                containing a ``"messages"`` list (chat history).

        Returns:
            An ``InferenceResponse`` with a single BYTES output named
            ``generated_text`` holding ``{"assistant": <reply>}`` as UTF-8 JSON.
        """
        # BUG FIX: the original called self._extract_json(), a method that does
        # not exist (the helper is named _parse_request), so every request
        # raised AttributeError.
        messages = self._parse_request(payload)["messages"]
        response = {
            "assistant": self.model.chat(self.tokenizer, messages, streaming=False)
        }
        response_bytes = json.dumps(response, ensure_ascii=False).encode("UTF-8")
        return types.InferenceResponse(
            id=payload.id,
            model_name=self.name,
            model_version=self.version,
            outputs=[
                types.ResponseOutput(
                    name="generated_text",
                    shape=[len(response_bytes)],
                    datatype="BYTES",
                    data=[response_bytes],
                    parameters=types.Parameters(content_type="str"),
                )
            ],
        )

    def _parse_request(self, payload: types.InferenceRequest) -> Dict[str, Any]:
        """Decode each request input as a string and parse it as JSON.

        Returns:
            Mapping from input name to its parsed JSON value.
        """
        inputs = {}
        for inp in payload.inputs:
            # decode() may yield a list of string fragments; join before parsing.
            inputs[inp.name] = json.loads(
                "".join(self.decode(inp, default_codec=StringCodec))
            )
        return inputs

0 commit comments

Comments
 (0)