From a9801f2915217e1b0cc6667be0432343577af699 Mon Sep 17 00:00:00 2001 From: "@Draichi" Date: Wed, 30 Oct 2024 22:10:31 -0300 Subject: [PATCH] feat: upgrade system prompt and add telemetry tool --- agent_prompt.txt | 106 ++++++++++++++++++++++++++ app.py | 81 +++++--------------- tools/__init__.py | 82 ++++++++++++++++++++ tools/telemetry_and_weather_query.sql | 25 ++++++ 4 files changed, 230 insertions(+), 64 deletions(-) create mode 100644 agent_prompt.txt create mode 100644 tools/__init__.py create mode 100644 tools/telemetry_and_weather_query.sql diff --git a/agent_prompt.txt b/agent_prompt.txt new file mode 100644 index 0000000..db21599 --- /dev/null +++ b/agent_prompt.txt @@ -0,0 +1,106 @@ +You are a specialized SQL Analysis Agent with deep knowledge of Formula 1 racing and access +to comprehensive data from the 2023 Bahrain Grand Prix qualifying session. Your purpose is to +assist users in analyzing and understanding qualifying performance data through SQL queries +and data analysis. + +## Available Data Tables + +1. SESSION_INFO + - Basic session information (start time, end time, track temperature, air temperature) + - Track conditions and weather data + - Session status changes + +2. DRIVER_LAPS + - Complete lap times for all drivers + - Lap types (Out lap, Flying lap, In lap) + - Lap validity status + - Sector times + - Tire compound used + +3. TELEMETRY_DATA + - Detailed car telemetry per lap + - Speed, throttle, brake, RPM, and gear data + - DRS usage + - Position data (X, Y coordinates) + +## Available Tools + +1. `get_session_weather(timestamp)` + - Returns weather conditions at specific timestamp + - Parameters: timestamp (ISO format) + - Returns: temperature, humidity, wind speed, track temperature + +2. `get_driver_laps(driver_id)` + - Returns all laps for specified driver + - Parameters: driver_id (string) + - Returns: lap times, sectors, tire info + +3. 
`get_lap_telemetry(driver_id, lap_number)` + - Returns detailed telemetry for specific lap + - Parameters: driver_id (string), lap_number (int) + - Returns: All telemetry data points for the lap + +4. `get_qualifying_results()` + - Returns final qualifying results + - Returns: Position, driver, best lap time, Q1/Q2/Q3 times + +## Your Capabilities + +1. Data Analysis: + - Analyze lap times and sector performance + - Compare driver performances + - Evaluate tire strategies + - Analyze telemetry data for technical insights + +2. Query Creation: + - When available tools are insufficient, create custom SQL queries + - Optimize queries for performance + - Join multiple data sources when needed + - Handle complex aggregations and calculations + +3. Result Presentation: + - Present data in clear, structured format + - Provide context and insights with results + - Highlight significant findings + - Explain technical terms when needed + +## Response Format + +For each query, provide: + +1. Understanding of the request +2. Analysis approach +3. Tool usage or SQL query +4. Results with explanation +5. Additional insights or recommendations + +## Guidelines + +1. Always verify data consistency before analysis +2. Consider track conditions when analyzing performance +3. Account for different qualifying sessions (Q1, Q2, Q3) +4. Note any anomalies or unusual patterns +5. Provide context for technical measurements +6. Consider tire compound impact on performance +7. Account for track evolution during session + +## Error Handling + +1. If data is missing or incomplete: + - Acknowledge the limitation + - Suggest alternative analysis approaches + - Explain impact on results + +2. If query is ambiguous: + - Ask for clarification + - Provide examples of possible interpretations + - Suggest refined query options + +## Example Queries + +1. "Show me Max Verstappen's fastest Q3 lap telemetry" +2. "Compare sector times between Hamilton and Russell" +3. 
"Which driver had the best middle sector in Q2?" +4. "How did track temperature affect lap times throughout qualifying?" + +Remember to maintain F1 technical accuracy while making insights accessible to users with varying levels of F1 knowledge. \ No newline at end of file diff --git a/app.py b/app.py index 7fec93e..bb3dfca 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,6 @@ import os import gradio as gr from dotenv import load_dotenv -from langchain_openai import ChatOpenAI from langchain_community.utilities import SQLDatabase from langchain_community.agent_toolkits import SQLDatabaseToolkit from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage @@ -9,30 +8,18 @@ from langchain.schema import AIMessage from rich.console import Console from langchain_google_genai import ChatGoogleGenerativeAI -from langchain.agents.agent_toolkits import create_retriever_tool -from langchain_community.vectorstores import FAISS -from langchain_openai import OpenAIEmbeddings -import ast from gradio import ChatMessage -import re - -console = Console(style="chartreuse1 on grey7") - -os.environ['LANGCHAIN_PROJECT'] = 'gradio-test' - -# Load environment variables +import textwrap +from tools import GetTelemetry load_dotenv() +os.environ['LANGCHAIN_PROJECT'] = 'gradio-test' -# # Ensure required environment variables are set -# if not os.environ.get("OPENAI_API_KEY"): -# raise EnvironmentError( -# "OPENAI_API_KEY not found in environment variables.") +console = Console(style="chartreuse1 on grey7") -# Initialize database connection +# * Initialize database db = SQLDatabase.from_uri("sqlite:///db/Bahrain_2023_Q.db") -# Initialize LLM -# llm = ChatOpenAI(model="gpt-4-0125-preview") +# * Initialize LLM llm = ChatGoogleGenerativeAI( model="gemini-1.5-flash", temperature=0.7, @@ -41,55 +28,21 @@ max_retries=2, ) +# * Initialize tools toolkit = SQLDatabaseToolkit(db=db, llm=llm) tools = toolkit.get_tools() +get_telemetry_tool = GetTelemetry() +tools.append(get_telemetry_tool) 
-def query_as_list(db, query): - res = db.run(query) - res = [el for sub in ast.literal_eval(res) for el in sub if el] - res = [re.sub(r"\b\d+\b", "", string).strip() for string in res] - return list(set(res)) - - -drivers = query_as_list(db, "SELECT driver_name FROM Drivers") - -vector_db = FAISS.from_texts(drivers, OpenAIEmbeddings()) -retriever = vector_db.as_retriever(search_kwargs={"k": 5}) -description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \ -valid proper nouns. Use the noun most similar to the search.""" +# * Initialize agent +agent_prompt = open("agent_prompt.txt", "r") +system_prompt = textwrap.dedent(agent_prompt.read()) +agent_prompt.close() +state_modifier = SystemMessage(content=system_prompt) +agent = create_react_agent(llm, tools, state_modifier=state_modifier) -retriever_tool = create_retriever_tool( - retriever, - name="search_proper_nouns", - description=description, -) -tools.append(retriever_tool) - -# Define system message -system = """You are an agent designed to interact with a SQL database. -Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. -Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results. -You can order the results by a relevant column to return the most interesting examples in the database. -Never query for all the columns from a specific table, only ask for the relevant columns given the question. -You have access to tools for interacting with the database. -Only use the given tools. Only use the information returned by the tools to construct your final answer. -You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. - -DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. 
- -You have access to the following tables: {table_names} - -If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the "search_proper_nouns" tool! -Do not try to guess at the proper name - use this function to find similar ones.""".format( - table_names=db.get_usable_table_names() -) - - -system_message = SystemMessage(content=system) - -# Create agent -agent = create_react_agent(llm, tools, state_modifier=system_message) +# * Interact with agent async def interact_with_agent(message, history): @@ -118,8 +71,8 @@ async def interact_with_agent(message, history): role="assistant", content=msg.content, metadata={"title": "💬 Assistant"})) yield history +# * Initialize Gradio theme = gr.themes.Ocean() - with gr.Blocks(theme=theme, fill_height=True) as demo: gr.Markdown("# Formula 1 Briefing Generator") chatbot = gr.Chatbot(
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000..a2e3690
--- /dev/null
+++ b/tools/__init__.py
@@ -0,0 +1,113 @@
+"""Custom LangChain tools for the Bahrain 2023 qualifying database."""
+import ast
+from pathlib import Path
+from typing import Type
+
+from pydantic import BaseModel, Field
+from langchain_core.tools import BaseTool
+from langchain_community.utilities import SQLDatabase
+from rich.console import Console
+
+console = Console(style="chartreuse1 on grey7")
+
+db = SQLDatabase.from_uri("sqlite:///db/Bahrain_2023_Q.db")
+
+# Resolve the query next to this module so the tool works from any CWD.
+_QUERY_PATH = Path(__file__).with_name("telemetry_and_weather_query.sql")
+
+# Column order of the SELECT in telemetry_and_weather_query.sql.
+_OUTPUT_COLUMNS = (
+    "lap_id",
+    "lap_number",
+    "lap_time_in_seconds",
+    "avg_speed",
+    "max_speed",
+    "avg_RPM",
+    "max_RPM",
+    "avg_throttle",
+    "brake_percentage",
+    "drs_usage_percentage",
+    "off_track_percentage",
+    "avg_air_temp",
+    "avg_track_temp",
+    "avg_wind_speed",
+)
+
+
+class GetTelemetryAndWeatherInput(BaseModel):
+    """Input for the get_telemetry_and_weather tool"""
+    driver_name: str = Field(
+        description="Name of the driver to analyze (e.g., 'VER', 'HAM', 'LEC', etc.)")
+    lap_number: int = Field(description="Lap number to analyze")
+
+
+class GetTelemetryAndWeatherOutput(BaseModel):
+    """Output for the get_telemetry_and_weather tool"""
+    lap_id: int = Field(description="Lap ID")
+    lap_number: int = Field(description="Lap number")
+    lap_time_in_seconds: float | None = Field(
+        description="Lap time in seconds")
+    avg_speed: float = Field(description="Average speed in km/h")
+    max_speed: float = Field(description="Maximum speed in km/h")
+    avg_RPM: float = Field(description="Average RPM")
+    max_RPM: float = Field(description="Maximum RPM")
+    avg_throttle: float = Field(description="Average throttle")
+    brake_percentage: float = Field(description="Brake percentage")
+    drs_usage_percentage: float = Field(description="Drs usage percentage")
+    off_track_percentage: float = Field(description="Off track percentage")
+    avg_air_temp: float | None = Field(
+        description="Average air temperature in celsius")
+    avg_track_temp: float | None = Field(
+        description="Average track temperature in celsius")
+    avg_wind_speed: float | None = Field(
+        description="Average wind speed in meters per second")
+
+
+class GetTelemetry(BaseTool):
+    """Tool returning aggregated telemetry and weather for one driver lap."""
+    name: str = "get_telemetry"
+    description: str = "useful for when you need to answer questions about telemetry for a given driver and lap"
+    args_schema: Type[BaseModel] = GetTelemetryAndWeatherInput
+
+    def _run(
+        self, driver_name: str, lap_number: int
+    ) -> GetTelemetryAndWeatherOutput:
+        """Return aggregated telemetry and weather for one lap.
+
+        Args:
+            driver_name: Value bound to :driver_name in the SQL query.
+            lap_number: Value bound to :lap_number in the SQL query.
+
+        Returns:
+            The first result row as a GetTelemetryAndWeatherOutput.
+
+        Raises:
+            ValueError: If the query matches no lap or the result cannot be
+                parsed into one value per output column.
+        """
+        sql_query = _QUERY_PATH.read_text(encoding="utf-8")
+        console.print("getting telemetry")
+        response = db.run(sql_query, parameters={
+            "driver_name": driver_name,
+            "lap_number": lap_number})
+
+        # SQLDatabase.run returns "" for an empty result set; report that
+        # clearly instead of failing later on float("").
+        if not response:
+            raise ValueError(
+                f"No telemetry found for driver {driver_name!r}, "
+                f"lap {lap_number}")
+
+        # A non-empty result is the repr of a list of row tuples. Parse it as
+        # a literal rather than splitting on commas, so NULL becomes None and
+        # numbers keep their types.
+        try:
+            rows = ast.literal_eval(str(response))
+        except (ValueError, SyntaxError) as err:
+            raise ValueError(
+                f"Unexpected telemetry query result: {response!r}") from err
+
+        # strict=True raises ValueError if the SELECT and the output model
+        # ever drift apart.
+        values = dict(zip(_OUTPUT_COLUMNS, rows[0], strict=True))
+        return GetTelemetryAndWeatherOutput(**values)
diff --git a/tools/telemetry_and_weather_query.sql b/tools/telemetry_and_weather_query.sql
new file mode 100644
index 0000000..7e25059
--- /dev/null
+++ b/tools/telemetry_and_weather_query.sql
@@ -0,0 +1,25 @@
+SELECT
+    l.lap_id,
+    l.lap_number,
+    l.lap_time_in_seconds,
+    AVG(tel.speed_in_km) AS avg_speed,
+    MAX(tel.speed_in_km) AS max_speed,
+    AVG(tel.RPM) AS avg_RPM,
+    MAX(tel.RPM) AS max_RPM,
+    AVG(tel.throttle_input) AS avg_throttle,
+    SUM(CASE WHEN tel.is_brake_pressed THEN 1 ELSE 0 END) * 100.0 / COUNT(*) AS brake_percentage,
+    SUM(CASE WHEN tel.is_DRS_open THEN 1 ELSE 0 END) * 100.0 / COUNT(*) AS drs_usage_percentage,
+    SUM(CASE WHEN tel.is_off_track THEN 1 ELSE 0 END) * 100.0 / COUNT(*) AS off_track_percentage,
+    AVG(w.air_temperature_in_celsius) AS avg_air_temp,
+    AVG(w.track_temperature_in_celsius) AS avg_track_temp,
+    AVG(w.wind_speed_in_meters_per_seconds) AS avg_wind_speed
+FROM Laps l
+JOIN Sessions s ON l.session_id = s.session_id
+JOIN Tracks t ON s.track_id = t.track_id
+JOIN Event e ON s.event_id = e.event_id
+JOIN Telemetry tel ON l.lap_id = tel.lap_id
+LEFT JOIN Weather w ON s.session_id = w.session_id
+    AND tel.datetime BETWEEN w.datetime AND datetime(w.datetime, '+1 minutes')
+WHERE l.driver_name = :driver_name
+    AND l.lap_number = :lap_number
+GROUP BY l.lap_id;
\ No newline at end of file