Skip to content

Commit 62cccc3

Browse files
authored
fix: Fixing test for literals due to change in sqlalchemy core tests (#384)
* fix: Fixing test for literals due to change in sqlalchemy core tests * tests: remove editable install in tests * One more literal test fix
1 parent 00561f8 commit 62cccc3

File tree

5 files changed

+122
-17
lines changed

5 files changed

+122
-17
lines changed

.github/sync-repo-settings.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ branchProtectionRules:
1111
- 'unit'
1212
- 'compliance_tests_13'
1313
- 'compliance_tests_14'
14+
- 'compliance_tests_20'
1415
- 'migration_tests'
1516
- 'cla/google'
1617
- 'Kokoro'

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
format_type,
2424
)
2525
from sqlalchemy.exc import NoSuchTableError
26+
from sqlalchemy.sql import elements
2627
from sqlalchemy import ForeignKeyConstraint, types
2728
from sqlalchemy.engine.base import Engine
2829
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
@@ -314,6 +315,12 @@ def render_literal_value(self, value, type_):
314315
in string. Override the method to add additional escape before using it to
315316
generate a SQL statement.
316317
"""
318+
if value is None and not type_.should_evaluate_none:
319+
# issue #10535 - handle NULL in the compiler without placing
320+
# this onto each type, except for "evaluate None" types
321+
# (e.g. JSON)
322+
return self.process(elements.Null._instance())
323+
317324
raw = ["\\", "'", '"', "\n", "\t", "\r"]
318325
if isinstance(value, str) and any(single in value for single in raw):
319326
value = 'r"""{}"""'.format(value)

noxfile.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def compliance_test_13(session):
145145
)
146146

147147
session.install("mock")
148-
session.install("-e", ".[tracing]")
148+
session.install(".[tracing]")
149149
session.run("pip", "install", "sqlalchemy>=1.1.13,<=1.3.24", "--force-reinstall")
150150
session.run("pip", "install", "opentelemetry-api<=1.10", "--force-reinstall")
151151
session.run("pip", "install", "opentelemetry-sdk<=1.10", "--force-reinstall")
@@ -191,7 +191,7 @@ def compliance_test_14(session):
191191
)
192192

193193
session.install("mock")
194-
session.install("-e", ".[tracing]")
194+
session.install(".[tracing]")
195195
session.run("pip", "install", "sqlalchemy>=1.4,<2.0", "--force-reinstall")
196196
session.run("python", "create_test_database.py")
197197
session.run(
@@ -231,7 +231,7 @@ def compliance_test_20(session):
231231
)
232232

233233
session.install("mock")
234-
session.install("-e", ".[tracing]")
234+
session.install(".[tracing]")
235235
session.run("pip", "install", "opentelemetry-api<=1.10", "--force-reinstall")
236236
session.run("python", "create_test_database.py")
237237

@@ -257,7 +257,7 @@ def unit(session):
257257
# Run SQLAlchemy dialect compliance test suite with OpenTelemetry.
258258
session.install("pytest")
259259
session.install("mock")
260-
session.install("-e", ".")
260+
session.install(".")
261261
session.install("opentelemetry-api==1.1.0")
262262
session.install("opentelemetry-sdk==1.1.0")
263263
session.install("opentelemetry-instrumentation==0.20b0")
@@ -292,7 +292,7 @@ def _migration_test(session):
292292
session.run("pip", "install", "sqlalchemy>=1.3.11,<2.0", "--force-reinstall")
293293

294294
session.install("pytest")
295-
session.install("-e", ".")
295+
session.install(".")
296296
session.install("alembic")
297297

298298
session.run("python", "create_test_database.py")
@@ -360,7 +360,7 @@ def snippets(session):
360360
session.install(
361361
"git+https://github.com/googleapis/python-spanner.git#egg=google-cloud-spanner"
362362
)
363-
session.install("-e", ".")
363+
session.install(".")
364364
session.run("python", "create_test_database.py")
365365
session.run(
366366
"py.test",

test/conftest.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,72 @@
1616

1717
import pytest
1818
from sqlalchemy.dialects import registry
19+
from sqlalchemy.testing.schema import Column
20+
from sqlalchemy.testing.schema import Table
21+
from sqlalchemy.sql.elements import literal
1922

2023
registry.register("spanner", "google.cloud.sqlalchemy_spanner", "SpannerDialect")
2124

2225
pytest.register_assert_rewrite("sqlalchemy.testing.assertions")
2326

2427
from sqlalchemy.testing.plugin.pytestplugin import * # noqa: E402, F401, F403
28+
29+
30+
@pytest.fixture
31+
def literal_round_trip_spanner(metadata, connection):
32+
# for literal, we test the literal render in an INSERT
33+
# into a typed column. we can then SELECT it back as its
34+
# official type;
35+
36+
def run(
37+
type_,
38+
input_,
39+
output,
40+
filter_=None,
41+
compare=None,
42+
support_whereclause=True,
43+
):
44+
t = Table("t", metadata, Column("x", type_))
45+
t.create(connection)
46+
47+
for value in input_:
48+
ins = t.insert().values(x=literal(value, type_, literal_execute=True))
49+
connection.execute(ins)
50+
51+
if support_whereclause:
52+
if compare:
53+
stmt = t.select().where(
54+
t.c.x
55+
== literal(
56+
compare,
57+
type_,
58+
literal_execute=True,
59+
),
60+
t.c.x
61+
== literal(
62+
input_[0],
63+
type_,
64+
literal_execute=True,
65+
),
66+
)
67+
else:
68+
stmt = t.select().where(
69+
t.c.x
70+
== literal(
71+
compare if compare is not None else input_[0],
72+
type_,
73+
literal_execute=True,
74+
)
75+
)
76+
else:
77+
stmt = t.select()
78+
79+
rows = connection.execute(stmt).all()
80+
assert rows, "No rows returned"
81+
for row in rows:
82+
value = row[0]
83+
if filter_ is not None:
84+
value = filter_(value)
85+
assert value in output
86+
87+
return run

test/test_suite_20.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,10 @@
149149
UnicodeTextTest as _UnicodeTextTest,
150150
_UnicodeFixture as __UnicodeFixture,
151151
) # noqa: F401, F403
152-
from test._helpers import get_db_url, get_project
152+
from test._helpers import (
153+
get_db_url,
154+
get_project,
155+
)
153156

154157
config.test_schema = ""
155158

@@ -162,7 +165,7 @@ class BooleanTest(_BooleanTest):
162165
def test_render_literal_bool(self):
163166
pass
164167

165-
def test_render_literal_bool_true(self, literal_round_trip):
168+
def test_render_literal_bool_true(self, literal_round_trip_spanner):
166169
"""
167170
SPANNER OVERRIDE:
168171
@@ -171,9 +174,9 @@ def test_render_literal_bool_true(self, literal_round_trip):
171174
following insertions will fail with `Row [] already exists".
172175
Overriding the test to avoid the same failure.
173176
"""
174-
literal_round_trip(Boolean(), [True], [True])
177+
literal_round_trip_spanner(Boolean(), [True], [True])
175178

176-
def test_render_literal_bool_false(self, literal_round_trip):
179+
def test_render_literal_bool_false(self, literal_round_trip_spanner):
177180
"""
178181
SPANNER OVERRIDE:
179182
@@ -182,7 +185,7 @@ def test_render_literal_bool_false(self, literal_round_trip):
182185
following insertions will fail with `Row [] already exists".
183186
Overriding the test to avoid the same failure.
184187
"""
185-
literal_round_trip(Boolean(), [False], [False])
188+
literal_round_trip_spanner(Boolean(), [False], [False])
186189

187190
@pytest.mark.skip("Not supported by Cloud Spanner")
188191
def test_whereclause(self):
@@ -2003,6 +2006,9 @@ def test_huge_int_auto_accommodation(self, connection, intvalue):
20032006
intvalue,
20042007
)
20052008

2009+
def test_literal(self, literal_round_trip_spanner):
2010+
literal_round_trip_spanner(Integer, [5], [5])
2011+
20062012

20072013
class _UnicodeFixture(__UnicodeFixture):
20082014
@classmethod
@@ -2189,6 +2195,19 @@ def test_dont_truncate_rightside(
21892195
args[1],
21902196
)
21912197

2198+
def test_literal(self, literal_round_trip_spanner):
2199+
# note that in Python 3, this invokes the Unicode
2200+
# datatype for the literal part because all strings are unicode
2201+
literal_round_trip_spanner(String(40), ["some text"], ["some text"])
2202+
2203+
def test_literal_quoting(self, literal_round_trip_spanner):
2204+
data = """some 'text' hey "hi there" that's text"""
2205+
literal_round_trip_spanner(String(40), [data], [data])
2206+
2207+
def test_literal_backslashes(self, literal_round_trip_spanner):
2208+
data = r"backslash one \ backslash two \\ end"
2209+
literal_round_trip_spanner(String(40), [data], [data])
2210+
21922211

