Skip to content

Commit

Permalink
feat: add transpile_sql and execute_last_sql for demo (#58)
Browse files Browse the repository at this point in the history
* change CodeAttachment attribute language to dialect

* change CodeAttachment attribute language to dialect

* add transplile and execute_last_sql

* add sqlglot

* change code attachment to sql

* get_attachment_by_type

* resolve comments

* multiple table attachments in one message
  • Loading branch information
jitingxu1 authored Mar 22, 2024
1 parent 76e42f9 commit ed0ca32
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 60 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
'ibis-framework[duckdb,examples]', # @ git+https://github.com/ibis-project/ibis',
'plotly',
'streamlit',
'sqlglot',
]

[project.urls]
Expand Down
36 changes: 24 additions & 12 deletions src/ibis_birdbrain/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ibis_birdbrain.bot import Bot
from ibis_birdbrain.attachments import (
CodeAttachment,
SQLAttachment,
ErrorAttachment,
TextAttachment,
WebpageAttachment,
Expand Down Expand Up @@ -51,11 +51,23 @@ def process_message(message, include_attachments=False):
results = []
results.append(st.markdown(message.body))
if include_attachments:
for attachment in message.attachments:
a = message.attachments[attachment] # TODO: hack
if isinstance(a, CodeAttachment):
expander = st.expander(label=a.language, expanded=False)
results.append(expander.markdown(f"```{a.language}\n{a.open()}"))
if sql_attachment := message.attachments.get_attachment_by_type(SQLAttachment):
expander = st.expander(label=a.dialect, expanded=False)
results.append(expander.markdown(f"```{sql_attachment.dialect}\n{sql_attachment.open()}"))

if table_attachments := message.attachments.get_attachment_by_type(TableAttachment):
# only have 1 table
results.append(
st.dataframe(
table_attachments[0].open().limit(1000).to_pandas(), use_container_width=True
)
)

# for attachment in message.attachments:
# a = message.attachments[attachment] # TODO: hack
# if isinstance(a, SQLAttachment):
# expander = st.expander(label=a.dialect, expanded=False)
# results.append(expander.markdown(f"```{a.dialect}\n{a.open()}"))
# elif isinstance(a, TextAttachment):
# results.append(st.markdown(a.open()))
# elif isinstance(a, ErrorAttachment):
Expand All @@ -64,12 +76,12 @@ def process_message(message, include_attachments=False):
# results.append(st.markdown(a.open())) # TODO: better?
# elif isinstance(a, DataAttachment):
# results.append(st.markdown(a.open()))
elif isinstance(a, TableAttachment):
results.append(
st.dataframe(
a.open().limit(1000).to_pandas(), use_container_width=True
)
)
# elif isinstance(a, TableAttachment):
# results.append(
# st.dataframe(
# a.open().limit(1000).to_pandas(), use_container_width=True
# )
# )
# elif isinstance(a, ChartAttachment):
# results.append(st.plotly_chart(a.open(), use_container_width=True))
# else:
Expand Down
29 changes: 26 additions & 3 deletions src/ibis_birdbrain/attachments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# imports
from collections import defaultdict
from uuid import uuid4
from typing import Any, Union, List
from typing import Any, Union, List, Type
from datetime import datetime

from ibis.expr.types.relations import Table
Expand Down Expand Up @@ -53,14 +54,20 @@ class Attachments:
"""Ibis Birdbrain attachments."""

attachments: dict[str, Attachment]
type_id_map: dict[Type[Attachment], List[str]]

def __init__(self, attachments: list[Attachment] = []) -> None:
"""Initialize the attachments."""
self.attachments = {a.id: a for a in attachments}
self.type_id_map = defaultdict(list)
for a in attachments:
self.type_id_map[type(a)].append(a.id)


def add_attachment(self, attachment: Attachment):
"""Add an attachment to the collection."""
self.attachments[attachment.id] = attachment
self.type_id_map[type(attachment)].append(attachment.id)

def append(self, attachment: Attachment):
"""Alias for add_attachment."""
Expand All @@ -75,6 +82,21 @@ def extend(self, attachments: Union[List[Attachment], "Attachments"]):

return self

def get_attachment_by_type(self, attachment_type: Type[Attachment]):
"""Get attachments of a specific type."""
if attachment_type not in self.type_id_map:
return None

ids = self.type_id_map[attachment_type]
if not isinstance(self.attachments[ids[0]], TableAttachment):
return self.attachments[ids[0]]

# One messages may have multiple table attachment
attachments = Attachments()
for id in ids:
attachments.append(self.attachments[id])
return attachments

def __getitem__(self, id: str | int):
"""Get an attachment from the collection."""
if isinstance(id, int):
Expand All @@ -84,6 +106,7 @@ def __getitem__(self, id: str | int):
def __setitem__(self, id: str, attachment: Attachment):
"""Set an attachment in the collection."""
self.attachments[id] = attachment
self.type_id_map[type(attachment)].append(id)

def __len__(self) -> int:
"""Get the length of the collection."""
Expand All @@ -109,7 +132,7 @@ def __repr__(self):
)
from ibis_birdbrain.attachments.text import (
TextAttachment,
CodeAttachment,
SQLAttachment,
ErrorAttachment,
WebpageAttachment,
)
Expand All @@ -122,7 +145,7 @@ def __repr__(self):
"TableAttachment",
"ChartAttachment",
"TextAttachment",
"CodeAttachment",
"SQLAttachment",
"ErrorAttachment",
"WebpageAttachment",
]
11 changes: 6 additions & 5 deletions src/ibis_birdbrain/attachments/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,16 @@ def open(self, browser=False):
return self.url


