From 287f4b7781d5b01fc84bb3d71bff9f56faed9262 Mon Sep 17 00:00:00 2001 From: Taiga Noumi Date: Fri, 12 May 2023 23:24:16 +0900 Subject: [PATCH] small improvements --- .gitignore | 1 + chat2plot/chat2plot.py | 18 ++---------------- chat2plot/render.py | 36 +++++++++++++++++++++++------------- chat2plot/schema.py | 23 ++++++++++++++++------- chat2plot/transform.py | 4 ++++ 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/.gitignore b/.gitignore index 38c7201..42a31e2 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ .ipynb_checkpoints .idea notebooks/ +datasets/ apps/ **/__pycache__/ *.egg-info/ diff --git a/chat2plot/chat2plot.py b/chat2plot/chat2plot.py index 5cbd121..3abfe6c 100644 --- a/chat2plot/chat2plot.py +++ b/chat2plot/chat2plot.py @@ -241,19 +241,5 @@ def parse_json(content: str) -> tuple[str, dict[str, Any]]: if not explanation_part: explanation_part = _extract_tag_content(content, "explanation") - return explanation_part.strip(), json.loads(json_part) - - -def _parse_json(content: str) -> tuple[str, dict[str, Any]]: - ptn = r"```json(.*)```" if "```json" in content else r"```(.*)```" - s = re.search(ptn, content, re.MULTILINE | re.DOTALL) - if s: - json_part = json.loads(s.group(1)) # type: ignore - non_json_part = content.replace(s.group(0), "") - return non_json_part, delete_null_field(json_part) - - try: - json_part = json.loads(content) - return "", json_part - except Exception: - raise ValueError("failed to find start(```) and end(```) marker") + return explanation_part.strip(), delete_null_field(json.loads(json_part)) + diff --git a/chat2plot/render.py b/chat2plot/render.py index 116c53b..fb70654 100644 --- a/chat2plot/render.py +++ b/chat2plot/render.py @@ -10,6 +10,7 @@ from chat2plot.schema import ( AggregationType, + BarMode, ChartType, Filter, PlotConfig, @@ -47,6 +48,7 @@ def draw_plotly(df: pd.DataFrame, config: PlotConfig, show: bool = True) -> Figu x = agg.columns[0] y = agg.columns[-1] orientation = "v" + bar_mode = "group" if config.bar_mode == BarMode.GROUP else "relative" if chart_type == ChartType.HORIZONTAL_BAR: x, y = y, x @@ -54,15 +56,16 @@ def draw_plotly(df: pd.DataFrame, config: PlotConfig, show: bool = True) -> Figu fig = px.bar( agg, - color=config.hue or None, + color=config.color or None, orientation=orientation, + barmode=bar_mode, **_ax_config(config, x, y), ) elif chart_type == ChartType.SCATTER: assert config.x is not None fig = px.scatter( df_filtered, - color=config.hue or None, + color=config.color or None, **_ax_config(config, config.x.column, config.y.column), ) elif chart_type == ChartType.PIE: @@ -80,7 +83,7 @@ def draw_plotly(df: pd.DataFrame, config: PlotConfig, show: bool = True) -> Figu assert config.x is not None fig = func_table[chart_type]( df_filtered, - color=config.hue or None, + color=config.color or None, **_ax_config(config, config.x.column, config.y.column), ) else: @@ -112,8 +115,8 @@ def draw_altair( def groupby_agg(df: pd.DataFrame, config: PlotConfig) -> pd.DataFrame: group_by = [config.x.column] if config.x is not None else [] - if config.hue and (not config.x or (config.hue != config.x.column)): - group_by.append(config.hue) + if config.color and (not config.x or (config.color != config.x.column)): + group_by.append(config.color) agg_method = { AggregationType.AVG: "mean", @@ -155,12 +158,19 @@ def filter_data(df: pd.DataFrame, filters: list[str]) -> pd.DataFrame: if not filters: return df - elements = [] - for f in filters: - try: - e = Filter.parse_from_llm(f).escaped() - except Exception: - e = f - elements.append(f"({e})") + def _filter_data(df: pd.DataFrame, filters: list[str], with_escape: bool) -> pd.DataFrame: + if with_escape: + return df.query(" and ".join([Filter.parse_from_llm(f).escaped() for f in filters])) + else: + return df.query(" and ".join(filters)) + + # 1. LLM sometimes forgets to escape column names when necessary. + # In this case, adding escaping will handle it correctly. + # 2. LLM sometimes writes multiple OR conditions in one filter. + # In this case, adding escapes leads to errors. + # Since both cases exist, add escapes and retry only when an error occurs. + try: + return _filter_data(df, filters, False) + except Exception: + return _filter_data(df, filters, True) - return df.query(" and ".join(elements)) diff --git a/chat2plot/schema.py b/chat2plot/schema.py index 5aaef73..5a1b92e 100644 --- a/chat2plot/schema.py +++ b/chat2plot/schema.py @@ -52,6 +52,12 @@ class TimeUnit(str, Enum): DAY = "day" +class BarMode(str, Enum): + STACK = "stacked" + GROUP = "group" + + + class Transform(pydantic.BaseModel): aggregation: AggregationType | None = pydantic.Field( None, @@ -114,8 +120,8 @@ def parse_from_llm(cls, f: str) -> "Filter": class Axis(pydantic.BaseModel): - column: str = pydantic.Field(None, description="column in datasets used for the axis") - transform: Transform | None = pydantic.Field(None, description="transformation applied to column") + column: str = pydantic.Field(description="column in datasets used for the axis") + transform: Transform = pydantic.Field(None, description="transformation applied to column") min_value: float | None max_value: float | None label: str | None @@ -135,9 +141,8 @@ def parse_from_llm(cls, d: dict[str, str | float | dict[str, str]]) -> "Axis": class PlotConfig(pydantic.BaseModel): - chart_type: ChartType = pydantic.Field(None, description="the type of the chart") + chart_type: ChartType = pydantic.Field(description="the type of the chart") filters: list[str] = pydantic.Field( - None, description="List of filter conditions, where each filter must be a legal string that can be passed to df.query()," ' such as "x >= 0". Filters will be calculated before transforming axis.', ) @@ -145,13 +150,16 @@ class PlotConfig(pydantic.BaseModel): None, description="X-axis for the chart, or label column for pie chart" ) y: Axis = pydantic.Field( - None, description="Y-axis or measure value for the chart, or the wedge sizes for pie chart.", ) - hue: str | None = pydantic.Field( + color: str | None = pydantic.Field( None, description="Column name used as grouping variables that will produce different colors.", ) + bar_mode: BarMode | None = pydantic.Field( + None, + description="If 'stacked', bars are stacked. In 'group' mode, bars are placed beside each other." + ) sort_criteria: SortingCriteria | None = pydantic.Field( None, description="The sorting criteria for x-axis" ) @@ -190,7 +198,8 @@ def wrap_if_not_list(value: str | list[str]) -> list[str]: x=Axis.parse_from_llm(json_data["x"]) if json_data.get("x") else None, y=Axis.parse_from_llm(json_data["y"]), filters=wrap_if_not_list(json_data.get("filters", [])), - hue=json_data.get("hue") or None, + color=json_data.get("color") or None, + bar_mode=BarMode(json_data["bar_mode"]) if json_data.get("bar_mode") else None, sort_criteria=SortingCriteria(json_data["sort_criteria"]) if json_data.get("sort_criteria") else None, diff --git a/chat2plot/transform.py b/chat2plot/transform.py index 930ee31..06e0bbf 100644 --- a/chat2plot/transform.py +++ b/chat2plot/transform.py @@ -52,6 +52,10 @@ def binning(series: pd.Series, interval: int) -> pd.Series: def round_datetime(series: pd.Series, period: TimeUnit) -> pd.Series: + if is_integer_dtype(series) and period == TimeUnit.YEAR: + # assuming that it is year column, so no transform is needed + return series + series = pd.to_datetime(series) period_map = {TimeUnit.DAY: "D", TimeUnit.WEEK: "W", TimeUnit.MONTH: "M", TimeUnit.QUARTER: "Q", TimeUnit.YEAR: "Y"}