Skip to content

Commit

Permalink
improve accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
nyanp committed May 13, 2023
1 parent bd5e740 commit 7df295e
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 79 deletions.
15 changes: 12 additions & 3 deletions chat2plot/chat2plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any

import altair as alt
import commentjson
import jsonschema
import pandas as pd
from langchain.chat_models import ChatOpenAI
Expand Down Expand Up @@ -152,8 +153,15 @@ def _parse_response(self, content: str, config_only: bool, show_plot: bool) -> P
return Plot(None, None, ResponseType.NOT_RELATED, content, content)

explanation, json_data = parse_json(content)
jsonschema.validate(json_data, PlotConfig.schema())
config = PlotConfig.from_json(json_data)

try:
config = PlotConfig.from_json(json_data)
except Exception:
# To reduce the number of failure cases as much as possible,
# only check against the json schema when instantiation fails.
jsonschema.validate(json_data, PlotConfig.schema())
raise

if self._verbose:
_logger.info(config)

Expand Down Expand Up @@ -244,4 +252,5 @@ def parse_json(content: str) -> tuple[str, dict[str, Any]]:
if not explanation_part:
explanation_part = _extract_tag_content(content, "explanation")

return explanation_part.strip(), delete_null_field(json.loads(json_part))
# The LLM occasionally emits JSON that contains comments, so parse it with commentjson instead of json
return explanation_part.strip(), delete_null_field(commentjson.loads(json_part))
6 changes: 2 additions & 4 deletions chat2plot/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def draw_plotly(df: pd.DataFrame, config: PlotConfig, show: bool = True) -> Figu
chart_type = config.chart_type

if chart_type == ChartType.BAR:
if config.horizontal:
config = config.transpose()
agg = groupby_agg(df_filtered, config)
x = agg.columns[0]
y = agg.columns[-1]
Expand Down Expand Up @@ -128,7 +126,7 @@ def groupby_agg(df: pd.DataFrame, config: PlotConfig) -> pd.DataFrame:
}

y = config.y
aggregation = y.transform.aggregation or AggregationType.AVG
aggregation = y.aggregation or AggregationType.AVG