class CodeAttachment(TextAttachment):
# TODO: add CodeAttachment
class SQLAttachment(TextAttachment):
"""A code attachment."""

content: str
language: str
dialect: str

def __init__(self, language="python", *args, **kwargs):
def __init__(self, dialect="duckdb", *args, **kwargs):
super().__init__(*args, **kwargs)
self.language = language
self.dialect = dialect

def encode(self):
...
Expand All @@ -94,7 +95,7 @@ def __str__(self):
return (
super().__str__()
+ f"""
**language**: {self.language}
**dialect**: {self.dialect}
**code**:\n{self.content}"""
)

Expand Down
42 changes: 36 additions & 6 deletions src/ibis_birdbrain/bot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# imports
import ibis
import sqlglot as sg

from uuid import uuid4
from typing import Any
Expand All @@ -12,12 +13,14 @@
Attachments,
TableAttachment,
DatabaseAttachment,
SQLAttachment,
)
from ibis_birdbrain.flows import Flows
from ibis_birdbrain.strings import bot_description
from ibis_birdbrain.messages import Message, Messages, Email
from ibis_birdbrain.utils.strings import shorten_str
from ibis_birdbrain.utils.attachments import to_attachments
from ibis_birdbrain.tasks.sql import ExecuteSQLTask

from ibis_birdbrain.flows.data import DataFlow

Expand Down Expand Up @@ -172,12 +175,39 @@ def respond(self, messages: Messages) -> Message:
"""Respond to the messages."""
...

# TODO: for demo
def translate_sql(self, sql: str, dialect_to: str, dialect_from: str) -> str:
def transpile_sql(self, sql: str, dialect_from: str, dialect_to: str) -> str:
"""Translate SQL from one dialect to another."""
...

return sg.transpile(
sql=sql,
read=dialect_from,
write=dialect_to,
identity=False,
pretty=True,
)[0]

# TODO: for demo
def execute_last_sql(self, con: BaseBackend) -> Message:
"""Execute the last SQL statement."""
...
"""Execute the last successfully executed SQL statement."""

sql_attachment = None
for m in reversed(self.messages):
if m.attachments.get_attachment_by_type(SQLAttachment) and m.attachments.get_attachment_by_type(TableAttachment):
sql_attachment = m.attachments.get_attachment_by_type(SQLAttachment)
break

if sql_attachment:
database_attachment = DatabaseAttachment(con)
task_message = Email(
body="execute this SQL on the {con.name}",
attachments=[database_attachment, sql_attachment],
to_address="execute-SQL",
from_address=self.name,
)
return ExecuteSQLTask("execute-SQL")(task_message)


return Email(
body=f"No Sql query executed",
to_address=self.messages[-1].from_address,
from_address=self.name,
)
14 changes: 5 additions & 9 deletions src/ibis_birdbrain/flows/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ibis_birdbrain.messages import Email, Messages
from ibis_birdbrain.attachments import (
TableAttachment,
CodeAttachment,
SQLAttachment,
ErrorAttachment,
Attachments,
)
Expand Down Expand Up @@ -66,8 +66,8 @@ def __call__(self, messages: Messages) -> Messages:

# check the response
assert len(task_response.attachments) == 1
assert isinstance(task_response.attachments[0], CodeAttachment)
assert task_response.attachments[0].language == "sql"
assert isinstance(task_response.attachments[0], SQLAttachment)
# assert task_response.attachments[0].language == "sql"

