Skip to content

Commit 184a7d5

Browse files
Ilya Gurovlarkeeskuruppu
authored
feat: support JSON data type (googleapis#135)
* feat: support JSON data type * fix type * bug fixes * erase excess test override * erase excess override * fix errors Co-authored-by: larkee <31196561+larkee@users.noreply.github.com> Co-authored-by: skuruppu <skuruppu@google.com>
1 parent f02a2c0 commit 184a7d5

File tree

3 files changed

+196
-4
lines changed

3 files changed

+196
-4
lines changed

google/cloud/sqlalchemy_spanner/requirements.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818

1919
class Requirements(SuiteRequirements): # pragma: no cover
20+
@property
21+
def json_type(self):
22+
return exclusions.open()
23+
2024
@property
2125
def computed_columns(self):
2226
return exclusions.open()

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,13 @@
3434
GenericTypeCompiler,
3535
IdentifierPreparer,
3636
SQLCompiler,
37+
OPERATORS,
3738
RESERVED_WORDS,
3839
)
40+
from sqlalchemy.sql.default_comparator import operator_lookup
41+
from sqlalchemy.sql.operators import json_getitem_op
3942

43+
from google.cloud.spanner_v1.data_types import JsonObject
4044
from google.cloud import spanner_dbapi
4145
from google.cloud.sqlalchemy_spanner._opentelemetry_tracing import trace_call
4246

@@ -47,6 +51,10 @@ def reset_connection(dbapi_conn, connection_record):
4751
dbapi_conn.connection.staleness = None
4852

4953

54+
# register a method to get a single value of a JSON object
55+
OPERATORS[json_getitem_op] = operator_lookup["json_getitem_op"]
56+
57+
5058
# Spanner-to-SQLAlchemy types map
5159
_type_map = {
5260
"BOOL": types.Boolean,
@@ -60,8 +68,10 @@ def reset_connection(dbapi_conn, connection_record):
6068
"TIME": types.TIME,
6169
"TIMESTAMP": types.TIMESTAMP,
6270
"ARRAY": types.ARRAY,
71+
"JSON": types.JSON,
6372
}
6473

74+
6575
_type_map_inv = {
6676
types.Boolean: "BOOL",
6777
types.BINARY: "BYTES(MAX)",
@@ -210,6 +220,53 @@ def visit_like_op_binary(self, binary, operator, **kw):
210220
binary.right._compiler_dispatch(self, **kw),
211221
)
212222

223+
def _generate_generic_binary(self, binary, opstring, eager_grouping=False, **kw):
224+
"""The method is overriden to process JSON data type cases."""
225+
_in_binary = kw.get("_in_binary", False)
226+
227+
kw["_in_binary"] = True
228+
229+
if isinstance(opstring, str):
230+
text = (
231+
binary.left._compiler_dispatch(
232+
self, eager_grouping=eager_grouping, **kw
233+
)
234+
+ opstring
235+
+ binary.right._compiler_dispatch(
236+
self, eager_grouping=eager_grouping, **kw
237+
)
238+
)
239+
if _in_binary and eager_grouping:
240+
text = "(%s)" % text
241+
else:
242+
# got JSON data
243+
right_value = getattr(
244+
binary.right, "value", None
245+
) or binary.right._compiler_dispatch(
246+
self, eager_grouping=eager_grouping, **kw
247+
)
248+
249+
text = (
250+
binary.left._compiler_dispatch(
251+
self, eager_grouping=eager_grouping, **kw
252+
)
253+
+ """, "$."""
254+
+ str(right_value)
255+
+ '"'
256+
)
257+
text = "JSON_VALUE(%s)" % text
258+
259+
return text
260+
261+
def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
262+
"""Build a JSON_VALUE() function call."""
263+
expr = """JSON_VALUE(%s, "$.%s")"""
264+
265+
return expr % (
266+
self.process(binary.left, **kw),
267+
self.process(binary.right, **kw),
268+
)
269+
213270
def render_literal_value(self, value, type_):
214271
"""Render the value of a bind parameter as a quoted literal.
215272
@@ -404,6 +461,9 @@ def visit_NUMERIC(self, type_, **kw):
404461
def visit_BIGINT(self, type_, **kw):
405462
return "INT64"
406463

464+
def visit_JSON(self, type_, **kw):
465+
return "JSON"
466+
407467

408468
class SpannerDialect(DefaultDialect):
409469
"""Cloud Spanner dialect.
@@ -434,6 +494,8 @@ class SpannerDialect(DefaultDialect):
434494
statement_compiler = SpannerSQLCompiler
435495
type_compiler = SpannerTypeCompiler
436496
execution_ctx_cls = SpannerExecutionContext
497+
_json_serializer = JsonObject
498+
_json_deserializer = JsonObject
437499

438500
@classmethod
439501
def dbapi(cls):

test/test_suite.py

Lines changed: 130 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import pkg_resources
2222
import pytest
23+
import random
2324
import unittest
2425
from unittest import mock
2526

@@ -61,7 +62,6 @@
6162
)
6263

6364
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
64-
6565
from google.cloud import spanner_dbapi
6666

