Skip to content

Commit 9af8e86

Browse files
authored
feat: enable instance-level connection (#316)
* chore: auto-release * feat:enable instance-level connection * lint
1 parent da1af57 commit 9af8e86

File tree

4 files changed

+90
-6
lines changed

4 files changed

+90
-6
lines changed

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -645,11 +645,14 @@ def create_connect_args(self, url):
645645
646646
The given URL follows the style:
647647
`spanner:///projects/{project-id}/instances/{instance-id}/databases/{database-id}`
648+
or `spanner:///projects/{project-id}/instances/{instance-id}`. For the latter,
649+
database operations will be not be possible and if required a new engine with
650+
database-id set will need to be created.
648651
"""
649652
match = re.match(
650653
(
651654
r"^projects/(?P<project>.+?)/instances/"
652-
"(?P<instance>.+?)/databases/(?P<database>.+?)$"
655+
"(?P<instance>.+?)(/databases/(?P<database>.+)|$)"
653656
),
654657
url.database,
655658
)
@@ -1346,17 +1349,29 @@ def do_rollback(self, dbapi_connection):
13461349
):
13471350
pass
13481351
else:
1349-
trace_attributes = {"db.instance": dbapi_connection.database.name}
1352+
trace_attributes = {
1353+
"db.instance": dbapi_connection.database.name
1354+
if dbapi_connection.database
1355+
else ""
1356+
}
13501357
with trace_call("SpannerSqlAlchemy.Rollback", trace_attributes):
13511358
dbapi_connection.rollback()
13521359

13531360
def do_commit(self, dbapi_connection):
1354-
trace_attributes = {"db.instance": dbapi_connection.database.name}
1361+
trace_attributes = {
1362+
"db.instance": dbapi_connection.database.name
1363+
if dbapi_connection.database
1364+
else ""
1365+
}
13551366
with trace_call("SpannerSqlAlchemy.Commit", trace_attributes):
13561367
dbapi_connection.commit()
13571368

13581369
def do_close(self, dbapi_connection):
1359-
trace_attributes = {"db.instance": dbapi_connection.database.name}
1370+
trace_attributes = {
1371+
"db.instance": dbapi_connection.database.name
1372+
if dbapi_connection.database
1373+
else ""
1374+
}
13601375
with trace_call("SpannerSqlAlchemy.Close", trace_attributes):
13611376
dbapi_connection.close()
13621377

test/test_suite_13.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2045,3 +2045,17 @@ def test_create_engine_w_invalid_client_object(self):
20452045

20462046
with pytest.raises(ValueError):
20472047
engine.connect()
2048+
2049+
2050+
class CreateEngineWithoutDatabaseTest(fixtures.TestBase):
2051+
def test_create_engine_wo_database(self):
2052+
"""
2053+
SPANNER TEST:
2054+
2055+
Check that we can connect to SqlAlchemy
2056+
without passing database id in the
2057+
connection URL.
2058+
"""
2059+
engine = create_engine(get_db_url().split("/database")[0])
2060+
with engine.connect() as connection:
2061+
assert connection.connection.database is None

test/test_suite_14.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2378,3 +2378,17 @@ def test_create_engine_w_invalid_client_object(self):
23782378

23792379
with pytest.raises(ValueError):
23802380
engine.connect()
2381+
2382+
2383+
class CreateEngineWithoutDatabaseTest(fixtures.TestBase):
2384+
def test_create_engine_wo_database(self):
2385+
"""
2386+
SPANNER TEST:
2387+
2388+
Check that we can connect to SqlAlchemy
2389+
without passing database id in the
2390+
connection URL.
2391+
"""
2392+
engine = create_engine(get_db_url().split("/database")[0])
2393+
with engine.connect() as connection:
2394+
assert connection.connection.database is None

test/test_suite_20.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import time
2525
from unittest import mock
2626

27-
from google.cloud.spanner_v1 import RequestOptions
27+
from google.cloud.spanner_v1 import RequestOptions, Client
2828
import sqlalchemy
2929
from sqlalchemy import create_engine
3030
from sqlalchemy.engine import Inspector
@@ -144,7 +144,7 @@
144144
UnicodeTextTest as _UnicodeTextTest,
145145
_UnicodeFixture as __UnicodeFixture,
146146
) # noqa: F401, F403
147-
from test._helpers import get_db_url
147+
from test._helpers import get_db_url, get_project
148148

149149
config.test_schema = ""
150150

@@ -3000,3 +3000,44 @@ def test_request_priority(self):
30003000
engine = create_engine("sqlite:///database")
30013001
with engine.connect() as connection:
30023002
pass
3003+
3004+
3005+
class CreateEngineWithClientObjectTest(fixtures.TestBase):
3006+
def test_create_engine_w_valid_client_object(self):
3007+
"""
3008+
SPANNER TEST:
3009+
3010+
Check that we can connect to SqlAlchemy
3011+
by passing custom Client object.
3012+
"""
3013+
client = Client(project=get_project())
3014+
engine = create_engine(get_db_url(), connect_args={"client": client})
3015+
with engine.connect() as connection:
3016+
assert connection.connection.instance._client == client
3017+
3018+
def test_create_engine_w_invalid_client_object(self):
3019+
"""
3020+
SPANNER TEST:
3021+
3022+
Check that if project id in url and custom Client
3023+
Object passed to enginer mismatch, error is thrown.
3024+
"""
3025+
client = Client(project="project_id")
3026+
engine = create_engine(get_db_url(), connect_args={"client": client})
3027+
3028+
with pytest.raises(ValueError):
3029+
engine.connect()
3030+
3031+
3032+
class CreateEngineWithoutDatabaseTest(fixtures.TestBase):
3033+
def test_create_engine_wo_database(self):
3034+
"""
3035+
SPANNER TEST:
3036+
3037+
Check that we can connect to SqlAlchemy
3038+
without passing database id in the
3039+
connection URL.
3040+
"""
3041+
engine = create_engine(get_db_url().split("/database")[0])
3042+
with engine.connect() as connection:
3043+
assert connection.connection.database is None

0 commit comments

Comments
 (0)