Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: consolidate Postgres providers #494

Merged
merged 15 commits on Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions retrieval_service/datastore/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import sqlparse


def format_sql(sql: str, params):
def format_sql(sql: str, params: dict):
"""
Format postgres sql to human readable text
Format Postgres SQL to human readable text by replacing placeholders.
Handles dict-based (:key) formats.
"""
for i in range(len(params)):
sql = sql.replace(f"${i+1}", f"{params[i]}")
# format the SQL
formatted_sql = (
sqlparse.format(
sql,
reindent=True,
keyword_case="upper",
use_space_around_operators=True,
strip_whitespace=True,
)
.replace("\n", "<br/>")
.replace(" ", '<div class="indent"></div>')
for key, value in params.items():
sql = sql.replace(f":{key}", f"{value}")
# format the SQL
formatted_sql = (
sqlparse.format(
sql,
reindent=True,
keyword_case="upper",
use_space_around_operators=True,
strip_whitespace=True,
)
.replace("\n", "<br/>")
.replace(" ", '<div class="indent"></div>')
)
return formatted_sql.replace("<br/>", "", 1)
48 changes: 23 additions & 25 deletions retrieval_service/datastore/providers/alloydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import models

from .. import datastore
from .postgres_datastore import PostgresDatastore
from .postgres import Client as PostgresClient

ALLOYDB_PG_IDENTIFIER = "alloydb-postgres"

Expand All @@ -41,14 +41,14 @@ class Config(BaseModel, datastore.AbstractConfig):

class Client(datastore.Client[Config]):
__connector: Optional[AsyncConnector] = None
__pg_ds: PostgresDatastore
__pg_client: PostgresClient

@datastore.classproperty
def kind(cls):
return ALLOYDB_PG_IDENTIFIER

def __init__(self, pool: AsyncEngine):
self.__pg_ds = PostgresDatastore(pool)
def __init__(self, async_engine: AsyncEngine):
self.__pg_client = PostgresClient(async_engine)

@classmethod
async def create(cls, config: Config) -> "Client":
Expand All @@ -68,13 +68,13 @@ async def getconn() -> asyncpg.Connection:
await register_vector(conn)
return conn

pool = create_async_engine(
async_engine = create_async_engine(
"postgresql+asyncpg://",
async_creator=getconn,
)
if pool is None:
raise TypeError("pool not instantiated")
return cls(pool)
if async_engine is None:
raise TypeError("async_engine not instantiated")
return cls(async_engine)

async def initialize_data(
self,
Expand All @@ -83,9 +83,7 @@ async def initialize_data(
flights: list[models.Flight],
policies: list[models.Policy],
) -> None:
return await self.__pg_ds.initialize_data(
airports, amenities, flights, policies
)
await self.__pg_client.initialize_data(airports, amenities, flights, policies)

async def export_data(
self,
Expand All @@ -95,57 +93,57 @@ async def export_data(
list[models.Flight],
list[models.Policy],
]:
return await self.__pg_ds.export_data()
return await self.__pg_client.export_data()

async def get_airport_by_id(
self, id: int
) -> tuple[Optional[models.Airport], Optional[str]]:
return await self.__pg_ds.get_airport_by_id(id)
return await self.__pg_client.get_airport_by_id(id)

async def get_airport_by_iata(
self, iata: str
) -> tuple[Optional[models.Airport], Optional[str]]:
return await self.__pg_ds.get_airport_by_iata(iata)
return await self.__pg_client.get_airport_by_iata(iata)

async def search_airports(
self,
country: Optional[str] = None,
city: Optional[str] = None,
name: Optional[str] = None,
) -> tuple[list[models.Airport], Optional[str]]:
return await self.__pg_ds.search_airports(country, city, name)
return await self.__pg_client.search_airports(country, city, name)

async def get_amenity(
self, id: int
) -> tuple[Optional[models.Amenity], Optional[str]]:
return await self.__pg_ds.get_amenity(id)
return await self.__pg_client.get_amenity(id)

async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[Any], Optional[str]]:
return await self.__pg_ds.amenities_search(
return await self.__pg_client.amenities_search(
query_embedding, similarity_threshold, top_k
)

async def get_flight(
self, flight_id: int
) -> tuple[Optional[models.Flight], Optional[str]]:
return await self.__pg_ds.get_flight(flight_id)
return await self.__pg_client.get_flight(flight_id)

async def search_flights_by_number(
self,
airline: str,
number: str,
) -> tuple[list[models.Flight], Optional[str]]:
return await self.__pg_ds.search_flights_by_number(airline, number)
return await self.__pg_client.search_flights_by_number(airline, number)

async def search_flights_by_airports(
self,
date: str,
departure_airport: Optional[str] = None,
arrival_airport: Optional[str] = None,
) -> tuple[list[models.Flight], Optional[str]]:
return await self.__pg_ds.search_flights_by_airports(
return await self.__pg_client.search_flights_by_airports(
date, departure_airport, arrival_airport
)

Expand All @@ -156,7 +154,7 @@ async def validate_ticket(
departure_airport: str,
departure_time: str,
) -> tuple[Optional[models.Flight], Optional[str]]:
return await self.__pg_ds.validate_ticket(
return await self.__pg_client.validate_ticket(
airline, flight_number, departure_airport, departure_time
)

Expand All @@ -172,7 +170,7 @@ async def insert_ticket(
departure_time: str,
arrival_time: str,
):
return await self.__pg_ds.insert_ticket(
await self.__pg_client.insert_ticket(
user_id,
user_name,
user_email,
Expand All @@ -188,14 +186,14 @@ async def list_tickets(
self,
user_id: str,
) -> tuple[list[Any], Optional[str]]:
return await self.__pg_ds.list_tickets(user_id)
return await self.__pg_client.list_tickets(user_id)

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[str], Optional[str]]:
return await self.__pg_ds.policies_search(
return await self.__pg_client.policies_search(
query_embedding, similarity_threshold, top_k
)

async def close(self):
return await self.__pg_ds.close()
await self.__pg_client.close()
48 changes: 23 additions & 25 deletions retrieval_service/datastore/providers/cloudsql_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import models

from .. import datastore
from .postgres_datastore import PostgresDatastore
from .postgres import Client as PostgresClient

CLOUD_SQL_PG_IDENTIFIER = "cloudsql-postgres"

Expand All @@ -40,15 +40,15 @@ class Config(BaseModel, datastore.AbstractConfig):


class Client(datastore.Client[Config]):
__pg_ds: PostgresDatastore
__pg_client: PostgresClient
__connector: Optional[Connector] = None

@datastore.classproperty
def kind(cls):
return CLOUD_SQL_PG_IDENTIFIER

def __init__(self, pool: AsyncEngine):
self.__pg_ds = PostgresDatastore(pool)
def __init__(self, async_engine: AsyncEngine):
self.__pg_client = PostgresClient(async_engine)

@classmethod
async def create(cls, config: Config) -> "Client":
Expand All @@ -70,13 +70,13 @@ async def getconn() -> asyncpg.Connection:
await register_vector(conn)
return conn

pool = create_async_engine(
async_engine = create_async_engine(
"postgresql+asyncpg://",
async_creator=getconn,
)
if pool is None:
raise TypeError("pool not instantiated")
return cls(pool)
if async_engine is None:
raise TypeError("async_engine not instantiated")
return cls(async_engine)

async def initialize_data(
self,
Expand All @@ -85,9 +85,7 @@ async def initialize_data(
flights: list[models.Flight],
policies: list[models.Policy],
) -> None:
return await self.__pg_ds.initialize_data(
airports, amenities, flights, policies
)
await self.__pg_client.initialize_data(airports, amenities, flights, policies)

async def export_data(
self,
Expand All @@ -97,57 +95,57 @@ async def export_data(
list[models.Flight],
list[models.Policy],
]:
return await self.__pg_ds.export_data()
return await self.__pg_client.export_data()

async def get_airport_by_id(
self, id: int
) -> tuple[Optional[models.Airport], Optional[str]]:
return await self.__pg_ds.get_airport_by_id(id)
return await self.__pg_client.get_airport_by_id(id)

async def get_airport_by_iata(
self, iata: str
) -> tuple[Optional[models.Airport], Optional[str]]:
return await self.__pg_ds.get_airport_by_iata(iata)
return await self.__pg_client.get_airport_by_iata(iata)

async def search_airports(
self,
country: Optional[str] = None,
city: Optional[str] = None,
name: Optional[str] = None,
) -> tuple[list[models.Airport], Optional[str]]:
return await self.__pg_ds.search_airports(country, city, name)
return await self.__pg_client.search_airports(country, city, name)

async def get_amenity(
self, id: int
) -> tuple[Optional[models.Amenity], Optional[str]]:
return await self.__pg_ds.get_amenity(id)
return await self.__pg_client.get_amenity(id)

async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[Any], Optional[str]]:
return await self.__pg_ds.amenities_search(
return await self.__pg_client.amenities_search(
query_embedding, similarity_threshold, top_k
)

async def get_flight(
self, flight_id: int
) -> tuple[Optional[models.Flight], Optional[str]]:
return await self.__pg_ds.get_flight(flight_id)
return await self.__pg_client.get_flight(flight_id)

async def search_flights_by_number(
self,
airline: str,
number: str,
) -> tuple[list[models.Flight], Optional[str]]:
return await self.__pg_ds.search_flights_by_number(airline, number)
return await self.__pg_client.search_flights_by_number(airline, number)

async def search_flights_by_airports(
self,
date: str,
departure_airport: Optional[str] = None,
arrival_airport: Optional[str] = None,
) -> tuple[list[models.Flight], Optional[str]]:
return await self.__pg_ds.search_flights_by_airports(
return await self.__pg_client.search_flights_by_airports(
date, departure_airport, arrival_airport
)

Expand All @@ -158,7 +156,7 @@ async def validate_ticket(
departure_airport: str,
departure_time: str,
) -> tuple[Optional[models.Flight], Optional[str]]:
return await self.__pg_ds.validate_ticket(
return await self.__pg_client.validate_ticket(
airline, flight_number, departure_airport, departure_time
)

Expand All @@ -174,7 +172,7 @@ async def insert_ticket(
departure_time: str,
arrival_time: str,
):
return await self.__pg_ds.insert_ticket(
await self.__pg_client.insert_ticket(
user_id,
user_name,
user_email,
Expand All @@ -190,14 +188,14 @@ async def list_tickets(
self,
user_id: str,
) -> tuple[list[Any], Optional[str]]:
return await self.__pg_ds.list_tickets(user_id)
return await self.__pg_client.list_tickets(user_id)

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[str], Optional[str]]:
return await self.__pg_ds.policies_search(
return await self.__pg_client.policies_search(
query_embedding, similarity_threshold, top_k
)

async def close(self):
await self.__pg_ds.close()
await self.__pg_client.close()
Loading