import asyncio, websockets, os, sys, json, ssl
from typing import Any
from langchain_openai import ChatOpenAI
- from langchain_community.vectorstores import Chroma
from langchain.memory import ConversationBufferWindowMemory, ConversationBufferMemory
- from langchain.chains import ConversationChain
+ from langchain.chains import ConversationChain, LLMChain
+ from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.messages import get_buffer_string
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
- from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv

load_dotenv()
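# load_dotenv() pulls configuration such as OPENAI_API_KEY from a local .env file,
# which is where ChatOpenAI below picks up its credentials from the environment.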
@@ -52,12 +52,6 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
    # ai_prefix=self.ai_prefix,
    # )

-     CHAT_LLM = ChatOpenAI(temperature=0, model="gpt-4", streaming=True,
-                           callbacks=[MyStreamingHandler()])  # ChatOpenAI cannot have max_token=-1
-     memory = ConversationBufferMemory(return_messages=False)
-     chain = ConversationChain(llm=CHAT_LLM, memory=memory, verbose=True)
-     chain.output_parser = StrOutputParser()
-
    ###########################################
    ### Format of input
    # {
@@ -73,6 +67,8 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
    # }
    ############################################
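    # For illustration, a request this handler can serve might look like the following
    # (field names are taken from the code below; the values are made up):
    #   {"parameters": {"llm": "openai", "temperature": 0.2, "model": "gpt-4"},
    #    "input": {"query": "...", "history": [{"Q": "...", "A": "..."}]}}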

+     CHAT_LLM = ChatOpenAI(temperature=0, model="gpt-4", streaming=True,
+                           callbacks=[MyStreamingHandler()])  # ChatOpenAI cannot have max_token=-1
    while True:
        try:
            async for message in websocket:
@@ -81,46 +77,49 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
                params = event["parameters"]
                if params["llm"] == "openai":
                    CHAT_LLM.temperature = float(params["temperature"])
-                     CHAT_LLM.model = params["model"]
+                     CHAT_LLM.model_name = params["model"]  # ChatOpenAI's runtime attribute is model_name; "model" is only a constructor alias
                elif params["llm"] == "qianfan":
                    pass

                # if params["client"] == "mobile":
                #     CHAT_LLM.streaming = False

-                 hlen = 0
-                 if "history" in event["input"]:
-                     # use server history if history key is not present in user request
-                     memory.clear()  # do not use memory on serverside. Add chat history kept by client.
-                     for c in event["input"]["history"]:
-                         hlen += len(c["Q"]) + len(c["A"])
-                         if hlen > MAX_TOKEN / 2:
-                             break
-                         else:
-                             memory.chat_memory.add_messages([HumanMessage(content=c["Q"]), AIMessage(content=c["A"])])
-                 chunks = []
-
if "secretary" in event ["input" ]:
104
88
# the request is from secretary APP. If it is too long, seperate it.
105
89
splitter = RecursiveCharacterTextSplitter (chunk_size = 3072 , chunk_overlap = 200 )
106
90
chunks_in = splitter .create_documents ([event ["input" ]["query" ]])
91
+
92
+ # prompt is sent from client, so that it can be customized.
93
+ prompt = PromptTemplate (input_variables = ["text" ],
94
+ prompt = event ["input" ]["prompt" ] + """
95
+ {text}
96
+ SUMMARY:
97
+ """ )
98
+ chain = LLMChain (llm = CHAT_LLM , verbose = True , prompt = prompt , output_parser = StrOutputParser ())
                    resp = ""
                    for ci in chunks_in:
-                         async for chunk in chain.astream(event["input"]["prompt"] + ci.page_content):
+                         chunks = []
+                         async for chunk in chain.astream({"text": ci.page_content}):
                            chunks.append(chunk)
                            print(chunk, end="|", flush=True)  # chunk size can be big
-                             resp += chunk["response"] + " "
+                             resp += chunk["text"] + " "  # LLMChain returns its output under the "text" key, not "response"
                    await websocket.send(json.dumps({"type": "result", "answer": resp}))

-                     # resp = ""
-                     # for ci in chunks_in:
-                     #     async for chunk in chain.astream("分段加标点改错别字。 " + ci.page_content):
-                     #         chunks.append(chunk)
-                     #         print(chunk, end="|", flush=True)  # chunk size can be big
-                     #         resp += chunk["response"] + " "
-                     # await websocket.send(json.dumps({"type": "result", "answer": resp}))
-
elif "query" in event ["input" ]:
109
+ memory = ConversationBufferMemory (return_messages = False )
110
+ if "history" in event ["input" ]:
111
+ # user server history if history key is not present in user request
112
+ memory .clear () # do not use memory on serverside. Add chat history kept by client.
113
+ hlen = 0
114
+ for c in event ["input" ]["history" ]:
115
+ hlen += len (c ["Q" ]) + len (c ["A" ])
116
+ if hlen > MAX_TOKEN / 2 :
117
+ break
118
+ else :
119
+ memory .chat_memory .add_messages ([HumanMessage (content = c ["Q" ]), AIMessage (content = c ["A" ])])
120
+
121
+ chain = ConversationChain (llm = CHAT_LLM , memory = memory , verbose = True , output_parser = StrOutputParser ())
122
+ chunks = []
124
                    async for chunk in chain.astream(event["input"]["query"]):
                        chunks.append(chunk)
                        print(chunk, end="|", flush=True)  # chunk size can be big
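For reference, a minimal sketch of a matching client, assuming the server listens on a
hypothetical ws://localhost:8765 and that the handler above parses each frame with
json.loads (that part is outside the hunks shown). Field names follow the code above;
the values and endpoint are illustrative only:

    import asyncio, json, websockets

    async def main():
        # Hypothetical endpoint; the real host/port is not part of this diff.
        async with websockets.connect("ws://localhost:8765") as ws:
            await ws.send(json.dumps({
                "parameters": {"llm": "openai", "temperature": 0.2, "model": "gpt-4"},
                "input": {"query": "What is a websocket?",
                          "history": [{"Q": "Hi", "A": "Hello!"}]},
            }))
            async for frame in ws:  # the server keeps the connection open
                try:
                    msg = json.loads(frame)
                except (TypeError, ValueError):
                    # assumption: the streaming callback may push raw token frames
                    print(frame, end="", flush=True)
                    continue
                if msg.get("type") == "result":  # final answer frame sent by the handler
                    print(msg["answer"])
                    break

    asyncio.run(main())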