-
-
Notifications
You must be signed in to change notification settings - Fork 30
/
engine.py
41 lines (31 loc) · 1.44 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import AsyncIterable
from kani import AIFunction, ChatMessage
from kani.engines.base import BaseCompletion, BaseEngine, Completion
class TestEngine(BaseEngine):
    """A mock engine used for testing.

    Each message has a token length equal to the length of its text (a message with no text
    counts as 0 tokens). By default, ``predict`` returns the one-token assistant message ``"a"``;
    with ``test_echo=True`` it instead echoes the text of the last message.
    """

    max_context_size = 10

    def message_len(self, message: ChatMessage) -> int:
        """Return the mock token length of *message*: the length of its text, or 0 if it has none."""
        # message.text may be None (predict below already tolerates this); count it as empty
        # instead of raising TypeError from len(None).
        return len(message.text or "")

    async def predict(
        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, test_echo=False, **hyperparams
    ) -> Completion:
        """Return a canned completion for *messages*.

        :param messages: The prompt; its total length (as measured by :meth:`message_len`) must not
            exceed ``max_context_size``.
        :param functions: Accepted for interface compatibility; not used.
        :param test_echo: If True, the prediction echoes the last message.
        :param hyperparams: Accepted for interface compatibility; not used.
        :raises AssertionError: (when assertions are enabled) if the prompt is longer than
            ``max_context_size``.
        """
        # Measure the prompt with the same per-message length rule the engine advertises.
        prompt_len = sum(self.message_len(m) for m in messages)
        assert prompt_len <= self.max_context_size, (
            f"prompt length {prompt_len} exceeds max_context_size {self.max_context_size}"
        )
        if test_echo:
            return Completion(ChatMessage.assistant(messages[-1].text))
        return Completion(ChatMessage.assistant("a"))
class TestStreamingEngine(TestEngine):
    """A mock engine used for testing but with streaming (yields one character at a time).

    Like ``TestEngine.predict``, it produces ``"a"`` by default, or echoes the text of the last
    message when ``test_echo=True``.
    """

    async def stream(
        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, test_echo=False, **hyperparams
    ) -> AsyncIterable[str | BaseCompletion]:
        """Stream a canned completion for *messages*, one character per yielded chunk.

        :param messages: The prompt; the total length of the messages' text must not exceed
            ``max_context_size``.
        :param functions: Accepted for interface compatibility; not used.
        :param test_echo: If True, stream the characters of the last message's text instead of ``"a"``.
        :param hyperparams: Accepted for interface compatibility; not used.
        :raises AssertionError: (when assertions are enabled) if the prompt is longer than
            ``max_context_size``.
        """
        # A message with no text (None) counts as 0 characters.
        prompt_len = sum(len(m.text or "") for m in messages)
        assert prompt_len <= self.max_context_size, (
            f"prompt length {prompt_len} exceeds max_context_size {self.max_context_size}"
        )
        if test_echo:
            # The last message may have no text (None); stream nothing rather than raising TypeError.
            for char in messages[-1].text or "":
                yield char
        else:
            yield "a"