Skip to content

Commit

Permalink
feat: add ticket validation and insertion to cloudsql postgres (#437)
Browse files Browse the repository at this point in the history
Also updated alloydb's sqlalchemy syntax for ticket insertion.
  • Loading branch information
Yuan325 authored Jul 17, 2024
1 parent 56e3c0b commit a4480fa
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 7 deletions.
9 changes: 5 additions & 4 deletions retrieval_service/datastore/providers/alloydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,11 +565,12 @@ async def insert_ticket(
"flight_number": flight_number,
"departure_airport": departure_airport,
"arrival_airport": arrival_airport,
"departure_time": departure_time,
"arrival_time": arrival_time,
"departure_time": departure_time_datetime,
"arrival_time": arrival_time_datetime,
}
results = (await conn.execute(s, params)).mappings().fetchall()
if results != "INSERT 0 1":
result = (await conn.execute(s, params)).mappings()
await conn.commit()
if not result:
raise Exception("Ticket Insertion failure")

async def list_tickets(
Expand Down
84 changes: 81 additions & 3 deletions retrieval_service/datastore/providers/cloudsql_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,29 @@ async def validate_ticket(
departure_airport: str,
departure_time: str,
) -> Optional[models.Flight]:
raise NotImplementedError("Not Implemented")
departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S")
async with self.__pool.connect() as conn:
s = text(
"""
SELECT * FROM flights
WHERE airline ILIKE :airline
AND flight_number ILIKE :flight_number
AND departure_airport ILIKE :departure_airport
AND departure_time = :departure_time
"""
)
params = {
"airline": airline,
"flight_number": flight_number,
"departure_airport": departure_airport,
"departure_time": departure_time_datetime,
}
result = (await conn.execute(s, params)).mappings().fetchone()

if result is None:
return None
res = models.Flight.model_validate(result)
return res

async def insert_ticket(
self,
Expand All @@ -506,13 +528,69 @@ async def insert_ticket(
departure_time: str,
arrival_time: str,
):
raise NotImplementedError("Not Implemented")
departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S")
arrival_time_datetime = datetime.strptime(arrival_time, "%Y-%m-%d %H:%M:%S")

async with self.__pool.connect() as conn:
s = text(
"""
INSERT INTO tickets (
user_id,
user_name,
user_email,
airline,
flight_number,
departure_airport,
arrival_airport,
departure_time,
arrival_time
) VALUES (
:user_id,
:user_name,
:user_email,
:airline,
:flight_number,
:departure_airport,
:arrival_airport,
:departure_time,
:arrival_time
);
"""
)
params = {
"user_id": user_id,
"user_name": user_name,
"user_email": user_email,
"airline": airline,
"flight_number": flight_number,
"departure_airport": departure_airport,
"arrival_airport": arrival_airport,
"departure_time": departure_time_datetime,
"arrival_time": arrival_time_datetime,
}
result = (await conn.execute(s, params)).mappings()
await conn.commit()
if not result:
raise Exception("Ticket Insertion failure")

async def list_tickets(
self,
user_id: str,
) -> list[models.Ticket]:
raise NotImplementedError("Not Implemented")
async with self.__pool.connect() as conn:
s = text(
"""
SELECT * FROM tickets
WHERE user_id = :user_id
"""
)
params = {
"user_id": user_id,
}
results = (await conn.execute(s, params)).mappings().fetchall()

res = [models.Ticket.model_validate(r) for r in results]
return res

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
Expand Down
46 changes: 46 additions & 0 deletions retrieval_service/datastore/providers/cloudsql_postgres_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,52 @@ async def test_search_flights_by_airports(
assert res == expected


async def test_insert_ticket(ds: cloudsql_postgres.Client):
await ds.insert_ticket(
"1",
"test",
"test",
"UA",
"1532",
"SFO",
"DEN",
"2024-01-01 05:50:00",
"2024-01-01 09:23:00",
)


async def test_list_tickets(ds: cloudsql_postgres.Client):
res = await ds.list_tickets("1")
expected = models.Ticket(
user_id=1,
user_name="test",
user_email="test",
airline="UA",
flight_number="1532",
departure_airport="SFO",
arrival_airport="DEN",
departure_time=datetime.strptime("2024-01-01 05:50:00", "%Y-%m-%d %H:%M:%S"),
arrival_time=datetime.strptime("2024-01-01 09:23:00", "%Y-%m-%d %H:%M:%S"),
)
assert res == [expected]


async def test_validate_ticket(ds: cloudsql_postgres.Client):
res = await ds.validate_ticket("UA", "1532", "SFO", "2024-01-01 05:50:00")
expected = models.Flight(
id=0,
airline="UA",
flight_number="1532",
departure_airport="SFO",
arrival_airport="DEN",
departure_time=datetime.strptime("2024-01-01 05:50:00", "%Y-%m-%d %H:%M:%S"),
arrival_time=datetime.strptime("2024-01-01 09:23:00", "%Y-%m-%d %H:%M:%S"),
departure_gate="E49",
arrival_gate="D6",
)
assert res == expected


policies_search_test_data = [
pytest.param(
# "What is the fee for extra baggage?"
Expand Down

0 comments on commit a4480fa

Please sign in to comment.