6767
from sqlalchemy.testing.suite.test_cte import * # noqa: F401, F403
@@ -98,15 +98,17 @@
9898
)
9999
from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest
100100
from sqlalchemy.testing.suite.test_types import ( # noqa: F401, F403
101+
_DateFixture as _DateFixtureTest,
102+
_LiteralRoundTripFixture,
103+
_UnicodeFixture as _UnicodeFixtureTest,
101104
BooleanTest as _BooleanTest,
102105
DateTest as _DateTest,
103-
_DateFixture as _DateFixtureTest,
104106
DateTimeHistoricTest,
105107
DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest,
106108
DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest,
107109
DateTimeTest as _DateTimeTest,
108110
IntegerTest as _IntegerTest,
109-
_LiteralRoundTripFixture,
111+
JSONTest as _JSONTest,
110112
NumericTest as _NumericTest,
111113
StringTest as _StringTest,
112114
TextTest as _TextTest,
@@ -115,7 +117,6 @@
115117
TimestampMicrosecondsTest,
116118
UnicodeVarcharTest as _UnicodeVarcharTest,
117119
UnicodeTextTest as _UnicodeTextTest,
118-
_UnicodeFixture as _UnicodeFixtureTest,
119120
)
120121
from test._helpers import get_db_url
121122

@@ -1751,3 +1752,128 @@ def test_get_column_returns_computed(self):
17511752
is_true("computed" in compData)
17521753
is_true("sqltext" in compData["computed"])
17531754
eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42")
1755+
1756+
1757+
@pytest.mark.skipif(
1758+
bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator"
1759+
)
1760+
class JSONTest(_JSONTest):
1761+
@pytest.mark.skip("Values without keys are not supported.")
1762+
def test_single_element_round_trip(self, element):
1763+
pass
1764+
1765+
def _test_round_trip(self, data_element):
1766+
data_table = self.tables.data_table
1767+
1768+
config.db.execute(
1769+
data_table.insert(),
1770+
{"id": random.randint(1, 100000000), "name": "row1", "data": data_element},
1771+
)
1772+
1773+
row = config.db.execute(select([data_table.c.data])).first()
1774+
1775+
eq_(row, (data_element,))
1776+
1777+
def test_unicode_round_trip(self):
1778+
# note we include Unicode supplementary characters as well
1779+
with config.db.connect() as conn:
1780+
conn.execute(
1781+
self.tables.data_table.insert(),
1782+
{
1783+
"id": random.randint(1, 100000000),
1784+
"name": "r1",
1785+
"data": {
1786+
util.u("réve🐍 illé"): util.u("réve🐍 illé"),
1787+
"data": {"k1": util.u("drôl🐍e")},
1788+
},
1789+
},
1790+
)
1791+
1792+
eq_(
1793+
conn.scalar(select([self.tables.data_table.c.data])),
1794+
{
1795+
util.u("réve🐍 illé"): util.u("réve🐍 illé"),
1796+
"data": {"k1": util.u("drôl🐍e")},
1797+
},
1798+
)
1799+
1800+
@pytest.mark.skip("Parameterized types are not supported.")
1801+
def test_eval_none_flag_orm(self):
1802+
pass
1803+
1804+
@pytest.mark.skip(
1805+
"Spanner JSON_VALUE() always returns STRING,"
1806+
"thus, this test case can't be executed."
1807+
)
1808+
def test_index_typed_comparison(self):
1809+
pass
1810+
1811+
@pytest.mark.skip(
1812+
"Spanner JSON_VALUE() always returns STRING,"
1813+
"thus, this test case can't be executed."
1814+
)
1815+
def test_path_typed_comparison(self):
1816+
pass
1817+
1818+
@pytest.mark.skip("Custom JSON de-/serializers are not supported.")
1819+
def test_round_trip_custom_json(self):
1820+
pass
1821+
1822+
def _index_fixtures(fn):
1823+
fn = testing.combinations(
1824+
("boolean", True),
1825+
("boolean", False),
1826+
("boolean", None),
1827+
("string", "some string"),
1828+
("string", None),
1829+
("integer", 15),
1830+
("integer", 1),
1831+
("integer", 0),
1832+
("integer", None),
1833+
("float", 28.5),
1834+
("float", None),
1835+
id_="sa",
1836+
)(fn)
1837+
return fn
1838+
1839+
@_index_fixtures
1840+
def test_index_typed_access(self, datatype, value):
1841+
data_table = self.tables.data_table
1842+
data_element = {"key1": value}
1843+
with config.db.connect() as conn:
1844+
conn.execute(
1845+
data_table.insert(),
1846+
{
1847+
"id": random.randint(1, 100000000),
1848+
"name": "row1",
1849+
"data": data_element,
1850+
"nulldata": data_element,
1851+
},
1852+
)
1853+
1854+
expr = data_table.c.data["key1"]
1855+
expr = getattr(expr, "as_%s" % datatype)()
1856+
1857+
roundtrip = conn.scalar(select([expr]))
1858+
if roundtrip in ("true", "false", None):
1859+
roundtrip = str(roundtrip).capitalize()
1860+
1861+
eq_(str(roundtrip), str(value))
1862+
1863+
@pytest.mark.skip(
1864+
"Spanner doesn't support type casts inside JSON_VALUE() function."
1865+
)
1866+
def test_round_trip_json_null_as_json_null(self):
1867+
pass
1868+
1869+
@pytest.mark.skip(
1870+
"Spanner doesn't support type casts inside JSON_VALUE() function."
1871+
)
1872+
def test_round_trip_none_as_json_null(self):
1873+
pass
1874+
1875+
@pytest.mark.skip(
1876+
"Spanner doesn't support type casts inside JSON_VALUE() function."
1877+
)
1878+
def test_round_trip_none_as_sql_null(self):
1879+
pass

0 commit comments

Comments
 (0)