Skip to content

Commit

Permalink
chore: migrate run/idp-sql to sqlalchemy 2 (#9121)
Browse files Browse the repository at this point in the history
* chore: migrate run/idp-sql to sqlalchemy 2

* linting

* temp remove cleanup for debugging

* switch to db.begin() for commits

* remove cleanup for debugging

* fix index parameters

* add back cleanup

---------

Co-authored-by: Averi Kitsch <akitsch@google.com>
  • Loading branch information
kweinmeister and averikitsch authored Feb 28, 2023
1 parent 7af5f61 commit f4cd664
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 25 deletions.
5 changes: 2 additions & 3 deletions run/idp-sql/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,20 @@ connection.

[Instructions to launch proxy with TCP](https://github.com/GoogleCloudPlatform/python-docs-samples/tree/main/cloud-sql/postgres/sqlalchemy#launch-proxy-with-tcp)


## Testing

Tests expect the Cloud SQL instance to already be created and environment Variables
to be set.

### Unit tests

```
```sh
pytest test_app.py
```

### System Tests

```
```sh
export GOOGLE_CLOUD_PROJECT=<YOUR_PROJECT_ID>
export CLOUD_SQL_CONNECTION_NAME=<YOUR_CLOUD_SQL_CONNECTION_NAME>
export DB_PASSWORD=<POSTGRESQL_PASSWORD>
Expand Down
39 changes: 18 additions & 21 deletions run/idp-sql/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import datetime
import os
from typing import Dict
from typing import Any, Dict, Type

import sqlalchemy
from sqlalchemy.orm import close_all_sessions
Expand All @@ -30,12 +30,12 @@
db = None


def init_connection_engine() -> Dict[str, int]:
def init_connection_engine() -> sqlalchemy.engine.base.Engine:
if os.getenv("TRAMPOLINE_CI", None):
logger.info("Using NullPool for testing")
db_config = {"poolclass": NullPool}
db_config: Dict[str, Any] = {"poolclass": NullPool}
else:
db_config = {
db_config: Dict[str, Any] = {
# Pool size is the maximum number of permanent connections to keep.
"pool_size": 5,
# Temporarily exceeds the set pool_size if no connections are available.
Expand All @@ -61,7 +61,7 @@ def init_connection_engine() -> Dict[str, int]:


def init_tcp_connection_engine(
db_config: Dict[str, str]
db_config: Dict[str, Type[NullPool]]
) -> sqlalchemy.engine.base.Engine:
creds = credentials.get_cred_config()
db_user = creds["DB_USER"]
Expand Down Expand Up @@ -94,7 +94,7 @@ def init_tcp_connection_engine(

# [START cloudrun_user_auth_sql_connect]
def init_unix_connection_engine(
db_config: Dict[str, str]
db_config: Dict[str, int]
) -> sqlalchemy.engine.base.Engine:
creds = credentials.get_cred_config()
db_user = creds["DB_USER"]
Expand All @@ -113,9 +113,8 @@ def init_unix_connection_engine(
password=db_pass, # e.g. "my-database-password"
database=db_name, # e.g. "my-database-name"
query={
"unix_sock": "{}/{}/.s.PGSQL.5432".format(
db_socket_dir, cloud_sql_connection_name # e.g. "/cloudsql"
) # i.e "<PROJECT-NAME>:<INSTANCE-REGION>:<INSTANCE-NAME>"
"unix_sock": f"{db_socket_dir}/{cloud_sql_connection_name}/.s.PGSQL.5432"
# e.g. "/cloudsql", "<PROJECT-NAME>:<INSTANCE-REGION>:<INSTANCE-NAME>"
},
),
**db_config,
Expand All @@ -136,26 +135,26 @@ def create_tables() -> None:
global db
db = init_connection_engine()
# Create pet_votes table if it doesn't already exist
with db.connect() as conn:
conn.execute(
with db.begin() as conn:
conn.execute(sqlalchemy.text(
"CREATE TABLE IF NOT EXISTS pet_votes"
"( vote_id SERIAL NOT NULL, "
"time_cast timestamp NOT NULL, "
"candidate VARCHAR(6) NOT NULL, "
"uid VARCHAR(128) NOT NULL, "
"PRIMARY KEY (vote_id)"
");"
)
))


def get_index_context() -> Dict:
def get_index_context() -> Dict[str, Any]:
votes = []
with db.connect() as conn:
# Execute the query and fetch all results
recent_votes = conn.execute(
recent_votes = conn.execute(sqlalchemy.text(
"SELECT candidate, time_cast FROM pet_votes "
"ORDER BY time_cast DESC LIMIT 5"
).fetchall()
)).fetchall()
# Convert the results into a list of dicts representing votes
for row in recent_votes:
votes.append(
Expand All @@ -168,11 +167,9 @@ def get_index_context() -> Dict:
"SELECT COUNT(vote_id) FROM pet_votes WHERE candidate=:candidate"
)
# Count number of votes for cats
cats_result = conn.execute(stmt, candidate="CATS").fetchone()
cats_count = cats_result[0]
cats_count = conn.execute(stmt, parameters={"candidate": "CATS"}).scalar()
# Count number of votes for dogs
dogs_result = conn.execute(stmt, candidate="DOGS").fetchone()
dogs_count = dogs_result[0]
dogs_count = conn.execute(stmt, parameters={"candidate": "DOGS"}).scalar()
return {
"dogs_count": dogs_count,
"recent_votes": votes,
Expand All @@ -189,8 +186,8 @@ def save_vote(team: str, uid: str, time_cast: datetime.datetime) -> None:

# Using a with statement ensures that the connection is always released
# back into the pool at the end of statement (even if an error occurs)
with db.connect() as conn:
conn.execute(stmt, time_cast=time_cast, candidate=team, uid=uid)
with db.begin() as conn:
conn.execute(stmt, parameters={"time_cast": time_cast, "candidate": team, "uid": uid})
logger.info("Vote for %s saved.", team)


Expand Down
2 changes: 1 addition & 1 deletion run/idp-sql/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Flask==2.1.0
SQLAlchemy==1.4.38
SQLAlchemy==2.0.3
pg8000==1.24.2
gunicorn==20.1.0
firebase-admin==6.0.0
Expand Down

0 comments on commit f4cd664

Please sign in to comment.