Skip to content

Commit

Permalink
chore: update lint errors (#453)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 authored Jul 31, 2024
1 parent 8cefed0 commit c82a5b5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
11 changes: 9 additions & 2 deletions llm_demo/orchestrator/langgraph/react_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
from langgraph.managed import IsLastStep

from .tool_node import ToolNode
from .tools import get_confirmation_needing_tools, insert_ticket, validate_ticket
from .tools import (
TicketInfo,
get_confirmation_needing_tools,
insert_ticket,
validate_ticket,
)


class UserState(TypedDict):
Expand Down Expand Up @@ -192,7 +197,9 @@ async def insert_ticket_node(state: UserState, config: RunnableConfig):
# Run insert ticket
if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0:
tool_call = last_message.tool_calls[0]
output = await insert_ticket(client, tool_call.get("args"), user_id_token)
tool_args = tool_call.get("args")
ticket_info = TicketInfo(**tool_args)
output = await insert_ticket(client, ticket_info, user_id_token)
tool_call_id = tool_call.get("id")
tool_message = ToolMessage(
content=output, name="Insert Ticket", tool_call_id=tool_call_id
Expand Down
25 changes: 18 additions & 7 deletions llm_demo/orchestrator/langgraph/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import json
import os
from dataclasses import dataclass
from datetime import date, datetime
from typing import Any, Dict, Optional

Expand Down Expand Up @@ -206,18 +207,28 @@ async def insert_ticket(
return insert_ticket


@dataclass
class TicketInfo:
airline: str
flight_number: str
departure_airport: str
departure_time: str
arrival_airport: str
arrival_time: str


async def insert_ticket(
client: aiohttp.ClientSession, params: dict[str, str], user_id_token: str
client: aiohttp.ClientSession, ticket_info: TicketInfo, user_id_token: str
):
response = await client.post(
url=f"{BASE_URL}/tickets/insert",
params={
"airline": params.get("airline"),
"flight_number": params.get("flight_number"),
"departure_airport": params.get("departure_airport"),
"arrival_airport": params.get("arrival_airport"),
"departure_time": params.get("departure_time"),
"arrival_time": params.get("arrival_time"),
"airline": ticket_info.airline,
"flight_number": ticket_info.flight_number,
"departure_airport": ticket_info.departure_airport,
"arrival_airport": ticket_info.arrival_airport,
"departure_time": ticket_info.departure_time,
"arrival_time": ticket_info.arrival_time,
},
headers=get_headers(client, user_id_token),
)
Expand Down

0 comments on commit c82a5b5

Please sign in to comment.