support function calling API, add custom schema support
nyanp committed Jun 24, 2023
1 parent efce7ca commit d8fc902
Showing 5 changed files with 252 additions and 74 deletions.
171 changes: 138 additions & 33 deletions chat2plot/chat2plot.py
@@ -3,7 +3,7 @@
 import traceback
 from dataclasses import dataclass
 from logging import getLogger
-from typing import Any
+from typing import Any, Callable, Literal, Type, TypeVar
 
 import altair as alt
 import commentjson
@@ -12,17 +12,25 @@
 import pydantic
 from langchain.chat_models import ChatOpenAI
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import BaseMessage, HumanMessage, SystemMessage
+from langchain.schema import BaseMessage, FunctionMessage, HumanMessage, SystemMessage
 from plotly.graph_objs import Figure
 
 from chat2plot.dataset_description import description
 from chat2plot.dictionary_helper import delete_null_field
-from chat2plot.prompt import JSON_TAG, error_correction_prompt, system_prompt
+from chat2plot.prompt import (
+    JSON_TAG,
+    error_correction_prompt,
+    explanation_prompt,
+    system_prompt,
+)
 from chat2plot.render import draw_altair, draw_plotly
-from chat2plot.schema import PlotConfig, ResponseType
+from chat2plot.schema import PlotConfig, ResponseType, get_schema_of_chart_config
 
 _logger = getLogger(__name__)
 
