Skip to content

Commit

Permalink
small improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
nyanp committed May 12, 2023
1 parent e932186 commit 287f4b7
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 36 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
.ipynb_checkpoints
.idea
notebooks/
datasets/
apps/
**/__pycache__/
*.egg-info/
18 changes: 2 additions & 16 deletions chat2plot/chat2plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,5 @@ def parse_json(content: str) -> tuple[str, dict[str, Any]]:
if not explanation_part:
explanation_part = _extract_tag_content(content, "explanation")

return explanation_part.strip(), json.loads(json_part)


def _parse_json(content: str) -> tuple[str, dict[str, Any]]:
    """Split a response into its non-JSON text and its parsed JSON payload.

    The JSON is looked up inside the first fenced block (```json ... ``` when
    that marker is present, otherwise a bare ``` ... ``` block). If no fenced
    block is found, the whole content is parsed as JSON.

    Args:
        content: Raw response text.

    Returns:
        A ``(non_json_part, json_part)`` tuple. ``non_json_part`` is the
        content with the fenced block removed, or ``""`` when the whole
        content was parsed as JSON.

    Raises:
        ValueError: If there is no fenced block and the content itself is not
            valid JSON. ``json.JSONDecodeError`` (a ``ValueError`` subclass)
            propagates when a fenced block exists but does not hold valid JSON.
    """
    # Non-greedy on purpose: with DOTALL a greedy ``.*`` would run from the
    # first fence to the *last* one and swallow any text between two blocks.
    ptn = r"```json(.*?)```" if "```json" in content else r"```(.*?)```"
    s = re.search(ptn, content, re.MULTILINE | re.DOTALL)
    if s:
        json_part = json.loads(s.group(1))  # type: ignore
        non_json_part = content.replace(s.group(0), "")
        return non_json_part, delete_null_field(json_part)

    try:
        return "", json.loads(content)
    except json.JSONDecodeError as err:
        # Keep the original message, but preserve the decoding error as cause.
        raise ValueError("failed to find start(```) and end(```) marker") from err
return explanation_part.strip(), delete_null_field(json.loads(json_part))

