import json

from typing import Dict, Any
from mlserver import MLModel, types
from mlserver.codecs import StringCodec


class Orion14BChatInt4(MLModel):
    MODEL_NAME = "OrionStarAI/Orion-14B-Chat-Int4"

    async def load(self) -> bool:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from transformers.generation.utils import GenerationConfig

        # Load the tokenizer and the int4-quantised chat model from the Hugging Face Hub.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL_NAME, use_fast=False, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        # Use the generation defaults shipped with the model.
        self.model.generation_config = GenerationConfig.from_pretrained(self.MODEL_NAME)
        return await super().load()

    async def predict(self, payload: types.InferenceRequest) -> types.InferenceResponse:
        # The request carries a JSON-encoded "messages" input (see _parse_request below).
        messages = self._parse_request(payload)["messages"]
        response = {
            "assistant": self.model.chat(self.tokenizer, messages, streaming=False)
        }
        response_bytes = json.dumps(response, ensure_ascii=False).encode("UTF-8")
        return types.InferenceResponse(
            id=payload.id,
            model_name=self.name,
            model_version=self.version,
            outputs=[
                types.ResponseOutput(
                    name="generated_text",
                    shape=[len(response_bytes)],
                    datatype="BYTES",
                    data=[response_bytes],
                    parameters=types.Parameters(content_type="str"),
                )
            ],
        )

    def _parse_request(self, payload: types.InferenceRequest) -> Dict[str, Any]:
        # Decode every input as a string and parse it as JSON, keyed by input name.
        inputs = {}
        for inp in payload.inputs:
            inputs[inp.name] = json.loads(
                "".join(self.decode(inp, default_codec=StringCodec))
            )
        return inputs
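
A minimal client-side sketch of how this runtime could be called once it is served by MLServer. The host, port, and model name (`orion-14b-chat-int4`) are assumptions that depend on your deployment settings; the payload simply mirrors what `_parse_request` expects, i.e. a string input named `messages` holding the JSON-encoded chat history.

import json

import requests  # any HTTP client works; requests is assumed here for brevity

# Chat history in the format consumed by model.chat().
messages = [{"role": "user", "content": "Hello, what is your name?"}]

inference_request = {
    "inputs": [
        {
            "name": "messages",
            "shape": [1],
            "datatype": "BYTES",
            "parameters": {"content_type": "str"},
            "data": [json.dumps(messages)],
        }
    ]
}

# Standard V2 inference endpoint exposed by MLServer (model name is hypothetical).
resp = requests.post(
    "http://localhost:8080/v2/models/orion-14b-chat-int4/infer",
    json=inference_request,
)

# The single BYTES output holds a JSON object with the reply under "assistant".
print(resp.json()["outputs"][0]["data"][0])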