From d8fc9021657979487cbb998ecbaee588d8656860 Mon Sep 17 00:00:00 2001
From: Taiga Noumi
Date: Sat, 24 Jun 2023 18:40:04 +0900
Subject: [PATCH] support function calling API, add custom schema support

---
 chat2plot/chat2plot.py        | 171 +++++++++++++++++++++++++++-------
 chat2plot/prompt.py           | 115 +++++++++++++++++------
 chat2plot/schema.py           |  26 +++++-
 example/streamlit_app/main.py |  12 +--
 requirements.txt              |   2 +-
 5 files changed, 252 insertions(+), 74 deletions(-)

diff --git a/chat2plot/chat2plot.py b/chat2plot/chat2plot.py
index 4e06118..177937c 100644
--- a/chat2plot/chat2plot.py
+++ b/chat2plot/chat2plot.py
@@ -3,7 +3,7 @@
 import traceback
 from dataclasses import dataclass
 from logging import getLogger
-from typing import Any
+from typing import Any, Callable, Literal, Type, TypeVar
 
 import altair as alt
 import commentjson
@@ -12,17 +12,25 @@
 import pydantic
 from langchain.chat_models import ChatOpenAI
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import BaseMessage, HumanMessage, SystemMessage
+from langchain.schema import BaseMessage, FunctionMessage, HumanMessage, SystemMessage
 from plotly.graph_objs import Figure
 
 from chat2plot.dataset_description import description
 from chat2plot.dictionary_helper import delete_null_field
-from chat2plot.prompt import JSON_TAG, error_correction_prompt, system_prompt
+from chat2plot.prompt import (
+    JSON_TAG,
+    error_correction_prompt,
+    explanation_prompt,
+    system_prompt,
+)
 from chat2plot.render import draw_altair, draw_plotly
-from chat2plot.schema import PlotConfig, ResponseType
+from chat2plot.schema import PlotConfig, ResponseType, get_schema_of_chart_config
 
 _logger = getLogger(__name__)
 
+T = TypeVar("T", bound=pydantic.BaseModel)
+ModelDeserializer = Callable[[dict[str, Any]], T]
+
 # These errors are caught within the application.
 # Other errors (e.g. openai.error.RateLimitError) are propagated to user code.
 _APPLICATION_ERRORS = (
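The `T`/`ModelDeserializer` pair introduced above is the contract for plugging a custom chart schema into chat2plot: any callable that turns the dict parsed from the LLM output into a pydantic model instance. A minimal sketch of such a deserializer (the `ScatterConfig` model and its field names are hypothetical, not part of this patch):

```python
from typing import Any

import pydantic


class ScatterConfig(pydantic.BaseModel):
    # Hypothetical custom schema; any pydantic model can play this role.
    x: str
    y: str
    color: str | None = None


def scatter_deserializer(data: dict[str, Any]) -> ScatterConfig:
    # Satisfies ModelDeserializer = Callable[[dict[str, Any]], T]: normalize
    # the raw dict (here: drop explicit nulls) before building the model.
    return ScatterConfig(**{k: v for k, v in data.items() if v is not None})
```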
@@ -37,7 +45,7 @@
 @dataclass(frozen=True)
 class Plot:
     figure: alt.Chart | Figure | None
-    config: PlotConfig | dict[str, Any] | None
+    config: PlotConfig | dict[str, Any] | pydantic.BaseModel | None
     response_type: ResponseType
     explanation: str
     conversation_history: list[BaseMessage] | None
@@ -48,15 +56,16 @@ class ChatSession:
     def __init__(
         self,
+        chat: BaseChatModel,
         df: pd.DataFrame,
         system_prompt_template: str,
         user_prompt_template: str,
         description_strategy: str = "head",
-        chat: BaseChatModel | None = None,
+        functions: list[dict[str, Any]] | None = None,
     ):
         self._system_prompt_template = system_prompt_template
         self._user_prompt_template = user_prompt_template
-        self._chat = chat or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")  # type: ignore
+        self._chat = chat
         self._conversation_history: list[BaseMessage] = [
             SystemMessage(
                 content=system_prompt_template.format(
@@ -64,6 +73,7 @@ def __init__(
             )
         ]
+        self._functions = functions
 
     @property
     def history(self) -> list[BaseMessage]:
@@ -72,19 +82,30 @@ def history(self) -> list[BaseMessage]:
     def set_chatmodel(self, chat: BaseChatModel) -> None:
         self._chat = chat
 
-    def query_without_history(self, q: str) -> str:
+    def query_without_history(self, q: str) -> BaseMessage:
         response = self._chat([HumanMessage(content=q)])
-        return response.content
+        return response
 
-    def query(self, q: str) -> str:
-        prompt = self._user_prompt_template.format(text=q)
+    def query(self, q: str, raw: bool = False) -> BaseMessage:
+        prompt = q if raw else self._user_prompt_template.format(text=q)
         response = self._query(prompt)
-        return response.content
+        return response
 
     def _query(self, prompt: str) -> BaseMessage:
         self._conversation_history.append(HumanMessage(content=prompt))
-        response = self._chat(self._conversation_history)
+        kwargs = {}
+        if self._functions:
+            kwargs["functions"] = self._functions
+        response = self._chat(self._conversation_history, **kwargs)  # type: ignore
         self._conversation_history.append(response)
+
+        if response.additional_kwargs.get("function_call"):
+            name = response.additional_kwargs["function_call"]["name"]
+            arguments = response.additional_kwargs["function_call"]["arguments"]
+            self._conversation_history.append(
+                FunctionMessage(name=name, content=arguments)
+            )
+
         return response
 
     def last_response(self) -> str:
@@ -109,20 +130,43 @@ class Chat2Plot(Chat2PlotBase):
     def __init__(
         self,
         df: pd.DataFrame,
+        chart_schema: Literal["simple"] | Type[pydantic.BaseModel],
+        *,
         chat: BaseChatModel | None = None,
+        function_call: bool | Literal["auto"] = False,
         language: str | None = None,
         description_strategy: str = "head",
         verbose: bool = False,
+        custom_deserializer: ModelDeserializer | None = None,
     ):
+        self._target_schema: Type[pydantic.BaseModel] = (
+            PlotConfig if chart_schema == "simple" else chart_schema  # type: ignore
+        )
+
+        chat_model = _get_or_default_chat_model(chat)
+
+        self._function_call = (
+            _has_function_call_capability(chat_model)
+            if function_call == "auto"
+            else function_call
+        )
+
         self._session = ChatSession(
+            chat_model,
             df,
-            system_prompt("simple", language),
+            system_prompt("simple", self._function_call, language, self._target_schema),
             "<{text}>",
             description_strategy,
-            chat,
+            functions=[
+                get_schema_of_chart_config(self._target_schema, as_function=True)
+            ]
+            if self._function_call
+            else None,
        )
         self._df = df
         self._verbose = verbose
+        self._custom_deserializer = custom_deserializer
+        self._language = language
 
     @property
     def session(self) -> ChatSession:
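The reworked `_query` above is the entire function-calling round trip: forward the function schema, and if the model answers with a `function_call` instead of plain text, record it in the history as a `FunctionMessage`. A rough standalone sketch of that exchange, assuming an OpenAI 0613 model and a configured API key (the function schema, question, and sample payload are illustrative):

```python
from langchain.chat_models import ChatOpenAI
from langchain.schema import FunctionMessage, HumanMessage

chat = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")
functions = [
    {
        "name": "generate_chart",
        "description": "Generate the chart with the given parameters",
        "parameters": {"type": "object", "properties": {"chart_type": {"type": "string"}}},
    }
]

response = chat([HumanMessage(content="sales by month")], functions=functions)

# With function calling, the payload arrives as a JSON string in
# additional_kwargs rather than in response.content:
call = response.additional_kwargs.get("function_call")
if call:
    # e.g. {"name": "generate_chart", "arguments": '{"chart_type": "line"}'}
    record = FunctionMessage(name=call["name"], content=call["arguments"])
```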
@@ -135,12 +179,12 @@ def query(self, q: str, config_only: bool = False, show_plot: bool = False) -> P
             if self._verbose:
                 _logger.info(f"request: {q}")
                 _logger.info(f"first response: {raw_response}")
-            return self._parse_response(raw_response, config_only, show_plot)
+            return self._parse_response(q, raw_response, config_only, show_plot)
         except _APPLICATION_ERRORS as e:
             if self._verbose:
                 _logger.warning(traceback.format_exc())
             msg = e.message if isinstance(e, jsonschema.ValidationError) else str(e)
-            error_correction = error_correction_prompt().format(
+            error_correction = error_correction_prompt(self._function_call).format(
                 error_message=msg,
             )
             corrected_response = self._session.query(error_correction)
@@ -148,7 +192,9 @@ def query(self, q: str, config_only: bool = False, show_plot: bool = False) -> P
                 _logger.info(f"retry response: {corrected_response}")
 
             try:
-                return self._parse_response(corrected_response, config_only, show_plot)
+                return self._parse_response(
+                    q, corrected_response, config_only, show_plot
+                )
             except _APPLICATION_ERRORS as e:
                 if self._verbose:
                     _logger.warning(e)
@@ -166,22 +212,38 @@ def __call__(
     ) -> Plot:
         return self.query(q, config_only, show_plot)
 
-    def _parse_response(self, content: str, config_only: bool, show_plot: bool) -> Plot:
-        explanation, json_data = parse_json(content)
+    def _parse_response(
+        self, q: str, response: BaseMessage, config_only: bool, show_plot: bool
+    ) -> Plot:
+        if self._function_call:
+            if not response.additional_kwargs.get("function_call"):
+                raise ValueError("LLM response does not contain a function call")
+            function_call = response.additional_kwargs["function_call"]
+            json_data = commentjson.loads(function_call["arguments"])
+
+            explanation = self._session.query(
+                explanation_prompt(self._language, q), raw=True
+            ).content
+        else:
+            explanation, json_data = parse_json(response.content)
 
         try:
-            config = PlotConfig.from_json(json_data)
+            if self._custom_deserializer:
+                config = self._custom_deserializer(json_data)
+            else:
+                config = pydantic.parse_obj_as(self._target_schema, json_data)
         except _APPLICATION_ERRORS:
             _logger.warning(traceback.format_exc())
             # To reduce the number of failure cases as much as possible,
             # only check against the json schema when instantiation fails.
-            jsonschema.validate(json_data, PlotConfig.schema())
+            jsonschema.validate(json_data, self._target_schema.schema())
             raise
 
         if self._verbose:
             _logger.info(config)
 
-        if config_only:
+        if config_only or not isinstance(config, PlotConfig):
             return Plot(
                 None, config, ResponseType.SUCCESS, explanation, self._session.history
             )
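The try/except above prefers a real model instantiation and only falls back to `jsonschema` when that fails, so the retry prompt receives a precise error message. A condensed sketch of the pattern, with `PlotConfig` standing in for the target schema (the sample dict is illustrative and may not satisfy every required field):

```python
import jsonschema
import pydantic

from chat2plot.schema import PlotConfig

json_data = {"chart_type": "line", "x": {"column": "month"}, "y": {"column": "sales"}}

try:
    # Preferred path: build the actual pydantic object.
    config = pydantic.parse_obj_as(PlotConfig, json_data)
except pydantic.ValidationError:
    # Fallback: jsonschema names the offending field, which becomes the
    # {error_message} fed back to the LLM by error_correction_prompt().
    jsonschema.validate(json_data, PlotConfig.schema())
    raise
```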
@@ -202,7 +264,11 @@ def __init__(
         verbose: bool = False,
     ):
         self._session = ChatSession(
-            df, system_prompt("vega", language), "<{text}>", description_strategy, chat
+            _get_or_default_chat_model(chat),
+            df,
+            system_prompt("vega", False, language, None),
+            "<{text}>",
+            description_strategy,
         )
         self._df = df
         self._verbose = verbose
@@ -215,7 +281,7 @@ def query(self, q: str, config_only: bool = False, show_plot: bool = False) -> P
         res = self._session.query(q)
 
         try:
-            explanation, config = parse_json(res)
+            explanation, config = parse_json(res.content)
             if "data" in config:
                 del config["data"]
             if self._verbose:
@@ -223,7 +289,9 @@ def query(self, q: str, config_only: bool = False, show_plot: bool = False) -> P
         except _APPLICATION_ERRORS:
             _logger.warning(f"failed to parse LLM response: {res}")
             _logger.warning(traceback.format_exc())
-            return Plot(None, None, ResponseType.UNKNOWN, res, self._session.history)
+            return Plot(
+                None, None, ResponseType.UNKNOWN, res.content, self._session.history
+            )
 
         if config_only:
             return Plot(
@@ -253,19 +321,21 @@ def __call__(
 
 def chat2plot(
     df: pd.DataFrame,
-    model_type: str = "simple",
+    schema_definition: Literal["simple", "vega"] | Type[pydantic.BaseModel] = "simple",
     chat: BaseChatModel | None = None,
+    function_call: bool | Literal["auto"] = "auto",
     language: str | None = None,
     description_strategy: str = "head",
     verbose: bool = False,
+    custom_deserializer: ModelDeserializer | None = None,
 ) -> Chat2PlotBase:
     """Create Chat2Plot instance.
 
     Args:
         df: Data source for visualization.
-        model_type: Type of json format. "vega" for a vega-lite compliant format, or "simple" or a simpler format.
+        schema_definition: Chart-format specification. "vega" for a vega-lite compliant format, "simple" for a simpler built-in format, or a pydantic model type for a custom schema.
         chat: The chat instance for interaction with LLMs.
-            If omitted, `ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")` will be used.
+            If omitted, `ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")` will be used.
         language: Language of explanations. If not specified, it will be automatically inferred from user prompts.
         description_strategy: Type of how the information in the dataset is embedded in the prompt.
             Defaults to "head" which embeds the contents of df.head(5) in the prompt.
@@ -277,13 +347,33 @@ def chat2plot(
         Chat instance.
     """

-    if model_type == "simple":
-        return Chat2Plot(df, chat, language, description_strategy, verbose)
-    elif model_type == "vega":
+    if schema_definition == "simple":
+        return Chat2Plot(
+            df,
+            "simple",
+            chat=chat,
+            language=language,
+            description_strategy=description_strategy,
+            verbose=verbose,
+            custom_deserializer=custom_deserializer or PlotConfig.from_json,
+            function_call=function_call,
+        )
+    if schema_definition == "vega":
         return Chat2Vega(df, chat, language, description_strategy, verbose)
+    elif isinstance(schema_definition, type) and issubclass(schema_definition, pydantic.BaseModel):
+        return Chat2Plot(
+            df,
+            schema_definition,
+            chat=chat,
+            language=language,
+            description_strategy=description_strategy,
+            verbose=verbose,
+            custom_deserializer=custom_deserializer,
+            function_call=function_call,
+        )
     else:
         raise ValueError(
-            f"model_type should be one of [default, vega] (given: {model_type})"
+            f"schema_definition should be 'simple', 'vega', or a subclass of pydantic.BaseModel (given: {schema_definition})"
         )
@@ -310,3 +400,18 @@ def parse_json(content: str) -> tuple[str, dict[str, Any]]:
 
     # LLM rarely generates JSON with comments, so use the commentjson package instead of json
     return explanation_part.strip(), delete_null_field(commentjson.loads(json_part))
+
+
+def _get_or_default_chat_model(chat: BaseChatModel | None) -> BaseChatModel:
+    if chat is None:
+        return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")  # type: ignore
+    return chat
+
+
+def _has_function_call_capability(chat: BaseChatModel) -> bool:
+    if not isinstance(chat, ChatOpenAI):
+        return False
+    return any(
+        chat.model_name.startswith(prefix)
+        for prefix in ["gpt-4-0613", "gpt-3.5-turbo-0613"]
+    )
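Taken together, the reworked `chat2plot()` entry point dispatches on three kinds of `schema_definition`. A usage sketch of all three modes (the DataFrame, questions, and custom model are illustrative):

```python
import pandas as pd
import pydantic

from chat2plot import chat2plot

df = pd.DataFrame({"month": ["Jan", "Feb", "Mar"], "sales": [100, 120, 90]})

# 1. Built-in "simple" schema; with the default function_call="auto",
#    function calling is enabled automatically for -0613 OpenAI models.
c2p = chat2plot(df, "simple")
plot = c2p("Show monthly sales as a line chart")

# 2. Vega-Lite output (function calling is not used in this mode).
c2p_vega = chat2plot(df, "vega")


# 3. Custom schema: any pydantic model, optionally with a custom deserializer.
class CustomConfig(pydantic.BaseModel):
    chart_type: str
    x: str
    y: str


c2p_custom = chat2plot(df, CustomConfig, function_call=True)
result = c2p_custom.query("sales by month", config_only=True)
print(result.config)  # a CustomConfig instance; no figure is rendered
```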
""" - if model_type == "simple": - return Chat2Plot(df, chat, language, description_strategy, verbose) - elif model_type == "vega": + if schema_definition == "simple": + return Chat2Plot( + df, + "simple", + chat=chat, + language=language, + description_strategy=description_strategy, + verbose=verbose, + custom_deserializer=custom_deserializer or PlotConfig.from_json, + function_call=function_call, + ) + if schema_definition == "vega": return Chat2Vega(df, chat, language, description_strategy, verbose) + elif issubclass(schema_definition, pydantic.BaseModel): + return Chat2Plot( + df, + schema_definition, + chat=chat, + language=language, + description_strategy=description_strategy, + verbose=verbose, + custom_deserializer=custom_deserializer, + function_call=function_call, + ) else: raise ValueError( - f"model_type should be one of [default, vega] (given: {model_type})" + f"schema_definition should be one of [simple, vega] or pydantic.BaseClass (given: {schema_definition})" ) @@ -310,3 +400,18 @@ def parse_json(content: str) -> tuple[str, dict[str, Any]]: # LLM rarely generates JSON with comments, so use the commentjson package instead of json return explanation_part.strip(), delete_null_field(commentjson.loads(json_part)) + + +def _get_or_default_chat_model(chat: BaseChatModel | None) -> BaseChatModel: + if chat is None: + return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613") # type: ignore + return chat + + +def _has_function_call_capability(chat: BaseChatModel) -> bool: + if not isinstance(chat, ChatOpenAI): + return False + return any( + chat.model_name.startswith(prefix) + for prefix in ["gpt-4-0613", "gpt-3.5-turbo-0613"] + ) diff --git a/chat2plot/prompt.py b/chat2plot/prompt.py index e27ee87..a3ce832 100644 --- a/chat2plot/prompt.py +++ b/chat2plot/prompt.py @@ -1,5 +1,8 @@ import json from textwrap import dedent +from typing import Type + +import pydantic from chat2plot.schema import get_schema_of_chart_config @@ -7,39 +10,84 @@ EXPLANATION_TAG = ["", ""] -def system_prompt(model_type: str = "simple", language: str | None = None) -> str: +def system_prompt( + model_type: str, + function_call: bool, + language: str | None, + target_schema: Type[pydantic.BaseModel] | None, +) -> str: return ( - _task_definition_part(model_type) + _task_definition_part(model_type, function_call, target_schema) + "\n" - + _data_and_detailed_instruction_part(language) + + _data_and_detailed_instruction_part(language, function_call) ) -def error_correction_prompt() -> str: - return dedent( +def error_correction_prompt(function_call: bool) -> str: + if function_call: + return dedent( + """ + Your function call fails with the following error: + {error_message} + + Correct the format and retry calling function that fixes the above mentioned error. + Do not generate the same arguments again. + """ + ) + else: + return dedent( + """ + Your response fails with the following error: + {error_message} + + Correct the json and return a new explanation and json that fixes the above mentioned error. + Do not generate the same json again. """ - Your response fails with the following error: - {error_message} + ) + - Correct the json and return a new explanation and json that fixes the above mentioned error. - Do not generate the same json again. 
+def explanation_prompt(language: str | None, user_original_query: str) -> str:
+    language_spec = (
+        language
+        or f'the same language as the user\'s original question (question: "{user_original_query}")'
+    )
+    prompt = dedent(
+        f"""
+        For the chart setting you have just output,
+        explain why you chose it in response to the user's question.
+        The response MUST be in {language_spec}.
+        """
+    )
+    return prompt
+
+
-def _task_definition_part(model_type: str) -> str:
+def _task_definition_part(
+    model_type: str, function_call: bool, target_schema: Type[pydantic.BaseModel] | None
+) -> str:
     if model_type == "simple":
+        if function_call:
+            return dedent(
+                """
+                Call the chart generation function for the given dataset and user question delimited by <>.
+                """
+            )
+
+        assert target_schema is not None
+
         schema_json = json.dumps(
-            get_schema_of_chart_config(inlining_refs=True, remove_title=True), indent=2
+            get_schema_of_chart_config(
+                target_schema, inlining_refs=True, remove_title=True
+            ),
+            indent=2,
         )
 
         return (
             dedent(
                 """
-                Your task is to generate chart configuration for the given dataset and user question delimited by <>.
-                Responses should be in JSON format compliant to the following JSON Schema.
+                Your task is to generate chart configuration for the given dataset and user question delimited by <>.
+                Responses should be in JSON format compliant with the following JSON Schema.
 
-                """
+                """
             )
             + schema_json.replace("{", "{{").replace("}", "}}")
         )
@@ -54,24 +102,35 @@ def _task_definition_part(model_type: str) -> str:
     )
 
 
-def _data_and_detailed_instruction_part(language: str | None = None) -> str:
+def _data_and_detailed_instruction_part(
+    language: str | None, function_call: bool
+) -> str:
     language_spec = language or "the same language as the user"
-    return dedent(
-        f"""
+
+    dataset_description_part = dedent(
+        """
         Note that the user may want to refine the chart by asking a follow-up question to a previous request,
         or may want to create a new chart in a completely new context.
         In the latter case, be careful not to use the context used for the previous chart.
 
-        {{dataset}}
-
-        You should do the following step by step, and your response should include both 1 and 2:
-        1. Explain whether filters should be applied to the data, which chart_type and columns should be used,
-           and what transformations are necessary to fulfill the user's request.
-           The explanation MUST be in {language_spec},
-           and be understandable to someone who does not know the JSON schema definition.
-           This text should be enclosed with {EXPLANATION_TAG[0]} and {EXPLANATION_TAG[1]} tag.
-        2. Generate schema-compliant JSON that represents 1.
-           This text should be enclosed with {JSON_TAG[0]} and {JSON_TAG[1]} tag.
+        {dataset}
         """
     )
+
+    if function_call:
+        instruction_part = ""
+    else:
+        instruction_part = dedent(
+            f"""
+            You should do the following step by step, and your response should include both 1 and 2:
+            1. Explain whether filters should be applied to the data, which chart_type and columns should be used,
+               and what transformations are necessary to fulfill the user's request.
+               The explanation MUST be in {language_spec},
+               and be understandable to someone who does not know the JSON schema definition.
+               This text should be enclosed with {EXPLANATION_TAG[0]} and {EXPLANATION_TAG[1]} tag.
+            2. Generate schema-compliant JSON that represents 1.
+               This text should be enclosed with {JSON_TAG[0]} and {JSON_TAG[1]} tag.
+            """
+        )
+
+    return dataset_description_part + instruction_part
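With these pieces in place, the system prompt differs sharply between the two modes, which is easy to inspect directly (output abbreviated; `{dataset}` stays as a placeholder that `ChatSession` later fills with the dataframe description):

```python
from chat2plot.prompt import system_prompt
from chat2plot.schema import PlotConfig

# Function-calling mode: a one-line task definition; the JSON schema travels
# separately as the function declaration, so no <explain>/<json> instructions.
print(system_prompt("simple", True, None, PlotConfig))

# Plain-completion mode: embeds the full JSON schema (with {/} escaped for
# str.format) plus the two-step explain-then-JSON instructions.
print(system_prompt("simple", False, "English", PlotConfig))
```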
+ """ + ) + + return dataset_description_part + instruction_part diff --git a/chat2plot/schema.py b/chat2plot/schema.py index 068c722..35703c4 100644 --- a/chat2plot/schema.py +++ b/chat2plot/schema.py @@ -1,7 +1,7 @@ import copy import re from enum import Enum -from typing import Any +from typing import Any, Type import jsonref import pydantic @@ -87,7 +87,9 @@ def parse_from_llm(cls, f: str) -> "Filter": class XAxis(pydantic.BaseModel): - column: str = pydantic.Field(description="column in datasets used for the x-axis") + column: str = pydantic.Field( + description="name of the column in the df used for the x-axis" + ) bin_size: int | None = pydantic.Field( None, description="Integer value as the number of bins used to discretizes numeric values into a set of bins", @@ -120,7 +122,9 @@ def parse_from_llm(cls, d: dict[str, str | float | dict[str, str]]) -> "XAxis": class YAxis(pydantic.BaseModel): - column: str = pydantic.Field(description="column in datasets used for the y-axis") + column: str = pydantic.Field( + description="name of the column in the df used for the y-axis" + ) aggregation: AggregationType | None = pydantic.Field( None, description="Type of aggregation. Required for all chart types but scatter plots.", @@ -239,13 +243,25 @@ def wrap_if_not_list(value: str | list[str]) -> list[str]: def get_schema_of_chart_config( - inlining_refs: bool = False, remove_title: bool = True + target_schema: Type[pydantic.BaseModel], + inlining_refs: bool = True, + remove_title: bool = True, + as_function: bool = False, ) -> dict[str, Any]: - defs = jsonref.loads(PlotConfig.schema_json()) if inlining_refs else PlotConfig.schema() # type: ignore + defs = jsonref.loads(target_schema.schema_json()) if inlining_refs else target_schema.schema() # type: ignore if remove_title: defs = remove_field_recursively(defs, "title") defs = flatten_single_element_allof(defs) + defs = copy.deepcopy(defs) + + if as_function: + return { + "name": "generate_chart", + "description": "Generate the chart with given parameters", + "parameters": defs, + } + return defs # type: ignore diff --git a/example/streamlit_app/main.py b/example/streamlit_app/main.py index 61e2e09..d4f6a94 100644 --- a/example/streamlit_app/main.py +++ b/example/streamlit_app/main.py @@ -56,12 +56,10 @@ def initialize_logger(): model_name = st.selectbox( "Model type", ( - "gpt-3.5-turbo", - "gpt-3.5-turbo-0301", - "gpt-4", - "gpt-4-0314", - "gpt-4-32k", - "gpt-4-32k-0314", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0613", + "gpt-4-32k-0613", ), index=0, ) @@ -89,7 +87,7 @@ def initialize_c2p(): df, st.session_state["chart_format"], verbose=True, - description_strategy="dtypes", + description_strategy="head", ) def reset_history(): diff --git a/requirements.txt b/requirements.txt index b979fa6..ff6b5fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ altair>=4.2.0 commentjson==0.9.0 jsonschema jsonref -langchain>=0.0.127 +langchain>=0.0.208 openai>=0.27.0 pandas>=1.5.0 plotly>=5.0.0