Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multimodality #967

Merged
merged 19 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 215 additions & 45 deletions core/cat/convo/messages.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,46 @@
from typing import List, Literal
from cat.utils import BaseModelDict
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from enum import Enum
from pydantic import BaseModel, Field, ConfigDict
import time
from enum import Enum
from typing import List, Optional, Literal, Union

from pydantic import BaseModel, Field, ConfigDict, computed_field

from cat.utils import BaseModelDict, deprecation_warning


class Role(Enum):
    """
    Roles that can take part in a conversation.

    Attributes
    ----------
    AI : str
        The artificial-intelligence participant.
    Human : str
        The human participant.
    """

    AI = "AI"
    Human = "Human"


class ModelInteraction(BaseModel):
"""
Base class for interactions with models, capturing essential attributes common to all model interactions.

Attributes
----------
model_type : Literal["llm", "embedder"]
The type of model involved in the interaction, either a large language model (LLM) or an embedder.
source : str
The source from which the interaction originates.
prompt : str
The prompt or input provided to the model.
input_tokens : int
The number of input tokens processed by the model.
started_at : float
The timestamp when the interaction started. Defaults to the current time.
"""

model_type: Literal["llm", "embedder"]
source: str
prompt: str
Expand All @@ -24,26 +53,65 @@ class ModelInteraction(BaseModel):


class LLMModelInteraction(ModelInteraction):
    """
    An exchange with a large language model (LLM).

    Extends ModelInteraction with the data produced by the LLM call.

    Attributes
    ----------
    model_type : Literal["llm"]
        Discriminator value, always "llm".
    reply : str
        The text generated by the LLM.
    output_tokens : int
        Number of output tokens generated by the LLM.
    ended_at : float
        Timestamp at which the interaction ended.
    """

    model_type: Literal["llm"] = "llm"
    reply: str
    output_tokens: int
    ended_at: float


class EmbedderModelInteraction(ModelInteraction):
    """
    An exchange with an embedding model.

    Extends ModelInteraction with the data produced by the embedder call.

    Attributes
    ----------
    model_type : Literal["embedder"]
        Discriminator value, always "embedder".
    source : str
        Origin of the interaction; defaults to "recall".
    reply : List[float]
        The embedding vector returned by the model.
    """

    model_type: Literal["embedder"] = "embedder"
    source: str = "recall"
    reply: List[float]


class MessageWhy(BaseModelDict):
"""Class for wrapping message why
"""
A class for encapsulating the context and reasoning behind a message, providing details on
input, intermediate steps, memory, and interactions with models.

