Skip to content

Commit

Permalink
samples: add code samples for PostgreSql dialect (googleapis#836)
Browse files Browse the repository at this point in the history
* samples: add code samples for PostgreSql dialect

* linting

* fix: remove unnecessary imports

* remove unused import

* fix: change method doc references in parser

* add another command

* test: add samples tests for PG

* fix: linting

* feat: sample tests config changes

* refactor

* refactor

* refactor

* refactor

* add database dialect

* database dialect fixture change

* fix ddl

* yield operation as well

* skip backup tests

* config changes

* fix

* minor lint fix

* some tests were getting skipped. fixing it.

* fix test

* fix test and skip few tests for faster testing

* re-enable tests

Co-authored-by: Astha Mohta <asthamohta@google.com>
Co-authored-by: Astha Mohta <35952883+asthamohta@users.noreply.github.com>
  • Loading branch information
3 people committed Oct 26, 2022
1 parent fb1948d commit fbb1440
Show file tree
Hide file tree
Showing 6 changed files with 2,214 additions and 125 deletions.
129 changes: 88 additions & 41 deletions samples/samples/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import uuid

from google.api_core import exceptions

from google.cloud import spanner_admin_database_v1
from google.cloud.spanner_admin_database_v1.types.common import DatabaseDialect
from google.cloud.spanner_v1 import backup
from google.cloud.spanner_v1 import client
from google.cloud.spanner_v1 import database
Expand All @@ -26,17 +29,32 @@

INSTANCE_CREATION_TIMEOUT = 560 # seconds

OPERATION_TIMEOUT_SECONDS = 120 # seconds

retry_429 = retry.RetryErrors(exceptions.ResourceExhausted, delay=15)


@pytest.fixture(scope="module")
def sample_name():
"""Sample testcase modules must define this fixture.
The name is used to label the instance created by the sample, to
aid in debugging leaked instances.
"""
raise NotImplementedError("Define 'sample_name' fixture in sample test driver")
The name is used to label the instance created by the sample, to
aid in debugging leaked instances.
"""
raise NotImplementedError(
"Define 'sample_name' fixture in sample test driver")


@pytest.fixture(scope="module")
def database_dialect():
"""Database dialect to be used for this sample.
The dialect is used to initialize the dialect for the database.
It can either be GoogleStandardSql or PostgreSql.
"""
# By default, we consider GOOGLE_STANDARD_SQL dialect. Other specific tests
# can override this if required.
return DatabaseDialect.GOOGLE_STANDARD_SQL


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -87,7 +105,7 @@ def multi_region_instance_id():
@pytest.fixture(scope="module")
def instance_config(spanner_client):
return "{}/instanceConfigs/{}".format(
spanner_client.project_name, "regional-us-central1"
spanner_client.project_name, "regional-us-central1"
)


Expand All @@ -98,20 +116,20 @@ def multi_region_instance_config(spanner_client):

@pytest.fixture(scope="module")
def sample_instance(
spanner_client,
cleanup_old_instances,
instance_id,
instance_config,
sample_name,
spanner_client,
cleanup_old_instances,
instance_id,
instance_config,
sample_name,
):
sample_instance = spanner_client.instance(
instance_id,
instance_config,
labels={
"cloud_spanner_samples": "true",
"sample_name": sample_name,
"created": str(int(time.time())),
},
instance_id,
instance_config,
labels={
"cloud_spanner_samples": "true",
"sample_name": sample_name,
"created": str(int(time.time())),
},
)
op = retry_429(sample_instance.create)()
op.result(INSTANCE_CREATION_TIMEOUT) # block until completion
Expand All @@ -133,20 +151,20 @@ def sample_instance(

@pytest.fixture(scope="module")
def multi_region_instance(
spanner_client,
cleanup_old_instances,
multi_region_instance_id,
multi_region_instance_config,
sample_name,
spanner_client,
cleanup_old_instances,
multi_region_instance_id,
multi_region_instance_config,
sample_name,
):
multi_region_instance = spanner_client.instance(
multi_region_instance_id,
multi_region_instance_config,
labels={
"cloud_spanner_samples": "true",
"sample_name": sample_name,
"created": str(int(time.time())),
},
multi_region_instance_id,
multi_region_instance_config,
labels={
"cloud_spanner_samples": "true",
"sample_name": sample_name,
"created": str(int(time.time())),
},
)
op = retry_429(multi_region_instance.create)()
op.result(INSTANCE_CREATION_TIMEOUT) # block until completion
Expand All @@ -170,30 +188,59 @@ def multi_region_instance(
def database_id():
"""Id for the database used in samples.
Sample testcase modules can override as needed.
"""
Sample testcase modules can override as needed.
"""
return "my-database-id"


@pytest.fixture(scope="module")
def database_ddl():
"""Sequence of DDL statements used to set up the database.
Sample testcase modules can override as needed.
"""
Sample testcase modules can override as needed.
"""
return []


@pytest.fixture(scope="module")
def sample_database(sample_instance, database_id, database_ddl):
def sample_database(
spanner_client,
sample_instance,
database_id,
database_ddl,
database_dialect):
if database_dialect == DatabaseDialect.POSTGRESQL:
sample_database = sample_instance.database(
database_id,
database_dialect=DatabaseDialect.POSTGRESQL,
)

if not sample_database.exists():
operation = sample_database.create()
operation.result(OPERATION_TIMEOUT_SECONDS)

request = spanner_admin_database_v1.UpdateDatabaseDdlRequest(
database=sample_database.name,
statements=database_ddl,
)

operation =\
spanner_client.database_admin_api.update_database_ddl(request)
operation.result(OPERATION_TIMEOUT_SECONDS)

yield sample_database

sample_database.drop()
return

sample_database = sample_instance.database(
database_id,
ddl_statements=database_ddl,
database_id,
ddl_statements=database_ddl,
)

if not sample_database.exists():
sample_database.create()
operation = sample_database.create()
operation.result(OPERATION_TIMEOUT_SECONDS)

yield sample_database

Expand All @@ -203,8 +250,8 @@ def sample_database(sample_instance, database_id, database_ddl):
@pytest.fixture(scope="module")
def kms_key_name(spanner_client):
return "projects/{}/locations/{}/keyRings/{}/cryptoKeys/{}".format(
spanner_client.project,
"us-central1",
"spanner-test-keyring",
"spanner-test-cmek",
spanner_client.project,
"us-central1",
"spanner-test-keyring",
"spanner-test-cmek",
)
15 changes: 8 additions & 7 deletions samples/samples/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def blacken(session: nox.sessions.Session) -> None:
# format = isort + black
#


@nox.session
def format(session: nox.sessions.Session) -> None:
"""
Expand Down Expand Up @@ -207,7 +208,9 @@ def _session_tests(
session: nox.sessions.Session, post_install: Callable = None
) -> None:
# check for presence of tests
test_list = glob.glob("**/*_test.py", recursive=True) + glob.glob("**/test_*.py", recursive=True)
test_list = glob.glob("**/*_test.py", recursive=True) + glob.glob(
"**/test_*.py", recursive=True
)
test_list.extend(glob.glob("**/tests", recursive=True))

if len(test_list) == 0:
Expand All @@ -229,9 +232,7 @@ def _session_tests(

if os.path.exists("requirements-test.txt"):
if os.path.exists("constraints-test.txt"):
session.install(
"-r", "requirements-test.txt", "-c", "constraints-test.txt"
)
session.install("-r", "requirements-test.txt", "-c", "constraints-test.txt")
else:
session.install("-r", "requirements-test.txt")
with open("requirements-test.txt") as rtfile:
Expand All @@ -244,9 +245,9 @@ def _session_tests(
post_install(session)

if "pytest-parallel" in packages:
concurrent_args.extend(['--workers', 'auto', '--tests-per-worker', 'auto'])
concurrent_args.extend(["--workers", "auto", "--tests-per-worker", "auto"])
elif "pytest-xdist" in packages:
concurrent_args.extend(['-n', 'auto'])
concurrent_args.extend(["-n", "auto"])

session.run(
"pytest",
Expand Down Expand Up @@ -276,7 +277,7 @@ def py(session: nox.sessions.Session) -> None:


def _get_repo_root() -> Optional[str]:
""" Returns the root folder of the project. """
"""Returns the root folder of the project."""
# Get root of this repository. Assume we don't have directories nested deeper than 10 items.
p = Path(os.getcwd())
for i in range(10):
Expand Down
Loading

0 comments on commit fbb1440

Please sign in to comment.