Skip to content

Commit

Permalink
feat: Add tracing for AlloyDB and CloudSQL Postgres providers (#494)
Browse files Browse the repository at this point in the history
The `format_sql` helper now supports both `tuple` and `dict` for its
`params` argument. This change accommodates the AlloyDB and CloudSQL
Postgres providers, which pass `params` as `dict` through
`postgres_datastore.py`, unlike the non-cloud Postgres provider that
uses `tuples`. The updated `format_sql` helper is used in
`postgres_datastore.py` query handlers to format SQL queries, populate
them with variables, and return them as traces.

---------

Co-authored-by: Vishwaraj Anand <vishwaraj.anand00@gmail.com>
  • Loading branch information
anubhav756 and vishwarajanand authored Sep 22, 2024
1 parent 62776bb commit 2fa03bc
Show file tree
Hide file tree
Showing 5 changed files with 452 additions and 982 deletions.
33 changes: 18 additions & 15 deletions retrieval_service/datastore/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import sqlparse


def format_sql(sql: str, params):
def format_sql(sql: str, params: dict):
"""
Format postgres sql to human readable text
Format Postgres SQL to human readable text by replacing placeholders.
Handles dict-based (:key) formats.
"""
for i in range(len(params)):
sql = sql.replace(f"${i+1}", f"{params[i]}")
# format the SQL
formatted_sql = (
sqlparse.format(
sql,
reindent=True,
keyword_case="upper",
use_space_around_operators=True,
strip_whitespace=True,
)
.replace("\n", "<br/>")
.replace(" ", '<div class="indent"></div>')
for key, value in params.items():
sql = sql.replace(f":{key}", f"{value}")
# format the SQL
formatted_sql = (
sqlparse.format(
sql,
reindent=True,
keyword_case="upper",
use_space_around_operators=True,
strip_whitespace=True,
)
.replace("\n", "<br/>")
.replace(" ", '<div class="indent"></div>')
)
return formatted_sql.replace("<br/>", "", 1)
48 changes: 23 additions & 25 deletions retrieval_service/datastore/providers/alloydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import models

from .. import datastore
from .postgres_datastore import PostgresDatastore
from .postgres import Client as PostgresClient

ALLOYDB_PG_IDENTIFIER = "alloydb-postgres"

Expand All @@ -41,14 +41,14 @@ class Config(BaseModel, datastore.AbstractConfig):

class Client(datastore.Client[Config]):
__connector: Optional[AsyncConnector] = None
__pg_ds: PostgresDatastore
__pg_client: PostgresClient

@datastore.classproperty
def kind(cls):
return ALLOYDB_PG_IDENTIFIER

def __init__(self, pool: AsyncEngine):
self.__pg_ds = PostgresDatastore(pool)
def __init__(self, async_engine: AsyncEngine):
self.__pg_client = PostgresClient(async_engine)

@classmethod
async def create(cls, config: Config) -> "Client":
Expand All @@ -68,13 +68,13 @@ async def getconn() -> asyncpg.Connection:
await register_vector(conn)
return conn

pool = create_async_engine(
async_engine = create_async_engine(
"postgresql+asyncpg://",
async_creator=getconn,
)
if pool is None:
raise TypeError("pool not instantiated")
return cls(pool)
if async_engine is None:
raise TypeError("async_engine not instantiated")
return cls(async_engine)

async def initialize_data(
self,
Expand All @@ -83,9 +83,7 @@ async def initialize_data(
flights: list[models.Flight],
policies: list[models.Policy],
) -> None:
return await self.__pg_ds.initialize_data(
airports, amenities, flights, policies
)
await self.__pg_client.initialize_data(airports, amenities, flights, policies)

async def export_data(
self,
Expand All @@ -95,57 +93,57 @@ async def export_data(
list[models.Flight],
list[models.Policy],
]:
return await self.__pg_ds.export_data()
return await self.__pg_client.export_data()

async def get_airport_by_id(
self, id: int
) -> tuple[Optional[models.Airport], Optional[str]]:
return await self.__pg_ds.get_airport_by_id(id)
return await self.__pg_client.get_airport_by_id(id)

async def get_airport_by_iata(
self, iata: str
) -> tuple[Optional[models.Airport], Optional[str]]:
return await self.__pg_ds.get_airport_by_iata(iata)
return await self.__pg_client.get_airport_by_iata(iata)

async def search_airports(
self,
country: Optional[str] = None,
city: Optional[str] = None,
name: Optional[str] = None,
) -> tuple[list[models.Airport], Optional[str]]:
return await self.__pg_ds.search_airports(country, city, name)
return await self.__pg_client.search_airports(country, city, name)

async def get_amenity(
self, id: int
) -> tuple[Optional[models.Amenity], Optional[str]]:
return await self.__pg_ds.get_amenity(id)
return await self.__pg_client.get_amenity(id)

async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[Any], Optional[str]]:
return await self.__pg_ds.amenities_search(
return await self.__pg_client.amenities_search(
query_embedding, similarity_threshold, top_k
)

async def get_flight(
self, flight_id: int
) -> tuple[Optional[models.Flight], Optional[str]]:
return await self.__pg_ds.get_flight(flight_id)
return await self.__pg_client.get_flight(flight_id)

async def search_flights_by_number(
self,
airline: str,
number: str,
) -> tuple[list[models.Flight], Optional[str]]:
return await self.__pg_ds.search_flights_by_number(airline, number)
return await self.__pg_client.search_flights_by_number(airline, number)

async def search_flights_by_airports(
self,
date: str,
departure_airport: Optional[str] = None,
arrival_airport: Optional[str] = None,
) -> tuple[list[models.Flight], Optional[str]]:
return await self.__pg_ds.search_flights_by_airports(
return await self.__pg_client.search_flights_by_airports(
date, departure_airport, arrival_airport
)

Expand All @@ -156,7 +154,7 @@ async def validate_ticket(
departure_airport: str,
departure_time: str,
) -> tuple[Optional[models.Flight], Optional[str]]:
return await self.__pg_ds.validate_ticket(
return await self.__pg_client.validate_ticket(
airline, flight_number, departure_airport, departure_time
)

Expand All @@ -172,7 +170,7 @@ async def insert_ticket(
departure_time: str,
arrival_time: str,
):
return await self.__pg_ds.insert_ticket(
await self.__pg_client.insert_ticket(
user_id,
user_name,
user_email,
Expand All @@ -188,14 +186,14 @@ async def list_tickets(
self,
user_id: str,
) -> tuple[list[Any], Optional[str]]:
return await self.__pg_ds.list_tickets(user_id)
return await self.__pg_client.list_tickets(user_id)

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[str], Optional[str]]:
return await self.__pg_ds.policies_search(
return await self.__pg_client.policies_search(
query_embedding, similarity_threshold, top_k
)

async def close(self):
return await self.__pg_ds.close()
await self.__pg_client.close()
48 changes: 23 additions & 25 deletions retrieval_service/datastore/providers/cloudsql_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import models

from .. import datastore
from .postgres_datastore import PostgresDatastore
from .postgres import Client as PostgresClient

CLOUD_SQL_PG_IDENTIFIER = "cloudsql-postgres"

Expand All @@ -40,15 +40,15 @@ class Config(BaseModel, datastore.AbstractConfig):


class Client(datastore.Client[Config]):
__pg_ds: PostgresDatastore
__pg_client: PostgresClient
__connector: Optional[Connector] = None

@datastore.classproperty
def kind(cls):
return CLOUD_SQL_PG_IDENTIFIER

def __init__(self, pool: AsyncEngine):
self.__pg_ds = PostgresDatastore(pool)
def __init__(self, async_engine: AsyncEngine):
self.__pg_client = PostgresClient(async_engine)

@classmethod
async def create(cls, config: Config) -> "Client":
Expand All @@ -70,13 +70,13 @@ async def getconn() -> asyncpg.Connection:
await register_vector(conn)
return conn

pool = create_async_engine(
async_engine = create_async_engine(
"postgresql+asyncpg://",
async_creator=getconn,
)
if pool is None:
raise TypeError("pool not instantiated")
return cls(pool)
if async_engine is None:
raise TypeError("async_engine not instantiated")
return cls(async_engine)

async def initialize_data(
self,
Expand All @@ -85,9 +85,7 @@ async def initialize_data(
flights: list[models.Flight],
policies: list[models.Policy],
) -> None:
return await self.__pg_ds.initialize_data(
airports, amenities, flights, policies
)
await self.__pg_client.initialize_data(airports, amenities, flights, policies)

async def export_data(
self,
Expand All @@ -97,57 +95,57 @@ async def export_data(
list[models.Flight],
list[models.Policy],
]:
return await self.__pg_ds.export_data()
return await self.__pg_client.export_data()

async def get_airport_by_id(
self, id: int
) -> tuple[Optional[models.Airport], Optional[str]]:
return await self.__pg_ds.get_airport_by_id(id)
return await self.__pg_client.get_airport_by_id(id)

async def get_airport_by_iata(
self, iata: str
) -> tuple[Optional[models.Airport], Optional[str]]:
return await self.__pg_ds.get_airport_by_iata(iata)
return await self.__pg_client.get_airport_by_iata(iata)

async def search_airports(
self,
country: Optional[str] = None,
city: Optional[str] = None,
name: Optional[str] = None,
) -> tuple[list[models.Airport], Optional[str]]:
return await self.__pg_ds.search_airports(country, city, name)
return await self.__pg_client.search_airports(country, city, name)

async def get_amenity(
self, id: int
) -> tuple[Optional[models.Amenity], Optional[str]]:
return await self.__pg_ds.get_amenity(id)
return await self.__pg_client.get_amenity(id)

async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[Any], Optional[str]]:
return await self.__pg_ds.amenities_search(
return await self.__pg_client.amenities_search(
query_embedding, similarity_threshold, top_k
)

async def get_flight(
self, flight_id: int
) -> tuple[Optional[models.Flight], Optional[str]]:
return await self.__pg_ds.get_flight(flight_id)
return await self.__pg_client.get_flight(flight_id)

async def search_flights_by_number(
self,
airline: str,
number: str,
) -> tuple[list[models.Flight], Optional[str]]:
return await self.__pg_ds.search_flights_by_number(airline, number)
return await self.__pg_client.search_flights_by_number(airline, number)

async def search_flights_by_airports(
self,
date: str,
departure_airport: Optional[str] = None,
arrival_airport: Optional[str] = None,
) -> tuple[list[models.Flight], Optional[str]]:
return await self.__pg_ds.search_flights_by_airports(
return await self.__pg_client.search_flights_by_airports(
date, departure_airport, arrival_airport
)

Expand All @@ -158,7 +156,7 @@ async def validate_ticket(
departure_airport: str,
departure_time: str,
) -> tuple[Optional[models.Flight], Optional[str]]:
return await self.__pg_ds.validate_ticket(
return await self.__pg_client.validate_ticket(
airline, flight_number, departure_airport, departure_time
)

Expand All @@ -174,7 +172,7 @@ async def insert_ticket(
departure_time: str,
arrival_time: str,
):
return await self.__pg_ds.insert_ticket(
await self.__pg_client.insert_ticket(
user_id,
user_name,
user_email,
Expand All @@ -190,14 +188,14 @@ async def list_tickets(
self,
user_id: str,
) -> tuple[list[Any], Optional[str]]:
return await self.__pg_ds.list_tickets(user_id)
return await self.__pg_client.list_tickets(user_id)

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[str], Optional[str]]:
return await self.__pg_ds.policies_search(
return await self.__pg_client.policies_search(
query_embedding, similarity_threshold, top_k
)

async def close(self):
await self.__pg_ds.close()
await self.__pg_client.close()
Loading

0 comments on commit 2fa03bc

Please sign in to comment.