Commit

Handle asyncio chain.arun, but for 4 chunks it's performing no better than one at a time.
pseudotensor committed Jul 28, 2023
1 parent cf6b579 commit cc3331d
Showing 2 changed files with 90 additions and 46 deletions.
2 changes: 2 additions & 0 deletions src/gen.py
@@ -1696,13 +1696,15 @@ def evaluate(
num_return_sequences=num_return_sequences,
)
t_generate = time.time()
async_output = True # if not streaming, will async over tasks if summarization and be faster result
for r in run_qa_db(query=instruction,
iinput=iinput,
context=context,
model_name=base_model, model=model, tokenizer=tokenizer,
inference_server=inference_server,
langchain_only_model=langchain_only_model,
stream_output=stream_output,
async_output=async_output,
prompter=prompter,
use_llm_if_no_docs=use_llm_if_no_docs,
load_db_if_exists=load_db_if_exists,
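The flag is simply hard-coded on here; whether the async path actually pays off is what the commit message questions. A small, hypothetical timing harness (not part of the repo) for comparing the two modes by draining the run_qa_db generator:

import time

def time_summarize(run_fn, **kwargs):
    # Drain the generator so the whole map-reduce summarization executes.
    t0 = time.time()
    for _ in run_fn(**kwargs):
        pass
    return time.time() - t0

# e.g. time_summarize(run_qa_db, query=instruction, stream_output=False,
#                     async_output=True, ...)  # remaining arguments as in the hunk above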
134 changes: 88 additions & 46 deletions src/gpt_langchain.py
@@ -284,7 +284,7 @@ def get_answer_from_sources(chain, sources, question):

from pydantic import Extra, Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
from langchain.llms.base import LLM


@@ -468,27 +468,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
context: Any = ''
iinput: Any = ''
tokenizer: Any = None
client: Any = None

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that python package exists in environment."""

try:
if values['client'] is None:
import text_generation

values["client"] = text_generation.Client(
values["inference_server_url"],
timeout=values["timeout"],
headers=values["headers"],
)
except ImportError:
raise ImportError(
"Could not import text_generation python package. "
"Please install it with `pip install text_generation`."
)
return values
#client: Any = None

def _call(
self,
@@ -498,9 +478,12 @@ def _call(
**kwargs: Any,
) -> str:
if stop is None:
stop = self.stop_sequences
stop = self.stop_sequences.copy()
else:
stop += self.stop_sequences
stop += self.stop_sequences.copy()
stop_tmp = stop.copy()
stop = []
[stop.append(x) for x in stop_tmp if x not in stop]

# HF inference server needs control over input tokens
assert self.tokenizer is not None
@@ -572,6 +555,41 @@ def _call(
text_callback(text_chunk)
return text

async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if stop is None:
stop = self.stop_sequences.copy()
else:
stop += self.stop_sequences.copy()
stop_tmp = stop.copy()
stop = []
[stop.append(x) for x in stop_tmp if x not in stop]

# HF inference server needs control over input tokens
assert self.tokenizer is not None
from h2oai_pipeline import H2OTextGenerationPipeline
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)

# NOTE: TGI server does not add prompting, so must do here
data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
prompt = self.prompter.generate_prompt(data_point)

gen_text = await super()._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)

# remove stop sequences from the end of the generated text
for stop_seq in stop:
if stop_seq in gen_text:
gen_text = gen_text[:gen_text.index(stop_seq)]
text = prompt + gen_text
text = self.prompter.get_response(text, prompt=prompt,
sanitize_bot_response=self.sanitize_bot_response)
return text


from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
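The stop-sequence handling in _call/_acall above copies self.stop_sequences before merging, then removes duplicates while keeping first-seen order. A minimal equivalent of that dedup step, written with dict.fromkeys rather than the committed append loop (the helper name is illustrative, not from the repo):

from typing import List, Optional

def merge_stop_sequences(stop: Optional[List[str]], defaults: List[str]) -> List[str]:
    # Combine caller-supplied stop strings with the instance defaults,
    # then drop duplicates while preserving first-seen order.
    combined = list(defaults) if stop is None else list(stop) + list(defaults)
    return list(dict.fromkeys(combined))

# merge_stop_sequences(["</s>", "Human:"], ["Human:", "\n\n"]) -> ["</s>", "Human:", "\n\n"]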
@@ -790,12 +808,13 @@ def get_llm(use_openai_model=False,
sanitize_bot_response=sanitize_bot_response,
)
elif hf_client:
# no need to pass original client, no state and fast, so can use same validate_environment from base class
llm = H2OHuggingFaceTextGenInference(
inference_server_url=inference_server,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=True,
return_full_text=False,
seed=SEED,

stop_sequences=prompter.stop_sequences,
@@ -809,7 +828,6 @@ def get_llm(use_openai_model=False,
context=context,
iinput=iinput,
tokenizer=tokenizer,
client=hf_client,
timeout=max_time,
sanitize_bot_response=sanitize_bot_response,
)
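With return_full_text=False the TGI client now returns only the completion, and _call/_acall rebuild the full text themselves (text = prompt + gen_text) before handing it to the prompter. A tiny sketch of that trim-and-rebuild step, with the prompter reduced to a plain prefix strip (an assumption; the real Prompter does more):

def extract_response(prompt, gen_text, stop_sequences):
    # Cut at the first stop sequence, rebuild prompt+completion, then strip
    # the prompt back off, mirroring the flow in _call/_acall above.
    for stop_seq in stop_sequences:
        if stop_seq in gen_text:
            gen_text = gen_text[:gen_text.index(stop_seq)]
    text = prompt + gen_text
    return text[len(prompt):].strip()

# extract_response("Q: hi\nA:", " hello</s> junk", ["</s>"]) -> "hello"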
@@ -1969,6 +1987,7 @@ def _run_qa_db(query=None,
langchain_only_model=False,
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
stream_output=False,
async_output=True,
prompter=None,
prompt_type=None,
prompt_dict=None,
@@ -2142,20 +2161,23 @@ def _run_qa_db(query=None,
# in case no exception and didn't join with thread yet, then join
if not thread.exc:
answer = thread.join()
answer = answer['output_text']
# in case raise StopIteration or broke queue loop in streamer, but still have exception
if thread.exc:
raise thread.exc
# FIXME: answer is not string outputs from streamer. How to get actual final output?
# answer = outputs
else:
answer = chain()
if not stream_output and async_output:
import asyncio
answer = asyncio.run(chain())
else:
answer = chain()

if not use_docs_planned:
ret = answer['output_text']
ret = answer
extra = ''
yield ret, extra
elif answer is not None:
ret, extra = get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=verbose)
ret, extra = get_sources_answer(query, docs, answer, scores, show_rank, answer_with_sources, verbose=verbose)
yield ret, extra
return
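On this path chain has already been bound to chain.arun in get_chain, so chain() returns a coroutine and asyncio.run drives it to completion from the synchronous generator. A self-contained sketch of that bridge (fake_arun stands in for LangChain's Chain.arun):

import asyncio
from functools import partial

async def fake_arun(inputs):
    # Stand-in for Chain.arun; a real chain would await its LLM calls here.
    await asyncio.sleep(0)
    return "summary of %d docs" % len(inputs["input_documents"])

target = partial(fake_arun, {"input_documents": ["chunk1", "chunk2"]})
answer = asyncio.run(target())  # blocks until the coroutine finishes
print(answer)

Note that asyncio.run raises if an event loop is already running in the calling thread, and the map over chunks only overlaps if the LLM's async calls genuinely run concurrently; either could explain why four chunks show no speedup, though the commit message leaves the cause open.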

@@ -2213,6 +2235,8 @@ def get_chain(query=None,
tokenizer=None,
verbose=False,
reverse_docs=True,
stream_output=True,
async_output=True,

# local
auto_reduce_chunks=True,
@@ -2506,21 +2530,40 @@ def get_chain(query=None,
elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
LangChainAction.SUMMARIZE_REFINE,
LangChainAction.SUMMARIZE_ALL.value]:
if not stream_output and async_output:
return_intermediate_steps = False
else:
return_intermediate_steps = True
from langchain.chains.summarize import load_summarize_chain
if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
prompt = PromptTemplate(input_variables=["text"], template=template)
chain = load_summarize_chain(llm, chain_type="map_reduce",
map_prompt=prompt, combine_prompt=prompt, return_intermediate_steps=True,
map_prompt=prompt, combine_prompt=prompt,
return_intermediate_steps=return_intermediate_steps,
token_max=max_input_tokens)
target = wrapped_partial(chain, {"input_documents": docs}) # , return_only_outputs=True)
if not stream_output and async_output:
chain_func = chain.arun
else:
chain_func = chain
target = wrapped_partial(chain_func, {"input_documents": docs}) # , return_only_outputs=True)
elif langchain_action == LangChainAction.SUMMARIZE_ALL.value:
assert use_template
prompt = PromptTemplate(input_variables=["text"], template=template)
chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt, return_intermediate_steps=True)
target = wrapped_partial(chain)
chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt,
return_intermediate_steps=return_intermediate_steps)
if not stream_output and async_output:
chain_func = chain.arun
else:
chain_func = chain
target = wrapped_partial(chain_func)
elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value:
chain = load_summarize_chain(llm, chain_type="refine", return_intermediate_steps=True)
target = wrapped_partial(chain)
chain = load_summarize_chain(llm, chain_type="refine",
return_intermediate_steps=return_intermediate_steps)
if not stream_output and async_output:
chain_func = chain.arun
else:
chain_func = chain
target = wrapped_partial(chain_func)
else:
raise RuntimeError("No such langchain_action=%s" % langchain_action)
else:
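Pulled together, the map-reduce branch above builds the chain once and then binds either chain.arun (async, non-streaming) or the chain itself as the callable target; return_intermediate_steps is disabled on the async path, presumably because arun returns a single string output. A consolidated sketch under those assumptions (functools.partial stands in for wrapped_partial, and the builder name is illustrative):

from functools import partial

def build_summarize_target(llm, prompt, docs, stream_output, async_output, max_input_tokens):
    from langchain.chains.summarize import load_summarize_chain
    use_async = async_output and not stream_output
    chain = load_summarize_chain(llm, chain_type="map_reduce",
                                 map_prompt=prompt, combine_prompt=prompt,
                                 return_intermediate_steps=not use_async,
                                 token_max=max_input_tokens)
    # arun returns a coroutine that the caller drives with asyncio.run
    chain_func = chain.arun if use_async else chain
    return partial(chain_func, {"input_documents": docs})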
@@ -2529,19 +2572,18 @@ def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
return docs, target, scores, use_docs_planned, have_any_docs


def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
def get_sources_answer(query, docs, answer, scores, show_rank, answer_with_sources, verbose=False):
if verbose:
print("query: %s" % query, flush=True)
print("answer: %s" % answer['output_text'], flush=True)
print("answer: %s" % answer, flush=True)

if len(answer['input_documents']) == 0:
if len(docs) == 0:
extra = ''
ret = answer['output_text'] + extra
ret = answer + extra
return ret, extra

# link
answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in
zip(scores, answer['input_documents'])]
answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in zip(scores, docs)]
answer_sources_dict = defaultdict(list)
[answer_sources_dict[url].append(score) for score, url in answer_sources]
answers_dict = {}
Expand All @@ -2559,14 +2601,14 @@ def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, ve
sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources)
sorted_sources_urls += f"</ul></p>{source_postfix}"

if not answer['output_text'].endswith('\n'):
answer['output_text'] += '\n'
if not answer.endswith('\n'):
answer += '\n'

if answer_with_sources:
extra = '\n' + sorted_sources_urls
else:
extra = ''
ret = answer['output_text'] + extra
ret = answer + extra
return ret, extra


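The source block above converts each document's distance score into a 0-1 relevance with max(0.0, 1.5 - score) / 1.5 and groups the scores per URL. A standalone sketch of that scoring and grouping (the aggregation shown, keeping the best score per URL, is illustrative; the repo's actual formatting lives in the elided lines):

from collections import defaultdict

def rank_sources(scores, urls):
    # Map distances (smaller is better) to 0-1 relevance, keep the best per URL.
    per_url = defaultdict(list)
    for score, url in zip(scores, urls):
        per_url[url].append(max(0.0, 1.5 - score) / 1.5)
    best = {url: max(vals) for url, vals in per_url.items()}
    return sorted(best.items(), key=lambda kv: kv[1], reverse=True)

# rank_sources([0.2, 0.9, 0.2], ["a.html", "b.html", "a.html"])
# -> [('a.html', 0.866...), ('b.html', 0.4)]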
