From 2fa03bcec54fd2fe4b463a54df772ae5e6490577 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 23 Sep 2024 00:11:02 +0530 Subject: [PATCH] feat: Add tracing for AlloyDB and CloudSQL Postgres providers (#494) The `format_sql` helper now supports both `tuple` and `dict` for its `params` argument. This change accommodates the AlloyDB and CloudSQL Postgres providers, which pass `params` as `dict` through `postgres_datastore.py`, unlike the non-cloud Postgres provider that uses `tuples`. The updated `format_sql` helper is used in `postgres_datastore.py` query handlers to format SQL queries, populate them with variables, and return them as traces. --------- Co-authored-by: Vishwaraj Anand --- retrieval_service/datastore/helpers.py | 33 +- .../datastore/providers/alloydb.py | 48 +- .../datastore/providers/cloudsql_postgres.py | 48 +- .../datastore/providers/postgres.py | 736 +++++++++--------- .../datastore/providers/postgres_datastore.py | 569 -------------- 5 files changed, 452 insertions(+), 982 deletions(-) delete mode 100644 retrieval_service/datastore/providers/postgres_datastore.py diff --git a/retrieval_service/datastore/helpers.py b/retrieval_service/datastore/helpers.py index 14d6b1dc..13e025f2 100644 --- a/retrieval_service/datastore/helpers.py +++ b/retrieval_service/datastore/helpers.py @@ -12,25 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + import sqlparse -def format_sql(sql: str, params): +def format_sql(sql: str, params: dict): """ - Format postgres sql to human readable text + Format Postgres SQL to human readable text by replacing placeholders. + Handles dict-based (:key) formats. """ - for i in range(len(params)): - sql = sql.replace(f"${i+1}", f"{params[i]}") - # format the SQL - formatted_sql = ( - sqlparse.format( - sql, - reindent=True, - keyword_case="upper", - use_space_around_operators=True, - strip_whitespace=True, - ) - .replace("\n", "
") - .replace(" ", '
') + for key, value in params.items(): + sql = sql.replace(f":{key}", f"{value}") + # format the SQL + formatted_sql = ( + sqlparse.format( + sql, + reindent=True, + keyword_case="upper", + use_space_around_operators=True, + strip_whitespace=True, ) + .replace("\n", "
") + .replace(" ", '
') + ) return formatted_sql.replace("
", "", 1) diff --git a/retrieval_service/datastore/providers/alloydb.py b/retrieval_service/datastore/providers/alloydb.py index 910c0248..04fcd34e 100644 --- a/retrieval_service/datastore/providers/alloydb.py +++ b/retrieval_service/datastore/providers/alloydb.py @@ -23,7 +23,7 @@ import models from .. import datastore -from .postgres_datastore import PostgresDatastore +from .postgres import Client as PostgresClient ALLOYDB_PG_IDENTIFIER = "alloydb-postgres" @@ -41,14 +41,14 @@ class Config(BaseModel, datastore.AbstractConfig): class Client(datastore.Client[Config]): __connector: Optional[AsyncConnector] = None - __pg_ds: PostgresDatastore + __pg_client: PostgresClient @datastore.classproperty def kind(cls): return ALLOYDB_PG_IDENTIFIER - def __init__(self, pool: AsyncEngine): - self.__pg_ds = PostgresDatastore(pool) + def __init__(self, async_engine: AsyncEngine): + self.__pg_client = PostgresClient(async_engine) @classmethod async def create(cls, config: Config) -> "Client": @@ -68,13 +68,13 @@ async def getconn() -> asyncpg.Connection: await register_vector(conn) return conn - pool = create_async_engine( + async_engine = create_async_engine( "postgresql+asyncpg://", async_creator=getconn, ) - if pool is None: - raise TypeError("pool not instantiated") - return cls(pool) + if async_engine is None: + raise TypeError("async_engine not instantiated") + return cls(async_engine) async def initialize_data( self, @@ -83,9 +83,7 @@ async def initialize_data( flights: list[models.Flight], policies: list[models.Policy], ) -> None: - return await self.__pg_ds.initialize_data( - airports, amenities, flights, policies - ) + await self.__pg_client.initialize_data(airports, amenities, flights, policies) async def export_data( self, @@ -95,17 +93,17 @@ async def export_data( list[models.Flight], list[models.Policy], ]: - return await self.__pg_ds.export_data() + return await self.__pg_client.export_data() async def get_airport_by_id( self, id: int ) -> tuple[Optional[models.Airport], Optional[str]]: - return await self.__pg_ds.get_airport_by_id(id) + return await self.__pg_client.get_airport_by_id(id) async def get_airport_by_iata( self, iata: str ) -> tuple[Optional[models.Airport], Optional[str]]: - return await self.__pg_ds.get_airport_by_iata(iata) + return await self.__pg_client.get_airport_by_iata(iata) async def search_airports( self, @@ -113,31 +111,31 @@ async def search_airports( city: Optional[str] = None, name: Optional[str] = None, ) -> tuple[list[models.Airport], Optional[str]]: - return await self.__pg_ds.search_airports(country, city, name) + return await self.__pg_client.search_airports(country, city, name) async def get_amenity( self, id: int ) -> tuple[Optional[models.Amenity], Optional[str]]: - return await self.__pg_ds.get_amenity(id) + return await self.__pg_client.get_amenity(id) async def amenities_search( self, query_embedding: list[float], similarity_threshold: float, top_k: int ) -> tuple[list[Any], Optional[str]]: - return await self.__pg_ds.amenities_search( + return await self.__pg_client.amenities_search( query_embedding, similarity_threshold, top_k ) async def get_flight( self, flight_id: int ) -> tuple[Optional[models.Flight], Optional[str]]: - return await self.__pg_ds.get_flight(flight_id) + return await self.__pg_client.get_flight(flight_id) async def search_flights_by_number( self, airline: str, number: str, ) -> tuple[list[models.Flight], Optional[str]]: - return await self.__pg_ds.search_flights_by_number(airline, number) + return await self.__pg_client.search_flights_by_number(airline, number) async def search_flights_by_airports( self, @@ -145,7 +143,7 @@ async def search_flights_by_airports( departure_airport: Optional[str] = None, arrival_airport: Optional[str] = None, ) -> tuple[list[models.Flight], Optional[str]]: - return await self.__pg_ds.search_flights_by_airports( + return await self.__pg_client.search_flights_by_airports( date, departure_airport, arrival_airport ) @@ -156,7 +154,7 @@ async def validate_ticket( departure_airport: str, departure_time: str, ) -> tuple[Optional[models.Flight], Optional[str]]: - return await self.__pg_ds.validate_ticket( + return await self.__pg_client.validate_ticket( airline, flight_number, departure_airport, departure_time ) @@ -172,7 +170,7 @@ async def insert_ticket( departure_time: str, arrival_time: str, ): - return await self.__pg_ds.insert_ticket( + await self.__pg_client.insert_ticket( user_id, user_name, user_email, @@ -188,14 +186,14 @@ async def list_tickets( self, user_id: str, ) -> tuple[list[Any], Optional[str]]: - return await self.__pg_ds.list_tickets(user_id) + return await self.__pg_client.list_tickets(user_id) async def policies_search( self, query_embedding: list[float], similarity_threshold: float, top_k: int ) -> tuple[list[str], Optional[str]]: - return await self.__pg_ds.policies_search( + return await self.__pg_client.policies_search( query_embedding, similarity_threshold, top_k ) async def close(self): - return await self.__pg_ds.close() + await self.__pg_client.close() diff --git a/retrieval_service/datastore/providers/cloudsql_postgres.py b/retrieval_service/datastore/providers/cloudsql_postgres.py index c6fba0e3..364b3fe1 100644 --- a/retrieval_service/datastore/providers/cloudsql_postgres.py +++ b/retrieval_service/datastore/providers/cloudsql_postgres.py @@ -24,7 +24,7 @@ import models from .. import datastore -from .postgres_datastore import PostgresDatastore +from .postgres import Client as PostgresClient CLOUD_SQL_PG_IDENTIFIER = "cloudsql-postgres" @@ -40,15 +40,15 @@ class Config(BaseModel, datastore.AbstractConfig): class Client(datastore.Client[Config]): - __pg_ds: PostgresDatastore + __pg_client: PostgresClient __connector: Optional[Connector] = None @datastore.classproperty def kind(cls): return CLOUD_SQL_PG_IDENTIFIER - def __init__(self, pool: AsyncEngine): - self.__pg_ds = PostgresDatastore(pool) + def __init__(self, async_engine: AsyncEngine): + self.__pg_client = PostgresClient(async_engine) @classmethod async def create(cls, config: Config) -> "Client": @@ -70,13 +70,13 @@ async def getconn() -> asyncpg.Connection: await register_vector(conn) return conn - pool = create_async_engine( + async_engine = create_async_engine( "postgresql+asyncpg://", async_creator=getconn, ) - if pool is None: - raise TypeError("pool not instantiated") - return cls(pool) + if async_engine is None: + raise TypeError("async_engine not instantiated") + return cls(async_engine) async def initialize_data( self, @@ -85,9 +85,7 @@ async def initialize_data( flights: list[models.Flight], policies: list[models.Policy], ) -> None: - return await self.__pg_ds.initialize_data( - airports, amenities, flights, policies - ) + await self.__pg_client.initialize_data(airports, amenities, flights, policies) async def export_data( self, @@ -97,17 +95,17 @@ async def export_data( list[models.Flight], list[models.Policy], ]: - return await self.__pg_ds.export_data() + return await self.__pg_client.export_data() async def get_airport_by_id( self, id: int ) -> tuple[Optional[models.Airport], Optional[str]]: - return await self.__pg_ds.get_airport_by_id(id) + return await self.__pg_client.get_airport_by_id(id) async def get_airport_by_iata( self, iata: str ) -> tuple[Optional[models.Airport], Optional[str]]: - return await self.__pg_ds.get_airport_by_iata(iata) + return await self.__pg_client.get_airport_by_iata(iata) async def search_airports( self, @@ -115,31 +113,31 @@ async def search_airports( city: Optional[str] = None, name: Optional[str] = None, ) -> tuple[list[models.Airport], Optional[str]]: - return await self.__pg_ds.search_airports(country, city, name) + return await self.__pg_client.search_airports(country, city, name) async def get_amenity( self, id: int ) -> tuple[Optional[models.Amenity], Optional[str]]: - return await self.__pg_ds.get_amenity(id) + return await self.__pg_client.get_amenity(id) async def amenities_search( self, query_embedding: list[float], similarity_threshold: float, top_k: int ) -> tuple[list[Any], Optional[str]]: - return await self.__pg_ds.amenities_search( + return await self.__pg_client.amenities_search( query_embedding, similarity_threshold, top_k ) async def get_flight( self, flight_id: int ) -> tuple[Optional[models.Flight], Optional[str]]: - return await self.__pg_ds.get_flight(flight_id) + return await self.__pg_client.get_flight(flight_id) async def search_flights_by_number( self, airline: str, number: str, ) -> tuple[list[models.Flight], Optional[str]]: - return await self.__pg_ds.search_flights_by_number(airline, number) + return await self.__pg_client.search_flights_by_number(airline, number) async def search_flights_by_airports( self, @@ -147,7 +145,7 @@ async def search_flights_by_airports( departure_airport: Optional[str] = None, arrival_airport: Optional[str] = None, ) -> tuple[list[models.Flight], Optional[str]]: - return await self.__pg_ds.search_flights_by_airports( + return await self.__pg_client.search_flights_by_airports( date, departure_airport, arrival_airport ) @@ -158,7 +156,7 @@ async def validate_ticket( departure_airport: str, departure_time: str, ) -> tuple[Optional[models.Flight], Optional[str]]: - return await self.__pg_ds.validate_ticket( + return await self.__pg_client.validate_ticket( airline, flight_number, departure_airport, departure_time ) @@ -174,7 +172,7 @@ async def insert_ticket( departure_time: str, arrival_time: str, ): - return await self.__pg_ds.insert_ticket( + await self.__pg_client.insert_ticket( user_id, user_name, user_email, @@ -190,14 +188,14 @@ async def list_tickets( self, user_id: str, ) -> tuple[list[Any], Optional[str]]: - return await self.__pg_ds.list_tickets(user_id) + return await self.__pg_client.list_tickets(user_id) async def policies_search( self, query_embedding: list[float], similarity_threshold: float, top_k: int ) -> tuple[list[str], Optional[str]]: - return await self.__pg_ds.policies_search( + return await self.__pg_client.policies_search( query_embedding, similarity_threshold, top_k ) async def close(self): - await self.__pg_ds.close() + await self.__pg_client.close() diff --git a/retrieval_service/datastore/providers/postgres.py b/retrieval_service/datastore/providers/postgres.py index 8c9661eb..a806f01b 100644 --- a/retrieval_service/datastore/providers/postgres.py +++ b/retrieval_service/datastore/providers/postgres.py @@ -20,6 +20,8 @@ import asyncpg from pgvector.asyncpg import register_vector from pydantic import BaseModel +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine import models @@ -39,31 +41,35 @@ class Config(BaseModel, datastore.AbstractConfig): class Client(datastore.Client[Config]): - __pool: asyncpg.Pool + __async_engine: AsyncEngine @datastore.classproperty def kind(cls): return POSTGRES_IDENTIFIER - def __init__(self, pool: asyncpg.Pool): - self.__pool = pool + def __init__(self, async_engine: AsyncEngine): + self.__async_engine = async_engine @classmethod async def create(cls, config: Config) -> "Client": - async def init(conn): + async def getconn() -> asyncpg.Connection: + conn: asyncpg.Connection = await asyncpg.connection.connect( + host=str(config.host), + user=config.user, + password=config.password, + database=config.database, + port=config.port, + ) await register_vector(conn) + return conn - pool = await asyncpg.create_pool( - host=str(config.host), - user=config.user, - password=config.password, - database=config.database, - port=config.port, - init=init, + async_engine = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, ) - if pool is None: - raise TypeError("pool not instantiated") - return cls(pool) + if async_engine is None: + raise TypeError("async_engine not instantiated") + return cls(async_engine) async def initialize_data( self, @@ -72,182 +78,214 @@ async def initialize_data( flights: list[models.Flight], policies: list[models.Policy], ) -> None: - async with self.__pool.acquire() as conn: - await conn.execute("CREATE EXTENSION IF NOT EXISTS vector") + async with self.__async_engine.connect() as conn: # If the table already exists, drop it to avoid conflicts - await conn.execute("DROP TABLE IF EXISTS airports CASCADE") + await conn.execute(text("DROP TABLE IF EXISTS airports CASCADE")) # Create a new table await conn.execute( - """ - CREATE TABLE airports( - id INT PRIMARY KEY, - iata TEXT, - name TEXT, - city TEXT, - country TEXT + text( + """ + CREATE TABLE airports( + id INT PRIMARY KEY, + iata TEXT, + name TEXT, + city TEXT, + country TEXT + ) + """ ) - """ ) # Insert all the data - await conn.executemany( - """INSERT INTO airports VALUES ($1, $2, $3, $4, $5)""", - [(a.id, a.iata, a.name, a.city, a.country) for a in airports], + await conn.execute( + text( + """INSERT INTO airports VALUES (:id, :iata, :name, :city, :country)""" + ), + [ + { + "id": a.id, + "iata": a.iata, + "name": a.name, + "city": a.city, + "country": a.country, + } + for a in airports + ], ) + await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # If the table already exists, drop it to avoid conflicts - await conn.execute("DROP TABLE IF EXISTS amenities CASCADE") + await conn.execute(text("DROP TABLE IF EXISTS amenities CASCADE")) # Create a new table await conn.execute( - """ - CREATE TABLE amenities( - id INT PRIMARY KEY, - name TEXT, - description TEXT, - location TEXT, - terminal TEXT, - category TEXT, - hour TEXT, - sunday_start_hour TIME, - sunday_end_hour TIME, - monday_start_hour TIME, - monday_end_hour TIME, - tuesday_start_hour TIME, - tuesday_end_hour TIME, - wednesday_start_hour TIME, - wednesday_end_hour TIME, - thursday_start_hour TIME, - thursday_end_hour TIME, - friday_start_hour TIME, - friday_end_hour TIME, - saturday_start_hour TIME, - saturday_end_hour TIME, - content TEXT NOT NULL, - embedding vector(768) NOT NULL + text( + """ + CREATE TABLE amenities( + id INT PRIMARY KEY, + name TEXT, + description TEXT, + location TEXT, + terminal TEXT, + category TEXT, + hour TEXT, + sunday_start_hour TIME, + sunday_end_hour TIME, + monday_start_hour TIME, + monday_end_hour TIME, + tuesday_start_hour TIME, + tuesday_end_hour TIME, + wednesday_start_hour TIME, + wednesday_end_hour TIME, + thursday_start_hour TIME, + thursday_end_hour TIME, + friday_start_hour TIME, + friday_end_hour TIME, + saturday_start_hour TIME, + saturday_end_hour TIME, + content TEXT NOT NULL, + embedding vector(768) NOT NULL + ) + """ ) - """ ) # Insert all the data - await conn.executemany( - """ - INSERT INTO amenities VALUES ( - $1, $2, $3, $4, $5, - $6, $7, $8, $9, $10, - $11, $12, $13, $14, $15, - $16, $17, $18, $19, $20, - $21, $22, $23) - """, + await conn.execute( + text( + """ + INSERT INTO amenities VALUES (:id, :name, :description, :location, + :terminal, :category, :hour, :sunday_start_hour, :sunday_end_hour, + :monday_start_hour, :monday_end_hour, :tuesday_start_hour, + :tuesday_end_hour, :wednesday_start_hour, :wednesday_end_hour, + :thursday_start_hour, :thursday_end_hour, :friday_start_hour, + :friday_end_hour, :saturday_start_hour, :saturday_end_hour, :content, :embedding) + """ + ), [ - ( - a.id, - a.name, - a.description, - a.location, - a.terminal, - a.category, - a.hour, - a.sunday_start_hour, - a.sunday_end_hour, - a.monday_start_hour, - a.monday_end_hour, - a.tuesday_start_hour, - a.tuesday_end_hour, - a.wednesday_start_hour, - a.wednesday_end_hour, - a.thursday_start_hour, - a.thursday_end_hour, - a.friday_start_hour, - a.friday_end_hour, - a.saturday_start_hour, - a.saturday_end_hour, - a.content, - a.embedding, - ) + { + "id": a.id, + "name": a.name, + "description": a.description, + "location": a.location, + "terminal": a.terminal, + "category": a.category, + "hour": a.hour, + "sunday_start_hour": a.sunday_start_hour, + "sunday_end_hour": a.sunday_end_hour, + "monday_start_hour": a.monday_start_hour, + "monday_end_hour": a.monday_end_hour, + "tuesday_start_hour": a.tuesday_start_hour, + "tuesday_end_hour": a.tuesday_end_hour, + "wednesday_start_hour": a.wednesday_start_hour, + "wednesday_end_hour": a.wednesday_end_hour, + "thursday_start_hour": a.thursday_start_hour, + "thursday_end_hour": a.thursday_end_hour, + "friday_start_hour": a.friday_start_hour, + "friday_end_hour": a.friday_end_hour, + "saturday_start_hour": a.saturday_start_hour, + "saturday_end_hour": a.saturday_end_hour, + "content": a.content, + "embedding": a.embedding, + } for a in amenities ], ) # If the table already exists, drop it to avoid conflicts - await conn.execute("DROP TABLE IF EXISTS flights CASCADE") + await conn.execute(text("DROP TABLE IF EXISTS flights CASCADE")) # Create a new table await conn.execute( - """ - CREATE TABLE flights( - id INTEGER PRIMARY KEY, - airline TEXT, - flight_number TEXT, - departure_airport TEXT, - arrival_airport TEXT, - departure_time TIMESTAMP, - arrival_time TIMESTAMP, - departure_gate TEXT, - arrival_gate TEXT + text( + """ + CREATE TABLE flights( + id INTEGER PRIMARY KEY, + airline TEXT, + flight_number TEXT, + departure_airport TEXT, + arrival_airport TEXT, + departure_time TIMESTAMP, + arrival_time TIMESTAMP, + departure_gate TEXT, + arrival_gate TEXT + ) + """ ) - """ ) # Insert all the data - await conn.executemany( - """INSERT INTO flights VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)""", + await conn.execute( + text( + """ + INSERT INTO flights VALUES (:id, :airline, :flight_number, + :departure_airport, :arrival_airport, :departure_time, + :arrival_time, :departure_gate, :arrival_gate) + """ + ), [ - ( - f.id, - f.airline, - f.flight_number, - f.departure_airport, - f.arrival_airport, - f.departure_time, - f.arrival_time, - f.departure_gate, - f.arrival_gate, - ) + { + "id": f.id, + "airline": f.airline, + "flight_number": f.flight_number, + "departure_airport": f.departure_airport, + "arrival_airport": f.arrival_airport, + "departure_time": f.departure_time, + "arrival_time": f.arrival_time, + "departure_gate": f.departure_gate, + "arrival_gate": f.arrival_gate, + } for f in flights ], ) # If the table already exists, drop it to avoid conflicts - await conn.execute("DROP TABLE IF EXISTS tickets CASCADE") + await conn.execute(text("DROP TABLE IF EXISTS tickets CASCADE")) # Create a new table await conn.execute( - """ - CREATE TABLE tickets( - user_id TEXT, - user_name TEXT, - user_email TEXT, - airline TEXT, - flight_number TEXT, - departure_airport TEXT, - arrival_airport TEXT, - departure_time TIMESTAMP, - arrival_time TIMESTAMP + text( + """ + CREATE TABLE tickets( + user_id TEXT, + user_name TEXT, + user_email TEXT, + airline TEXT, + flight_number TEXT, + departure_airport TEXT, + arrival_airport TEXT, + departure_time TIMESTAMP, + arrival_time TIMESTAMP + ) + """ ) - """ ) # If the table already exists, drop it to avoid conflicts - await conn.execute("DROP TABLE IF EXISTS policies CASCADE") + await conn.execute(text("DROP TABLE IF EXISTS policies CASCADE")) # Create a new table await conn.execute( - """ - CREATE TABLE policies( - id INT PRIMARY KEY, - content TEXT NOT NULL, - embedding vector(768) NOT NULL + text( + """ + CREATE TABLE policies( + id INT PRIMARY KEY, + content TEXT NOT NULL, + embedding vector(768) NOT NULL + ) + """ ) - """ ) # Insert all the data - await conn.executemany( - """ - INSERT INTO policies VALUES ($1, $2, $3) - """, + await conn.execute( + text( + """ + INSERT INTO policies VALUES (:id, :content, :embedding) + """ + ), [ - ( - p.id, - p.content, - p.embedding, - ) + { + "id": p.id, + "content": p.content, + "embedding": p.embedding, + } for p in policies ], ) + await conn.commit() async def export_data( self, @@ -257,60 +295,61 @@ async def export_data( list[models.Flight], list[models.Policy], ]: - airport_task = asyncio.create_task( - self.__pool.fetch("""SELECT * FROM airports ORDER BY id ASC""") - ) - amenity_task = asyncio.create_task( - self.__pool.fetch("""SELECT * FROM amenities ORDER BY id ASC""") - ) - flight_task = asyncio.create_task( - self.__pool.fetch("""SELECT * FROM flights ORDER BY id ASC""") - ) - policy_task = asyncio.create_task( - self.__pool.fetch("""SELECT * FROM policies ORDER BY id ASC""") - ) + async with self.__async_engine.connect() as conn: + airport_task = asyncio.create_task( + conn.execute(text("""SELECT * FROM airports ORDER BY id ASC""")) + ) + amenity_task = asyncio.create_task( + conn.execute(text("""SELECT * FROM amenities ORDER BY id ASC""")) + ) + flights_task = asyncio.create_task( + conn.execute(text("""SELECT * FROM flights ORDER BY id ASC""")) + ) + policy_task = asyncio.create_task( + conn.execute(text("""SELECT * FROM policies ORDER BY id ASC""")) + ) + + airport_results = (await airport_task).mappings().fetchall() + amenity_results = (await amenity_task).mappings().fetchall() + flights_results = (await flights_task).mappings().fetchall() + policy_results = (await policy_task).mappings().fetchall() + + airports = [models.Airport.model_validate(a) for a in airport_results] + amenities = [models.Amenity.model_validate(a) for a in amenity_results] + flights = [models.Flight.model_validate(f) for f in flights_results] + policies = [models.Policy.model_validate(p) for p in policy_results] - airports = [models.Airport.model_validate(dict(a)) for a in await airport_task] - amenities = [models.Amenity.model_validate(dict(a)) for a in await amenity_task] - flights = [models.Flight.model_validate(dict(f)) for f in await flight_task] - policies = [models.Policy.model_validate(dict(p)) for p in await policy_task] - return airports, amenities, flights, policies + return airports, amenities, flights, policies async def get_airport_by_id( self, id: int ) -> tuple[Optional[models.Airport], Optional[str]]: - sql = """ - SELECT * FROM airports WHERE id=$1 - """ - params = (id,) - result = await self.__pool.fetchrow( - sql, - *params, - ) + async with self.__async_engine.connect() as conn: + sql = """SELECT * FROM airports WHERE id=:id""" + s = text(sql) + params = {"id": id} + result = (await conn.execute(s, params)).mappings().fetchone() if result is None: return None, None - result = models.Airport.model_validate(dict(result)) - return result, format_sql(sql, params) + res = models.Airport.model_validate(result) + return res, format_sql(sql, params) async def get_airport_by_iata( self, iata: str ) -> tuple[Optional[models.Airport], Optional[str]]: - sql = """ - SELECT * FROM airports WHERE iata ILIKE $1 - """ - params = (iata,) - result = await self.__pool.fetchrow( - sql, - *params, - ) + async with self.__async_engine.connect() as conn: + sql = """SELECT * FROM airports WHERE iata ILIKE :iata""" + s = text(sql) + params = {"iata": iata} + result = (await conn.execute(s, params)).mappings().fetchone() if result is None: return None, None - result = models.Airport.model_validate(dict(result)) - return result, format_sql(sql, params) + res = models.Airport.model_validate(result) + return res, format_sql(sql, params) async def search_airports( self, @@ -318,112 +357,104 @@ async def search_airports( city: Optional[str] = None, name: Optional[str] = None, ) -> tuple[list[models.Airport], Optional[str]]: - sql = """ - SELECT * FROM airports - WHERE ($1::TEXT IS NULL OR country ILIKE $1) - AND ($2::TEXT IS NULL OR city ILIKE $2) - AND ($3::TEXT IS NULL OR name ILIKE '%' || $3 || '%') - LIMIT 10 - """ - params = ( - country, - city, - name, - ) - results = await self.__pool.fetch( - sql, - *params, - timeout=10, - ) - - results = [models.Airport.model_validate(dict(r)) for r in results] - return results, format_sql(sql, params) + async with self.__async_engine.connect() as conn: + sql = """ + SELECT * FROM airports + WHERE (CAST(:country AS TEXT) IS NULL OR country ILIKE :country) + AND (CAST(:city AS TEXT) IS NULL OR city ILIKE :city) + AND (CAST(:name AS TEXT) IS NULL OR name ILIKE '%' || :name || '%') + LIMIT 10 + """ + s = text(sql) + params = { + "country": country, + "city": city, + "name": name, + } + results = (await conn.execute(s, params)).mappings().fetchall() + + res = [models.Airport.model_validate(r) for r in results] + return res, format_sql(sql, params) async def get_amenity( self, id: int ) -> tuple[Optional[models.Amenity], Optional[str]]: - sql = """ - SELECT id, name, description, location, terminal, category, hour - FROM amenities WHERE id=$1 - """ - params = (id,) - result = await self.__pool.fetchrow( - sql, - *params, - ) + async with self.__async_engine.connect() as conn: + sql = """ + SELECT id, name, description, location, terminal, category, hour + FROM amenities WHERE id=:id + """ + s = text(sql) + params = {"id": id} + result = (await conn.execute(s, params)).mappings().fetchone() if result is None: return None, None - result = models.Amenity.model_validate(dict(result)) - return result, format_sql(sql, params) + res = models.Amenity.model_validate(result) + return res, format_sql(sql, params) async def amenities_search( self, query_embedding: list[float], similarity_threshold: float, top_k: int ) -> tuple[list[Any], Optional[str]]: - sql = """ - SELECT name, description, location, terminal, category, hour - FROM amenities - WHERE (embedding <=> $1) < $2 - ORDER BY (embedding <=> $1) - LIMIT $3 - """ - params = ( - query_embedding, - similarity_threshold, - top_k, - ) - results = await self.__pool.fetch( - sql, - *params, - timeout=10, - ) - - results = [dict(r) for r in results] - return results, format_sql(sql, params) + async with self.__async_engine.connect() as conn: + sql = """ + SELECT name, description, location, terminal, category, hour + FROM amenities + WHERE (embedding <=> :query_embedding) < :similarity_threshold + ORDER BY (embedding <=> :query_embedding) + LIMIT :top_k + """ + s = text(sql) + params = { + "query_embedding": query_embedding, + "similarity_threshold": similarity_threshold, + "top_k": top_k, + } + results = (await conn.execute(s, params)).mappings().fetchall() + + res = [r for r in results] + return res, format_sql(sql, params) async def get_flight( self, flight_id: int ) -> tuple[Optional[models.Flight], Optional[str]]: - sql = """ + async with self.__async_engine.connect() as conn: + sql = """ SELECT * FROM flights - WHERE id = $1 - """ - params = (flight_id,) - result = await self.__pool.fetchrow( - sql, - *params, - timeout=10, - ) + WHERE id = :flight_id + """ + s = text(sql) + params = {"flight_id": flight_id} + result = (await conn.execute(s, params)).mappings().fetchone() if result is None: return None, None - result = models.Flight.model_validate(dict(result)) - return result, format_sql(sql, params) + res = models.Flight.model_validate(result) + return res, format_sql(sql, params) async def search_flights_by_number( self, airline: str, number: str, ) -> tuple[list[models.Flight], Optional[str]]: - sql = """ + async with self.__async_engine.connect() as conn: + sql = """ SELECT * FROM flights - WHERE airline = $1 - AND flight_number = $2 - LIMIT 10 - """ - params = ( - airline, - number, - ) - results = await self.__pool.fetch( - sql, - *params, - timeout=10, - ) - results = [models.Flight.model_validate(dict(r)) for r in results] - return results, format_sql(sql, params) + WHERE airline = :airline + AND flight_number = :number + LIMIT 10 + """ + s = text(sql) + params = { + "airline": airline, + "number": number, + } + results = (await conn.execute(s, params)).mappings().fetchall() + + res = [models.Flight.model_validate(r) for r in results] + return res, format_sql(sql, params) async def search_flights_by_airports( self, @@ -431,26 +462,26 @@ async def search_flights_by_airports( departure_airport: Optional[str] = None, arrival_airport: Optional[str] = None, ) -> tuple[list[models.Flight], Optional[str]]: - sql = """ + async with self.__async_engine.connect() as conn: + sql = """ SELECT * FROM flights - WHERE ($1::TEXT IS NULL OR departure_airport ILIKE $1) - AND ($2::TEXT IS NULL OR arrival_airport ILIKE $2) - AND departure_time >= $3::timestamp - AND departure_time < $3::timestamp + interval '1 day' - LIMIT 10 - """ - params = ( - departure_airport, - arrival_airport, - datetime.strptime(date, "%Y-%m-%d"), - ) - results = await self.__pool.fetch( - sql, - *params, - timeout=10, - ) - results = [models.Flight.model_validate(dict(r)) for r in results] - return results, format_sql(sql, params) + WHERE (CAST(:departure_airport AS TEXT) IS NULL OR departure_airport ILIKE :departure_airport) + AND (CAST(:arrival_airport AS TEXT) IS NULL OR arrival_airport ILIKE :arrival_airport) + AND departure_time >= CAST(:datetime AS timestamp) + AND departure_time < CAST(:datetime AS timestamp) + interval '1 day' + LIMIT 10 + """ + s = text(sql) + params = { + "departure_airport": departure_airport, + "arrival_airport": arrival_airport, + "datetime": datetime.strptime(date, "%Y-%m-%d"), + } + + results = (await conn.execute(s, params)).mappings().fetchall() + + res = [models.Flight.model_validate(r) for r in results] + return res, format_sql(sql, params) async def validate_ticket( self, @@ -460,29 +491,26 @@ async def validate_ticket( departure_time: str, ) -> tuple[Optional[models.Flight], Optional[str]]: departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S") - sql = """ - SELECT * FROM flights - WHERE airline ILIKE $1 - AND flight_number ILIKE $2 - AND departure_airport ILIKE $3 - AND departure_time::date = $4::date - """ - params = ( - airline, - flight_number, - departure_airport, - departure_time_datetime, - ) - result = await self.__pool.fetchrow( - sql, - *params, - timeout=10, - ) + async with self.__async_engine.connect() as conn: + sql = """ + SELECT * FROM flights + WHERE airline ILIKE :airline + AND flight_number ILIKE :flight_number + AND departure_airport ILIKE :departure_airport + AND departure_time = :departure_time + """ + s = text(sql) + params = { + "airline": airline, + "flight_number": flight_number, + "departure_airport": departure_airport, + "departure_time": departure_time_datetime, + } + result = (await conn.execute(s, params)).mappings().fetchone() if result is None: return None, None - - res = models.Flight.model_validate(dict(result)) + res = models.Flight.model_validate(result) return res, format_sql(sql, params) async def insert_ticket( @@ -499,8 +527,10 @@ async def insert_ticket( ): departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S") arrival_time_datetime = datetime.strptime(arrival_time, "%Y-%m-%d %H:%M:%S") - results = await self.__pool.execute( - """ + + async with self.__async_engine.connect() as conn: + s = text( + """ INSERT INTO tickets ( user_id, user_name, @@ -512,63 +542,73 @@ async def insert_ticket( departure_time, arrival_time ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9 + :user_id, + :user_name, + :user_email, + :airline, + :flight_number, + :departure_airport, + :arrival_airport, + :departure_time, + :arrival_time ); - """, - user_id, - user_name, - user_email, - airline, - flight_number, - departure_airport, - arrival_airport, - departure_time_datetime, - arrival_time_datetime, - timeout=10, - ) - if results != "INSERT 0 1": - raise Exception("Ticket Insertion failure") + """ + ) + params = { + "user_id": user_id, + "user_name": user_name, + "user_email": user_email, + "airline": airline, + "flight_number": flight_number, + "departure_airport": departure_airport, + "arrival_airport": arrival_airport, + "departure_time": departure_time_datetime, + "arrival_time": arrival_time_datetime, + } + result = (await conn.execute(s, params)).mappings() + await conn.commit() + if not result: + raise Exception("Ticket Insertion failure") async def list_tickets( self, user_id: str, ) -> tuple[list[Any], Optional[str]]: - sql = """ - SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets - WHERE user_id = $1 - """ - params = (user_id,) - results = await self.__pool.fetch( - sql, - *params, - timeout=10, - ) - results = [r for r in results] - return results, format_sql(sql, params) + async with self.__async_engine.connect() as conn: + sql = """ + SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets + WHERE user_id = :user_id + """ + s = text(sql) + params = { + "user_id": user_id, + } + results = (await conn.execute(s, params)).mappings().fetchall() + + res = [r for r in results] + return res, format_sql(sql, params) async def policies_search( self, query_embedding: list[float], similarity_threshold: float, top_k: int ) -> tuple[list[str], Optional[str]]: - sql = """ - SELECT content - FROM policies - WHERE (embedding <=> $1) < $2 - ORDER BY (embedding <=> $1) - LIMIT $3 - """ - params = ( - query_embedding, - similarity_threshold, - top_k, - ) - results = await self.__pool.fetch( - sql, - *params, - timeout=10, - ) - - results = [r["content"] for r in results] - return results, format_sql(sql, params) + async with self.__async_engine.connect() as conn: + sql = """ + SELECT content + FROM policies + WHERE (embedding <=> :query_embedding) < :similarity_threshold + ORDER BY (embedding <=> :query_embedding) + LIMIT :top_k + """ + s = text(sql) + params = { + "query_embedding": query_embedding, + "similarity_threshold": similarity_threshold, + "top_k": top_k, + } + results = (await conn.execute(s, params)).mappings().fetchall() + + res = [r["content"] for r in results] + return res, format_sql(sql, params) async def close(self): - await self.__pool.close() + await self.__async_engine.dispose() diff --git a/retrieval_service/datastore/providers/postgres_datastore.py b/retrieval_service/datastore/providers/postgres_datastore.py deleted file mode 100644 index f3707f48..00000000 --- a/retrieval_service/datastore/providers/postgres_datastore.py +++ /dev/null @@ -1,569 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from datetime import datetime -from typing import Any, Optional - -from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncEngine - -import models - - -class PostgresDatastore: - def __init__(self, pool: AsyncEngine): - self.__pool = pool - - async def initialize_data( - self, - airports: list[models.Airport], - amenities: list[models.Amenity], - flights: list[models.Flight], - policies: list[models.Policy], - ) -> None: - async with self.__pool.connect() as conn: - # If the table already exists, drop it to avoid conflicts - await conn.execute(text("DROP TABLE IF EXISTS airports CASCADE")) - # Create a new table - await conn.execute( - text( - """ - CREATE TABLE airports( - id INT PRIMARY KEY, - iata TEXT, - name TEXT, - city TEXT, - country TEXT - ) - """ - ) - ) - # Insert all the data - await conn.execute( - text( - """INSERT INTO airports VALUES (:id, :iata, :name, :city, :country)""" - ), - [ - { - "id": a.id, - "iata": a.iata, - "name": a.name, - "city": a.city, - "country": a.country, - } - for a in airports - ], - ) - - await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) - # If the table already exists, drop it to avoid conflicts - await conn.execute(text("DROP TABLE IF EXISTS amenities CASCADE")) - # Create a new table - await conn.execute( - text( - """ - CREATE TABLE amenities( - id INT PRIMARY KEY, - name TEXT, - description TEXT, - location TEXT, - terminal TEXT, - category TEXT, - hour TEXT, - sunday_start_hour TIME, - sunday_end_hour TIME, - monday_start_hour TIME, - monday_end_hour TIME, - tuesday_start_hour TIME, - tuesday_end_hour TIME, - wednesday_start_hour TIME, - wednesday_end_hour TIME, - thursday_start_hour TIME, - thursday_end_hour TIME, - friday_start_hour TIME, - friday_end_hour TIME, - saturday_start_hour TIME, - saturday_end_hour TIME, - content TEXT NOT NULL, - embedding vector(768) NOT NULL - ) - """ - ) - ) - # Insert all the data - await conn.execute( - text( - """ - INSERT INTO amenities VALUES (:id, :name, :description, :location, - :terminal, :category, :hour, :sunday_start_hour, :sunday_end_hour, - :monday_start_hour, :monday_end_hour, :tuesday_start_hour, - :tuesday_end_hour, :wednesday_start_hour, :wednesday_end_hour, - :thursday_start_hour, :thursday_end_hour, :friday_start_hour, - :friday_end_hour, :saturday_start_hour, :saturday_end_hour, :content, :embedding) - """ - ), - [ - { - "id": a.id, - "name": a.name, - "description": a.description, - "location": a.location, - "terminal": a.terminal, - "category": a.category, - "hour": a.hour, - "sunday_start_hour": a.sunday_start_hour, - "sunday_end_hour": a.sunday_end_hour, - "monday_start_hour": a.monday_start_hour, - "monday_end_hour": a.monday_end_hour, - "tuesday_start_hour": a.tuesday_start_hour, - "tuesday_end_hour": a.tuesday_end_hour, - "wednesday_start_hour": a.wednesday_start_hour, - "wednesday_end_hour": a.wednesday_end_hour, - "thursday_start_hour": a.thursday_start_hour, - "thursday_end_hour": a.thursday_end_hour, - "friday_start_hour": a.friday_start_hour, - "friday_end_hour": a.friday_end_hour, - "saturday_start_hour": a.saturday_start_hour, - "saturday_end_hour": a.saturday_end_hour, - "content": a.content, - "embedding": a.embedding, - } - for a in amenities - ], - ) - - # If the table already exists, drop it to avoid conflicts - await conn.execute(text("DROP TABLE IF EXISTS flights CASCADE")) - # Create a new table - await conn.execute( - text( - """ - CREATE TABLE flights( - id INTEGER PRIMARY KEY, - airline TEXT, - flight_number TEXT, - departure_airport TEXT, - arrival_airport TEXT, - departure_time TIMESTAMP, - arrival_time TIMESTAMP, - departure_gate TEXT, - arrival_gate TEXT - ) - """ - ) - ) - # Insert all the data - await conn.execute( - text( - """ - INSERT INTO flights VALUES (:id, :airline, :flight_number, - :departure_airport, :arrival_airport, :departure_time, - :arrival_time, :departure_gate, :arrival_gate) - """ - ), - [ - { - "id": f.id, - "airline": f.airline, - "flight_number": f.flight_number, - "departure_airport": f.departure_airport, - "arrival_airport": f.arrival_airport, - "departure_time": f.departure_time, - "arrival_time": f.arrival_time, - "departure_gate": f.departure_gate, - "arrival_gate": f.arrival_gate, - } - for f in flights - ], - ) - - # If the table already exists, drop it to avoid conflicts - await conn.execute(text("DROP TABLE IF EXISTS tickets CASCADE")) - # Create a new table - await conn.execute( - text( - """ - CREATE TABLE tickets( - user_id TEXT, - user_name TEXT, - user_email TEXT, - airline TEXT, - flight_number TEXT, - departure_airport TEXT, - arrival_airport TEXT, - departure_time TIMESTAMP, - arrival_time TIMESTAMP - ) - """ - ) - ) - - # If the table already exists, drop it to avoid conflicts - await conn.execute(text("DROP TABLE IF EXISTS policies CASCADE")) - # Create a new table - await conn.execute( - text( - """ - CREATE TABLE policies( - id INT PRIMARY KEY, - content TEXT NOT NULL, - embedding vector(768) NOT NULL - ) - """ - ) - ) - # Insert all the data - await conn.execute( - text( - """ - INSERT INTO policies VALUES (:id, :content, :embedding) - """ - ), - [ - { - "id": p.id, - "content": p.content, - "embedding": p.embedding, - } - for p in policies - ], - ) - await conn.commit() - - async def export_data( - self, - ) -> tuple[ - list[models.Airport], - list[models.Amenity], - list[models.Flight], - list[models.Policy], - ]: - async with self.__pool.connect() as conn: - airport_task = asyncio.create_task( - conn.execute(text("""SELECT * FROM airports ORDER BY id ASC""")) - ) - amenity_task = asyncio.create_task( - conn.execute(text("""SELECT * FROM amenities ORDER BY id ASC""")) - ) - flights_task = asyncio.create_task( - conn.execute(text("""SELECT * FROM flights ORDER BY id ASC""")) - ) - policy_task = asyncio.create_task( - conn.execute(text("""SELECT * FROM policies ORDER BY id ASC""")) - ) - - airport_results = (await airport_task).mappings().fetchall() - amenity_results = (await amenity_task).mappings().fetchall() - flights_results = (await flights_task).mappings().fetchall() - policy_results = (await policy_task).mappings().fetchall() - - airports = [models.Airport.model_validate(a) for a in airport_results] - amenities = [models.Amenity.model_validate(a) for a in amenity_results] - flights = [models.Flight.model_validate(f) for f in flights_results] - policies = [models.Policy.model_validate(p) for p in policy_results] - - return airports, amenities, flights, policies - - async def get_airport_by_id( - self, id: int - ) -> tuple[Optional[models.Airport], Optional[str]]: - async with self.__pool.connect() as conn: - sql = """SELECT * FROM airports WHERE id=:id""" - s = text(sql) - params = {"id": id} - result = (await conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Airport.model_validate(result) - return res, sql - - async def get_airport_by_iata( - self, iata: str - ) -> tuple[Optional[models.Airport], Optional[str]]: - async with self.__pool.connect() as conn: - sql = """SELECT * FROM airports WHERE iata ILIKE :iata""" - s = text(sql) - params = {"iata": iata} - result = (await conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Airport.model_validate(result) - return res, sql - - async def search_airports( - self, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, - ) -> tuple[list[models.Airport], Optional[str]]: - async with self.__pool.connect() as conn: - sql = """ - SELECT * FROM airports - WHERE (CAST(:country AS TEXT) IS NULL OR country ILIKE :country) - AND (CAST(:city AS TEXT) IS NULL OR city ILIKE :city) - AND (CAST(:name AS TEXT) IS NULL OR name ILIKE '%' || :name || '%') - LIMIT 10 - """ - s = text(sql) - params = { - "country": country, - "city": city, - "name": name, - } - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [models.Airport.model_validate(r) for r in results] - return res, sql - - async def get_amenity( - self, id: int - ) -> tuple[Optional[models.Amenity], Optional[str]]: - async with self.__pool.connect() as conn: - sql = """ - SELECT id, name, description, location, terminal, category, hour - FROM amenities WHERE id=:id - """ - s = text(sql) - params = {"id": id} - result = (await conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Amenity.model_validate(result) - return res, sql - - async def amenities_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[Any], Optional[str]]: - async with self.__pool.connect() as conn: - sql = """ - SELECT name, description, location, terminal, category, hour - FROM amenities - WHERE (embedding <=> :query_embedding) < :similarity_threshold - ORDER BY (embedding <=> :query_embedding) - LIMIT :top_k - """ - s = text(sql) - params = { - "query_embedding": query_embedding, - "similarity_threshold": similarity_threshold, - "top_k": top_k, - } - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [r for r in results] - return res, sql - - async def get_flight( - self, flight_id: int - ) -> tuple[Optional[models.Flight], Optional[str]]: - async with self.__pool.connect() as conn: - sql = """ - SELECT * FROM flights - WHERE id = :flight_id - """ - s = text(sql) - params = {"flight_id": flight_id} - result = (await conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Flight.model_validate(result) - return res, sql - - async def search_flights_by_number( - self, - airline: str, - number: str, - ) -> tuple[list[models.Flight], Optional[str]]: - async with self.__pool.connect() as conn: - sql = """ - SELECT * FROM flights - WHERE airline = :airline - AND flight_number = :number - LIMIT 10 - """ - s = text(sql) - params = { - "airline": airline, - "number": number, - } - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [models.Flight.model_validate(r) for r in results] - return res, sql - - async def search_flights_by_airports( - self, - date: str, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - ) -> tuple[list[models.Flight], Optional[str]]: - async with self.__pool.connect() as conn: - sql = """ - SELECT * FROM flights - WHERE (CAST(:departure_airport AS TEXT) IS NULL OR departure_airport ILIKE :departure_airport) - AND (CAST(:arrival_airport AS TEXT) IS NULL OR arrival_airport ILIKE :arrival_airport) - AND departure_time >= CAST(:datetime AS timestamp) - AND departure_time < CAST(:datetime AS timestamp) + interval '1 day' - LIMIT 10 - """ - s = text(sql) - params = { - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "datetime": datetime.strptime(date, "%Y-%m-%d"), - } - - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [models.Flight.model_validate(r) for r in results] - return res, sql - - async def validate_ticket( - self, - airline: str, - flight_number: str, - departure_airport: str, - departure_time: str, - ) -> tuple[Optional[models.Flight], Optional[str]]: - departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S") - async with self.__pool.connect() as conn: - sql = """ - SELECT * FROM flights - WHERE airline ILIKE :airline - AND flight_number ILIKE :flight_number - AND departure_airport ILIKE :departure_airport - AND departure_time = :departure_time - """ - s = text(sql) - params = { - "airline": airline, - "flight_number": flight_number, - "departure_airport": departure_airport, - "departure_time": departure_time_datetime, - } - result = (await conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - res = models.Flight.model_validate(result) - return res, sql - - async def insert_ticket( - self, - user_id: str, - user_name: str, - user_email: str, - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, - ): - departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S") - arrival_time_datetime = datetime.strptime(arrival_time, "%Y-%m-%d %H:%M:%S") - - async with self.__pool.connect() as conn: - s = text( - """ - INSERT INTO tickets ( - user_id, - user_name, - user_email, - airline, - flight_number, - departure_airport, - arrival_airport, - departure_time, - arrival_time - ) VALUES ( - :user_id, - :user_name, - :user_email, - :airline, - :flight_number, - :departure_airport, - :arrival_airport, - :departure_time, - :arrival_time - ); - """ - ) - params = { - "user_id": user_id, - "user_name": user_name, - "user_email": user_email, - "airline": airline, - "flight_number": flight_number, - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "departure_time": departure_time_datetime, - "arrival_time": arrival_time_datetime, - } - result = (await conn.execute(s, params)).mappings() - await conn.commit() - if not result: - raise Exception("Ticket Insertion failure") - - async def list_tickets( - self, - user_id: str, - ) -> tuple[list[Any], Optional[str]]: - async with self.__pool.connect() as conn: - sql = """ - SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets - WHERE user_id = :user_id - """ - s = text(sql) - params = { - "user_id": user_id, - } - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [r for r in results] - return res, sql - - async def policies_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[str], Optional[str]]: - async with self.__pool.connect() as conn: - sql = """ - SELECT content - FROM policies - WHERE (embedding <=> :query_embedding) < :similarity_threshold - ORDER BY (embedding <=> :query_embedding) - LIMIT :top_k - """ - s = text(sql) - params = { - "query_embedding": query_embedding, - "similarity_threshold": similarity_threshold, - "top_k": top_k, - } - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [r["content"] for r in results] - return res, sql - - async def close(self): - await self.__pool.dispose()