From b590e9f4d2103400da5d87f2ec568ce6f93ae012 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Wed, 17 Apr 2024 14:33:42 -0400 Subject: [PATCH 1/9] feat: WrapperEngine --- docs/community/extensions.rst | 15 +++++++----- docs/engine_reference.rst | 10 +++++--- kani/engines/__init__.py | 1 + kani/engines/base.py | 45 +++++++++++++++++++++++++++++++++++ tests/test_streaming.py | 14 +++++++---- 5 files changed, 71 insertions(+), 14 deletions(-) diff --git a/docs/community/extensions.rst b/docs/community/extensions.rst index 9951ae6..bc1d913 100644 --- a/docs/community/extensions.rst +++ b/docs/community/extensions.rst @@ -38,19 +38,22 @@ your engine *wrap* another engine: """An example showing how to wrap another kani engine.""" - class MyEngineWrapper(BaseEngine): - def __init__(self, inner_engine: BaseEngine): - self.inner_engine = inner_engine - self.max_context_size = inner_engine.max_context_size + from kani.engines import BaseEngine, WrapperEngine + # subclassing WrapperEngine automatically implements passthrough of untouched attributes + # to the wrapped engine! + class MyEngineWrapper(WrapperEngine): def message_len(self, message): # wrap the inner message with the prompt framework... prompted_message = ChatMessage(...) - return self.inner_engine.message_len(prompted_message) + return super().message_len(prompted_message) async def predict(self, messages, functions=None, **hyperparams): # wrap the messages with the prompt framework and pass it to the inner engine - prompted_completion = await self.inner_engine.predict(prompted_messages, ...) + prompted_completion = await super().predict(prompted_messages, ...) # unwrap the resulting message (if necessary) and store the metadata separately completion = self.unwrap(prompted_completion) return completion + +The :class:`kani.engines.WrapperEngine` is a base class that automatically creates a constructor that takes in the +engine to wrap, and passes through any non-overriden attributes to the wrapped engine. diff --git a/docs/engine_reference.rst b/docs/engine_reference.rst index aed303a..5f436e9 100644 --- a/docs/engine_reference.rst +++ b/docs/engine_reference.rst @@ -5,13 +5,17 @@ Engine Reference Base ---- -.. autoclass:: kani.engines.base.BaseEngine +.. autoclass:: kani.engines.BaseEngine :members: -.. autoclass:: kani.engines.base.BaseCompletion +.. autoclass:: kani.engines.Completion :members: -.. autoclass:: kani.engines.base.Completion +.. autoclass:: kani.engines.WrapperEngine + + .. autoattribute:: engine + +.. autoclass:: kani.engines.base.BaseCompletion :members: .. autoclass:: kani.engines.httpclient.BaseClient diff --git a/kani/engines/__init__.py b/kani/engines/__init__.py index e69de29..1fa9801 100644 --- a/kani/engines/__init__.py +++ b/kani/engines/__init__.py @@ -0,0 +1 @@ +from .base import BaseEngine, Completion, WrapperEngine diff --git a/kani/engines/base.py b/kani/engines/base.py index 71b1371..a72af55 100644 --- a/kani/engines/base.py +++ b/kani/engines/base.py @@ -129,3 +129,48 @@ async def stream( async def close(self): """Optional: Clean up any resources the engine might need.""" pass + + +class WrapperEngine(BaseEngine): + """ + A base class for engines that are meant to wrap other engines. By default, this class takes in another engine + as the first parameter in its constructor and will pass through all non-overriden attributes to the wrapped + engine. + """ + + def __init__(self, engine: BaseEngine, *args, **kwargs): + """ + :param engine: The engine to wrap. 
+ """ + super().__init__(*args, **kwargs) + self.engine = engine + """The wrapped engine.""" + + # passthrough attrs + self.max_context_size = engine.max_context_size + self.token_reserve = engine.token_reserve + + # passthrough methods + def message_len(self, message: ChatMessage) -> int: + return self.engine.message_len(message) + + async def predict( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams + ) -> BaseCompletion: + return await self.engine.predict(messages, functions, **hyperparams) + + async def stream( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams + ) -> AsyncIterable[str | BaseCompletion]: + async for elem in self.engine.stream(messages, functions, **hyperparams): + yield elem + + def function_token_reserve(self, functions: list[AIFunction]) -> int: + return self.engine.function_token_reserve(functions) + + async def close(self): + return await self.engine.close() + + # all other attributes are caught by this default passthrough handler + def __getattr__(self, item): + return getattr(self.engine, item) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 4bd61e8..e85c127 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -2,14 +2,17 @@ from hypothesis import HealthCheck, given, settings, strategies as st from kani import ChatMessage, ChatRole, Kani +from kani.engines.base import WrapperEngine from tests.engine import TestEngine, TestStreamingEngine from tests.utils import flatten_chatmessages engine = TestEngine() streaming_engine = TestStreamingEngine() +wrapped_engine = WrapperEngine(engine) +wrapped_streaming_engine = WrapperEngine(streaming_engine) -@pytest.mark.parametrize("eng", [engine, streaming_engine]) +@pytest.mark.parametrize("eng", [engine, streaming_engine, wrapped_engine, wrapped_streaming_engine]) async def test_chat_round_stream_consume_all(eng): ai = Kani(eng, desired_response_tokens=3) # 5 tokens, no omitting @@ -22,15 +25,16 @@ async def test_chat_round_stream_consume_all(eng): assert flatten_chatmessages(prompt) == "12345a" -async def test_chat_round_stream(): - ai = Kani(streaming_engine, desired_response_tokens=3) +@pytest.mark.parametrize("eng", [streaming_engine, wrapped_streaming_engine]) +async def test_chat_round_stream(eng): + ai = Kani(eng, desired_response_tokens=3) stream = ai.chat_round_stream("12345") async for token in stream: assert token == "a" resp = await stream.message() assert resp.content == "a" - ai = Kani(streaming_engine, desired_response_tokens=3) + ai = Kani(eng, desired_response_tokens=3) stream = ai.chat_round_stream("aaa", test_echo=True) async for token in stream: assert token == "a" @@ -40,7 +44,7 @@ async def test_chat_round_stream(): @settings(suppress_health_check=(HealthCheck.too_slow,), deadline=None) @given(st.data()) -@pytest.mark.parametrize("eng", [engine, streaming_engine]) +@pytest.mark.parametrize("eng", [engine, streaming_engine, wrapped_engine, wrapped_streaming_engine]) async def test_spam_stream(eng, data): # spam the kani with a bunch of random prompts # and make sure it never breaks From 3b6e67010d6ee4016c56e3a9d5a1f0f80c6bc4b2 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Wed, 17 Apr 2024 14:43:51 -0400 Subject: [PATCH 2/9] docs: add kani-ratelimits extension --- docs/community/extensions.rst | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/docs/community/extensions.rst b/docs/community/extensions.rst index bc1d913..b9b0239 100644 --- 
a/docs/community/extensions.rst +++ b/docs/community/extensions.rst @@ -13,13 +13,17 @@ make your package available on pip. Community Extensions -------------------- -If you've made a cool extension, add it to this table with a PR! +If you've made a cool extension, add it to this list with a PR! -+-------------+----------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ -| Name | Description | Links | -+=============+============================================================================+==============================================================================================================+ -| kani-vision | Adds support for multimodal vision-language models, like GPT-4V and LLaVA. | `GitHub `_ `Docs `_ | -+-------------+----------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ +* **kani-ratelimits**: Adds a wrapper engine to enforce request-per-minute (RPM), token-per-minute (TPM), and/or + max-concurrency ratelimits before making requests to an underlying engine. + + * `GitHub (kani-ratelimits) `_ + +* **kani-vision**: Adds support for multimodal vision-language models, like GPT-4V and LLaVA. + + * `GitHub (kani-vision) `_ + * `Docs (kani-vision) `_ Design Considerations --------------------- From f563677b548868cfe9db04282534a91955d78ea7 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Thu, 18 Apr 2024 13:41:24 -0400 Subject: [PATCH 3/9] feat: add llama 3 --- README.md | 1 + examples/4_engines_zoo.py | 13 ++++++++++++ kani/prompts/impl/__init__.py | 1 + kani/prompts/impl/llama3.py | 38 +++++++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+) create mode 100644 kani/prompts/impl/llama3.py diff --git a/README.md b/README.md index a520e5a..d66d5ea 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,7 @@ kani supports every chat model available on Hugging Face through `transformers` In particular, we have reference implementations for the following base models, and their fine-tunes: +- [LLaMA 3](https://huggingface.co/collections/meta-llama/meta-llama-3-66214712577ca38149ebb2b6) (all sizes) - [Command R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) and [Command R+](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) diff --git a/examples/4_engines_zoo.py b/examples/4_engines_zoo.py index 4407115..14a7ff2 100644 --- a/examples/4_engines_zoo.py +++ b/examples/4_engines_zoo.py @@ -18,6 +18,19 @@ engine = AnthropicEngine(api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-opus-20240229") # ========== Hugging Face ========== +# ---- LLaMA v3 (Hugging Face) ---- +import torch +from kani.engines.huggingface import HuggingEngine +from kani.prompts.impl import LLAMA3_PIPELINE +engine = HuggingEngine( + model_id="meta-llama/Meta-Llama-3-8B-Instruct", + prompt_pipeline=LLAMA3_PIPELINE, + use_auth_token=True, # log in with huggingface-cli + # suggested args from the Llama model card + model_load_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16}, + eos_token_id=[128001, 128009], # [<|end_of_text|>, <|eot_id|>] +) + # ---- LLaMA v2 (Hugging Face) ---- from kani.engines.huggingface.llama2 import LlamaEngine engine = LlamaEngine(model_id="meta-llama/Llama-2-7b-chat-hf", use_auth_token=True) # log in with huggingface-cli diff 
--git a/kani/prompts/impl/__init__.py b/kani/prompts/impl/__init__.py index 45a3ee5..08beb40 100644 --- a/kani/prompts/impl/__init__.py +++ b/kani/prompts/impl/__init__.py @@ -2,6 +2,7 @@ from .gemma import GEMMA_PIPELINE from .llama2 import LLAMA2_PIPELINE +from .llama3 import LLAMA3_PIPELINE from .vicuna import VICUNA_PIPELINE MISTRAL_PIPELINE = LLAMA2_PIPELINE diff --git a/kani/prompts/impl/llama3.py b/kani/prompts/impl/llama3.py new file mode 100644 index 0000000..6618b37 --- /dev/null +++ b/kani/prompts/impl/llama3.py @@ -0,0 +1,38 @@ +"""Common builder for the LLaMAv3-chat prompt.""" + +from kani.models import ChatRole +from kani.prompts.pipeline import PromptPipeline + +LLAMA3_PIPELINE = ( + PromptPipeline() + .translate_role( + role=ChatRole.FUNCTION, + to=ChatRole.USER, + warn=( + "The Llama 3 prompt format does not natively support the FUNCTION role. These messages will be" + " sent to the model as USER messages." + ), + ) + .conversation_fmt( + prefix="<|begin_of_text|>", + generation_suffix="<|start_header_id|>assistant<|end_header_id|>\n\n", + user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", + user_suffix="<|eot_id|>", + assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", + assistant_suffix="<|eot_id|>", + assistant_suffix_if_last="", + system_prefix="<|start_header_id|>system<|end_header_id|>\n\n", + system_suffix="<|eot_id|>", + ) +) # fmt: skip + +# from https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json +# {% set loop_messages = messages %} +# {% for message in loop_messages %} +# {% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %} +# {% if loop.index0 == 0 %} +# {% set content = bos_token + content %} +# {% endif %} +# {{ content }} +# {% endfor %} +# {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }} From 6dc4d1997baf3383126b3e72defa76529f7d2d30 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Thu, 18 Apr 2024 14:52:37 -0400 Subject: [PATCH 4/9] docs: llama 3 engine table --- docs/shared/engine_table.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/shared/engine_table.rst b/docs/shared/engine_table.rst index 1d5ee93..18043b5 100644 --- a/docs/shared/engine_table.rst +++ b/docs/shared/engine_table.rst @@ -7,6 +7,8 @@ +----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ | |:hugging:| transformers\ [#runtime]_ | ``huggingface``\ [#torch]_ | (runtime) | :class:`kani.engines.huggingface.HuggingEngine` | +----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ +| |:hugging:| |:llama:| LLaMA 3 | ``huggingface, llama``\ [#torch]_ | |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.HuggingEngine`\ [#zoo]_ | ++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ | |:hugging:| Command R, Command R+ | ``huggingface``\ [#torch]_ | |function| |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.cohere.CommandREngine` | +----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ | |:hugging:| |:llama:| LLaMA v2 | 
``huggingface, llama``\ [#torch]_ | |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.llama2.LlamaEngine` | @@ -36,6 +38,8 @@ models! .. |gpu| replace:: :abbr:`🚀 (runs on local gpu)` .. |api| replace:: :abbr:`📡 (hosted API)` +.. [#zoo] See the `model zoo `_ for a code sample + to initialize this model with the given engine. .. [#torch] You will also need to install `PyTorch `_ manually. .. [#abstract] This is an abstract class of models; kani includes a couple concrete implementations for reference. From be24343996b3683888b8ae4c11d6c5e2c9b3d8ca Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Thu, 18 Apr 2024 15:10:38 -0400 Subject: [PATCH 5/9] docs: misc ref fix --- kani/engines/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kani/engines/base.py b/kani/engines/base.py index a72af55..61842d0 100644 --- a/kani/engines/base.py +++ b/kani/engines/base.py @@ -103,7 +103,7 @@ async def stream( """ Optional: Stream a completion from the engine, token-by-token. - This method's signature is the same as :meth:`predict`. + This method's signature is the same as :meth:`.BaseEngine.predict`. This method should yield strings as an asynchronous iterable. From 05fe1acbd66c8c3473611687b64265ed06f807bb Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Thu, 18 Apr 2024 16:55:39 -0400 Subject: [PATCH 6/9] refactor(cohere): simplify command-r prompt build --- kani/engines/huggingface/cohere.py | 46 +++++++------ kani/prompts/impl/cohere.py | 105 +++++++++-------------------- 2 files changed, 57 insertions(+), 94 deletions(-) diff --git a/kani/engines/huggingface/cohere.py b/kani/engines/huggingface/cohere.py index 8e2e28e..7b36ad2 100644 --- a/kani/engines/huggingface/cohere.py +++ b/kani/engines/huggingface/cohere.py @@ -9,13 +9,11 @@ from kani.exceptions import MissingModelDependencies from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall from kani.prompts.impl.cohere import ( - COMMAND_R_PIPELINE, DEFAULT_PREAMBLE, DEFAULT_TASK, DEFAULT_TOOL_INSTRUCTIONS, DEFAULT_TOOL_PROMPT, - build_rag_pipeline, - build_tool_pipeline, + build_pipeline, function_prompt, tool_call_formatter, ) @@ -87,6 +85,7 @@ def __init__( tool_prompt_include_function_calls=True, tool_prompt_include_function_results=True, tool_prompt_instructions=DEFAULT_TOOL_INSTRUCTIONS, + rag_prompt_include_function_calls=True, rag_prompt_include_function_results=True, rag_prompt_instructions=None, **kwargs, @@ -98,13 +97,15 @@ def __init__( :param tokenizer_kwargs: Additional arguments to pass to ``AutoTokenizer.from_pretrained()``. :param model_load_kwargs: Additional arguments to pass to ``AutoModelForCausalLM.from_pretrained()``. :param tool_prompt_include_function_calls: Whether to include previous turns' function calls or just the model's - answers. + answers when it is the model's generation turn and the last message is not FUNCTION. :param tool_prompt_include_function_results: Whether to include the results of previous turns' function calls in - the context. + the context when it is the model's generation turn and the last message is not FUNCTION. :param tool_prompt_instructions: The system prompt to send just before the model's generation turn that includes instructions on the format to generate tool calls in. Generally you shouldn't change this. + :param rag_prompt_include_function_calls: Whether to include previous turns' function calls or just the model's + answers when it is the model's generation turn and the last message is FUNCTION. 
:param rag_prompt_include_function_results: Whether to include the results of previous turns' function calls in - the context. + the context when it is hte model's generation turn and the last message is FUNCTION. :param rag_prompt_instructions: The system prompt to send just before the model's generation turn that includes instructions on the format to generate the result in. Can be None to only generate a model turn. Defaults to ``None`` to for maximum interoperability between models. Options: @@ -120,14 +121,18 @@ def __init__( self._tool_prompt_include_function_calls = tool_prompt_include_function_calls - self._tool_pipeline = build_tool_pipeline( + self._default_pipeline = build_pipeline() + self._tool_pipeline = build_pipeline( include_function_calls=tool_prompt_include_function_calls, - include_function_results=tool_prompt_include_function_results, - tool_instructions=tool_prompt_instructions, + include_all_function_results=tool_prompt_include_function_results, + include_last_function_result=tool_prompt_include_function_results, + instruction_suffix=tool_prompt_instructions, ) - self._rag_pipeline = build_rag_pipeline( - include_previous_results=rag_prompt_include_function_results, - rag_instructions=rag_prompt_instructions, + self._rag_pipeline = build_pipeline( + include_function_calls=rag_prompt_include_function_calls, + include_all_function_results=rag_prompt_include_function_results, + include_last_function_result=True, + instruction_suffix=rag_prompt_instructions, ) # ==== token counting ==== @@ -163,7 +168,7 @@ def build_prompt( ) -> str | torch.Tensor: # no functions: we can just do the default simple format if not functions: - prompt = COMMAND_R_PIPELINE(messages) + prompt = self._default_pipeline(messages) log.debug(f"PROMPT: {prompt}") return prompt @@ -297,21 +302,22 @@ async def stream( completion = self._generate(input_toks, input_len, hyperparams, functions) tool_calls = completion.message.tool_calls or [] - # if the model generated multiple calls that happen to include a directly_answer, remove the directly_answer - if len(tool_calls) > 1: - completion.message.tool_calls = [ - tc for tc in completion.message.tool_calls if tc.function.name != "directly_answer" - ] + # if tool says directly answer, stream with the rag pipeline (but no result) - elif len(tool_calls) == 1 and tool_calls[0].function.name == "directly_answer": + if len(tool_calls) == 1 and tool_calls[0].function.name == "directly_answer": log.debug("GOT DIRECTLY_ANSWER, REPROMPTING RAG...") prompt = self._build_prompt_rag(messages) log.debug(f"RAG PROMPT: {prompt}") input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams) async for elem in self._stream(input_toks, hyperparams, streamer_timeout=streamer_timeout): yield elem - # otherwise yield as normal + # if the model generated multiple calls that happen to include a directly_answer, remove the directly_answer + # then yield as normal else: + if completion.message.tool_calls: + completion.message.tool_calls = [ + tc for tc in completion.message.tool_calls if tc.function.name != "directly_answer" + ] if completion.message.text: yield completion.message.text yield completion diff --git a/kani/prompts/impl/cohere.py b/kani/prompts/impl/cohere.py index 9c0acdc..189c0c0 100644 --- a/kani/prompts/impl/cohere.py +++ b/kani/prompts/impl/cohere.py @@ -58,29 +58,11 @@ def directly_answer() -> List[Dict]: Firstly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by 
comma-separated list of document numbers. If none are relevant, you should instead write 'None'. Secondly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'. Finally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols and to indicate when a fact comes from a document in the search result, e.g my fact for a fact from document 0.""" -# fmt: on -# ==== no tool calling ==== -COMMAND_R_PIPELINE = ( - PromptPipeline() - .conversation_fmt( - prefix="", - generation_suffix="<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", - user_prefix="<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", - user_suffix="<|END_OF_TURN_TOKEN|>", - assistant_prefix="<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", - assistant_suffix="<|END_OF_TURN_TOKEN|>", - assistant_suffix_if_last="", - system_prefix="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", - system_suffix="<|END_OF_TURN_TOKEN|>", - function_prefix="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\n\n", - function_suffix="\n<|END_OF_TURN_TOKEN|>", - ) -) # fmt: skip -"""The pipeline to use when interfacing with Command R without tools defined.""" + +# fmt: on -# ==== tool calling (function not last) ==== def function_result_joiner(msgs): contents = [] for idx, msg in enumerate(msgs): @@ -104,17 +86,23 @@ def tool_call_formatter(msg: ChatMessage) -> str: return msg.content -def build_tool_pipeline( - *, include_function_calls=True, include_function_results=True, tool_instructions=DEFAULT_TOOL_INSTRUCTIONS +def build_pipeline( + *, + include_function_calls=True, + include_all_function_results=True, + include_last_function_result=True, + instruction_suffix=None, ): """ - The pipeline to use when interfacing with Command R WITH tools defined. Use this pipeline if the last message is - NOT a FUNCTION message. - :param include_function_calls: Whether to include previous turns' function calls or just the model's answers. - :param include_function_results: Whether to include the results of previous turns' function calls in the context. - :param tool_instructions: The system prompt to send just before the model's generation turn that includes - instructions on the format to generate tool calls in. Generally you shouldn't change this. + :param include_all_function_results: Whether to include the results of all previous turns' function calls in the + context. + :param include_last_function_result: If *include_all_function_results* is False, whether to include just the last + function call's result (useful for RAG). + :param instruction_suffix: The system prompt to send just before the model's generation turn that includes + instructions on the format to generate the result in. Can be None to only generate a model turn. + For tool calling, this should be the DEFAULT_TOOL_INSTRUCTIONS. + For RAG, this should be DEFAULT_RAG_INSTRUCTIONS_ACC or DEFAULT_RAG_INSTRUCTIONS_FAST. 
""" steps = [] @@ -130,18 +118,30 @@ def apply_tc_format(msg): else: steps.append(Remove(role=ChatRole.ASSISTANT, predicate=lambda msg: msg.content is None)) - # keep function results around as SYSTEM messages - if include_function_results: + # keep/drop function results + if include_all_function_results: + # keep all function results around as SYSTEM messages + steps.append(MergeConsecutive(role=ChatRole.FUNCTION, joiner=function_result_joiner)) + elif include_last_function_result: + # merge consecutive FUNCTION messages then remove all but the last (if it's the last message) + + def remover(m, is_last): + return None if not is_last else m + steps.append(MergeConsecutive(role=ChatRole.FUNCTION, joiner=function_result_joiner)) + steps.append(Apply(remover, role=ChatRole.FUNCTION)) else: + # remove all FUNCTION messages steps.append(Remove(role=ChatRole.FUNCTION)) steps.append( ConversationFmt( prefix="", generation_suffix=( - f"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{tool_instructions}<|END_OF_TURN_TOKEN|>" + f"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{instruction_suffix}<|END_OF_TURN_TOKEN|>" "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + if instruction_suffix + else "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" ), user_prefix="<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", user_suffix="<|END_OF_TURN_TOKEN|>", @@ -158,49 +158,6 @@ def apply_tc_format(msg): return PromptPipeline(steps) -# ==== tool calling (function last) ==== -def build_rag_pipeline(*, include_previous_results=True, rag_instructions=DEFAULT_RAG_INSTRUCTIONS_ACC): - """ - The pipeline to use when interfacing with Command R WITH tools defined. Use this pipeline if the last message IS a - FUNCTION message. - - :param include_previous_results: Include previous turns' results in the chat history. - :param rag_instructions: The system prompt to send just before the model's generation turn that includes - instructions on the format to generate the result in. Can be None to only generate a model turn. Defaults - to the "accurate" grounded RAG prompt (``from kani.prompts.impl.cohere import DEFAULT_RAG_INSTRUCTIONS_ACC``). 
- """ - - def remover(m, is_last): - return None if is_last and not include_previous_results else m - - return ( - PromptPipeline() - .merge_consecutive(role=ChatRole.FUNCTION, joiner=function_result_joiner) - # remove all but the last function message - .apply(remover, role=ChatRole.FUNCTION) - # remove asst messages with no content (function calls) - .remove(role=ChatRole.ASSISTANT, predicate=lambda msg: msg.content is None) - .conversation_fmt( - prefix="", - generation_suffix=( - f"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{rag_instructions}<|END_OF_TURN_TOKEN|>" - "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" - if rag_instructions - else "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" - ), - user_prefix="<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", - user_suffix="<|END_OF_TURN_TOKEN|>", - assistant_prefix="<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", - assistant_suffix="<|END_OF_TURN_TOKEN|>", - assistant_suffix_if_last="", - system_prefix="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", - system_suffix="<|END_OF_TURN_TOKEN|>", - function_prefix="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\n\n", - function_suffix="\n<|END_OF_TURN_TOKEN|>", - ) - ) - - # ==== helpers ==== def function_prompt(f: AIFunction) -> str: params = f.get_params() From b10114ee9a95799410f574ed119ace8a7490d5da Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Fri, 19 Apr 2024 12:18:25 -0400 Subject: [PATCH 7/9] refactor: cohere command-r mixin --- kani/engines/huggingface/cohere.py | 133 +++-------------------------- kani/prompts/impl/cohere.py | 129 +++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 121 deletions(-) diff --git a/kani/engines/huggingface/cohere.py b/kani/engines/huggingface/cohere.py index 7b36ad2..93c519d 100644 --- a/kani/engines/huggingface/cohere.py +++ b/kani/engines/huggingface/cohere.py @@ -1,19 +1,13 @@ import functools -import json import logging -import re from collections.abc import AsyncIterable from threading import Thread from kani.ai_function import AIFunction from kani.exceptions import MissingModelDependencies -from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall +from kani.models import ChatMessage, ChatRole from kani.prompts.impl.cohere import ( - DEFAULT_PREAMBLE, - DEFAULT_TASK, - DEFAULT_TOOL_INSTRUCTIONS, - DEFAULT_TOOL_PROMPT, - build_pipeline, + CommandRMixin, function_prompt, tool_call_formatter, ) @@ -33,7 +27,7 @@ log = logging.getLogger(__name__) -class CommandREngine(HuggingEngine): +class CommandREngine(CommandRMixin, HuggingEngine): """Implementation of Command R (35B) and Command R+ (104B) using huggingface transformers. Model IDs: @@ -78,18 +72,7 @@ class CommandREngine(HuggingEngine): token_reserve = 200 # generous reserve due to large ctx size and weird 3-mode prompt - def __init__( - self, - model_id: str = "CohereForAI/c4ai-command-r-v01", - *args, - tool_prompt_include_function_calls=True, - tool_prompt_include_function_results=True, - tool_prompt_instructions=DEFAULT_TOOL_INSTRUCTIONS, - rag_prompt_include_function_calls=True, - rag_prompt_include_function_results=True, - rag_prompt_instructions=None, - **kwargs, - ): + def __init__(self, model_id: str = "CohereForAI/c4ai-command-r-v01", *args, **kwargs): """ :param model_id: The ID of the model to load from HuggingFace. :param max_context_size: The context size of the model (defaults to Command R's size of 128k). 
@@ -119,22 +102,6 @@ def __init__( kwargs.setdefault("max_context_size", 128000) super().__init__(model_id, *args, **kwargs) - self._tool_prompt_include_function_calls = tool_prompt_include_function_calls - - self._default_pipeline = build_pipeline() - self._tool_pipeline = build_pipeline( - include_function_calls=tool_prompt_include_function_calls, - include_all_function_results=tool_prompt_include_function_results, - include_last_function_result=tool_prompt_include_function_results, - instruction_suffix=tool_prompt_instructions, - ) - self._rag_pipeline = build_pipeline( - include_function_calls=rag_prompt_include_function_calls, - include_all_function_results=rag_prompt_include_function_results, - include_last_function_result=True, - instruction_suffix=rag_prompt_instructions, - ) - # ==== token counting ==== def message_len(self, message: ChatMessage) -> int: # prompt str to tokens @@ -162,54 +129,6 @@ def function_token_reserve(self, functions: list[AIFunction]) -> int: function_tokens = len(self.tokenizer.encode(function_text, add_special_tokens=False)) return function_tokens + default_prompt_tokens - # ==== prompt ==== - def build_prompt( - self, messages: list[ChatMessage], functions: list[AIFunction] | None = None - ) -> str | torch.Tensor: - # no functions: we can just do the default simple format - if not functions: - prompt = self._default_pipeline(messages) - log.debug(f"PROMPT: {prompt}") - return prompt - - # if we do have functions things get wacky - # is the last message a FUNCTION? if so, we need to use the RAG template - if messages and messages[-1].role == ChatRole.FUNCTION: - prompt = self._build_prompt_rag(messages) - log.debug(f"RAG PROMPT: {prompt}") - return prompt - - # otherwise use the TOOL template - prompt = self._build_prompt_tools(messages, functions) - log.debug(f"TOOL PROMPT: {prompt}") - return prompt - - def _build_prompt_tools(self, messages: list[ChatMessage], functions: list[AIFunction]): - # get the function definitions - function_text = "\n\n".join(map(function_prompt, functions)) - tool_prompt = DEFAULT_TOOL_PROMPT.format(user_functions=function_text) - - # wrap the initial system message, if any - messages = messages.copy() - if messages and messages[0].role == ChatRole.SYSTEM: - messages[0] = messages[0].copy_with(content=DEFAULT_PREAMBLE + messages[0].text + tool_prompt) - # otherwise add it in - else: - messages.insert(0, ChatMessage.system(DEFAULT_PREAMBLE + DEFAULT_TASK + tool_prompt)) - - return self._tool_pipeline(messages) - - def _build_prompt_rag(self, messages: list[ChatMessage]): - # wrap the initial system message, if any - messages = messages.copy() - if messages and messages[0].role == ChatRole.SYSTEM: - messages[0] = messages[0].copy_with(content=DEFAULT_PREAMBLE + messages[0].text) - # otherwise add it in - else: - messages.insert(0, ChatMessage.system(DEFAULT_PREAMBLE + DEFAULT_TASK)) - - return self._rag_pipeline(messages) - # ==== generate ==== def _generate(self, input_toks, input_len, hyperparams, functions): """Generate and return a completion (may be a directly_answer call).""" @@ -219,28 +138,8 @@ def _generate(self, input_toks, input_len, hyperparams, functions): # the completion shouldn't include the prompt or stop token content = self.tokenizer.decode(output[0][input_len:-1]).strip() completion_tokens = len(output[0]) - (input_len + 1) - log.debug(f"COMPLETION: {content}") - - # if we have tools, possibly parse out the Action - tool_calls = None - if functions and (action_json := 
re.match(r"Action:\s*```json\n(.+)\n```", content, re.IGNORECASE | re.DOTALL)): - actions = json.loads(action_json.group(1)) - - # translate back to kani spec - tool_calls = [] - for action in actions: - tool_name = action["tool_name"] - tool_args = json.dumps(action["parameters"]) - tool_call = ToolCall.from_function_call(FunctionCall(name=tool_name, arguments=tool_args)) - tool_calls.append(tool_call) - - content = None - log.debug(f"PARSED TOOL CALLS: {tool_calls}") - - return Completion( - ChatMessage.assistant(content, tool_calls=tool_calls), - prompt_tokens=input_len, - completion_tokens=completion_tokens, + return self._parse_completion( + content, functions is not None, prompt_tokens=input_len, completion_tokens=completion_tokens ) async def _stream(self, input_toks, hyperparams, *, streamer_timeout=None) -> AsyncIterable[str | Completion]: @@ -270,14 +169,12 @@ async def predict( input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams) completion = self._generate(input_toks, input_len, hyperparams, functions) - tool_calls = completion.message.tool_calls or [] + cmd_r_tc_info = self._toolcall_info(completion.message.tool_calls) + # if the model generated multiple calls that happen to include a directly_answer, remove the directly_answer - if len(tool_calls) > 1: - completion.message.tool_calls = [ - tc for tc in completion.message.tool_calls if tc.function.name != "directly_answer" - ] + completion.message.tool_calls = cmd_r_tc_info.filtered_tool_calls # if tool says directly answer, call again with the rag pipeline (but no result) - elif len(tool_calls) == 1 and tool_calls[0].function.name == "directly_answer": + if cmd_r_tc_info.is_directly_answer: log.debug("GOT DIRECTLY_ANSWER, REPROMPTING RAG...") prompt = self._build_prompt_rag(messages) log.debug(f"RAG PROMPT: {prompt}") @@ -301,10 +198,9 @@ async def stream( input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams) completion = self._generate(input_toks, input_len, hyperparams, functions) - tool_calls = completion.message.tool_calls or [] - + cmd_r_tc_info = self._toolcall_info(completion.message.tool_calls) # if tool says directly answer, stream with the rag pipeline (but no result) - if len(tool_calls) == 1 and tool_calls[0].function.name == "directly_answer": + if cmd_r_tc_info.is_directly_answer: log.debug("GOT DIRECTLY_ANSWER, REPROMPTING RAG...") prompt = self._build_prompt_rag(messages) log.debug(f"RAG PROMPT: {prompt}") @@ -314,10 +210,7 @@ async def stream( # if the model generated multiple calls that happen to include a directly_answer, remove the directly_answer # then yield as normal else: - if completion.message.tool_calls: - completion.message.tool_calls = [ - tc for tc in completion.message.tool_calls if tc.function.name != "directly_answer" - ] + completion.message.tool_calls = cmd_r_tc_info.filtered_tool_calls if completion.message.text: yield completion.message.text yield completion diff --git a/kani/prompts/impl/cohere.py b/kani/prompts/impl/cohere.py index 189c0c0..bbbbce7 100644 --- a/kani/prompts/impl/cohere.py +++ b/kani/prompts/impl/cohere.py @@ -1,11 +1,17 @@ import inspect import json +import logging +import re +from collections import namedtuple from kani import AIFunction -from kani.models import ChatMessage, ChatRole +from kani.engines import Completion +from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall from kani.prompts.pipeline import PromptPipeline from kani.prompts.steps import Apply, ConversationFmt, MergeConsecutive, 
Remove +log = logging.getLogger(__name__) + # ==== default prompts ==== # fmt: off DEFAULT_SYSTEM_PROMPT = "You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere." @@ -160,6 +166,7 @@ def remover(m, is_last): # ==== helpers ==== def function_prompt(f: AIFunction) -> str: + """Build the Cohere python signature prompt for a given AIFunction.""" params = f.get_params() # build params @@ -180,3 +187,123 @@ def function_prompt(f: AIFunction) -> str: # return return f'```python\ndef {f.name}({params_str}) -> List[Dict]:\n """{f.desc}{args}\n """\n pass\n```' + + +CommandRToolCallInfo = namedtuple("CommandRToolCallInfo", "is_directly_answer filtered_tool_calls") + + +class CommandRMixin: + """Common Command R functionality to share between engines""" + + def __init__( + self, + *args, + tool_prompt_include_function_calls=True, + tool_prompt_include_function_results=True, + tool_prompt_instructions=DEFAULT_TOOL_INSTRUCTIONS, + rag_prompt_include_function_calls=True, + rag_prompt_include_function_results=True, + rag_prompt_instructions=None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._tool_prompt_include_function_calls = tool_prompt_include_function_calls + + self._default_pipeline = build_pipeline() + self._tool_pipeline = build_pipeline( + include_function_calls=tool_prompt_include_function_calls, + include_all_function_results=tool_prompt_include_function_results, + include_last_function_result=tool_prompt_include_function_results, + instruction_suffix=tool_prompt_instructions, + ) + self._rag_pipeline = build_pipeline( + include_function_calls=rag_prompt_include_function_calls, + include_all_function_results=rag_prompt_include_function_results, + include_last_function_result=True, + instruction_suffix=rag_prompt_instructions, + ) + + # ==== prompting ==== + def build_prompt(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None) -> str: + # no functions: we can just do the default simple format + if not functions: + prompt = self._default_pipeline(messages) + log.debug(f"PROMPT: {prompt}") + return prompt + + # if we do have functions things get wacky + # is the last message a FUNCTION? 
if so, we need to use the RAG template + if messages and messages[-1].role == ChatRole.FUNCTION: + prompt = self._build_prompt_rag(messages) + log.debug(f"RAG PROMPT: {prompt}") + return prompt + + # otherwise use the TOOL template + prompt = self._build_prompt_tools(messages, functions) + log.debug(f"TOOL PROMPT: {prompt}") + return prompt + + def _build_prompt_tools(self, messages: list[ChatMessage], functions: list[AIFunction]): + # get the function definitions + function_text = "\n\n".join(map(function_prompt, functions)) + tool_prompt = DEFAULT_TOOL_PROMPT.format(user_functions=function_text) + + # wrap the initial system message, if any + messages = messages.copy() + if messages and messages[0].role == ChatRole.SYSTEM: + messages[0] = messages[0].copy_with(content=DEFAULT_PREAMBLE + messages[0].text + tool_prompt) + # otherwise add it in + else: + messages.insert(0, ChatMessage.system(DEFAULT_PREAMBLE + DEFAULT_TASK + tool_prompt)) + + return self._tool_pipeline(messages) + + def _build_prompt_rag(self, messages: list[ChatMessage]): + # wrap the initial system message, if any + messages = messages.copy() + if messages and messages[0].role == ChatRole.SYSTEM: + messages[0] = messages[0].copy_with(content=DEFAULT_PREAMBLE + messages[0].text) + # otherwise add it in + else: + messages.insert(0, ChatMessage.system(DEFAULT_PREAMBLE + DEFAULT_TASK)) + + return self._rag_pipeline(messages) + + # ==== completions ==== + @staticmethod + def _parse_completion(content: str, parse_functions=True, **kwargs) -> Completion: + """Given the completion string, parse out any function calls.""" + log.debug(f"COMPLETION: {content}") + + # if we have tools, possibly parse out the Action + tool_calls = None + if parse_functions and ( + action_json := re.match(r"Action:\s*```json\n(.+)\n```", content, re.IGNORECASE | re.DOTALL) + ): + actions = json.loads(action_json.group(1)) + + # translate back to kani spec + tool_calls = [] + for action in actions: + tool_name = action["tool_name"] + tool_args = json.dumps(action["parameters"]) + tool_call = ToolCall.from_function_call(FunctionCall(name=tool_name, arguments=tool_args)) + tool_calls.append(tool_call) + + content = None + log.debug(f"PARSED TOOL CALLS: {tool_calls}") + + return Completion(ChatMessage.assistant(content, tool_calls=tool_calls), **kwargs) + + @staticmethod + def _toolcall_info(tool_calls: list[ToolCall]) -> CommandRToolCallInfo: + """Return an info tuple containing Command R-specific metadata (is_directly_answer, filtered_tcs).""" + tool_calls = tool_calls or [] + + # if tool says directly answer, stream with the rag pipeline (but no result) + if len(tool_calls) == 1 and tool_calls[0].function.name == "directly_answer": + return CommandRToolCallInfo(is_directly_answer=True, filtered_tool_calls=[]) + # if the model generated multiple calls that happen to include a directly_answer, remove the directly_answer + # then yield as normal + tool_calls = [tc for tc in tool_calls if tc.function.name != "directly_answer"] + return CommandRToolCallInfo(is_directly_answer=False, filtered_tool_calls=tool_calls) From 1cf6b320f0ee55655feb448375a4eef540dbb73d Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Fri, 19 Apr 2024 13:46:59 -0400 Subject: [PATCH 8/9] docs: llama3 default settings --- examples/4_engines_zoo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/4_engines_zoo.py b/examples/4_engines_zoo.py index 14a7ff2..e88b694 100644 --- a/examples/4_engines_zoo.py +++ b/examples/4_engines_zoo.py @@ -28,9 +28,11 @@ 
     use_auth_token=True,  # log in with huggingface-cli
     # suggested args from the Llama model card
     model_load_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16},
-    eos_token_id=[128001, 128009],  # [<|end_of_text|>, <|eot_id|>]
 )
 
+# NOTE: If you're running transformers<4.40 and LLaMA 3 continues generating after the <|eot_id|> token,
+# add `eos_token_id=[128001, 128009]` or upgrade transformers
+
 # ---- LLaMA v2 (Hugging Face) ----
 from kani.engines.huggingface.llama2 import LlamaEngine
 engine = LlamaEngine(model_id="meta-llama/Llama-2-7b-chat-hf", use_auth_token=True)  # log in with huggingface-cli

From 0a3f6d921d8cd4c12226d3b1a171965023c427cd Mon Sep 17 00:00:00 2001
From: Andrew Zhu
Date: Fri, 19 Apr 2024 14:21:03 -0400
Subject: [PATCH 9/9] chore: isort

---
 kani/engines/huggingface/cohere.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/kani/engines/huggingface/cohere.py b/kani/engines/huggingface/cohere.py
index 93c519d..3500e25 100644
--- a/kani/engines/huggingface/cohere.py
+++ b/kani/engines/huggingface/cohere.py
@@ -6,11 +6,7 @@
 from kani.ai_function import AIFunction
 from kani.exceptions import MissingModelDependencies
 from kani.models import ChatMessage, ChatRole
-from kani.prompts.impl.cohere import (
-    CommandRMixin,
-    function_prompt,
-    tool_call_formatter,
-)
+from kani.prompts.impl.cohere import CommandRMixin, function_prompt, tool_call_formatter
 from .base import HuggingEngine
 from ..base import Completion
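
The series above can be exercised end to end with a short script. The sketch below is not part of any patch: the LoggingEngine subclass and its print format are illustrative assumptions, while WrapperEngine, HuggingEngine, LLAMA3_PIPELINE, and the model ID are used as introduced in patches 1 and 3.

# Usage sketch (assumed, not from the patches): wrap the Llama 3 HuggingEngine in a custom WrapperEngine.
from kani import Kani, chat_in_terminal
from kani.engines import WrapperEngine
from kani.engines.huggingface import HuggingEngine
from kani.prompts.impl import LLAMA3_PIPELINE


class LoggingEngine(WrapperEngine):
    """Example wrapper: log each request, then delegate to the wrapped engine."""

    async def predict(self, messages, functions=None, **hyperparams):
        # self.engine is the wrapped engine set by WrapperEngine.__init__
        print(f"sending {len(messages)} messages to {type(self.engine).__name__}")
        return await super().predict(messages, functions, **hyperparams)


# the Llama 3 engine from examples/4_engines_zoo.py (model_load_kwargs omitted for brevity)
inner = HuggingEngine(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",
    prompt_pipeline=LLAMA3_PIPELINE,
    use_auth_token=True,  # log in with huggingface-cli
)
ai = Kani(LoggingEngine(inner))

if __name__ == "__main__":
    chat_in_terminal(ai)

Because WrapperEngine also passes stream() and all other attributes through to the wrapped engine, the same wrapper works unchanged with streaming rounds, as exercised by the new cases in tests/test_streaming.py.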