+T = TypeVar("T", bound=pydantic.BaseModel)
+ModelDeserializer = Callable[[dict[str, Any]], T]
+
 # These errors are caught within the application.
 # Other errors (e.g. openai.error.RateLimitError) are propagated to user code.
 _APPLICATION_ERRORS = (
@@ -37,7 +45,7 @@
 @dataclass(frozen=True)
 class Plot:
     figure: alt.Chart | Figure | None
-    config: PlotConfig | dict[str, Any] | None
+    config: PlotConfig | dict[str, Any] | pydantic.BaseModel | None
     response_type: ResponseType
     explanation: str
     conversation_history: list[BaseMessage] | None
@@ -48,22 +56,24 @@ class ChatSession:
 
     def __init__(
         self,
+        chat: BaseChatModel,
         df: pd.DataFrame,
         system_prompt_template: str,
         user_prompt_template: str,
         description_strategy: str = "head",
-        chat: BaseChatModel | None = None,
+        functions: list[dict[str, Any]] | None = None,
     ):
         self._system_prompt_template = system_prompt_template
         self._user_prompt_template = user_prompt_template
-        self._chat = chat or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")  # type: ignore
+        self._chat = chat
         self._conversation_history: list[BaseMessage] = [
             SystemMessage(
                 content=system_prompt_template.format(
                     dataset=description(df, description_strategy)
                 )
             )
         ]
+        self._functions = functions
 
     @property
     def history(self) -> list[BaseMessage]:
@@ -72,19 +82,30 @@ def history(self) -> list[BaseMessage]:
     def set_chatmodel(self, chat: BaseChatModel) -> None:
         self._chat = chat
 
-    def query_without_history(self, q: str) -> str:
+    def query_without_history(self, q: str) -> BaseMessage:
         response = self._chat([HumanMessage(content=q)])
-        return response.content
+        return response
 
-    def query(self, q: str) -> str:
-        prompt = self._user_prompt_template.format(text=q)
+    def query(self, q: str, raw: bool = False) -> BaseMessage:
+        prompt = q if raw else self._user_prompt_template.format(text=q)
         response = self._query(prompt)
-        return response.content
+        return response
 
     def _query(self, prompt: str) -> BaseMessage:
         self._conversation_history.append(HumanMessage(content=prompt))
-        response = self._chat(self._conversation_history)
+        kwargs = {}
+        if self._functions:
+            kwargs["functions"] = self._functions
+        response = self._chat(self._conversation_history, **kwargs)  # type: ignore
         self._conversation_history.append(response)
+
+        if response.additional_kwargs.get("function_call"):
+            name = response.additional_kwargs["function_call"]["name"]
+            arguments = response.additional_kwargs["function_call"]["arguments"]
+            self._conversation_history.append(
+                FunctionMessage(name=name, content=arguments)
+            )
+
         return response
 
     def last_response(self) -> str:
@@ -109,20 +130,43 @@ class Chat2Plot(Chat2PlotBase):
     def __init__(
         self,
         df: pd.DataFrame,
+        chart_schema: Literal["simple"] | Type[pydantic.BaseModel],
+        *,
         chat: BaseChatModel | None = None,
+        function_call: bool | Literal["auto"] = False,
         language: str | None = None,
         description_strategy: str = "head",
         verbose: bool = False,
+        custom_deserializer: ModelDeserializer | None = None,
     ):
+        self._target_schema: Type[pydantic.BaseModel] = (
+            PlotConfig if chart_schema == "simple" else chart_schema  # type: ignore
+        )
+
+        chat_model = _get_or_default_chat_model(chat)
+
+        self._function_call = (
+            _has_function_call_capability(chat_model)
+            if function_call == "auto"
+            else function_call
+        )
+
         self._session = ChatSession(
+            chat_model,
             df,
-            system_prompt("simple", language),
+            system_prompt("simple", self._function_call, language, self._target_schema),
             "<{text}>",
             description_strategy,
-            chat,
+            functions=[
+                get_schema_of_chart_config(self._target_schema, as_function=True)
+            ]
+            if self._function_call
+            else None,
         )
         self._df = df
         self._verbose = verbose
+        self._custom_deserializer = custom_deserializer
+        self._language = language
 
     @property
     def session(self) -> ChatSession:
@@ -135,20 +179,22 @@ def query(self, q: str, config_only: bool = False, show_plot: bool = False) -> Plot:
             if self._verbose:
                 _logger.info(f"request: {q}")
                 _logger.info(f"first response: {raw_response}")
-            return self._parse_response(raw_response, config_only, show_plot)
+            return self._parse_response(q, raw_response, config_only, show_plot)
         except _APPLICATION_ERRORS as e:
             if self._verbose:
                 _logger.warning(traceback.format_exc())
             msg = e.message if isinstance(e, jsonschema.ValidationError) else str(e)
-            error_correction = error_correction_prompt().format(
+            error_correction = error_correction_prompt(self._function_call).format(
                 error_message=msg,
             )
             corrected_response = self._session.query(error_correction)
             if self._verbose:
                 _logger.info(f"retry response: {corrected_response}")
 
             try:
-                return self._parse_response(corrected_response, config_only, show_plot)
+                return self._parse_response(
+                    q, corrected_response, config_only, show_plot
+                )
             except _APPLICATION_ERRORS as e:
                 if self._verbose:
                     _logger.warning(e)
@@ -166,22 +212,38 @@ def __call__(
     ) -> Plot:
         return self.query(q, config_only, show_plot)
 
-    def _parse_response(self, content: str, config_only: bool, show_plot: bool) -> Plot:
-        explanation, json_data = parse_json(content)
+    def _parse_response(
+        self, q: str, response: BaseMessage, config_only: bool, show_plot: bool
+    ) -> Plot:
+        if self._function_call:
+            if not response.additional_kwargs.get("function_call"):
+                raise ValueError("Function should be called")
+            function_call = response.additional_kwargs["function_call"]
+            json_data = commentjson.loads(function_call["arguments"])
+
+            explanation = self._session.query(
+                explanation_prompt(self._language, q), raw=True
+            ).content
+        else:
+            explanation, json_data = parse_json(response.content)
 
         try:
-            config = PlotConfig.from_json(json_data)
+            if self._custom_deserializer:
+                config = self._custom_deserializer(json_data)
+            else:
+                config = pydantic.parse_obj_as(self._target_schema, json_data)
         except _APPLICATION_ERRORS:
             _logger.warning(traceback.format_exc())
             # To reduce the number of failure cases as much as possible,
             # only check against the json schema when instantiation fails.
-            jsonschema.validate(json_data, PlotConfig.schema())
+            jsonschema.validate(json_data, self._target_schema.schema())
             raise
 
         if self._verbose:
             _logger.info(config)
 
-        if config_only:
+        if config_only or not isinstance(config, PlotConfig):
             return Plot(
                 None, config, ResponseType.SUCCESS, explanation, self._session.history
             )
@@ -202,7 +264,11 @@ def __init__(
         verbose: bool = False,
     ):
         self._session = ChatSession(
-            df, system_prompt("vega", language), "<{text}>", description_strategy, chat
+            _get_or_default_chat_model(chat),
+            df,
+            system_prompt("vega", False, language, None),
+            "<{text}>",
+            description_strategy,
         )
         self._df = df
         self._verbose = verbose
@@ -215,15 +281,17 @@ def query(self, q: str, config_only: bool = False, show_plot: bool = False) -> Plot:
         res = self._session.query(q)
 
         try:
-            explanation, config = parse_json(res)
+            explanation, config = parse_json(res.content)
             if "data" in config:
                 del config["data"]
             if self._verbose:
                 _logger.info(config)
         except _APPLICATION_ERRORS:
             _logger.warning(f"failed to parse LLM response: {res}")
             _logger.warning(traceback.format_exc())
-            return Plot(None, None, ResponseType.UNKNOWN, res, self._session.history)
+            return Plot(
+                None, None, ResponseType.UNKNOWN, res.content, self._session.history
+            )
 
         if config_only:
             return Plot(
@@ -253,19 +321,21 @@ def __call__(
 
 def chat2plot(
     df: pd.DataFrame,
-    model_type: str = "simple",
+    schema_definition: Literal["simple", "vega"] | Type[pydantic.BaseModel] = "simple",
     chat: BaseChatModel | None = None,
+    function_call: bool | Literal["auto"] = "auto",
     language: str | None = None,
     description_strategy: str = "head",
     verbose: bool = False,
+    custom_deserializer: ModelDeserializer | None = None,
 ) -> Chat2PlotBase:
     """Create Chat2Plot instance.
     Args:
         df: Data source for visualization.
-        model_type: Type of json format. "vega" for a vega-lite compliant format, or "simple" or a simpler format.
+        schema_definition: Type of json format. "vega" for a vega-lite compliant format, or "simple" for a simpler format.
         chat: The chat instance for interaction with LLMs.
-            If omitted, `ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")` will be used.
+            If omitted, `ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")` will be used.
         language: Language of explanations. If not specified, it will be automatically inferred from user prompts.
         description_strategy: Type of how the information in the dataset is embedded in the prompt.
             Defaults to "head" which embeds the contents of df.head(5) in the prompt.
@@ -277,13 +347,33 @@ def chat2plot(
     Returns:
         Chat instance.
     """
 
-    if model_type == "simple":
-        return Chat2Plot(df, chat, language, description_strategy, verbose)
-    elif model_type == "vega":
+    if schema_definition == "simple":
+        return Chat2Plot(
+            df,
+            "simple",
+            chat=chat,
+            language=language,
+            description_strategy=description_strategy,
+            verbose=verbose,
+            custom_deserializer=custom_deserializer or PlotConfig.from_json,
+            function_call=function_call,
+        )
+    if schema_definition == "vega":
         return Chat2Vega(df, chat, language, description_strategy, verbose)
+    elif issubclass(schema_definition, pydantic.BaseModel):
+        return Chat2Plot(
+            df,
+            schema_definition,
+            chat=chat,
+            language=language,
+            description_strategy=description_strategy,
+            verbose=verbose,
+            custom_deserializer=custom_deserializer,
+            function_call=function_call,
+        )
     else:
         raise ValueError(
-            f"model_type should be one of [default, vega] (given: {model_type})"
+            f"schema_definition should be one of [simple, vega] or a subclass of pydantic.BaseModel (given: {schema_definition})"
         )


@@ -310,3 +400,18 @@ def parse_json(content: str) -> tuple[str, dict[str, Any]]:
 
     # LLM rarely generates JSON with comments, so use the commentjson package instead of json
     return explanation_part.strip(), delete_null_field(commentjson.loads(json_part))
+
+
+def _get_or_default_chat_model(chat: BaseChatModel | None) -> BaseChatModel:
+    if chat is None:
+        return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")  # type: ignore
+    return chat
+
+
+def _has_function_call_capability(chat: BaseChatModel) -> bool:
+    if not isinstance(chat, ChatOpenAI):
+        return False
+    return any(
+        chat.model_name.startswith(prefix)
+        for prefix in ["gpt-4-0613", "gpt-3.5-turbo-0613"]
+    )