From 3d1e7401876cdb568d32283324ef32ef3a036b22 Mon Sep 17 00:00:00 2001
From: "Jonathan C. McKinney"
Date: Mon, 23 Oct 2023 01:36:11 -0700
Subject: [PATCH] Work around langchain's non-thread-safe SERP code that
 reassigns sys.stdout; note other offenders

---
 src/gpt4all_llm.py   | 11 +++++++++++
 src/gpt_langchain.py | 14 ++++++--------
 src/serpapi.py       |  8 ++++++++
 3 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/src/gpt4all_llm.py b/src/gpt4all_llm.py
index e214fc55e..dacb36496 100644
--- a/src/gpt4all_llm.py
+++ b/src/gpt4all_llm.py
@@ -132,6 +132,9 @@ def get_llm_gpt4all(model_name,
                           max_seq_len=max_seq_len,
                           )
     if model_name == 'llama':
+        # FIXME: streaming is not thread-safe, because llama_cpp reassigns sys.stdout:
+        # llama_cpp/utils.py: sys.stdout = self.outnull_file
+        # llama_cpp/utils.py: sys.stdout = self.old_stdout
         cls = H2OLlamaCpp
         if model is None:
             llamacpp_dict = llamacpp_dict.copy()
@@ -159,6 +162,10 @@ def get_llm_gpt4all(model_name,
         llm.client.verbose = verbose
         inner_model = llm.client
     elif model_name == 'gpt4all_llama':
+        # FIXME: streaming is not thread-safe, because gpt4all reassigns sys.stdout:
+        # gpt4all/pyllmodel.py: sys.stdout = stream_processor
+        # gpt4all/pyllmodel.py: sys.stdout = old_stdout
+
         cls = H2OGPT4All
         if model is None:
             llamacpp_dict = llamacpp_dict.copy()
@@ -177,6 +184,10 @@ def get_llm_gpt4all(model_name,
         llm = cls(**model_kwargs)
         inner_model = llm.client
     elif model_name == 'gptj':
+        # FIXME: streaming is not thread-safe, because gpt4all reassigns sys.stdout:
+        # gpt4all/pyllmodel.py: sys.stdout = stream_processor
+        # gpt4all/pyllmodel.py: sys.stdout = old_stdout
+
         cls = H2OGPT4All
         if model is None:
             llamacpp_dict = llamacpp_dict.copy()
diff --git a/src/gpt_langchain.py b/src/gpt_langchain.py
index 54dfde50e..b8759f406 100644
--- a/src/gpt_langchain.py
+++ b/src/gpt_langchain.py
@@ -3675,13 +3675,7 @@ def run_qa_db(**kwargs):
     # only keep actual used
     kwargs = {k: v for k, v in kwargs.items() if k in func_names}
     try:
-        if kwargs.get('verbose', False):
-            # maybe helps avoid sys.stdout getting closed
-            from contextlib import redirect_stdout
-            with redirect_stdout(None):
-                return _run_qa_db(**kwargs)
-        else:
-            return _run_qa_db(**kwargs)
+        return _run_qa_db(**kwargs)
     finally:
         clear_torch_cache()
 
@@ -4518,7 +4512,11 @@ def get_chain(query=None,
         llm, model_name, streamer, prompt_type_out, async_output, only_new_text
 
     if LangChainAgent.PYTHON.value in langchain_agents:
-        if does_support_functiontools(inference_server, model_name):
+        # FIXME: not thread-safe due to sys.stdout reassignment in the worker:
+        # langchain/utilities/python.py: sys.stdout = mystdout = StringIO()
+        # langchain/utilities/python.py: sys.stdout = old_stdout
+        # langchain/utilities/python.py: sys.stdout = old_stdout
+        if does_support_functiontools(inference_server, model_name) and False:  # disabled until thread-safe
             chain = create_python_agent(
                 llm=llm,
                 tool=PythonREPLTool(),
diff --git a/src/serpapi.py b/src/serpapi.py
index f7ed7f066..8d5ae49a3 100644
--- a/src/serpapi.py
+++ b/src/serpapi.py
@@ -165,3 +165,11 @@ def __process_response(res: dict, query: str, headsize: int) -> list:
 
         add_meta(docs, query)
         return docs
+
+    def results(self, query: str) -> dict:
+        """Run query through SerpAPI and return the raw result."""
+        # Work-around: langchain's version swaps out sys.stdout directly (not thread-safe).
+        params = self.get_params(query)
+        search = self.search_engine(params)
+        res = search.get_dict()
+        return res
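
All of the FIXME comments above share one root cause: sys.stdout is a single
process-global object, so any library that temporarily reassigns it races with
every other thread producing output. A minimal sketch of the race, illustrative
only and not part of the patch:

    import io
    import sys
    import threading

    def swap_and_restore():
        # the swap-and-restore pattern used by the libraries noted above
        old_stdout = sys.stdout     # two threads can read the same "old" value
        sys.stdout = io.StringIO()  # the later assignment clobbers the earlier one
        try:
            print("suppressed")     # may land in another thread's buffer
        finally:
            sys.stdout = old_stdout # may "restore" another thread's StringIO

    threads = [threading.Thread(target=swap_and_restore) for _ in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # depending on interleaving, sys.stdout can be left pointing at a discarded
    # StringIO, which is how stdout ends up closed or hijacked mid-stream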
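For the llama_cpp and gpt4all streaming paths, where the reassignment happens
inside the dependency and cannot simply be bypassed, one possible mitigation is
to serialize those calls behind a process-wide lock. This is only a sketch of
that idea; the lock and helper below are hypothetical, not part of this patch:

    import threading

    _stdout_swap_lock = threading.Lock()  # hypothetical process-wide guard

    def call_serialized(fn, *args, **kwargs):
        # hold the lock for the whole call so at most one sys.stdout
        # swap-and-restore cycle is in flight at any time
        with _stdout_swap_lock:
            return fn(*args, **kwargs)

This trades streaming parallelism for a consistent sys.stdout, e.g. by wrapping
the generate/stream entry points of the affected clients.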
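The results() override added to src/serpapi.py performs the same SerpAPI calls
as langchain's SerpAPIWrapper.results, minus the stdout-redirecting context
manager langchain wraps them in, which is what makes it callable from worker
threads. Hypothetical usage, assuming the wrapper class name and import path
(neither is shown in the diff):

    from src.serpapi import H2OSerpAPIWrapper  # class name assumed

    wrapper = H2OSerpAPIWrapper(serpapi_api_key="...")  # key as in langchain's SerpAPIWrapper
    raw = wrapper.results("h2ogpt langchain thread safety")
    print(raw.get("organic_results", [])[:1])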