Skip to content

Commit bb3d1f9

Browse files
author
iurii iurii
committed
add update compiler
1 parent 30712d9 commit bb3d1f9

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

src/drivers/base.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import six
44
from sqlalchemy import schema, types as sqltypes, exc, util as sa_util
55
from sqlalchemy.engine import default, reflection
6-
from sqlalchemy.sql import compiler, expression, type_api
6+
from sqlalchemy.sql import compiler, expression, type_api, crud
77
from sqlalchemy.types import DATE, DATETIME, INTEGER, VARCHAR, FLOAT
88

99
from .. import types
@@ -179,6 +179,26 @@ def visit_join(self, join, asfrom=False, **kwargs):
179179
"ON", self.process(join.onclause, asfrom=True, **kwargs)
180180
))
181181

182+
def visit_update(self, update_stmt, asfrom=False, **kw):
183+
text = 'ALTER TABLE '
184+
table_text = self.update_tables_clause(update_stmt, update_stmt.table, [], **kw)
185+
text += table_text
186+
text += ' UPDATE '
187+
crud_params = crud._setup_crud_params(
188+
self, update_stmt, crud.ISUPDATE, include_table=False, **kw)
189+
190+
text += ', '.join(
191+
c[0]._compiler_dispatch(self,
192+
include_table=False) +
193+
'=' + c[1] for c in crud_params
194+
)
195+
196+
if update_stmt._whereclause is not None:
197+
t = update_stmt._whereclause._compiler_dispatch(self, include_table=False)
198+
if t:
199+
text += " WHERE " + t
200+
return text
201+
182202

183203
class ClickHouseDDLCompiler(compiler.DDLCompiler):
184204
def visit_create_column(self, create, **kw):

tests/orm/test_update.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from sqlalchemy import Column, update
2+
3+
from src import types, Table
4+
from tests.testcase import BaseTestCase
5+
6+
7+
class UpdateTestCase(BaseTestCase):
8+
def create_table(self):
9+
metadata = self.metadata()
10+
11+
return Table(
12+
't1', metadata,
13+
Column('x', types.Int32, primary_key=True),
14+
Column('y', types.Int32, primary_key=True)
15+
)
16+
17+
def test_update(self):
18+
table = self.create_table()
19+
query = table.update().values(x=3).where(table.c.x.in_([1, 2]))
20+
self.assertEqual(
21+
self.compile(query),
22+
'ALTER TABLE t1 UPDATE x=%(x)s WHERE x IN (%(x_1)s, %(x_2)s)'
23+
)
24+
query = update(table).values(x=3).where(table.c.x.in_([1, 2]))
25+
self.assertEqual(
26+
self.compile(query),
27+
'ALTER TABLE t1 UPDATE x=%(x)s WHERE x IN (%(x_1)s, %(x_2)s)'
28+
)
29+
30+
def test_update_from_own_field(self):
31+
table = self.create_table()
32+
query = table.update().values(x=table.c.y).where(table.c.x.in_([1, 2]))
33+
self.assertEqual(
34+
self.compile(query),
35+
'ALTER TABLE t1 UPDATE x=y WHERE x IN (%(x_1)s, %(x_2)s)'
36+
)

0 commit comments

Comments
 (0)