if not group_by:
return pd.DataFrame(
Expand All @@ -151,7 +149,7 @@ def groupby_agg(df: pd.DataFrame, config: PlotConfig) -> pd.DataFrame:


def is_aggregation(config: PlotConfig) -> bool:
return config.y.transform and config.y.transform.aggregation is not None
return config.y.aggregation is not None


def filter_data(df: pd.DataFrame, filters: list[str]) -> pd.DataFrame:
Expand Down
122 changes: 66 additions & 56 deletions chat2plot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,40 +59,6 @@ class BarMode(str, Enum):
GROUP = "group"


class Transform(pydantic.BaseModel):
aggregation: AggregationType | None = pydantic.Field(
None,
description=f"Type of aggregation. It will be ignored when it is scatter plot",
)
bin_size: int | None = pydantic.Field(
None,
description="Integer value as the number of bins used to discretizes numeric values into a set of bins",
)
time_unit: TimeUnit | None = pydantic.Field(
None, description="The time unit used to descretize date/datetime values"
)

def transformed_name(self, col: str) -> str:
dst = col
if self.time_unit:
dst = f"UNIT({col}, {self.time_unit.value})"
if self.bin_size:
dst = f"BINNING({col}, {self.bin_size})"
if self.aggregation:
dst = f"{self.aggregation.value}({col})"
return dst

@classmethod
def parse_from_llm(cls, d: dict[str, str | int]) -> "Transform":
return Transform(
aggregation=AggregationType(d["aggregation"].upper())
if d.get("aggregation")
else None,
bin_size=d.get("bin_size") or None, # type: ignore
time_unit=d.get("time_unit") or None, # type: ignore
)


class Filter(pydantic.BaseModel):
lhs: str
rhs: str
Expand All @@ -119,26 +85,66 @@ def parse_from_llm(cls, f: str) -> "Filter":
raise ValueError(f"Unsupported op or failed to parse: {f}")


class Axis(pydantic.BaseModel):
column: str = pydantic.Field(description="column in datasets used for the axis")
transform: Transform = pydantic.Field(
None, description="transformation applied to column"
class XAxis(pydantic.BaseModel):
column: str = pydantic.Field(description="column in datasets used for the x-axis")
bin_size: int | None = pydantic.Field(
None,
description="Integer value as the number of bins used to discretizes numeric values into a set of bins",
)
time_unit: TimeUnit | None = pydantic.Field(
None, description="The time unit used to descretize date/datetime values"
)
min_value: float | None
max_value: float | None
label: str | None

def transformed_name(self) -> str:
return (
self.transform.transformed_name(self.column)
if self.transform
else self.column
dst = self.column
if self.time_unit:
dst = f"UNIT({dst}, {self.time_unit.value})"
if self.bin_size:
dst = f"BINNING({dst}, {self.bin_size})"
return dst

@classmethod
def parse_from_llm(cls, d: dict[str, str | float | dict[str, str]]) -> "XAxis":
return XAxis(
column=d.get("column") or None, # type: ignore
min_value=d.get("min_value"), # type: ignore
max_value=d.get("max_value"), # type: ignore
label=d.get("label") or None, # type: ignore
bin_size=d.get("bin_size") or None, # type: ignore
time_unit=TimeUnit(d["time_unit"]) if d.get("time_unit") else None, # type: ignore
)


class YAxis(pydantic.BaseModel):
column: str = pydantic.Field(description="column in datasets used for the y-axis")
aggregation: AggregationType | None = pydantic.Field(
None,
description=f"Type of aggregation. It will be ignored when it is scatter plot",
)
min_value: float | None
max_value: float | None
label: str | None

def transformed_name(self) -> str:
dst = self.column
if self.aggregation:
dst = f"{self.aggregation.value}({dst})"
return dst

@classmethod
def parse_from_llm(cls, d: dict[str, str | float | dict[str, str]]) -> "Axis":
return Axis(
def parse_from_llm(
cls, d: dict[str, str | float | dict[str, str]], needs_aggregation: bool = False
) -> "YAxis":
agg = d.get("aggregation")
if needs_aggregation and not agg:
agg = "AVG"

return YAxis(
column=d.get("column") or None, # type: ignore
aggregation=AggregationType(agg) if agg else None,
transform=Transform.parse_from_llm(d["transform"]) if "transform" in d else None, # type: ignore
min_value=d.get("min_value"), # type: ignore
max_value=d.get("max_value"), # type: ignore
Expand All @@ -152,10 +158,10 @@ class PlotConfig(pydantic.BaseModel):
description="List of filter conditions, where each filter must be a legal string that can be passed to df.query(),"
' such as "x >= 0". Filters will be calculated before transforming axis.',
)
x: Axis | None = pydantic.Field(
x: XAxis | None = pydantic.Field(
None, description="X-axis for the chart, or label column for pie chart"
)
y: Axis = pydantic.Field(
y: YAxis = pydantic.Field(
description="Y-axis or measure value for the chart, or the wedge sizes for pie chart.",
)
color: str | None = pydantic.Field(
Expand All @@ -176,12 +182,6 @@ class PlotConfig(pydantic.BaseModel):
None, description="If true, the chart is drawn in a horizontal orientation"
)

def transpose(self) -> "PlotConfig":
transposed = copy.deepcopy(self)
transposed.y = self.x
transposed.x = self.y
return transposed

@property
def required_columns(self) -> list[str]:
columns = [self.y.column]
Expand All @@ -200,12 +200,22 @@ def wrap_if_not_list(value: str | list[str]) -> list[str]:
else:
return value

chart_type = ChartType(json_data["chart_type"])
if not json_data.get("x") or json_data["x"] == "none":
# treat chart as bar if x-axis does not exist
chart_type = ChartType.BAR
else:
chart_type = ChartType(json_data["chart_type"])

if not json_data.get("x"):
# treat chart as bar if x-axis does not exist
chart_type = ChartType.BAR

return cls(
chart_type=chart_type,
x=Axis.parse_from_llm(json_data["x"]) if json_data.get("x") else None,
y=Axis.parse_from_llm(json_data["y"]),
x=XAxis.parse_from_llm(json_data["x"]) if json_data.get("x") else None,
y=YAxis.parse_from_llm(
json_data["y"], needs_aggregation=chart_type != ChartType.SCATTER
),
filters=wrap_if_not_list(json_data.get("filters", [])),
color=json_data.get("color") or None,
bar_mode=BarMode(json_data["bar_mode"])
Expand All @@ -231,4 +241,4 @@ def get_schema_of_chart_config(

defs = flatten_single_element_allof(defs)

return defs
return defs # type: ignore
24 changes: 8 additions & 16 deletions chat2plot/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,28 @@
import pandas as pd
from pandas.api.types import is_integer_dtype

from chat2plot.schema import Axis, PlotConfig, TimeUnit
from chat2plot.schema import PlotConfig, TimeUnit, XAxis


def transform(df: pd.DataFrame, config: PlotConfig) -> tuple[pd.DataFrame, PlotConfig]:
config = copy.deepcopy(config)

if config.x and config.x.transform:
x_trans = _transform(df, config.x)
if config.x and (config.x.bin_size or config.x.time_unit):
x_trans = _transform_x(df, config.x)
df[x_trans.name] = x_trans
config.x.column = x_trans.name

if config.y.transform:
y_trans = _transform(df, config.y)
df[y_trans.name] = y_trans
config.y.column = y_trans.name

return df, config


def _transform(df: pd.DataFrame, ax: Axis) -> pd.Series:
if not ax.transform:
return df[ax.column]

def _transform_x(df: pd.DataFrame, ax: XAxis) -> pd.Series:
dst = df[ax.column].copy()

if ax.transform.bin_size:
dst = binning(dst, ax.transform.bin_size)
if ax.bin_size:
dst = binning(dst, ax.bin_size)

if ax.transform.time_unit:
dst = round_datetime(dst, ax.transform.time_unit)
if ax.time_unit:
dst = round_datetime(dst, ax.time_unit)

return pd.Series(dst.values, name=ax.transformed_name())

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
altair>=4.2.0
commentjson==0.9.0
jsonschema
jsonref
langchain>=0.0.127
Expand Down

0 comments on commit 7df295e

Please sign in to comment.