Fix queue.get() blocking agents and entire worker. Add anthropic caching and add enable_caching API arg
pseudotensor committed Sep 27, 2024
1 parent 56f9644 commit fbdc336
Showing 13 changed files with 110 additions and 19 deletions.
2 changes: 2 additions & 0 deletions gradio_utils/grclient.py
@@ -741,6 +741,7 @@ def query_or_summarize_or_extract(
model: str | int | None = None,
model_lock: dict | None = None,
stream_output: bool = False,
enable_caching: bool = False,
do_sample: bool = False,
seed: int | None = 0,
temperature: float = 0.0,
@@ -905,6 +906,7 @@ def query_or_summarize_or_extract(
:param max_input_tokens: see src/gen.py
:param max_total_input_tokens: see src/gen.py
:param stream_output: Whether to stream output
:param enable_caching: Whether to enable caching
:param max_time: how long to take
:param add_search_to_context: Whether to do web search and add results to context
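For illustration, a minimal client-side sketch of the new argument; the server URL and the instruction keyword are assumptions, while stream_output and enable_caching match the diff above:

```python
# Hedged sketch (not from the repository): call the Gradio client with the new
# enable_caching argument. URL and instruction text are placeholders; the return
# value is treated as plain text here, which may differ from the actual client.
from gradio_utils.grclient import GradioClient

client = GradioClient("http://localhost:7860")
answer = client.query_or_summarize_or_extract(
    instruction="Summarize the key points of the provided context.",
    stream_output=False,
    enable_caching=True,  # new in this commit; only affects Anthropic models with prompt caching
)
print(answer)
```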
12 changes: 9 additions & 3 deletions openai_server/autogen_2agent_backend.py
@@ -101,15 +101,21 @@ def run_autogen_2agent(query=None,
api_key, model, text_context_list, image_file,
temp_dir, query)

enable_caching = True

code_writer_agent = H2OConversableAgent(
"code_writer_agent",
system_message=system_message,
llm_config={"config_list": [{"model": model,
llm_config={'timeout': autogen_timeout,
'extra_body': dict(enable_caching=enable_caching),
"config_list": [{"model": model,
"api_key": api_key,
"base_url": base_url,
"stream": stream_output,
"cache_seed": autogen_cache_seed,
'max_tokens': max_new_tokens}]},
'max_tokens': max_new_tokens,
'cache_seed': autogen_cache_seed,
}]
},
code_execution_config=False, # Turn off code execution for this agent.
human_input_mode="NEVER",
is_termination_msg=terminate_message_func,
12 changes: 6 additions & 6 deletions openai_server/autogen_streaming.py
@@ -59,7 +59,7 @@ def run_autogen_in_proc(func, output_queue, result_queue, exception_queue, **kwa
result_queue.put(ret_dict)


async def iostream_generator(func, use_process=False, **kwargs) -> typing.Generator[str, None, None]:
async def iostream_generator(func, use_process=False, **kwargs) -> typing.AsyncGenerator[str, None]:
# start capture
custom_stream = CustomIOStream()
IOStream.set_global_default(custom_stream)
@@ -91,11 +91,11 @@ async def iostream_generator(func, use_process=False, **kwargs) -> typing.Genera
if not exception_queue.empty():
e = exception_queue.get()
raise e

output = output_queue.get()
if output is None: # End of agent execution
break
yield output
if not output_queue.empty():
output = output_queue.get()
if output is None: # End of agent execution
break
yield output
await asyncio.sleep(0.005)

agent_proc.join()
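The change above stops calling output_queue.get() unconditionally, which blocked the event loop (and with it every other agent on the worker) whenever the queue was empty; instead the queue is polled and control is handed back with a short asyncio.sleep(). A self-contained sketch of that pattern, with illustrative names only:

```python
# Minimal sketch of the non-blocking queue pattern used in iostream_generator():
# poll the queue instead of blocking on get(), and sleep briefly so the event loop
# keeps servicing other coroutines. All names here are illustrative.
import asyncio
import queue
import threading


def producer(q: queue.Queue) -> None:
    for token in ["hello", " ", "world"]:
        q.put(token)
    q.put(None)  # sentinel: end of output


async def consume(q: queue.Queue):
    while True:
        if not q.empty():
            item = q.get()
            if item is None:  # sentinel seen, stop iterating
                break
            yield item
        await asyncio.sleep(0.005)  # yield control instead of blocking on q.get()


async def main() -> None:
    q: queue.Queue = queue.Queue()
    threading.Thread(target=producer, args=(q,), daemon=True).start()
    async for chunk in consume(q):
        print(chunk, end="")
    print()


if __name__ == "__main__":
    asyncio.run(main())
```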
1 change: 1 addition & 0 deletions openai_server/server.py
@@ -50,6 +50,7 @@ class ResponseFormat(BaseModel):
class H2oGPTParams(BaseModel):
# keep in sync with evaluate()
# handled by extra_body passed to OpenAI API
enable_caching: bool | None = None
prompt_type: str | None = None
prompt_dict: Dict | str | None = None
chat_template: str | None = None
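Because H2oGPTParams is filled from extra_body on the OpenAI-compatible endpoint, a caller can now toggle caching per request. A hedged sketch; base_url, api_key, and the model name are placeholders:

```python
# Hedged sketch: send enable_caching through the OpenAI-compatible API via extra_body,
# the standard openai-python way to attach extra JSON body fields.
# base_url, api_key, and model are placeholders for a running h2oGPT server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "Summarize the shared context."}],
    extra_body={"enable_caching": True},  # read server-side as H2oGPTParams.enable_caching
)
print(resp.choices[0].message.content)
```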
2 changes: 1 addition & 1 deletion src/cli.py
@@ -36,7 +36,7 @@ def run_cli( # for local function:
try_pdf_as_html=None,
# for some evaluate args
load_awq='',
stream_output=None, async_output=None, num_async=None, stream_map=None,
stream_output=None, enable_caching=None, async_output=None, num_async=None, stream_map=None,
prompt_type=None, prompt_dict=None, chat_template=None, system_prompt=None,
temperature=None, top_p=None, top_k=None, penalty_alpha=None, num_beams=None,
max_new_tokens=None, min_new_tokens=None, early_stopping=None, max_time=None, repetition_penalty=None,
5 changes: 4 additions & 1 deletion src/client_test.py
@@ -71,7 +71,9 @@ def get_client(serialize=not is_gradio_version4):
return client


def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
def get_args(prompt, prompt_type=None, chat=False,
stream_output=False,
enable_caching=False,
max_new_tokens=50,
top_k_docs=3,
langchain_mode='Disabled',
@@ -110,6 +112,7 @@ def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
# streaming output is supported, loops over and outputs each generation in streaming mode
# but leave stream_output=False for simple input/output mode
stream_output=stream_output,
enable_caching=enable_caching,
prompt_type=prompt_type,
prompt_dict=prompt_dict,
chat_template=chat_template,
6 changes: 5 additions & 1 deletion src/enums.py
@@ -233,6 +233,11 @@ class LangChainAgent(Enum):
"claude-3-haiku-20240307": 4096,
}

anthropic_prompt_caching = ["claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
"claude-3-haiku-20240307",
]

claude3imagetag = 'claude-3-image'
gpt4imagetag = 'gpt-4-image'
geminiimagetag = 'gemini-image'
@@ -798,7 +803,6 @@ def gr_to_lg(image_audio_loaders,
json_code_post_prompt_reminder0 = 'Ensure your response satisfies the schema mentioned above and place the response inside JSON code block. Do not just repeat the JSON schema, ensure your response uses that schema to respond by choosing particular values for each type.'
json_code2_post_prompt_reminder0 = 'Ensure your response is inside a JSON code block.'


image_batch_image_prompt0 = """<response_instructions>
- Act as a keen observer with a sharp eye for detail.
- Analyze the content within the images.
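The anthropic_prompt_caching list added above is what get_llm() (src/gpt_langchain.py below) consults before attaching the prompt-caching beta header; a minimal sketch of that gating, with an illustrative model name:

```python
# Minimal sketch mirroring the get_llm() gating in src/gpt_langchain.py below.
# Assumes src/ is on sys.path, as in the repository; model_name is illustrative.
from enums import anthropic_prompt_caching

model_name = "claude-3-5-sonnet-20240620"
supports_caching = model_name in anthropic_prompt_caching
extra_headers = {"anthropic-beta": "prompt-caching-2024-07-31"} if supports_caching else {}
```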
2 changes: 1 addition & 1 deletion src/eval.py
@@ -15,7 +15,7 @@ def run_eval( # for local function:
regenerate_clients=None, regenerate_gradio_clients=None, validate_clients=None, fail_if_invalid_client=None,
prompt_type=None, prompt_dict=None, chat_template=None, system_prompt=None,
debug=None, chat=False,
stream_output=None, async_output=None, num_async=None, stream_map=None,
stream_output=None, enable_caching=None, async_output=None, num_async=None, stream_map=None,
eval_filename=None, eval_prompts_only_num=None, eval_prompts_only_seed=None, eval_as_output=None,
examples=None, memory_restriction_level=None,
# evaluate kwargs
2 changes: 1 addition & 1 deletion src/evaluate_params.py
@@ -38,7 +38,7 @@
]

eval_func_param_names = (
["instruction", "iinput", "context", "stream_output", "prompt_type", "prompt_dict", "chat_template"]
["instruction", "iinput", "context", "stream_output", "enable_caching", "prompt_type", "prompt_dict", "chat_template"]
+ gen_hyper
+ [
"chat",
11 changes: 8 additions & 3 deletions src/gen.py
@@ -234,6 +234,7 @@ def main(
text_context_list: typing.List[str] = None,

stream_output: bool = True,
enable_caching: bool = False,
async_output: bool = True,
num_async: int = 3,
stream_map: bool = False,
@@ -844,6 +845,7 @@ def main(
Forces LangChain code path and uses as many entries in list as possible given max_seq_len, with first assumed to be most relevant and to go near prompt.
:param stream_output: whether to stream output
:param enable_caching: whether to enable prompt caching (Anthropic models only)
:param async_output: Whether to do asyncio handling
For summarization
Applicable to HF TGI server
@@ -1950,7 +1952,7 @@ def main(
inference_server,
llamacpp_dict,
chat,
stream_output, show_examples,
stream_output, enable_caching, show_examples,
prompt_type, prompt_dict, chat_template,
system_prompt,
pre_prompt_query, prompt_query,
@@ -2460,6 +2462,7 @@ def evaluate(
iinput,
context,
stream_output,
enable_caching,
prompt_type,
prompt_dict,
chat_template,
@@ -3295,6 +3298,7 @@ def evaluate(
context=context,
stream_output0=stream_output0,
stream_output=stream_output,
enable_caching=enable_caching,
chunk=chunk,
chunk_size=chunk_size,

@@ -3805,6 +3809,7 @@ def evaluate(
# streaming output is supported, loops over and outputs each generation in streaming mode
# but leave stream_output=False for simple input/output mode
stream_output=stream_output,
enable_caching=enable_caching,

**gen_server_kwargs,

@@ -4424,7 +4429,7 @@ def get_generate_params(model_lower,
inference_server,
llamacpp_dict,
chat,
stream_output, show_examples,
stream_output, enable_caching, show_examples,
prompt_type, prompt_dict, chat_template,
system_prompt,
pre_prompt_query, prompt_query,
@@ -4611,7 +4616,7 @@ def get_generate_params(model_lower,
do_sample = False if do_sample is None else do_sample
# doesn't include chat, instruction_nochat, iinput_nochat, added later
params_list = ["",
stream_output,
stream_output, enable_caching,
prompt_type, prompt_dict, chat_template,
temperature, top_p, top_k, penalty_alpha, num_beams,
max_new_tokens, min_new_tokens,
71 changes: 70 additions & 1 deletion src/gpt_langchain.py
@@ -35,6 +35,7 @@
import tabulate

from joblib import delayed
from langchain_anthropic.chat_models import _format_messages
from langchain_core.callbacks import streaming_stdout, AsyncCallbackManager, BaseCallbackHandler, BaseCallbackManager
from langchain.callbacks.base import Callbacks
from langchain_community.document_transformers import Html2TextTransformer, BeautifulSoupTransformer
@@ -45,6 +46,7 @@
from langchain.schema import LLMResult, Generation, PromptValue
from langchain.schema.output import GenerationChunk
from langchain_core.globals import get_llm_cache
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.llms import aget_prompts, aupdate_cache
from langchain_core.load import dumpd
from langchain_core.messages import BaseMessage
@@ -87,7 +89,7 @@
noop_prompt_type, unknown_prompt_type, template_prompt_type, none, claude3_image_tokens, gemini_image_tokens, \
gpt4_image_tokens, user_prompt_for_fake_system_prompt0, empty_prompt_type, \
is_gradio_vision_model, is_json_model, anthropic_mapping, gemini15image_num_max, gemini15imagetag, \
openai_supports_functiontools, openai_supports_parallel_functiontools
openai_supports_functiontools, openai_supports_parallel_functiontools, anthropic_prompt_caching
from evaluate_params import gen_hyper, gen_hyper0
from gen import SEED, get_limited_prompt, get_relaxed_max_new_tokens, get_model_retry, gradio_to_llm, \
get_client_from_inference_server
@@ -930,6 +932,7 @@ class GradioInference(AGenerateStreamFirst, H2Oagenerate, LLM):

return_full_text: bool = False
stream_output: bool = False
enable_caching: bool = False
sanitize_bot_response: bool = False

prompter: Any = None
@@ -1017,6 +1020,7 @@ def setup_call(self, prompt):
# This is good, so gradio server can also handle stopping.py conditions
# this is different than TGI server that uses prompter to inject prompt_type prompting
stream_output = self.stream_output
enable_caching = self.enable_caching
# don't double-up langchain behavior, already did langchain part
client_langchain_mode = LangChainMode.LLM.value
client_add_chat_history_to_context = self.add_chat_history_to_context
@@ -1047,6 +1051,7 @@ def setup_call(self, prompt):
# streaming output is supported, loops over and outputs each generation in streaming mode
# but leave stream_output=False for simple input/output mode
stream_output=stream_output,
enable_caching=enable_caching,
prompt_type=prompt_type,
prompt_dict='',
chat_template=self.chat_template,
@@ -2734,6 +2739,8 @@ class H2OChatAnthropic2(ChatAGenerateStreamFirst, GenerateNormal, ExtraChat, Cha
count_output_tokens: Any = 0
tokenizer: Any = None
prompter: Any = None
supports_caching: bool = False
enable_caching: bool = False

# max_new_tokens0: Any = None # FIXME: Doesn't seem to have same max_tokens == -1 for prompts==1

@@ -2752,9 +2759,61 @@ class H2OChatAnthropic3(ChatAGenerateStreamFirst, GenerateStream, ExtraChat, Cha
count_output_tokens: Any = 0
tokenizer: Any = None
prompter: Any = None
supports_caching: bool = False
enable_caching: bool = False

# max_new_tokens0: Any = None # FIXME: Doesn't seem to have same max_tokens == -1 for prompts==1

def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: Optional[List[str]] = None,
**kwargs: Dict,
) -> Dict:
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
if hasattr(self, 'supports_caching') and self.supports_caching and \
hasattr(self, 'enable_caching') and self.enable_caching:
messages = payload['messages']
system = payload.get('system', '')

# fix system
system_cached = {
"type": "text",
"text": system,
"cache_control": {"type": "ephemeral"}
}

# fix messages
# Prepare the messages list
messages_cached = []

# Process user and assistant messages
user_message_count = sum(1 for msg in messages if msg["role"] == "user")
for i, message in enumerate(messages):
if message["role"] == "user":
content = {
"type": "text",
"text": message["content"]
}

# Add cache control to the last two user messages
if user_message_count - i <= 2:
content["cache_control"] = {"type": "ephemeral"}

messages_cached.append({
"role": "user",
"content": [content]
})
else:
messages_cached.append(message)

# put messages and system back in
payload['messages'] = messages_cached
payload['system'] = system_cached
return payload


class H2OChatAnthropic3Sys(H2OChatAnthropic3):
pass
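For reference, a hedged sketch of the request shape that _get_request_payload() builds when caching is enabled, issued directly with the anthropic SDK; the model, texts, and max_tokens are placeholders, and the SDK surface may differ by version:

```python
# Hedged sketch of the cached payload shape produced above, sent with the anthropic SDK.
# The system block and the latest user turn carry ephemeral cache_control markers, and
# the beta header matches the one get_llm() adds below. Model and texts are placeholders.
import os

import anthropic

client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
response = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=1024,
    system=[
        {
            "type": "text",
            "text": "You are a helpful assistant. <long shared context goes here>",
            "cache_control": {"type": "ephemeral"},
        }
    ],
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Answer using the shared context.",
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        }
    ],
    extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
)
print(response.content[0].text)
```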
@@ -2896,6 +2955,7 @@ def get_llm(use_openai_model=False,
langchain_only_model=None,
load_awq='',
stream_output=False,
enable_caching=False,
async_output=True,
num_async=3,
do_sample=False,
@@ -3317,6 +3377,10 @@ def get_llm(use_openai_model=False,
# FIXME: _AnthropicCommon ignores these and makes no client anyways
kwargs_extra.update(dict(client=model['client'], async_client=model['async_client']))

supports_caching = model_name in anthropic_prompt_caching
if supports_caching:
kwargs_extra.update(extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"})

callbacks = [streaming_callback]
llm = cls(model=model_name,
anthropic_api_key=os.getenv('ANTHROPIC_API_KEY'),
@@ -3332,6 +3396,8 @@ def get_llm(use_openai_model=False,
tokenizer=tokenizer,
prompter=prompter,
verbose=verbose,
supports_caching=supports_caching,
enable_caching=enable_caching,
**kwargs_extra
)
streamer = callbacks[0] if stream_output else None
@@ -3705,6 +3771,7 @@ def get_llm(use_openai_model=False,

callbacks=callbacks if stream_output else None,
stream_output=stream_output,
enable_caching=enable_caching,

prompter=prompter,
context=context,
@@ -6752,6 +6819,7 @@ def _run_qa_db(query=None,
migrate_embedding_model=False,
stream_output0=False,
stream_output=False,
enable_caching=False,
async_output=True,
num_async=3,
prompter=None,
@@ -7023,6 +7091,7 @@ def _run_qa_db(query=None,
langchain_only_model=langchain_only_model,
load_awq=load_awq,
stream_output=stream_output,
enable_caching=enable_caching,
async_output=async_output,
num_async=num_async,
do_sample=do_sample,
1 change: 1 addition & 0 deletions src/gradio_runner.py
@@ -1776,6 +1776,7 @@ def show_llava(x):
label="guided_whitespace_pattern, empty string means None",
info="https://github.com/vllm-project/vllm/pull/4305/files",
visible=not is_public)
enable_caching = gr.Checkbox(value=kwargs['enable_caching'], visible=False)
images_num_max = gr.Number(
label='Number of Images per LLM call, -1 is auto mode, 0 is avoid using images',
value=kwargs['images_num_max'] if kwargs['images_num_max'] is not None else -1,
2 changes: 1 addition & 1 deletion src/version.py
@@ -1 +1 @@
__version__ = "5ea2b0ea086f988e176e433cfc4567ceb330e755"
__version__ = "56f96444402e9ce3ed6bdd78f28bb85cab582b7b"
