Commit 2989249

chore: add calc_tokens method on session
1 parent 9cef559 commit 2989249

5 files changed: 31 additions, 24 deletions

bot/chatgpt/chat_gpt_bot.py

Lines changed: 4 additions & 4 deletions
@@ -58,7 +58,7 @@ def reply(self, query, context=None):
             # # reply in stream
             # return self.reply_text_stream(query, new_query, session_id)

-            reply_content = self.reply_text(session, session_id, api_key, 0)
+            reply_content = self.reply_text(session, api_key)
             logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
             if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
                 reply = Reply(ReplyType.ERROR, reply_content['content'])
@@ -94,7 +94,7 @@ def compose_args(self):
             "timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
         }

-    def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict:
+    def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict:
         '''
         call openai's ChatCompletion to get the answer
         :param session: a conversation session
@@ -133,11 +133,11 @@ def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0)
             else:
                 logger.warn("[CHATGPT] Exception: {}".format(e))
                 need_retry = False
-                self.sessions.clear_session(session_id)
+                self.sessions.clear_session(session.session_id)

             if need_retry:
                 logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
-                return self.reply_text(session, session_id, api_key, retry_count+1)
+                return self.reply_text(session, api_key, retry_count+1)
             else:
                 return result

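With this change the ChatGPT bot passes the session object itself to reply_text; the session id travels on the session, and the result is consumed as a dict. A minimal sketch of the new calling convention (illustrative only: FakeSession and the stubbed reply_text stand in for ChatGPTSession and the real OpenAI call):

# Sketch of the call shape after this commit; not the bot's full code.
class FakeSession:
    # Stand-in for ChatGPTSession: carries its own id and message list.
    def __init__(self, session_id, messages=None):
        self.session_id = session_id
        self.messages = messages or []

def reply_text(session, api_key=None, retry_count=0) -> dict:
    # The real method calls the ChatCompletion API; here we only show
    # the return contract that reply() relies on.
    return {"total_tokens": 0, "completion_tokens": 0, "content": "stub reply"}

session = FakeSession("wx_user_1")
reply_content = reply_text(session, api_key=None)   # no separate session_id argument
print(reply_content["content"], reply_content["completion_tokens"])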
bot/chatgpt/chat_gpt_session.py

Lines changed: 6 additions & 3 deletions
@@ -17,7 +17,7 @@ def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"):
     def discard_exceeding(self, max_tokens, cur_tokens= None):
         precise = True
         try:
-            cur_tokens = num_tokens_from_messages(self.messages, self.model)
+            cur_tokens = self.calc_tokens()
         except Exception as e:
             precise = False
             if cur_tokens is None:
@@ -29,7 +29,7 @@ def discard_exceeding(self, max_tokens, cur_tokens= None):
             elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
                 self.messages.pop(1)
                 if precise:
-                    cur_tokens = num_tokens_from_messages(self.messages, self.model)
+                    cur_tokens = self.calc_tokens()
                 else:
                     cur_tokens = cur_tokens - max_tokens
                 break
@@ -40,11 +40,14 @@ def discard_exceeding(self, max_tokens, cur_tokens= None):
                 logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
                 break
         if precise:
-            cur_tokens = num_tokens_from_messages(self.messages, self.model)
+            cur_tokens = self.calc_tokens()
         else:
             cur_tokens = cur_tokens - max_tokens
         return cur_tokens

+    def calc_tokens(self):
+        return num_tokens_from_messages(self.messages, self.model)
+

 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def num_tokens_from_messages(messages, model):

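calc_tokens() gives discard_exceeding() a single, overridable place to count tokens instead of calling num_tokens_from_messages directly. A self-contained sketch of the same pattern; the simplified counter below only approximates the cookbook helper referenced in the diff, and tiktoken is assumed to be installed:

import tiktoken

def num_tokens_from_messages(messages, model):
    # Rough stand-in for the OpenAI cookbook helper: per-message overhead
    # plus the encoded length of each message's content.
    encoding = tiktoken.encoding_for_model(model)
    tokens = 0
    for msg in messages:
        tokens += 4  # approximate per-message overhead; the cookbook uses model-specific values
        tokens += len(encoding.encode(msg.get("content", "")))
    return tokens

class ChatGPTSessionSketch:
    def __init__(self, model="gpt-3.5-turbo"):
        self.model = model
        self.messages = [{"role": "system", "content": "You are a helpful assistant."}]

    def calc_tokens(self):
        # Single place that knows how this session type counts tokens;
        # discard_exceeding() can now just call self.calc_tokens().
        return num_tokens_from_messages(self.messages, self.model)

print(ChatGPTSessionSketch().calc_tokens())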
bot/openai/open_ai_bot.py

Lines changed: 14 additions & 14 deletions
@@ -42,11 +42,9 @@ def reply(self, query, context=None):
                 reply = Reply(ReplyType.INFO, '所有人记忆已清除')
             else:
                 session = self.sessions.session_query(query, session_id)
-                new_query = str(session)
-                logger.debug("[OPEN_AI] session query={}".format(new_query))
-
-                total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0)
-                logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens))
+                result = self.reply_text(session)
+                total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content']
+                logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens))

             if total_tokens == 0 :
                 reply = Reply(ReplyType.ERROR, reply_content)
@@ -63,11 +61,11 @@ def reply(self, query, context=None):
                 reply = Reply(ReplyType.ERROR, retstring)
             return reply

-    def reply_text(self, query, session_id, retry_count=0):
+    def reply_text(self, session:OpenAISession, retry_count=0):
         try:
             response = openai.Completion.create(
                 model= conf().get("model") or "text-davinci-003", # 对话模型的名称
-                prompt=query,
+                prompt=str(session),
                 temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
                 max_tokens=1200, # 回复最大的字符数
                 top_p=1,
@@ -79,31 +77,33 @@ def reply_text(self, query, session_id, retry_count=0):
             total_tokens = response["usage"]["total_tokens"]
             completion_tokens = response["usage"]["completion_tokens"]
             logger.info("[OPEN_AI] reply={}".format(res_content))
-            return total_tokens, completion_tokens, res_content
+            return {"total_tokens": total_tokens,
+                    "completion_tokens": completion_tokens,
+                    "content": res_content}
         except Exception as e:
             need_retry = retry_count < 2
-            result = [0,0,"我现在有点累了,等会再来吧"]
+            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
             if isinstance(e, openai.error.RateLimitError):
                 logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
-                result[2] = "提问太快啦,请休息一下再问我吧"
+                result['content'] = "提问太快啦,请休息一下再问我吧"
                 if need_retry:
                     time.sleep(5)
             elif isinstance(e, openai.error.Timeout):
                 logger.warn("[OPEN_AI] Timeout: {}".format(e))
-                result[2] = "我没有收到你的消息"
+                result['content'] = "我没有收到你的消息"
                 if need_retry:
                     time.sleep(5)
             elif isinstance(e, openai.error.APIConnectionError):
                 logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
                 need_retry = False
-                result[2] = "我连接不到你的网络"
+                result['content'] = "我连接不到你的网络"
             else:
                 logger.warn("[OPEN_AI] Exception: {}".format(e))
                 need_retry = False
-                self.sessions.clear_session(session_id)
+                self.sessions.clear_session(session.session_id)

             if need_retry:
                 logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
-                return self.reply_text(query, session_id, retry_count+1)
+                return self.reply_text(session, retry_count+1)
             else:
                 return result

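reply_text for the completion bot now returns a dict on both the success and the error path, replacing the old positional tuple/list. A small illustrative snippet of how such a result can be unpacked by key (a sketch, not the bot's exact code; the .get() defaults are an assumption for keys the error path omits):

# Result shapes taken from the diff above.
ok_result = {"total_tokens": 42, "completion_tokens": 10, "content": "hello"}
err_result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}

def unpack(result):
    # Keyed access replaces positional indexing such as result[2];
    # .get() guards keys that may be absent on the error path.
    total_tokens = result.get("total_tokens", 0)
    completion_tokens = result.get("completion_tokens", 0)
    content = result["content"]
    return total_tokens, completion_tokens, content

print(unpack(ok_result))
print(unpack(err_result))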
bot/openai/open_ai_session.py

Lines changed: 5 additions & 3 deletions
@@ -29,7 +29,7 @@ def __str__(self):
     def discard_exceeding(self, max_tokens, cur_tokens= None):
         precise = True
         try:
-            cur_tokens = num_tokens_from_string(str(self), self.model)
+            cur_tokens = self.calc_tokens()
         except Exception as e:
             precise = False
             if cur_tokens is None:
@@ -41,7 +41,7 @@ def discard_exceeding(self, max_tokens, cur_tokens= None):
             elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
                 self.messages.pop(0)
                 if precise:
-                    cur_tokens = num_tokens_from_string(str(self), self.model)
+                    cur_tokens = self.calc_tokens()
                 else:
                     cur_tokens = len(str(self))
                 break
@@ -52,11 +52,13 @@ def discard_exceeding(self, max_tokens, cur_tokens= None):
                 logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
                 break
         if precise:
-            cur_tokens = num_tokens_from_string(str(self), self.model)
+            cur_tokens = self.calc_tokens()
         else:
             cur_tokens = len(str(self))
         return cur_tokens

+    def calc_tokens(self):
+        return num_tokens_from_string(str(self), self.model)

 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def num_tokens_from_string(string: str, model: str) -> int:

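For the completion-style session, calc_tokens() counts the tokens of the flattened prompt string (str(self)). The project's num_tokens_from_string is not shown in this diff; a typical cookbook-style implementation, sketched here under the assumption that tiktoken is available, looks like this:

import tiktoken

def num_tokens_from_string(string: str, model: str) -> int:
    # Encode the whole prompt with the model's tokenizer and count tokens.
    # Falling back to cl100k_base for unknown model names is an assumption.
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(string))

print(num_tokens_from_string("Q: 你好\nA: ", "text-davinci-003"))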
bot/session_manager.py

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@ def add_reply(self, reply):
     def discard_exceeding(self, max_tokens=None, cur_tokens=None):
         raise NotImplementedError

+    def calc_tokens(self):
+        raise NotImplementedError


 class SessionManager(object):

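Declaring calc_tokens() on the base Session alongside discard_exceeding() makes token counting part of the session interface, so each bot's session supplies its own counter. A minimal sketch of that contract (the base methods mirror the diff; the constructor and the word-count subclass are purely hypothetical, just to show the override point):

class Session(object):
    def __init__(self, session_id, system_prompt=None):
        # Illustrative constructor; the real base class does more setup.
        self.session_id = session_id
        self.messages = []

    def discard_exceeding(self, max_tokens=None, cur_tokens=None):
        raise NotImplementedError

    def calc_tokens(self):
        # Overridden by ChatGPTSession (message-based counting) and
        # OpenAISession (string-based counting) in this commit.
        raise NotImplementedError


class WordCountSession(Session):
    # Hypothetical subclass: trivially counts words instead of tokens.
    def calc_tokens(self):
        return sum(len(m.get("content", "").split()) for m in self.messages)


s = WordCountSession("demo")
s.messages.append({"role": "user", "content": "hello world"})
print(s.calc_tokens())  # 2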