Skip to content

Commit ab5d836

Browse files
committed
added support for expandable parameters
1 parent 508937a commit ab5d836

File tree

1 file changed

+77
-43
lines changed

1 file changed

+77
-43
lines changed

cs50/sql.py

Lines changed: 77 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import re
1+
import datetime
22
import sqlalchemy
3-
import sqlparse
3+
import sys
44

55
class SQL(object):
66
"""Wrap SQLAlchemy to provide a simple SQL API."""
@@ -16,58 +16,91 @@ def __init__(self, url):
1616
try:
1717
self.engine = sqlalchemy.create_engine(url)
1818
except Exception as e:
19+
e.__context__ = None
1920
raise RuntimeError(e)
2021

21-
def execute(self, text, *multiparams, **params):
22+
def execute(self, text, **params):
2223
"""
2324
Execute a SQL statement.
2425
"""
2526

26-
# parse text
27-
parsed = sqlparse.parse(text)
28-
if len(parsed) == 0:
29-
raise RuntimeError("missing statement")
30-
elif len(parsed) > 1:
31-
raise RuntimeError("too many statements")
32-
statement = parsed[0]
33-
if statement.get_type() == "UNKNOWN":
34-
raise RuntimeError("unknown type of statement")
35-
36-
# infer paramstyle
37-
# https://www.python.org/dev/peps/pep-0249/#paramstyle
38-
paramstyle = None
39-
for token in statement.flatten():
40-
if sqlparse.utils.imt(token.ttype, t=sqlparse.tokens.Token.Name.Placeholder):
41-
_paramstyle = None
42-
if re.search(r"^\?$", token.value):
43-
_paramstyle = "qmark"
44-
elif re.search(r"^:\d+$", token.value):
45-
_paramstyle = "numeric"
46-
elif re.search(r"^:\w+$", token.value):
47-
_paramstyle = "named"
48-
elif re.search(r"^%s$", token.value):
49-
_paramstyle = "format"
50-
elif re.search(r"^%\(\w+\)s$", token.value):
51-
_paramstyle = "pyformat"
52-
else:
53-
raise RuntimeError("unknown paramstyle")
54-
if paramstyle and paramstyle != _paramstyle:
55-
raise RuntimeError("inconsistent paramstyle")
56-
paramstyle = _paramstyle
27+
class UserDefinedType(sqlalchemy.TypeDecorator):
28+
"""
29+
Add support for expandable values, a la https://bitbucket.org/zzzeek/sqlalchemy/issues/3953/expanding-parameter.
30+
"""
31+
impl = sqlalchemy.types.UserDefinedType
32+
def process_literal_param(self, value, dialect):
33+
"""Receive a literal parameter value to be rendered inline within a statement."""
34+
def process(value):
35+
"""Render a literal value, escaping as needed."""
36+
37+
# bool
38+
if isinstance(value, bool):
39+
return sqlalchemy.types.Boolean().literal_processor(dialect)(value)
40+
41+
# datetime.date
42+
elif isinstance(value, datetime.date):
43+
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%Y-%m-%d"))
44+
45+
# datetime.datetime
46+
elif isinstance(value, datetime.datetime):
47+
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))
48+
49+
# datetime.time
50+
elif isinstance(value, datetime.time):
51+
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%H:%M:%S"))
52+
53+
# float
54+
elif isinstance(value, float):
55+
return sqlalchemy.types.Float().literal_processor(dialect)(value)
56+
57+
# int
58+
elif isinstance(value, int):
59+
return sqlalchemy.types.Integer().literal_processor(dialect)(value)
60+
61+
# long
62+
elif sys.version_info.major != 3 and isinstance(value, long):
63+
return sqlalchemy.types.Integer().literal_processor(dialect)(value)
64+
65+
# str
66+
elif isinstance(value, str):
67+
return sqlalchemy.types.String().literal_processor(dialect)(value)
68+
69+
# None
70+
elif isinstance(value, sqlalchemy.sql.elements.Null):
71+
return sqlalchemy.types.NullType().literal_processor(dialect)(value)
72+
73+
# unsupported value
74+
raise RuntimeError("unsupported value")
75+
76+
# process value(s), separating with commas as needed
77+
if type(value) is list:
78+
return ", ".join([process(v) for v in value])
79+
else:
80+
return process(value)
5781

5882
try:
5983

60-
parsed = sqlparse.split("SELECT * FROM cs50 WHERE id IN (SELECT id FROM cs50); SELECT 1; CREATE TABLE foo")
61-
print(parsed)
62-
return 0
84+
# construct a new TextClause clause
85+
statement = sqlalchemy.text(text)
86+
87+
# iterate over parameters
88+
for key, value in params.items():
6389

64-
# bind parameters before statement reaches database, so that bound parameters appear in exceptions
65-
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
66-
# https://groups.google.com/forum/#!topic/sqlalchemy/FfLwKT1yQlg
67-
# http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.Engine.execute
90+
# translate None to NULL
91+
if value is None:
92+
value = sqlalchemy.sql.null()
93+
94+
# bind parameters before statement reaches database, so that bound parameters appear in exceptions
95+
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
96+
statement = statement.bindparams(sqlalchemy.bindparam(key, value=value, type_=UserDefinedType()))
97+
98+
# stringify bound parameters
6899
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
69-
statement = sqlalchemy.text(text).bindparams(*multiparams, **params)
70-
result = self.engine.execute(str(statement.compile(compile_kwargs={"literal_binds": True})))
100+
self.statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
101+
102+
# execute statement
103+
result = self.engine.execute(self.statement)
71104

72105
# if SELECT (or INSERT with RETURNING), return result set as list of dict objects
73106
if result.returns_rows:
@@ -88,4 +121,5 @@ def execute(self, text, *multiparams, **params):
88121

89122
# else raise error
90123
except Exception as e:
124+
e.__context__ = None
91125
raise RuntimeError(e)

0 commit comments

Comments
 (0)