# extract the SQL attachment
sql_attachment = task_response.attachments[0]
Expand All @@ -83,7 +83,7 @@ def __call__(self, messages: Messages) -> Messages:
task_response = self.tasks["execute-SQL"](task_message)
response_messages.append(task_response)

assert len(task_response.attachments) == 1
assert len(task_response.attachments) == 2

# check the response
if isinstance(task_response.attachments[0], TableAttachment):
Expand All @@ -106,11 +106,7 @@ def __call__(self, messages: Messages) -> Messages:
response_messages.append(task_response)

# get the new sql_attachment
for attachment in task_response.attachments:
if isinstance(
task_response.attachments[attachment], CodeAttachment
):
sql_attachment = task_response.attachments[attachment]
sql_attachment = task_response.attachments.get_attachment_by_type(SQLAttachment)

# try executing
task_response = self.tasks["execute-SQL"](
Expand Down
39 changes: 14 additions & 25 deletions src/ibis_birdbrain/tasks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ibis_birdbrain.attachments import (
Attachments,
TableAttachment,
CodeAttachment,
SQLAttachment,
ErrorAttachment,
DatabaseAttachment,
)
Expand All @@ -32,9 +32,11 @@ def __call__(self, message: Message) -> Message:
# TODO: add proper methods for this
table_attachments = Attachments()
database_attachment = None
dialect = "duckdb"
for attachment in message.attachments:
if isinstance(message.attachments[attachment], DatabaseAttachment):
database_attachment = message.attachments[attachment]
dialect = database_attachment.open().name
elif isinstance(message.attachments[attachment], TableAttachment):
table_attachments.append(message.attachments[attachment])

Expand All @@ -46,8 +48,9 @@ def __call__(self, message: Message) -> Message:
text=message.body,
tables=table_attachments,
data_description=database_attachment.description,
dialect=dialect,
)
code_attachment = CodeAttachment(language="sql", content=sql)
code_attachment = SQLAttachment(dialect=dialect, content=sql)

# generate the response message
response_message = Email(
Expand Down Expand Up @@ -111,12 +114,8 @@ def __call__(self, message: Message) -> Message:
log.info("Executing the SQL task")

# get the database attachment and sql attachments
# TODO: add proper methods for this
for attachment in message.attachments:
if isinstance(message.attachments[attachment], CodeAttachment):
sql_attachment = message.attachments[attachment]
elif isinstance(message.attachments[attachment], DatabaseAttachment):
database_attachment = message.attachments[attachment]
sql_attachment = message.attachments.get_attachment_by_type(SQLAttachment)
database_attachment = message.attachments.get_attachment_by_type(DatabaseAttachment)

con = database_attachment.open()
sql = sql_attachment.open()
Expand All @@ -133,7 +132,7 @@ def __call__(self, message: Message) -> Message:

response_message = Email(
body="execute SQL called",
attachments=[attachment],
attachments=[attachment, sql_attachment],
to_address=message.from_address,
from_address=self.name,
)
Expand All @@ -153,21 +152,10 @@ def __call__(self, message: Message) -> Message:
"""Fix the SQL task."""
log.info("Fixing the SQL task")

# hackily get the database attachment, table attachments, sql attachment, and error attachment

database_attachment = None
table_attachments = None
sql_attachment = None
error_attachment = None
for attachment in message.attachments:
if isinstance(message.attachments[attachment], DatabaseAttachment):
database_attachment = message.attachments[attachment]
elif isinstance(message.attachments[attachment], TableAttachment):
table_attachments = message.attachments[attachment]
elif isinstance(message.attachments[attachment], CodeAttachment):
sql_attachment = message.attachments[attachment]
elif isinstance(message.attachments[attachment], ErrorAttachment):
error_attachment = message.attachments[attachment]
database_attachment = message.attachments.get_attachment_by_type(DatabaseAttachment)
table_attachments = message.attachments.get_attachment_by_type(TableAttachment)
sql_attachment = message.attachments.get_attachment_by_type(SQLAttachment)
error_attachment = message.attachments.get_attachment_by_type(ErrorAttachment)

assert database_attachment is not None, "No database attachment found"
assert table_attachments is not None, "No table attachments found"
Expand All @@ -180,11 +168,12 @@ def __call__(self, message: Message) -> Message:
error=error_attachment.open(),
tables=table_attachments,
data_description=database_attachment.description,
dialect=sql_attachment.dialect
)

response_message = Email(
body="fix SQL called",
attachments=[CodeAttachment(language="sql", content=sql)],
attachments=[SQLAttachment(dialect=sql_attachment.dialect, content=sql)],
to_address=message.from_address,
from_address=self.name,
)
Expand Down

0 comments on commit ed0ca32

Please sign in to comment.