Skip to content

Commit

Permalink
Work-around langchain's SERP non-thread-safe code for `sys.stdout =` assignment. Note others.
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudotensor committed Oct 23, 2023
1 parent 0b84624 commit 3d1e740
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
11 changes: 11 additions & 0 deletions src/gpt4all_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def get_llm_gpt4all(model_name,
max_seq_len=max_seq_len,
)
if model_name == 'llama':
# FIXME: streaming not thread safe due to:
# llama_cpp/utils.py: sys.stdout = self.outnull_file
# llama_cpp/utils.py: sys.stdout = self.old_stdout
cls = H2OLlamaCpp
if model is None:
llamacpp_dict = llamacpp_dict.copy()
Expand Down Expand Up @@ -159,6 +162,10 @@ def get_llm_gpt4all(model_name,
llm.client.verbose = verbose
inner_model = llm.client
elif model_name == 'gpt4all_llama':
# FIXME: streaming not thread safe due to:
# gpt4all/pyllmodel.py: sys.stdout = stream_processor
# gpt4all/pyllmodel.py: sys.stdout = old_stdout

cls = H2OGPT4All
if model is None:
llamacpp_dict = llamacpp_dict.copy()
Expand All @@ -177,6 +184,10 @@ def get_llm_gpt4all(model_name,
llm = cls(**model_kwargs)
inner_model = llm.client
elif model_name == 'gptj':
# FIXME: streaming not thread safe due to:
# gpt4all/pyllmodel.py: sys.stdout = stream_processor
# gpt4all/pyllmodel.py: sys.stdout = old_stdout

cls = H2OGPT4All
if model is None:
llamacpp_dict = llamacpp_dict.copy()
Expand Down
14 changes: 6 additions & 8 deletions src/gpt_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3675,13 +3675,7 @@ def run_qa_db(**kwargs):
# only keep actual used
kwargs = {k: v for k, v in kwargs.items() if k in func_names}
try:
if kwargs.get('verbose', False):
# maybe helps avoid sys.stdout getting closed
from contextlib import redirect_stdout
with redirect_stdout(None):
return _run_qa_db(**kwargs)
else:
return _run_qa_db(**kwargs)
return _run_qa_db(**kwargs)
finally:
clear_torch_cache()

Expand Down Expand Up @@ -4518,7 +4512,11 @@ def get_chain(query=None,
llm, model_name, streamer, prompt_type_out, async_output, only_new_text

if LangChainAgent.PYTHON.value in langchain_agents:
if does_support_functiontools(inference_server, model_name):
# FIXME: not thread-safe due to sys.stdout = assignments in worker
# langchain/utilities/python.py: sys.stdout = mystdout = StringIO()
# langchain/utilities/python.py: sys.stdout = old_stdout
# langchain/utilities/python.py: sys.stdout = old_stdout
if does_support_functiontools(inference_server, model_name) and False:
chain = create_python_agent(
llm=llm,
tool=PythonREPLTool(),
Expand Down
8 changes: 8 additions & 0 deletions src/serpapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,11 @@ def __process_response(res: dict, query: str, headsize: int) -> list:
add_meta(docs, query)

return docs

def results(self, query: str) -> dict:
    """Run *query* through SerpAPI and return the raw result dict.

    Overrides langchain's implementation, which reassigns ``sys.stdout``
    directly and is therefore not thread-safe.
    """
    search_params = self.get_params(query)
    engine = self.search_engine(search_params)
    return engine.get_dict()

0 comments on commit 3d1e740

Please sign in to comment.