Skip to content

Commit

Permalink
use root validator
Browse files Browse the repository at this point in the history
  • Loading branch information
nyanp committed Jul 4, 2023
1 parent 83483f9 commit 5845192
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 66 deletions.
5 changes: 2 additions & 3 deletions chat2plot/chat2plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,7 @@ def _parse_response(
if self._custom_deserializer:
config = self._custom_deserializer(json_data)
else:
config = pydantic.parse_obj_as(self._target_schema, json_data)
# config = self._target_schema.from_json(json_data)
config = self._target_schema.parse_obj(json_data)
except _APPLICATION_ERRORS:
_logger.warning(traceback.format_exc())
# To reduce the number of failure cases as much as possible,
Expand Down Expand Up @@ -360,7 +359,7 @@ def chat2plot(
language=language,
description_strategy=description_strategy,
verbose=verbose,
custom_deserializer=custom_deserializer or PlotConfig.from_json,
custom_deserializer=custom_deserializer,
function_call=function_call,
)
if schema_definition == "vega":
Expand Down
94 changes: 31 additions & 63 deletions chat2plot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ class XAxis(pydantic.BaseModel):
time_unit: TimeUnit | None = pydantic.Field(
None, description="The time unit used to descretize date/datetime values"
)
min_value: float | None
max_value: float | None
label: str | None

def transformed_name(self) -> str:
Expand All @@ -109,17 +107,6 @@ def transformed_name(self) -> str:
dst = f"BINNING({dst}, {self.bin_size})"
return dst

@classmethod
def parse_from_llm(cls, d: dict[str, str | float | dict[str, str]]) -> "XAxis":
return XAxis(
column=d.get("column") or None, # type: ignore
min_value=d.get("min_value"), # type: ignore
max_value=d.get("max_value"), # type: ignore
label=d.get("label") or None, # type: ignore
bin_size=d.get("bin_size") or None, # type: ignore
time_unit=TimeUnit(d["time_unit"]) if d.get("time_unit") else None, # type: ignore
)


class YAxis(pydantic.BaseModel):
column: str = pydantic.Field(
Expand All @@ -129,8 +116,6 @@ class YAxis(pydantic.BaseModel):
None,
description="Type of aggregation. Required for all chart types but scatter plots.",
)
min_value: float | None
max_value: float | None
label: str | None

def transformed_name(self) -> str:
Expand All @@ -139,28 +124,6 @@ def transformed_name(self) -> str:
dst = f"{self.aggregation.value}({dst})"
return dst

@classmethod
def parse_from_llm(
cls, d: dict[str, str | float | dict[str, str]], needs_aggregation: bool = False
) -> "YAxis":
agg = d.get("aggregation")
if needs_aggregation and not agg:
agg = "AVG"

if not d.get("column") and needs_aggregation:
agg = "COUNTROWS"
elif agg == "COUNTROWS":
agg = "COUNT"

return YAxis(
column=d.get("column") or "", # type: ignore
aggregation=AggregationType(agg) if agg else None,
min_value=d.get("min_value"), # type: ignore
max_value=d.get("max_value"), # type: ignore
label=d.get("label") or None, # type: ignore
)


class PlotConfig(pydantic.BaseModel):
chart_type: ChartType = pydantic.Field(
description="The type of the chart. Use scatter plots as little as possible unless explicitly specified by the user. Choose 'scalar' if we need only single scalar."
Expand Down Expand Up @@ -197,11 +160,15 @@ class PlotConfig(pydantic.BaseModel):
None, description="Limit a number of data to top-N items"
)

@classmethod
def from_json(cls, json_data: dict[str, Any]) -> "PlotConfig":
@pydantic.root_validator(pre=True)
def validate(cls, json_data: dict[str, Any]) -> dict[str, Any]:
assert "chart_type" in json_data
assert "y" in json_data

if isinstance(json_data["y"], YAxis):
# use validator only if json_data is raw dictionary
return json_data

json_data = copy.deepcopy(json_data)

def wrap_if_not_list(value: str | list[str]) -> list[str]:
Expand All @@ -213,33 +180,34 @@ def wrap_if_not_list(value: str | list[str]) -> list[str]:
if not json_data["chart_type"] or json_data["chart_type"].lower() == "none":
# treat chart as bar if x-axis does not exist
chart_type = ChartType.BAR
elif json_data["chart_type"] == "histogram":
chart_type = ChartType.BAR
else:
chart_type = ChartType(json_data["chart_type"])

if not json_data.get("x") and chart_type == ChartType.PIE:
# LLM sometimes forget to fill x in pie-chart
json_data["x"] = copy.deepcopy(json_data["y"])

return cls(
chart_type=chart_type,
x=XAxis.parse_from_llm(json_data["x"]) if json_data.get("x") else None,
y=YAxis.parse_from_llm(
json_data["y"], needs_aggregation=chart_type != ChartType.SCATTER
),
filters=wrap_if_not_list(json_data.get("filters", [])),
color=json_data.get("color") or None,
bar_mode=BarMode(json_data["bar_mode"])
if json_data.get("bar_mode")
else None,
sort_criteria=SortingCriteria(json_data["sort_criteria"])
if json_data.get("sort_criteria")
else None,
sort_order=SortOrder(json_data["sort_order"])
if json_data.get("sort_order")
else None,
horizontal=json_data.get("horizontal"),
limit=json_data.get("limit"),
)
if chart_type == ChartType.PIE:
if not json_data.get("x"):
# LLM sometimes forget to fill x in pie-chart
json_data["x"] = copy.deepcopy(json_data["y"])
elif not json_data["y"].get("column") and json_data["x"].get("column"):
# ...and vice versa.
json_data["y"]["column"] = copy.deepcopy(json_data["x"]["column"])

if chart_type == ChartType.SCATTER:
if json_data["y"].get("aggregation"):
del json_data["y"]["aggregation"]
else:
if not json_data["y"].get("column"):
json_data["y"]["aggregation"] = AggregationType.COUNTROWS.value
elif json_data["y"].get("aggregation") == AggregationType.COUNTROWS.value:
json_data["y"]["aggregation"] = AggregationType.COUNT.value

if json_data.get("filters") is None:
json_data["filters"] = []
else:
json_data["filters"] = wrap_if_not_list(json_data["filters"])

return json_data


def get_schema_of_chart_config(
Expand Down

0 comments on commit 5845192

Please sign in to comment.