-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Jonathan Talmi
committed
Mar 17, 2023
0 parents
commit fb08864
Showing
5 changed files
with
1,891 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import streamlit as st | ||
from streamlit_ace import st_ace | ||
|
||
from util import RULE_MAPPING, SAMPLE_QUERY, apply_optimizations, format_sql_with_sqlfmt | ||
|
||
st.set_page_config(layout="wide") | ||
|
||
|
||
# Set custom CSS | ||
st.markdown( | ||
""" | ||
# Optimize and lint SQL using [sqlglot](https://github.com/tobymao/sqlglot) and [sqlfmt](http://sqlfmt.com/) | ||
<style> | ||
body { | ||
background-color: black; | ||
color: white; | ||
} | ||
</style> | ||
""", | ||
unsafe_allow_html=True, | ||
) | ||
|
||
left, right = st.columns(2) | ||
|
||
# Add rule selector | ||
selected_rules = st.multiselect( | ||
'Optimization [rules](https://github.com/tobymao/sqlglot/blob/main/sqlglot/optimizer/optimizer.py). Drop "canonicalize" rule to prevent quoting.', | ||
list(RULE_MAPPING.keys()), | ||
default=list(RULE_MAPPING.keys()), | ||
) | ||
|
||
# Add checkboxes and button | ||
cols = [col for col in st.columns(12)] | ||
remove_ctes = cols[0].checkbox("Remove CTEs", on_change=None) | ||
format_with_sqlfmt = cols[1].checkbox("Lint w/ sqlfmt", on_change=None) | ||
optimize_button = st.button("Optimize SQL") | ||
|
||
# Initialize session state | ||
if "new_query" not in st.session_state: | ||
st.session_state.new_query = "" | ||
|
||
if "state" not in st.session_state: | ||
st.session_state.state = 0 | ||
|
||
# Add input editor | ||
with left: | ||
sql_input = st_ace( | ||
value=SAMPLE_QUERY, | ||
height=320, | ||
theme="monokai", | ||
language="sql", | ||
keybinding="vscode", | ||
font_size=14, | ||
wrap=True, | ||
min_lines=10, | ||
auto_update=True, | ||
) | ||
|
||
# Optimize and lint query | ||
if optimize_button: | ||
try: | ||
rules = [RULE_MAPPING[rule] for rule in selected_rules] | ||
new_query = apply_optimizations(sql_input, rules, remove_ctes).sql(pretty=True) | ||
if format_with_sqlfmt: | ||
new_query = format_sql_with_sqlfmt(new_query) | ||
st.session_state.new_query = new_query | ||
st.session_state.state += 1 | ||
except Exception as e: | ||
st.error(f"Error: {e}") | ||
|
||
# Add output editor | ||
with right: | ||
view_editor = st_ace( | ||
value=st.session_state.new_query, | ||
height=320, | ||
theme="monokai", | ||
language="sql", | ||
keybinding="vscode", | ||
font_size=14, | ||
wrap=True, | ||
readonly=True, | ||
min_lines=10, | ||
auto_update=True, | ||
key=f"ace-{st.session_state.state}", | ||
) |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
[tool.poetry] | ||
name = "automodel" | ||
version = "0.1.0" | ||
description = "" | ||
authors = ["Jonathan Talmi <jtalmi@gmail.com>"] | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.8.0,<3.9.7 || >3.9.7,<4.0" | ||
sqlglot = "11.3.7" | ||
streamlit = "^1.20.0" | ||
shandy-sqlfmt = "^0.17.0" | ||
streamlit-ace = "^0.1.1" | ||
|
||
[tool.poetry.group.dev.dependencies] | ||
pytest-snapshot = "^0.9.0" | ||
black = "^23.1.0" | ||
isort = "^5.12.0" | ||
pytest = "^5.2" | ||
ipdb = "^0.13.9" | ||
|
||
[build-system] | ||
requires = ["poetry-core>=1.0.0"] | ||
build-backend = "poetry.core.masonry.api" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from typing import Callable, Dict, Sequence | ||
|
||
from sqlfmt.api import Mode, format_string | ||
from sqlfmt.exception import SqlfmtError | ||
from sqlglot import parse_one | ||
from sqlglot.expressions import Select | ||
from sqlglot.optimizer import RULES, optimize | ||
|
||
RULE_MAPPING: Dict[str, Callable] = {rule.__name__: rule for rule in RULES} | ||
SAMPLE_QUERY: str = """WITH users AS ( | ||
SELECT * | ||
FROM users_table), | ||
orders AS ( | ||
SELECT * | ||
FROM orders_table), | ||
combined AS ( | ||
SELECT users.id, users.name, orders.order_id, orders.total | ||
FROM users | ||
JOIN orders ON users.id = orders.user_id) | ||
SELECT combined.id, combined.name, combined.order_id, combined.total | ||
FROM combined | ||
""" | ||
|
||
|
||
def _generate_ast(query: str) -> Select: | ||
""" | ||
Generate an AST from a query. | ||
""" | ||
ast = parse_one(query) | ||
return ast | ||
|
||
|
||
def apply_optimizations( | ||
query: str, rules: Sequence[Callable] = RULES, remove_ctes: bool = False | ||
) -> Select: | ||
""" | ||
Apply optimizations to an AST. | ||
""" | ||
ast = _generate_ast(query) | ||
if remove_ctes: | ||
return optimize(ast, rules=rules) | ||
else: | ||
return optimize(ast, rules=rules, leave_tables_isolated=True) | ||
|
||
|
||
def format_sql_with_sqlfmt(query: str) -> str: | ||
""" | ||
Format a query using sqlfmt. | ||
""" | ||
mode = Mode() | ||
return format_string(query, mode) |