forked from OpenBMB/ChatDev
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_backend.py
136 lines (109 loc) · 4.8 KB
/
model_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Any, Dict
import openai
import tiktoken
from camel.typing import ModelType
from chatdev.statistics import prompt_cost
from chatdev.utils import log_and_print_online
class ModelBackend(ABC):
    r"""Abstract interface shared by every model backend.

    A concrete backend might wrap the OpenAI API, a locally hosted LLM, or a
    stub used by unit tests.
    """

    @abstractmethod
    def run(self, *args, **kwargs) -> Dict[str, Any]:
        r"""Send a query to the underlying model and return its reply.

        Raises:
            RuntimeError: if the return value from OpenAI API
                is not a dict that is expected.

        Returns:
            Dict[str, Any]: All backends must return a dict in OpenAI format.
        """
        ...
class OpenAIModel(ModelBackend):
    r"""OpenAI API in a unified ModelBackend interface.

    Supports both the legacy ``openai<1.0`` module-level API
    (``openai.ChatCompletion.create``) and the ``openai>=1.0`` API
    (``openai.chat.completions.create``). In both cases :meth:`run` returns a
    plain OpenAI-format dict.

    Args:
        model_type (ModelType): Model to query; its ``value`` is passed to the
            API as the model name.
        model_config_dict (Dict): Extra keyword arguments forwarded to every
            chat-completion call. :meth:`run` does not modify this dict.
    """

    # Context-window size (prompt + completion tokens) for each model name.
    _MAX_TOKENS_BY_MODEL: Dict[str, int] = {
        "gpt-3.5-turbo": 4096,
        "gpt-3.5-turbo-16k": 16384,
        "gpt-3.5-turbo-0613": 4096,
        "gpt-3.5-turbo-16k-0613": 16384,
        "gpt-4": 8192,
        "gpt-4-0613": 8192,
        "gpt-4-32k": 32768,
    }

    # Fixed per-message padding added to the prompt-token estimate
    # (heuristic; presumably covers role/formatting tokens that the plain-text
    # count misses).
    _TOKENS_PER_MESSAGE_OVERHEAD: int = 15

    def __init__(self, model_type: ModelType, model_config_dict: Dict) -> None:
        super().__init__()
        self.model_type = model_type
        self.model_config_dict = model_config_dict

    def run(self, *args, **kwargs) -> Dict[str, Any]:
        r"""Query the OpenAI chat-completion endpoint.

        ``max_tokens`` is set so that prompt plus completion fill the model's
        context window. The prompt size is estimated with ``tiktoken``.

        Args:
            *args: Positional arguments forwarded to the API call.
            **kwargs: Keyword arguments forwarded to the API call. Must include
                ``messages``, a list of dicts that each have a ``content`` key.

        Raises:
            KeyError: if the model name is not in the context-window table.
            RuntimeError: if the API return value cannot be represented as a
                dict.

        Returns:
            Dict[str, Any]: The completion response as an OpenAI-format dict.
        """
        messages = kwargs["messages"]
        model_name = self.model_type.value

        # Estimate the prompt size. A None content (e.g. a function-call
        # message) counts as empty text instead of crashing the join.
        text = "\n".join(message["content"] or "" for message in messages)
        encoding = tiktoken.encoding_for_model(model_name)
        num_prompt_tokens = len(encoding.encode(text))
        num_prompt_tokens += self._TOKENS_PER_MESSAGE_OVERHEAD * len(messages)

        num_max_token = self._MAX_TOKENS_BY_MODEL[model_name]
        # Per-call copy: do not write max_tokens back into the caller-owned
        # config dict. The value is not clamped, so a prompt that already
        # fills the window yields a non-positive max_tokens.
        call_config = dict(self.model_config_dict)
        call_config["max_tokens"] = num_max_token - num_prompt_tokens

        if hasattr(openai, "OpenAI"):
            # openai>=1.0 (which introduced the OpenAI client class). There the
            # legacy openai.ChatCompletion symbol still exists but raises
            # APIRemovedInV1 (not AttributeError) when called, so it cannot be
            # used as a try/except probe.
            response = openai.chat.completions.create(
                *args, **kwargs, model=model_name, **call_config)
        else:
            response = openai.ChatCompletion.create(
                *args, **kwargs, model=model_name, **call_config)

        # openai>=1.0 returns a pydantic model rather than a dict; convert it
        # to the dict format this interface promises.
        if not isinstance(response, dict) and hasattr(response, "model_dump"):
            response = response.model_dump()
        # Validate before subscripting, so a bad return surfaces as the
        # documented RuntimeError rather than a TypeError/KeyError.
        if not isinstance(response, dict):
            raise RuntimeError("Unexpected return from OpenAI API")

        usage = response["usage"]
        cost = prompt_cost(
            model_name,
            num_prompt_tokens=usage["prompt_tokens"],
            num_completion_tokens=usage["completion_tokens"],
        )
        log_and_print_online(
            "**[OpenAI_Usage_Info Receive]**\nprompt_tokens: {}\ncompletion_tokens: {}\ntotal_tokens: {}\ncost: ${:.6f}\n".format(
                usage["prompt_tokens"], usage["completion_tokens"],
                usage["total_tokens"], cost))
        return response
class StubModel(ModelBackend):
    r"""Placeholder backend for unit tests that always replies "Lorem Ipsum"."""

    def __init__(self, *args, **kwargs) -> None:
        # Any arguments are accepted and ignored, so this class can be built
        # with the same (model_type, model_config_dict) call as real backends.
        super().__init__()

    def run(self, *args, **kwargs) -> Dict[str, Any]:
        r"""Return a canned, OpenAI-shaped completion with a fixed reply.

        Returns:
            Dict[str, Any]: A dict with ``id``, an empty ``usage`` dict, and a
            single ``choices`` entry whose assistant message is fixed text.
        """
        reply = {"content": "Lorem Ipsum", "role": "assistant"}
        only_choice = {"finish_reason": "stop", "message": reply}
        return {
            "id": "stub_model_id",
            "usage": {},
            "choices": [only_choice],
        }
class ModelFactory:
    r"""Builds the backend object that matches a requested model type.

    Raises:
        ValueError: in case the provided model type is unknown.
    """

    @staticmethod
    def create(model_type: ModelType, model_config_dict: Dict) -> ModelBackend:
        r"""Instantiate the backend for ``model_type``.

        Args:
            model_type (ModelType): Requested model. ``None`` selects
                ``ModelType.GPT_3_5_TURBO``.
            model_config_dict (Dict): Configuration passed to the backend's
                constructor.

        Raises:
            ValueError: if ``model_type`` is neither one of the OpenAI types
                listed here nor ``ModelType.STUB``.

        Returns:
            ModelBackend: An ``OpenAIModel`` or a ``StubModel`` instance.
        """
        fallback = ModelType.GPT_3_5_TURBO
        openai_types = {
            ModelType.GPT_3_5_TURBO,
            ModelType.GPT_4,
            ModelType.GPT_4_32k,
            None,  # None is routed to OpenAI and replaced by the fallback below
        }

        if model_type in openai_types:
            backend_cls = OpenAIModel
        elif model_type == ModelType.STUB:
            backend_cls = StubModel
        else:
            raise ValueError("Unknown model")

        resolved = fallback if model_type is None else model_type
        return backend_cls(resolved, model_config_dict)