36 changes: 23 additions & 13 deletions chat2plot/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from chat2plot.schema import (
AggregationType,
BarMode,
ChartType,
Filter,
PlotConfig,
Expand Down Expand Up @@ -47,22 +48,24 @@ def draw_plotly(df: pd.DataFrame, config: PlotConfig, show: bool = True) -> Figu
x = agg.columns[0]
y = agg.columns[-1]
orientation = "v"
bar_mode = "group" if config.bar_mode == BarMode.GROUP else "relative"

if chart_type == ChartType.HORIZONTAL_BAR:
x, y = y, x
orientation = "h"

fig = px.bar(
agg,
color=config.hue or None,
color=config.color or None,
orientation=orientation,
barmode=bar_mode,
**_ax_config(config, x, y),
)
elif chart_type == ChartType.SCATTER:
assert config.x is not None
fig = px.scatter(
df_filtered,
color=config.hue or None,
color=config.color or None,
**_ax_config(config, config.x.column, config.y.column),
)
elif chart_type == ChartType.PIE:
Expand All @@ -80,7 +83,7 @@ def draw_plotly(df: pd.DataFrame, config: PlotConfig, show: bool = True) -> Figu
assert config.x is not None
fig = func_table[chart_type](
df_filtered,
color=config.hue or None,
color=config.color or None,
**_ax_config(config, config.x.column, config.y.column),
)
else:
Expand Down Expand Up @@ -112,8 +115,8 @@ def draw_altair(
def groupby_agg(df: pd.DataFrame, config: PlotConfig) -> pd.DataFrame:
group_by = [config.x.column] if config.x is not None else []

if config.hue and (not config.x or (config.hue != config.x.column)):
group_by.append(config.hue)
if config.color and (not config.x or (config.color != config.x.column)):
group_by.append(config.color)

agg_method = {
AggregationType.AVG: "mean",
Expand Down Expand Up @@ -155,12 +158,19 @@ def filter_data(df: pd.DataFrame, filters: list[str]) -> pd.DataFrame:
if not filters:
return df

elements = []
for f in filters:
try:
e = Filter.parse_from_llm(f).escaped()
except Exception:
e = f
elements.append(f"({e})")
def _filter_data(df: pd.DataFrame, filters: list[str], with_escape: bool) -> pd.DataFrame:
    """Apply all filters to ``df`` as a single ``and``-joined query expression.

    When ``with_escape`` is true, every filter is first rewritten through
    ``Filter.parse_from_llm(...).escaped()``; otherwise the filter strings
    are passed to ``df.query`` verbatim.
    """
    conditions = filters
    if with_escape:
        conditions = [Filter.parse_from_llm(raw).escaped() for raw in filters]
    expression = " and ".join(conditions)
    return df.query(expression)

# 1. LLM sometimes forgets to escape column names when necessary.
# In this case, adding escaping will handle it correctly.
# 2. LLM sometimes writes multiple OR conditions in one filter.
# In this case, adding escapes leads to errors.
# Since both cases exist, add escapes and retry only when an error occurs.
try:
return _filter_data(df, filters, False)
except Exception:
return _filter_data(df, filters, True)

return df.query(" and ".join(elements))
23 changes: 16 additions & 7 deletions chat2plot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ class TimeUnit(str, Enum):
DAY = "day"


class BarMode(str, Enum):
    """Layout used when a bar chart draws several series."""

    # Bars are stacked.
    STACK = "stacked"
    # Bars are placed beside each other.
    GROUP = "group"



class Transform(pydantic.BaseModel):
aggregation: AggregationType | None = pydantic.Field(
None,
Expand Down Expand Up @@ -114,8 +120,8 @@ def parse_from_llm(cls, f: str) -> "Filter":


class Axis(pydantic.BaseModel):
column: str = pydantic.Field(None, description="column in datasets used for the axis")
transform: Transform | None = pydantic.Field(None, description="transformation applied to column")
column: str = pydantic.Field(description="column in datasets used for the axis")
transform: Transform = pydantic.Field(None, description="transformation applied to column")
min_value: float | None
max_value: float | None
label: str | None
Expand All @@ -135,23 +141,25 @@ def parse_from_llm(cls, d: dict[str, str | float | dict[str, str]]) -> "Axis":


class PlotConfig(pydantic.BaseModel):
chart_type: ChartType = pydantic.Field(None, description="the type of the chart")
chart_type: ChartType = pydantic.Field(description="the type of the chart")
filters: list[str] = pydantic.Field(
None,
description="List of filter conditions, where each filter must be a legal string that can be passed to df.query(),"
' such as "x >= 0". Filters will be calculated before transforming axis.',
)
x: Axis | None = pydantic.Field(
None, description="X-axis for the chart, or label column for pie chart"
)
y: Axis = pydantic.Field(
None,
description="Y-axis or measure value for the chart, or the wedge sizes for pie chart.",
)
hue: str | None = pydantic.Field(
color: str | None = pydantic.Field(
None,
description="Column name used as grouping variables that will produce different colors.",
)
bar_mode: BarMode | None = pydantic.Field(
None,
description="If 'stacked', bars are stacked. In 'group' mode, bars are placed beside each other."
)
sort_criteria: SortingCriteria | None = pydantic.Field(
None, description="The sorting criteria for x-axis"
)
Expand Down Expand Up @@ -190,7 +198,8 @@ def wrap_if_not_list(value: str | list[str]) -> list[str]:
x=Axis.parse_from_llm(json_data["x"]) if json_data.get("x") else None,
y=Axis.parse_from_llm(json_data["y"]),
filters=wrap_if_not_list(json_data.get("filters", [])),
hue=json_data.get("hue") or None,
color=json_data.get("color") or None,
bar_mode=BarMode(json_data["bar_mode"]) if json_data.get("bar_mode") else None,
sort_criteria=SortingCriteria(json_data["sort_criteria"])
if json_data.get("sort_criteria")
else None,
Expand Down
4 changes: 4 additions & 0 deletions chat2plot/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def binning(series: pd.Series, interval: int) -> pd.Series:


def round_datetime(series: pd.Series, period: TimeUnit) -> pd.Series:
if is_integer_dtype(series) and period == TimeUnit.YEAR:
# assuming that it is year column, so no transform is needed
return series

series = pd.to_datetime(series)

period_map = {TimeUnit.DAY: "D", TimeUnit.WEEK: "W", TimeUnit.MONTH: "M", TimeUnit.QUARTER: "Q", TimeUnit.YEAR: "Y"}
Expand Down

0 comments on commit 287f4b7

Please sign in to comment.