Commit 1af1b84

Update openChat.py
1 parent 9d538ad commit 1af1b84

File tree

1 file changed: +30 / -31 lines changed

openChat.py

Lines changed: 30 additions & 31 deletions
@@ -1,15 +1,15 @@
 import asyncio, websockets, os, sys, json, ssl
 from typing import Any
 from langchain_openai import ChatOpenAI
-from langchain_community.vectorstores import Chroma
 from langchain.memory import ConversationBufferWindowMemory, ConversationBufferMemory
-from langchain.chains import ConversationChain
+from langchain.chains import ConversationChain, LLMChain
+from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
 from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_core.messages import get_buffer_string
 from langchain_core.messages.ai import AIMessage
 from langchain_core.messages.human import HumanMessage
-from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from dotenv import load_dotenv

 load_dotenv()
@@ -52,12 +52,6 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
 # ai_prefix=self.ai_prefix,
 # )

-CHAT_LLM = ChatOpenAI(temperature=0, model="gpt-4", streaming=True,
-                      callbacks=[MyStreamingHandler()]) # ChatOpenAI cannot have max_token=-1
-memory = ConversationBufferMemory(return_messages=False)
-chain = ConversationChain(llm=CHAT_LLM, memory=memory, verbose=True)
-chain.output_parser=StrOutputParser()
-
 ###########################################
 ### Format of input
 # {
@@ -73,6 +67,8 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
 # }
 ############################################

+CHAT_LLM = ChatOpenAI(temperature=0, model="gpt-4", streaming=True,
+                      callbacks=[MyStreamingHandler()]) # ChatOpenAI cannot have max_token=-1
 while True:
     try:
         async for message in websocket:
@@ -81,46 +77,49 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
             params = event["parameters"]
             if params["llm"] == "openai":
                 CHAT_LLM.temperature = float(params["temperature"])
-                CHAT_LLM.model = params["model"]
+                CHAT_LLM.model_name = params["model"]
             elif params["llm"] == "qianfan":
                 pass

             # if params["client"] == "mobile":
             #     CHAT_LLM.streaming = False

-            hlen = 0
-            if "history" in event["input"]:
-                # use server history if history key is not present in user request
-                memory.clear() # do not use memory on serverside. Add chat history kept by client.
-                for c in event["input"]["history"]:
-                    hlen += len(c["Q"]) + len(c["A"])
-                    if hlen > MAX_TOKEN/2:
-                        break
-                    else:
-                        memory.chat_memory.add_messages([HumanMessage(content=c["Q"]), AIMessage(content=c["A"])])
-            chunks = []
-
             if "secretary" in event["input"]:
                 # the request is from secretary APP. If it is too long, separate it.
                 splitter = RecursiveCharacterTextSplitter(chunk_size=3072, chunk_overlap=200)
                 chunks_in = splitter.create_documents([event["input"]["query"]])
+
+                # prompt is sent from client, so that it can be customized.
+                prompt = PromptTemplate(input_variables=["text"],
+                                        prompt=event["input"]["prompt"] + """
+{text}
+SUMMARY:
+""")
+                chain = LLMChain(llm=CHAT_LLM, verbose=True, prompt=prompt, output_parser=StrOutputParser())
                 resp = ""
                 for ci in chunks_in:
-                    async for chunk in chain.astream(event["input"]["prompt"] + ci.page_content):
+                    chunks = []
+                    async for chunk in chain.astream({"text": ci.page_content}):
                         chunks.append(chunk)
                         print(chunk, end="|", flush=True) # chunk size can be big
                         resp += chunk["response"]+" "
                 await websocket.send(json.dumps({"type": "result", "answer": resp}))

-                # resp = ""
-                # for ci in chunks_in:
-                #     async for chunk in chain.astream("分段加标点改错别字。 "+ ci.page_content):
-                #         chunks.append(chunk)
-                #         print(chunk, end="|", flush=True) # chunk size can be big
-                #         resp += chunk["response"]+" "
-                # await websocket.send(json.dumps({"type": "result", "answer": resp}))
-
             elif "query" in event["input"]:
+                memory = ConversationBufferMemory(return_messages=False)
+                if "history" in event["input"]:
+                    # use server history if history key is not present in user request
+                    memory.clear() # do not use memory on serverside. Add chat history kept by client.
+                    hlen = 0
+                    for c in event["input"]["history"]:
+                        hlen += len(c["Q"]) + len(c["A"])
+                        if hlen > MAX_TOKEN/2:
+                            break
+                        else:
+                            memory.chat_memory.add_messages([HumanMessage(content=c["Q"]), AIMessage(content=c["A"])])
+
+                chain = ConversationChain(llm=CHAT_LLM, memory=memory, verbose=True, output_parser=StrOutputParser())
+                chunks = []
                 async for chunk in chain.astream(event["input"]["query"]):
                     chunks.append(chunk)
                     print(chunk, end="|", flush=True) # chunk size can be big
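
The secretary branch now builds a one-off summarization chain per request instead of reusing a module-level ConversationChain. Below is a minimal sketch of that path, not the committed code: it assumes OPENAI_API_KEY is set, uses PromptTemplate's standard template= keyword (the diff passes prompt=), and names the output key explicitly so it does not collide with the "text" input variable; the client prompt and input text are placeholders.

# Sketch of the new summarization path for "secretary" requests.
# Assumptions: OPENAI_API_KEY is set; the client prompt and the input
# text below are placeholders, not values from the commit.
from langchain_openai import ChatOpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter

llm = ChatOpenAI(temperature=0, model="gpt-4")

# The instruction comes from the client; the server appends the {text} slot.
client_prompt = "Summarize the following dictated note."  # placeholder
prompt = PromptTemplate(
    input_variables=["text"],
    template=client_prompt + "\n{text}\nSUMMARY:\n",  # standard keyword is template=
)
chain = LLMChain(llm=llm, prompt=prompt, verbose=True, output_key="summary")

# Long input is split so each chunk stays within the model's context window.
splitter = RecursiveCharacterTextSplitter(chunk_size=3072, chunk_overlap=200)
docs = splitter.create_documents(["... long dictated text from the client ..."])

resp = ""
for doc in docs:
    out = chain.invoke({"text": doc.page_content})
    resp += out["summary"] + " "
print(resp)

Splitting first keeps every call inside the gpt-4 context window; the per-chunk summaries are concatenated and returned in a single websocket message, as in the diff.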

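The query branch likewise becomes stateless on the server: memory is rebuilt per request from the history the client sends, then handed to a fresh ConversationChain. A minimal sketch under the same caveats (placeholder MAX_TOKEN, model, and history; ConversationChain's default input/response keys):

# Sketch of the reworked conversation path for "query" requests.
# Assumptions: MAX_TOKEN and the sample history are placeholders; the
# {"Q": ..., "A": ...} record format mirrors the diff.
from langchain_openai import ChatOpenAI
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage

MAX_TOKEN = 4096  # placeholder budget
llm = ChatOpenAI(temperature=0, model="gpt-4")

history = [{"Q": "My name is Li.", "A": "Nice to meet you, Li."}]  # sent by the client
memory = ConversationBufferMemory(return_messages=False)

# Replay the client-side history into fresh memory, capped at roughly half the
# token budget, so no chat state has to live on the server between requests.
hlen = 0
for c in history:
    hlen += len(c["Q"]) + len(c["A"])
    if hlen > MAX_TOKEN / 2:
        break
    memory.chat_memory.add_messages([HumanMessage(content=c["Q"]), AIMessage(content=c["A"])])

chain = ConversationChain(llm=llm, memory=memory, verbose=True)
print(chain.invoke({"input": "What is my name?"})["response"])  # default keys: input / response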
0 commit comments
