Merge pull request #39 from zhudotexe/v1-dev
v1: llama3, wrapperengine
zhudotexe committed Apr 22, 2024
2 parents 82f5548 + 0a3f6d9 commit 4b05830
Showing 12 changed files with 320 additions and 221 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -50,6 +50,7 @@ kani supports every chat model available on Hugging Face through `transformers`

 In particular, we have reference implementations for the following base models, and their fine-tunes:

+- [LLaMA 3](https://huggingface.co/collections/meta-llama/meta-llama-3-66214712577ca38149ebb2b6) (all sizes)
 - [Command R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
   and [Command R+](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 - [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
31 changes: 19 additions & 12 deletions docs/community/extensions.rst
@@ -13,13 +13,17 @@ make your package available on pip.

 Community Extensions
 --------------------
-If you've made a cool extension, add it to this table with a PR!
+If you've made a cool extension, add it to this list with a PR!

-+-------------+----------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+
-| Name | Description | Links |
-+=============+============================================================================+==============================================================================================================+
-| kani-vision | Adds support for multimodal vision-language models, like GPT-4V and LLaVA. | `GitHub <https://github.com/zhudotexe/kani-vision>`_ `Docs <https://kani-vision.readthedocs.io/en/latest/>`_ |
-+-------------+----------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+
+* **kani-ratelimits**: Adds a wrapper engine to enforce request-per-minute (RPM), token-per-minute (TPM), and/or
+  max-concurrency ratelimits before making requests to an underlying engine.
+
+  * `GitHub (kani-ratelimits) <https://github.com/zhudotexe/kani-ratelimits>`_
+
+* **kani-vision**: Adds support for multimodal vision-language models, like GPT-4V and LLaVA.
+
+  * `GitHub (kani-vision) <https://github.com/zhudotexe/kani-vision>`_
+  * `Docs (kani-vision) <https://kani-vision.readthedocs.io/en/latest/>`_

 Design Considerations
 ---------------------
@@ -38,19 +42,22 @@ your engine *wrap* another engine:
     """An example showing how to wrap another kani engine."""
-    class MyEngineWrapper(BaseEngine):
-        def __init__(self, inner_engine: BaseEngine):
-            self.inner_engine = inner_engine
-            self.max_context_size = inner_engine.max_context_size
+    from kani.engines import BaseEngine, WrapperEngine
+
+    # subclassing WrapperEngine automatically implements passthrough of untouched attributes
+    # to the wrapped engine!
+    class MyEngineWrapper(WrapperEngine):
         def message_len(self, message):
             # wrap the inner message with the prompt framework...
             prompted_message = ChatMessage(...)
-            return self.inner_engine.message_len(prompted_message)
+            return super().message_len(prompted_message)

         async def predict(self, messages, functions=None, **hyperparams):
             # wrap the messages with the prompt framework and pass it to the inner engine
-            prompted_completion = await self.inner_engine.predict(prompted_messages, ...)
+            prompted_completion = await super().predict(prompted_messages, ...)
             # unwrap the resulting message (if necessary) and store the metadata separately
             completion = self.unwrap(prompted_completion)
             return completion
+
+The :class:`kani.engines.WrapperEngine` is a base class that automatically creates a constructor that takes in the
+engine to wrap, and passes through any non-overridden attributes to the wrapped engine.
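
The wrapping pattern described in the docs change above can be tried without installing kani. The following is a minimal sketch, not kani's implementation: `StandInEngine`, `StandInWrapper`, and `PrefixingWrapper` are hypothetical names, messages are plain strings rather than `ChatMessage` objects, and token counting is faked as one token per word. It only illustrates how a subclass overrides `message_len` and `predict`, delegates to the wrapped engine through `super()`, and lets everything else fall through.

```python
import asyncio


class StandInEngine:
    """Hypothetical stand-in for a concrete engine (not part of kani)."""

    max_context_size = 8192

    def message_len(self, message: str) -> int:
        # Assumption for the sketch: one token per whitespace-separated word.
        return len(message.split())

    async def predict(self, messages: list[str]) -> str:
        return f"echo: {messages[-1]}"


class StandInWrapper:
    """Hypothetical base: holds the inner engine and forwards calls to it."""

    def __init__(self, engine):
        self.engine = engine

    def message_len(self, message):
        return self.engine.message_len(message)

    async def predict(self, messages):
        return await self.engine.predict(messages)

    def __getattr__(self, item):
        # Anything not defined on the wrapper is looked up on the inner engine.
        return getattr(self.engine, item)


class PrefixingWrapper(StandInWrapper):
    """Adds a fixed prompt prefix before delegating to the wrapped engine."""

    PREFIX = "You are helpful."

    def message_len(self, message):
        # Count the message as the inner engine will actually see it.
        return super().message_len(f"{self.PREFIX} {message}")

    async def predict(self, messages):
        prompted = [f"{self.PREFIX} {m}" for m in messages]
        return await super().predict(prompted)


wrapped = PrefixingWrapper(StandInEngine())
length = wrapped.message_len("hello there")
reply = asyncio.run(wrapped.predict(["hello there"]))
context = wrapped.max_context_size  # resolved on the inner engine
```

Here the subclass never defines `__init__` or `max_context_size`; both come from the base wrapper, which is the convenience the paragraph above attributes to `WrapperEngine`.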
10 changes: 7 additions & 3 deletions docs/engine_reference.rst
@@ -5,13 +5,17 @@ Engine Reference

 Base
 ----
-.. autoclass:: kani.engines.base.BaseEngine
+.. autoclass:: kani.engines.BaseEngine
     :members:

-.. autoclass:: kani.engines.base.BaseCompletion
+.. autoclass:: kani.engines.Completion
     :members:

-.. autoclass:: kani.engines.base.Completion
+.. autoclass:: kani.engines.WrapperEngine
+
+    .. autoattribute:: engine
+
+.. autoclass:: kani.engines.base.BaseCompletion
     :members:

 .. autoclass:: kani.engines.httpclient.BaseClient
4 changes: 4 additions & 0 deletions docs/shared/engine_table.rst
@@ -7,6 +7,8 @@
 +----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
 | |:hugging:| transformers\ [#runtime]_ | ``huggingface``\ [#torch]_ | (runtime) | :class:`kani.engines.huggingface.HuggingEngine` |
 +----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
+| |:hugging:| |:llama:| LLaMA 3 | ``huggingface, llama``\ [#torch]_ | |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.HuggingEngine`\ [#zoo]_ |
++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
 | |:hugging:| Command R, Command R+ | ``huggingface``\ [#torch]_ | |function| |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.cohere.CommandREngine` |
 +----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
 | |:hugging:| |:llama:| LLaMA v2 | ``huggingface, llama``\ [#torch]_ | |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.llama2.LlamaEngine` |
@@ -36,6 +38,8 @@ models!
 .. |gpu| replace:: :abbr:`🚀 (runs on local gpu)`
 .. |api| replace:: :abbr:`📡 (hosted API)`

+.. [#zoo] See the `model zoo <https://github.com/zhudotexe/kani/blob/main/examples/4_engines_zoo.py>`_ for a code sample
+   to initialize this model with the given engine.
 .. [#torch] You will also need to install `PyTorch <https://pytorch.org/get-started/locally/>`_ manually.
 .. [#abstract] This is an abstract class of models; kani includes a couple concrete implementations for
    reference.
15 changes: 15 additions & 0 deletions examples/4_engines_zoo.py
@@ -18,6 +18,21 @@
 engine = AnthropicEngine(api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-opus-20240229")

 # ========== Hugging Face ==========
+# ---- LLaMA v3 (Hugging Face) ----
+import torch
+from kani.engines.huggingface import HuggingEngine
+from kani.prompts.impl import LLAMA3_PIPELINE
+engine = HuggingEngine(
+    model_id="meta-llama/Meta-Llama-3-8B-Instruct",
+    prompt_pipeline=LLAMA3_PIPELINE,
+    use_auth_token=True,  # log in with huggingface-cli
+    # suggested args from the Llama model card
+    model_load_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16},
+)
+
+# NOTE: If you're running transformers<4.40 and LLaMA 3 continues generating after the <|eot_id|> token,
+# add `eos_token_id=[128001, 128009]` or upgrade transformers
+
 # ---- LLaMA v2 (Hugging Face) ----
 from kani.engines.huggingface.llama2 import LlamaEngine
 engine = LlamaEngine(model_id="meta-llama/Llama-2-7b-chat-hf", use_auth_token=True)  # log in with huggingface-cli
1 change: 1 addition & 0 deletions kani/engines/__init__.py
@@ -0,0 +1 @@
+from .base import BaseEngine, Completion, WrapperEngine
47 changes: 46 additions & 1 deletion kani/engines/base.py
@@ -103,7 +103,7 @@ async def stream(
         """
         Optional: Stream a completion from the engine, token-by-token.

-        This method's signature is the same as :meth:`predict`.
+        This method's signature is the same as :meth:`.BaseEngine.predict`.

         This method should yield strings as an asynchronous iterable.

@@ -129,3 +129,48 @@ async def stream(
     async def close(self):
         """Optional: Clean up any resources the engine might need."""
         pass
+
+
+class WrapperEngine(BaseEngine):
+    """
+    A base class for engines that are meant to wrap other engines. By default, this class takes in another engine
+    as the first parameter in its constructor and will pass through all non-overridden attributes to the wrapped
+    engine.
+    """
+
+    def __init__(self, engine: BaseEngine, *args, **kwargs):
+        """
+        :param engine: The engine to wrap.
+        """
+        super().__init__(*args, **kwargs)
+        self.engine = engine
+        """The wrapped engine."""
+
+        # passthrough attrs
+        self.max_context_size = engine.max_context_size
+        self.token_reserve = engine.token_reserve
+
+    # passthrough methods
+    def message_len(self, message: ChatMessage) -> int:
+        return self.engine.message_len(message)
+
+    async def predict(
+        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
+    ) -> BaseCompletion:
+        return await self.engine.predict(messages, functions, **hyperparams)
+
+    async def stream(
+        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
+    ) -> AsyncIterable[str | BaseCompletion]:
+        async for elem in self.engine.stream(messages, functions, **hyperparams):
+            yield elem
+
+    def function_token_reserve(self, functions: list[AIFunction]) -> int:
+        return self.engine.function_token_reserve(functions)
+
+    async def close(self):
+        return await self.engine.close()
+
+    # all other attributes are caught by this default passthrough handler
+    def __getattr__(self, item):
+        return getattr(self.engine, item)
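
One consequence of the design above is worth spelling out: Python calls `__getattr__` only when normal attribute lookup fails. Attributes assigned in `__init__` (here `max_context_size` and `token_reserve`) are therefore copied once at construction, while everything else is resolved on the wrapped engine at access time. The sketch below demonstrates that language behavior with hypothetical stand-in classes (`Inner`, `Wrapper`, and the `model_id` attribute are illustrative, not kani's API).

```python
class Inner:
    """Hypothetical stand-in for a wrapped engine."""

    def __init__(self):
        self.max_context_size = 4096
        self.model_id = "inner-model"


class Wrapper:
    """Hypothetical stand-in with the same passthrough shape as the diff above."""

    def __init__(self, engine):
        self.engine = engine
        # Copied once at construction, like the "passthrough attrs" above.
        self.max_context_size = engine.max_context_size

    def __getattr__(self, item):
        # Reached only when `item` is not found on the wrapper itself.
        return getattr(self.engine, item)


inner = Inner()
wrapper = Wrapper(inner)

# Mutate the inner engine after wrapping.
inner.max_context_size = 8192
inner.model_id = "swapped"

snapshot = wrapper.max_context_size  # still the value copied in __init__
live = wrapper.model_id  # resolved on the inner engine at access time
```

If the wrapped engine's context size can change after construction, a subclass would need to re-read it explicitly; whether that matters in practice depends on the engine being wrapped.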