Variables:
input (str): input message
intermediate_steps (List): intermediate steps
memory (dict): memory
model_interactions (List[LLMModelInteraction | EmbedderModelInteraction]): model interactions
Attributes
----------
input : str
The initial input message that triggered the response.
intermediate_steps : List
A list capturing intermediate steps or actions taken as part of processing the message.
memory : dict
A dictionary containing relevant memory information used during the processing of the message.
model_interactions : List[Union[LLMModelInteraction, EmbedderModelInteraction]]
A list of interactions with language or embedding models, detailing how models were used in generating
or understanding the message context.
"""

input: str
Expand All @@ -53,50 +121,152 @@ class MessageWhy(BaseModelDict):


class CatMessage(BaseModelDict):
    """
    Represents a Cat (AI) message.

    Parameters
    ----------
    user_id : str
        Unique identifier for the user associated with the message.
    content : Optional[str], default=None
        Deprecated. The text content of the message. Use `text` instead.
    text : Optional[str], default=None
        The text content of the message.
    images : Optional[Union[List[str], str]], default=None
        List of image file URLs or base64 data URIs that represent images associated with the message.
        A single string can also be provided and will be converted to a list.
    audio : Optional[Union[List[str], str]], default=None
        List of audio file URLs or base64 data URIs that represent audio associated with the message.
        A single string can also be provided and will be converted to a list.
    why : Optional[MessageWhy], default=None
        Additional contextual information related to the message.
    who : str, default="AI"
        The name of the message author.

    Attributes
    ----------
    type : str
        The type of message. Defaults to "chat".
    user_id : str
        Unique identifier for the user associated with the message.
    who : str
        The name of the message author, by default "AI".
    text : Optional[str]
        The text content of the message.
    images : Optional[List[str]]
        List of image URLs or base64 data URIs associated with the message, if any.
    audio : Optional[List[str]]
        List of audio file URLs or base64 data URIs associated with the message, if any.
    why : Optional[MessageWhy]
        Additional contextual information related to the message.

    Notes
    -----
    - The `content` parameter and attribute are deprecated. Use `text` instead.
    """

    type: str = "chat"
    user_id: str
    who: str = "AI"
    text: Optional[str] = None
    images: Optional[List[str]] = None
    audio: Optional[List[str]] = None
    why: Optional[MessageWhy] = None

    def __init__(
        self,
        user_id: str,
        content: Optional[str] = None,
        text: Optional[str] = None,
        images: Optional[Union[List[str], str]] = None,
        audio: Optional[Union[List[str], str]] = None,
        why: Optional[MessageWhy] = None,
        who: str = "AI",
        **kwargs,
    ):
        # Accept a bare string for multimedia and normalize it to a one-element list.
        if isinstance(images, str):
            images = [images]

        if isinstance(audio, str):
            audio = [audio]

        # Backward compatibility: map the deprecated `content` onto `text`.
        if content:
            deprecation_warning("The `content` parameter is deprecated. Use `text` instead.")
            text = content  # Map 'content' to 'text'

        # Do not forward `content`: it is exposed as a computed field, not a model field.
        super().__init__(user_id=user_id, who=who, text=text, images=images, audio=audio, why=why, **kwargs)

    @computed_field
    @property
    def content(self) -> str:
        """
        This attribute is deprecated. Use `text` instead.

        The text content of the message. Use `text` instead.

        Returns
        -------
        str
            The text content of the message.
        """
        deprecation_warning("The `content` attribute is deprecated. Use `text` instead.")
        return self.text

    @content.setter
    def content(self, value):
        # Deprecated write path: keep it working by redirecting to `text`.
        deprecation_warning("The `content` attribute is deprecated. Use `text` instead.")
        self.text = value


class UserMessage(BaseModelDict):
    """
    Represents a message from a user, containing text and optional multimedia content such as images and audio.

    Parameters
    ----------
    user_id : str
        Unique identifier for the user sending the message.
    text : Optional[str], default=None
        The text content of the message. Can be `None` if no text is provided.
    images : Optional[Union[List[str], str]], default=None
        List of image file URLs or base64 data URIs that represent images associated with the message.
        A single string can also be provided and will be converted to a list.
    audio : Optional[Union[List[str], str]], default=None
        List of audio file URLs or base64 data URIs that represent audio associated with the message.
        A single string can also be provided and will be converted to a list.
    who : str, default="Human"
        The name of the message author.

    Attributes
    ----------
    user_id : str
        Unique identifier for the user sending the message.
    who : str
        The name of the message author, by default "Human".
    text : Optional[str]
        The text content of the message.
    images : Optional[List[str]]
        List of images associated with the message, if any.
    audio : Optional[List[str]]
        List of audio files associated with the message, if any.
    """

    user_id: str
    who: str = "Human"
    text: Optional[str] = None
    images: Optional[List[str]] = None
    audio: Optional[List[str]] = None

    def __init__(
        self,
        user_id: str,
        text: Optional[str] = None,
        images: Optional[Union[List[str], str]] = None,
        audio: Optional[Union[List[str], str]] = None,
        who: str = "Human",
        **kwargs,
    ):
        # Accept a bare string for multimedia and normalize it to a one-element list.
        if isinstance(images, str):
            images = [images]

        if isinstance(audio, str):
            audio = [audio]

        super().__init__(user_id=user_id, who=who, text=text, images=images, audio=audio, **kwargs)

def convert_to_Langchain_message(
    messages: List[UserMessage | CatMessage],
) -> List[BaseMessage]:
    """
    Convert Cat conversation messages into their Langchain equivalents.

    Parameters
    ----------
    messages : List[UserMessage | CatMessage]
        Conversation history to convert.

    Returns
    -------
    List[BaseMessage]
        An AIMessage for each CatMessage and a HumanMessage for each
        UserMessage, each carrying the user id in `response_metadata`.
    """
    # Accumulate in a new list: the original shadowed the `messages`
    # parameter with `messages = []`, so it always returned an empty list.
    converted = []
    for m in messages:
        if isinstance(m, CatMessage):
            # Cat messages are AI replies, so they map to AIMessage
            # (the original mapping was inverted).
            converted.append(
                AIMessage(content=m.text, response_metadata={"userId": m.user_id})
            )
        else:
            converted.append(
                HumanMessage(content=m.text, response_metadata={"userId": m.user_id})
            )
    return converted


def convert_to_Cat_message(cat_message: AIMessage, why: MessageWhy) -> CatMessage:
    """
    Convert a Langchain AIMessage into a CatMessage.

    Parameters
    ----------
    cat_message : AIMessage
        The Langchain message produced by the LLM; expected to carry
        "userId" in its `response_metadata`.
    why : MessageWhy
        Contextual/reasoning information to attach to the message.

    Returns
    -------
    CatMessage
        The equivalent Cat message.
    """
    return CatMessage(
        # Use `text=`: the `content=` keyword is deprecated and emits a
        # deprecation warning on every conversion.
        text=cat_message.content,
        user_id=cat_message.response_metadata["userId"],
        why=why,
    )
Loading
Loading