Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: consolidate Postgres providers #494

Merged
merged 15 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions retrieval_service/datastore/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import sqlparse


def format_sql(sql: str, params):
def format_sql(sql: str, params: Union[tuple, dict]):
anubhav756 marked this conversation as resolved.
Show resolved Hide resolved
"""
Format postgres sql to human readable text
Format Postgres SQL to human readable text by replacing placeholders.
Handles tuple-based ($1, $2, ...) and dict-based (:key) formats.
"""
for i in range(len(params)):
sql = sql.replace(f"${i+1}", f"{params[i]}")
# format the SQL
formatted_sql = (
sqlparse.format(
sql,
reindent=True,
keyword_case="upper",
use_space_around_operators=True,
strip_whitespace=True,
)
.replace("\n", "<br/>")
.replace(" ", '<div class="indent"></div>')
if isinstance(params, tuple):
for i, value in enumerate(params):
sql = sql.replace(f"${i+1}", f"{value}")
elif isinstance(params, dict):
for key, value in params.items():
sql = sql.replace(f":{key}", f"{value}")
else:
raise ValueError("params must be a tuple or dict")
# format the SQL
formatted_sql = (
sqlparse.format(
sql,
reindent=True,
keyword_case="upper",
use_space_around_operators=True,
strip_whitespace=True,
)
.replace("\n", "<br/>")
.replace(" ", '<div class="indent"></div>')
)
return formatted_sql.replace("<br/>", "", 1)
8 changes: 3 additions & 5 deletions retrieval_service/datastore/providers/alloydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ async def initialize_data(
flights: list[models.Flight],
policies: list[models.Policy],
) -> None:
return await self.__pg_ds.initialize_data(
airports, amenities, flights, policies
)
await self.__pg_ds.initialize_data(airports, amenities, flights, policies)

async def export_data(
self,
Expand Down Expand Up @@ -172,7 +170,7 @@ async def insert_ticket(
departure_time: str,
arrival_time: str,
):
return await self.__pg_ds.insert_ticket(
await self.__pg_ds.insert_ticket(
user_id,
user_name,
user_email,
Expand All @@ -198,4 +196,4 @@ async def policies_search(
)

async def close(self):
return await self.__pg_ds.close()
await self.__pg_ds.close()
6 changes: 2 additions & 4 deletions retrieval_service/datastore/providers/cloudsql_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ async def initialize_data(
flights: list[models.Flight],
policies: list[models.Policy],
) -> None:
return await self.__pg_ds.initialize_data(
airports, amenities, flights, policies
)
await self.__pg_ds.initialize_data(airports, amenities, flights, policies)

async def export_data(
self,
Expand Down Expand Up @@ -174,7 +172,7 @@ async def insert_ticket(
departure_time: str,
arrival_time: str,
):
return await self.__pg_ds.insert_ticket(
await self.__pg_ds.insert_ticket(
user_id,
user_name,
user_email,
Expand Down
24 changes: 13 additions & 11 deletions retrieval_service/datastore/providers/postgres_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import models

from ..helpers import format_sql


class PostgresDatastore:
def __init__(self, pool: AsyncEngine):
Expand Down Expand Up @@ -289,7 +291,7 @@ async def get_airport_by_id(
return None, None

res = models.Airport.model_validate(result)
return res, sql
return res, format_sql(sql, params)

async def get_airport_by_iata(
self, iata: str
Expand All @@ -304,7 +306,7 @@ async def get_airport_by_iata(
return None, None

res = models.Airport.model_validate(result)
return res, sql
return res, format_sql(sql, params)

async def search_airports(
self,
Expand All @@ -329,7 +331,7 @@ async def search_airports(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [models.Airport.model_validate(r) for r in results]
return res, sql
return res, format_sql(sql, params)

async def get_amenity(
self, id: int
Expand All @@ -347,7 +349,7 @@ async def get_amenity(
return None, None

res = models.Amenity.model_validate(result)
return res, sql
return res, format_sql(sql, params)

async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
Expand All @@ -369,7 +371,7 @@ async def amenities_search(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [r for r in results]
return res, sql
return res, format_sql(sql, params)

async def get_flight(
self, flight_id: int
Expand All @@ -387,7 +389,7 @@ async def get_flight(
return None, None

res = models.Flight.model_validate(result)
return res, sql
return res, format_sql(sql, params)

async def search_flights_by_number(
self,
Expand All @@ -409,7 +411,7 @@ async def search_flights_by_number(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [models.Flight.model_validate(r) for r in results]
return res, sql
return res, format_sql(sql, params)

async def search_flights_by_airports(
self,
Expand All @@ -436,7 +438,7 @@ async def search_flights_by_airports(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [models.Flight.model_validate(r) for r in results]
return res, sql
return res, format_sql(sql, params)

async def validate_ticket(
self,
Expand Down Expand Up @@ -466,7 +468,7 @@ async def validate_ticket(
if result is None:
return None, None
res = models.Flight.model_validate(result)
return res, sql
return res, format_sql(sql, params)

async def insert_ticket(
self,
Expand Down Expand Up @@ -541,7 +543,7 @@ async def list_tickets(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [r for r in results]
return res, sql
return res, format_sql(sql, params)

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
Expand All @@ -563,7 +565,7 @@ async def policies_search(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [r["content"] for r in results]
return res, sql
return res, format_sql(sql, params)

async def close(self):
await self.__pool.dispose()