Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: consolidate Postgres providers #494

Merged
merged 15 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: Add tracing for AlloyDB and CloudSQL Postgres providers
The `format_sql` helper now supports both `tuple` and `dict` for its `params` argument. This change accommodates the AlloyDB and CloudSQL Postgres providers, which pass `params` as `dict` through `postgres_datastore.py`, unlike the non-cloud Postgres provider that uses `tuples`.  The updated `format_sql` helper is used in `postgres_datastore.py` query handlers to format SQL queries, populate them with variables, and return them as traces.
  • Loading branch information
anubhav756 committed Sep 19, 2024
commit a82e09631c4e8c0dd5054fcdb676f3e0c31a8c0f
36 changes: 22 additions & 14 deletions retrieval_service/datastore/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import sqlparse


def format_sql(sql: str, params):
def format_sql(sql: str, params: Union[tuple, dict]):
anubhav756 marked this conversation as resolved.
Show resolved Hide resolved
"""
Format postgres sql to human readable text
anubhav756 marked this conversation as resolved.
Show resolved Hide resolved
"""
for i in range(len(params)):
sql = sql.replace(f"${i+1}", f"{params[i]}")
# format the SQL
formatted_sql = (
sqlparse.format(
sql,
reindent=True,
keyword_case="upper",
use_space_around_operators=True,
strip_whitespace=True,
)
.replace("\n", "<br/>")
.replace(" ", '<div class="indent"></div>')
if isinstance(params, tuple):
for i, value in enumerate(params):
sql = sql.replace(f"${i+1}", f"{value}")
elif isinstance(params, dict):
for key, value in params.items():
sql = sql.replace(f":{key}", f"{value}")
else:
raise ValueError("params must be a tuple or dict")
# format the SQL
formatted_sql = (
sqlparse.format(
sql,
reindent=True,
keyword_case="upper",
use_space_around_operators=True,
strip_whitespace=True,
)
.replace("\n", "<br/>")
.replace(" ", '<div class="indent"></div>')
)
return formatted_sql.replace("<br/>", "", 1)
24 changes: 13 additions & 11 deletions retrieval_service/datastore/providers/postgres_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import models

from ..helpers import format_sql


class PostgresDatastore:
def __init__(self, pool: AsyncEngine):
Expand Down Expand Up @@ -289,7 +291,7 @@ async def get_airport_by_id(
return None, None

res = models.Airport.model_validate(result)
return res, sql
return res, format_sql(sql, params)

async def get_airport_by_iata(
self, iata: str
Expand All @@ -304,7 +306,7 @@ async def get_airport_by_iata(
return None, None

res = models.Airport.model_validate(result)
return res, sql
return res, format_sql(sql, params)

async def search_airports(
self,
Expand All @@ -329,7 +331,7 @@ async def search_airports(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [models.Airport.model_validate(r) for r in results]
return res, sql
return res, format_sql(sql, params)

async def get_amenity(
self, id: int
Expand All @@ -347,7 +349,7 @@ async def get_amenity(
return None, None

res = models.Amenity.model_validate(result)
return res, sql
return res, format_sql(sql, params)

async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
Expand All @@ -369,7 +371,7 @@ async def amenities_search(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [r for r in results]
return res, sql
return res, format_sql(sql, params)

async def get_flight(
self, flight_id: int
Expand All @@ -387,7 +389,7 @@ async def get_flight(
return None, None

res = models.Flight.model_validate(result)
return res, sql
return res, format_sql(sql, params)

async def search_flights_by_number(
self,
Expand All @@ -409,7 +411,7 @@ async def search_flights_by_number(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [models.Flight.model_validate(r) for r in results]
return res, sql
return res, format_sql(sql, params)

async def search_flights_by_airports(
self,
Expand All @@ -436,7 +438,7 @@ async def search_flights_by_airports(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [models.Flight.model_validate(r) for r in results]
return res, sql
return res, format_sql(sql, params)

async def validate_ticket(
self,
Expand Down Expand Up @@ -466,7 +468,7 @@ async def validate_ticket(
if result is None:
return None, None
res = models.Flight.model_validate(result)
return res, sql
return res, format_sql(sql, params)

async def insert_ticket(
self,
Expand Down Expand Up @@ -541,7 +543,7 @@ async def list_tickets(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [r for r in results]
return res, sql
return res, format_sql(sql, params)

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
Expand All @@ -563,7 +565,7 @@ async def policies_search(
results = (await conn.execute(s, params)).mappings().fetchall()

res = [r["content"] for r in results]
return res, sql
return res, format_sql(sql, params)

async def close(self):
await self.__pool.dispose()