
Commit

refactor HistoryEntry; fix tests
pieroit committed Dec 20, 2024
1 parent 888bfe9 commit e2dcd27
Showing 12 changed files with 161 additions and 337 deletions.
214 changes: 43 additions & 171 deletions core/cat/convo/messages.py
@@ -2,13 +2,14 @@
import base64
from io import BytesIO
from enum import Enum
from typing import List, Optional, Literal
from typing import List, Optional, Literal, Union
import requests
from PIL import Image

from pydantic import BaseModel, Field, ConfigDict, computed_field
from langchain_core.messages import AIMessage, HumanMessage

from cat.convo.model_interactions import LLMModelInteraction, EmbedderModelInteraction
from cat.utils import BaseModelDict, deprecation_warning
from cat.log import log

@@ -29,79 +30,6 @@ class Role(Enum):
Human = "Human"


class ModelInteraction(BaseModel):
"""
Base class for interactions with models, capturing essential attributes common to all model interactions.
Attributes
----------
model_type : Literal["llm", "embedder"]
The type of model involved in the interaction, either a large language model (LLM) or an embedder.
source : str
The source from which the interaction originates.
prompt : str
The prompt or input provided to the model.
input_tokens : int
The number of input tokens processed by the model.
started_at : float
The timestamp when the interaction started. Defaults to the current time.
"""

model_type: Literal["llm", "embedder"]
source: str
prompt: str
input_tokens: int
started_at: float = Field(default_factory=lambda: time.time())

model_config = ConfigDict(
protected_namespaces=()
)


class LLMModelInteraction(ModelInteraction):
"""
Represents an interaction with a large language model (LLM).
Inherits from ModelInteraction and adds specific attributes related to LLM interactions.
Attributes
----------
model_type : Literal["llm"]
The type of model, which is fixed to "llm".
reply : str
The response generated by the LLM.
output_tokens : int
The number of output tokens generated by the LLM.
ended_at : float
The timestamp when the interaction ended.
"""

model_type: Literal["llm"] = Field(default="llm")
reply: str
output_tokens: int
ended_at: float


class EmbedderModelInteraction(ModelInteraction):
"""
Represents an interaction with an embedding model.
Inherits from ModelInteraction and includes attributes specific to embedding interactions.
Attributes
----------
model_type : Literal["embedder"]
The type of model, which is fixed to "embedder".
source : str
The source of the interaction, defaulting to "recall".
reply : List[float]
The embeddings generated by the embedder.
"""
model_type: Literal["embedder"] = Field(default="embedder")
source: str = Field(default="recall")
reply: List[float]


class MessageWhy(BaseModelDict):
"""
A class for encapsulating the context and reasoning behind a message, providing details on
@@ -126,68 +54,55 @@ class MessageWhy(BaseModelDict):
model_interactions: List[LLMModelInteraction | EmbedderModelInteraction]


class BaseMessage(BaseModelDict):
class Message(BaseModelDict):
"""
Base class for messages, containing common attributes shared by all message types.
Base class for working memory history entries.
It is subclassed by `ConversationMessage`, which in turn is subclassed by `CatMessage` and `UserMessage`.
Attributes
----------
user_id : str
Unique identifier for the user associated with the message.
who : str
The name of the message author.
text : Optional[str]
The text content of the message.
image : Optional[str]
Image file URL or base64 data URI representing an image associated with the message.
audio : Optional[str]
Audio file URL or base64 data URI representing audio associated with the message.
why : Optional[MessageWhy]
Additional contextual information related to the message.
when : Optional[float]
The timestamp when the message was sent.
"""

user_id: str
who: str
text: Optional[str] = None
image: Optional[str] = None
audio: Optional[str] = None
when: float = Field(default_factory=time.time)


class CatMessage(BaseMessage):
class ConversationMessage(Message):
"""
Represents a Cat message.
Base class for conversation messages, containing common attributes shared by all message types.
Subclassed by `CatMessage` and `UserMessage`.
Parameters
Attributes
----------
user_id : str
Unique identifier for the user associated with the message.
text : Optional[str], default=None
The text content of the message.
image : Optional[str], default=None
Image file URL or base64 data URI representing an image associated with the message.
audio : Optional[str], default=None
Audio file URL or base64 data URI representing audio associated with the message.
why : Optional[MessageWhy], default=None
Additional contextual information related to the message.
who : str, default="AI"
The name of the message author, by default "AI".
content : Optional[str], default=None
Deprecated. The text content of the message. Use `text` instead.
"""

text: Optional[str] = None
image: Optional[str] = None
audio: Optional[str] = None


class CatMessage(ConversationMessage):
"""
Represents a Cat message.
Attributes
----------
type : str
The type of message. Defaults to "chat".
user_id : str
Unique identifier for the user associated with the message.
who : str
The name of the message author, by default AI.
text : Optional[str]
The text content of the message.
image : Optional[str]
Image file URL or base64 data URI representing an image associated with the message.
audio : Optional[str]
Audio file URL or base64 data URI representing audio associated with the message.
why : Optional[MessageWhy]
Additional contextual information related to the message.
@@ -196,26 +111,24 @@ class CatMessage(BaseMessage):
- The `content` parameter and attribute are deprecated. Use `text` instead.
"""

type: str = "chat" # For now is always "chat" and is not used
who: str = "AI"
why: Optional[MessageWhy]

def __init__(
self,
user_id: str,
text: Optional[str] = None,
image: Optional[str] = None,
audio: Optional[str] = None,
why: Optional[MessageWhy] = None,
who: str = "AI",
content: Optional[str] = None,
**kwargs,
):
if content:
deprecation_warning("The `content` parameter is deprecated. Use `text` instead.")
text = content # Map 'content' to 'text'

super().__init__(user_id=user_id, who=who, text=text, image=image, audio=audio, why=why, **kwargs)
type: str = "chat" # For now is always "chat" and is not used
why: Optional[MessageWhy] = None

def langchainfy(self) -> AIMessage:
"""
Convert the internal CatMessage to a LangChain AIMessage.
Returns
-------
AIMessage
The LangChain AIMessage converted from the internal CatMessage.
"""

return AIMessage(
name=self.who,
content=self.text
)

@computed_field
@property
@@ -238,56 +151,15 @@ def content(self, value):
deprecation_warning("The `content` attribute is deprecated. Use `text` instead.")
self.text = value

def langchainfy(self) -> AIMessage:
"""
Convert the internal CatMessage to a LangChain AIMessage.
Returns
-------
AIMessage
The LangChain AIMessage converted from the internal CatMessage.
"""

return AIMessage(
name=self.who,
content=self.text
)



class UserMessage(BaseMessage):
class UserMessage(ConversationMessage):
"""
Represents a message from a user, containing text and optional multimedia content such as images and audio.
This class is used to encapsulate the details of a message sent by a user, including the user's identifier,
the text content of the message, and any associated multimedia content such as image or audio files.
Parameters
----------
user_id : str
Unique identifier for the user sending the message.
text : Optional[str], default=None
The text content of the message.
image : Optional[str], default=None
Image file URL or base64 data URI representing an image associated with the message.
audio : Optional[str], default=None
Audio file URL or base64 data URI representing audio associated with the message.
who : str, default="Human"
The name of the message author, by default “Human”.
Attributes
----------
user_id : str
Unique identifier for the user sending the message.
who : str
The name of the message author, by default “Human”.
text : Optional[str]
The text content of the message.
image : Optional[str]
Image file URL or base64 data URI representing an image associated with the message.
audio : Optional[str]
Audio file URL or base64 data URI representing audio associated with the message.
"""

who: str = "Human"

def langchainfy(self) -> HumanMessage:
… (diff truncated)
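
As a quick orientation for the refactor above: the old BaseMessage/CatMessage/UserMessage chain becomes Message → ConversationMessage → CatMessage/UserMessage, and the custom CatMessage.__init__ is dropped in favour of plain Pydantic fields, while the deprecated `content` accessor survives as a computed field. A minimal usage sketch based on the visible hunks; the `who` values are passed explicitly here because some class-level defaults fall outside the shown lines:

from cat.convo.messages import CatMessage, UserMessage

# Both sides of the conversation share the ConversationMessage fields
# (text, image, audio) plus the base Message fields (user_id, who, when).
user_msg = UserMessage(user_id="alice", who="Human", text="What time is it?")
cat_msg = CatMessage(user_id="alice", who="AI", text="It is noon.")

# langchainfy() converts to the corresponding LangChain message types.
human_lc = user_msg.langchainfy()   # langchain_core.messages.HumanMessage
ai_lc = cat_msg.langchainfy()       # langchain_core.messages.AIMessage

# The deprecated `content` accessor still proxies `text` and emits a
# deprecation warning when used.
assert cat_msg.content == cat_msg.text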
77 changes: 77 additions & 0 deletions core/cat/convo/model_interactions.py
@@ -0,0 +1,77 @@
import time
from io import BytesIO
from enum import Enum
from typing import List, Optional, Literal
from pydantic import BaseModel, Field, ConfigDict

class ModelInteraction(BaseModel):
"""
Base class for interactions with models, capturing essential attributes common to all model interactions.
Attributes
----------
model_type : Literal["llm", "embedder"]
The type of model involved in the interaction, either a large language model (LLM) or an embedder.
source : str
The source from which the interaction originates.
prompt : str
The prompt or input provided to the model.
input_tokens : int
The number of input tokens processed by the model.
started_at : float
The timestamp when the interaction started. Defaults to the current time.
"""

model_type: Literal["llm", "embedder"]
source: str
prompt: str
input_tokens: int
started_at: float = Field(default_factory=lambda: time.time())

model_config = ConfigDict(
protected_namespaces=()
)


class LLMModelInteraction(ModelInteraction):
"""
Represents an interaction with a large language model (LLM).
Inherits from ModelInteraction and adds specific attributes related to LLM interactions.
Attributes
----------
model_type : Literal["llm"]
The type of model, which is fixed to "llm".
reply : str
The response generated by the LLM.
output_tokens : int
The number of output tokens generated by the LLM.
ended_at : float
The timestamp when the interaction ended.
"""

model_type: Literal["llm"] = Field(default="llm")
reply: str
output_tokens: int
ended_at: float


class EmbedderModelInteraction(ModelInteraction):
"""
Represents an interaction with an embedding model.
Inherits from ModelInteraction and includes attributes specific to embedding interactions.
Attributes
----------
model_type : Literal["embedder"]
The type of model, which is fixed to "embedder".
source : str
The source of the interaction, defaulting to "recall".
reply : List[float]
The embeddings generated by the embedder.
"""
model_type: Literal["embedder"] = Field(default="embedder")
source: str = Field(default="recall")
reply: List[float]
2 changes: 1 addition & 1 deletion core/cat/looking_glass/callbacks.py
@@ -2,7 +2,7 @@
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.outputs.llm_result import LLMResult
import tiktoken
from cat.convo.messages import LLMModelInteraction
from cat.convo.model_interactions import LLMModelInteraction
import time


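The only change here is the import path. Code outside core that imported these classes from cat.convo.messages would need the same one-line update, although the re-export added to messages.py above suggests the old path keeps working for now:

# before this commit
# from cat.convo.messages import LLMModelInteraction

# after this commit, the canonical location is the dedicated module
from cat.convo.model_interactions import LLMModelInteraction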
… (remaining 9 changed files not shown)