21932212
class TextTest(_TextTest):
21942213
@classmethod
@@ -2224,6 +2243,21 @@ def test_text_empty_strings(self, connection):
22242243
def test_text_null_strings(self, connection):
22252244
pass
22262245

2246+
def test_literal(self, literal_round_trip_spanner):
2247+
literal_round_trip_spanner(Text, ["some text"], ["some text"])
2248+
2249+
def test_literal_quoting(self, literal_round_trip_spanner):
2250+
data = """some 'text' hey "hi there" that's text"""
2251+
literal_round_trip_spanner(Text, [data], [data])
2252+
2253+
def test_literal_backslashes(self, literal_round_trip_spanner):
2254+
data = r"backslash one \ backslash two \\ end"
2255+
literal_round_trip_spanner(Text, [data], [data])
2256+
2257+
def test_literal_percentsigns(self, literal_round_trip_spanner):
2258+
data = r"percent % signs %% percent"
2259+
literal_round_trip_spanner(Text, [data], [data])
2260+
22272261

22282262
class NumericTest(_NumericTest):
22292263
@testing.fixture
@@ -2254,7 +2288,7 @@ def run(type_, input_, output, filter_=None, check_scale=False):
22542288
return run
22552289

22562290
@emits_warning(r".*does \*not\* support Decimal objects natively")
2257-
def test_render_literal_numeric(self, literal_round_trip):
2291+
def test_render_literal_numeric(self, literal_round_trip_spanner):
22582292
"""
22592293
SPANNER OVERRIDE:
22602294
@@ -2263,14 +2297,14 @@ def test_render_literal_numeric(self, literal_round_trip):
22632297
following insertions will fail with `Row [] already exists".
22642298
Overriding the test to avoid the same failure.
22652299
"""
2266-
literal_round_trip(
2300+
literal_round_trip_spanner(
22672301
Numeric(precision=8, scale=4),
22682302
[decimal.Decimal("15.7563")],
22692303
[decimal.Decimal("15.7563")],
22702304
)
22712305

22722306
@emits_warning(r".*does \*not\* support Decimal objects natively")
2273-
def test_render_literal_numeric_asfloat(self, literal_round_trip):
2307+
def test_render_literal_numeric_asfloat(self, literal_round_trip_spanner):
22742308
"""
22752309
SPANNER OVERRIDE:
22762310
@@ -2279,13 +2313,13 @@ def test_render_literal_numeric_asfloat(self, literal_round_trip):
22792313
following insertions will fail with `Row [] already exists".
22802314
Overriding the test to avoid the same failure.
22812315
"""
2282-
literal_round_trip(
2316+
literal_round_trip_spanner(
22832317
Numeric(precision=8, scale=4, asdecimal=False),
22842318
[decimal.Decimal("15.7563")],
22852319
[15.7563],
22862320
)
22872321

2288-
def test_render_literal_float(self, literal_round_trip):
2322+
def test_render_literal_float(self, literal_round_trip_spanner):
22892323
"""
22902324
SPANNER OVERRIDE:
22912325
@@ -2294,7 +2328,7 @@ def test_render_literal_float(self, literal_round_trip):
22942328
following insertions will fail with `Row [] already exists".
22952329
Overriding the test to avoid the same failure.
22962330
"""
2297-
literal_round_trip(
2331+
literal_round_trip_spanner(
22982332
Float(4),
22992333
[decimal.Decimal("15.7563")],
23002334
[15.7563],

0 commit comments

Comments
 (0)