Skip to content

Commit

Permalink
Expand joins when calculating PostgreSQL "WITH FOR UPDATE OF"
Browse files Browse the repository at this point in the history
Modified the :paramref:`.Select.with_for_update.of` parameter so that if a
join or other composed selectable is passed, the individual :class:`.Table`
objects will be filtered from it, allowing one to pass a join() object to
the parameter, as occurs normally when using joined table inheritance with
the ORM.  Pull request courtesy Raymond Lu.

Fixes: sqlalchemy#4550
Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Closes: sqlalchemy#4551
Pull-request: sqlalchemy#4551
Pull-request-sha: 452da77

Change-Id: If4b7c231f7b71190d7245543959fb5c3351125a1
  • Loading branch information
raylu authored and zzzeek committed Mar 21, 2019
1 parent 8eaccf1 commit 8acbc26
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 4 deletions.
10 changes: 10 additions & 0 deletions doc/build/changelog/unreleased_13/4550.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
.. change::
:tags: bug, postgresql
:tickets: 4550

Modified the :paramref:`.Select.with_for_update.of` parameter so that if a
join or other composed selectable is passed, the individual :class:`.Table`
objects will be filtered from it, allowing one to pass a join() object to
the parameter, as occurs normally when using joined table inheritance with
the ORM. Pull request courtesy Raymond Lu.

10 changes: 6 additions & 4 deletions lib/sqlalchemy/dialects/postgresql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,7 @@ def bind_expression(self, bindvalue):
from ...sql import elements
from ...sql import expression
from ...sql import sqltypes
from ...sql import util as sql_util
from ...types import BIGINT
from ...types import BOOLEAN
from ...types import CHAR
Expand Down Expand Up @@ -1681,10 +1682,11 @@ def for_update_clause(self, select, **kw):
tmp = " FOR UPDATE"

if select._for_update_arg.of:
tables = util.OrderedSet(
c.table if isinstance(c, expression.ColumnClause) else c
for c in select._for_update_arg.of
)

tables = util.OrderedSet()
for c in select._for_update_arg.of:
tables.update(sql_util.surface_selectables_only(c))

tmp += " OF " + ", ".join(
self.process(table, ashint=True, use_schema=False, **kw)
for table in tables
Expand Down
16 changes: 16 additions & 0 deletions lib/sqlalchemy/sql/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
from .elements import Null
from .elements import UnaryExpression
from .schema import Column
from .selectable import Alias
from .selectable import FromClause
from .selectable import FromGrouping
from .selectable import Join
from .selectable import ScalarSelect
from .selectable import SelectBase
from .selectable import TableClause
from .. import exc
from .. import util

Expand Down Expand Up @@ -339,6 +341,20 @@ def surface_selectables(clause):
stack.append(elem.element)


def surface_selectables_only(clause):
stack = [clause]
while stack:
elem = stack.pop()
if isinstance(elem, (TableClause, Alias)):
yield elem
if isinstance(elem, Join):
stack.extend((elem.left, elem.right))
elif isinstance(elem, FromGrouping):
stack.append(elem.element)
elif isinstance(elem, ColumnClause):
stack.append(elem.table)


def surface_column_elements(clause, include_scalar_selects=True):
"""traverse and yield only outer-exposed column elements, such as would
be addressable in the WHERE clause of a SELECT if this element were
Expand Down
24 changes: 24 additions & 0 deletions test/dialect/postgresql/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,30 @@ def test_for_update(self):
"WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1",
)

table2 = table("table2", column("mytable_id"))
join = table2.join(table1, table2.c.mytable_id == table1.c.myid)
self.assert_compile(
join.select(table2.c.mytable_id == 7).with_for_update(of=[join]),
"SELECT table2.mytable_id, "
"mytable.myid, mytable.name, mytable.description "
"FROM table2 "
"JOIN mytable ON table2.mytable_id = mytable.myid "
"WHERE table2.mytable_id = %(mytable_id_1)s "
"FOR UPDATE OF mytable, table2",
)

join = table2.join(ta, table2.c.mytable_id == ta.c.myid)
self.assert_compile(
join.select(table2.c.mytable_id == 7).with_for_update(of=[join]),
"SELECT table2.mytable_id, "
"mytable_1.myid, mytable_1.name, mytable_1.description "
"FROM table2 "
"JOIN mytable AS mytable_1 "
"ON table2.mytable_id = mytable_1.myid "
"WHERE table2.mytable_id = %(mytable_id_1)s "
"FOR UPDATE OF mytable_1, table2",
)

def test_for_update_with_schema(self):
m = MetaData()
table1 = Table(
Expand Down

0 comments on commit 8acbc26

Please sign in to comment.