Skip to content

Commit

Permalink
Merge pull request #61 from Capstone-Projects-2024-Spring/prompting-a…
Browse files Browse the repository at this point in the history
…idan

Prompting Testing
  • Loading branch information
alishahidd authored Apr 10, 2024
2 parents 8fb5355 + a76ab9d commit ae66900
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 21 deletions.
6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
.env.production.local
__pycache__
.pytest_cache
yarn.lock

#API Keys/Tokens
.env

#virtual environment
venv
.env
2 changes: 2 additions & 0 deletions flask-backend/flask.log
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
2024-03-25 12:20:55,949 INFO: Flask application startup [in /Users/ayejay/Downloads/CapstoneProject/project-phillygpt/flask-backend/server.py:37]
2024-03-25 12:20:56,350 INFO: Flask application startup [in /Users/ayejay/Downloads/CapstoneProject/project-phillygpt/flask-backend/server.py:37]
86 changes: 86 additions & 0 deletions flask-backend/resources/database_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
from dotenv import load_dotenv
import mysql.connector
from mysql.connector import Error
from sshtunnel import SSHTunnelForwarder

load_dotenv()

def get_database_uri():
    """Open an SSH tunnel to the MySQL server and return the database schema.

    Connection settings are read from the environment: SSH_HOST, SSH_PORT
    (default 22), SSH_USERNAME, SSH_PRIVATE_KEY, MYSQL_HOST, MYSQL_PORT
    (default 3306), MYSQL_USERNAME, MYSQL_PASSWORD and MYSQL_DB.

    Returns:
        On success, the dict built by get_schema_representation().
        On failure, the tuple (None, None). That shape is kept for backward
        compatibility, so callers should check for a dict before using the
        result.
    """
    ssh_host = os.getenv('SSH_HOST')
    ssh_port = int(os.getenv('SSH_PORT', 22))  # Default to port 22 if not specified
    ssh_username = os.getenv('SSH_USERNAME')
    ssh_private_key = os.getenv('SSH_PRIVATE_KEY')

    mysql_host = os.getenv('MYSQL_HOST')
    mysql_port = int(os.getenv('MYSQL_PORT', 3306))  # Default to port 3306 if not specified
    mysql_username = os.getenv('MYSQL_USERNAME')
    mysql_password = os.getenv('MYSQL_PASSWORD')
    mysql_db = os.getenv('MYSQL_DB')

    try:
        # Establish SSH tunnel; leaving the "with" block tears it down.
        with SSHTunnelForwarder(
            (ssh_host, ssh_port),
            ssh_username=ssh_username,
            ssh_pkey=ssh_private_key,
            remote_bind_address=(mysql_host, mysql_port),
        ) as tunnel:
            # Connect to MySQL through the local end of the SSH tunnel
            connection = mysql.connector.connect(
                host='127.0.0.1',
                port=tunnel.local_bind_port,
                user=mysql_username,
                password=mysql_password,
                database=mysql_db,
            )
            try:
                if not connection.is_connected():
                    print("Error connecting to MySQL database: connection is not open")
                    return None, None

                cursor = connection.cursor()
                try:
                    # Retrieve schema representation
                    return get_schema_representation(cursor)
                finally:
                    # Close the cursor even if the schema queries raise.
                    cursor.close()
            finally:
                # Always release the connection before the tunnel closes.
                connection.close()

    except Error as e:
        print("Error connecting to MySQL database:", e)
        return None, None

def get_schema_representation(cursor):
    """Get the schema of the current database in a JSON-like format.

    Args:
        cursor: An open database cursor whose execute() accepts ``%s``
            placeholders with a parameter tuple and whose fetchall()
            returns row tuples.

    Returns:
        dict: ``{table_name: {column_name: data_type}}`` for every table in
        the database selected on the connection.
    """
    # Query to get all table names in the current database
    cursor.execute(
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = DATABASE();"
    )
    table_names = [row[0] for row in cursor.fetchall()]

    db_schema = {}
    for table_name in table_names:
        # Restrict to the current schema so a same-named table in another
        # database cannot leak its columns into the result, and pass the
        # table name as a bound parameter instead of formatting it into SQL.
        cursor.execute(
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_schema = DATABASE() AND table_name = %s "
            "ORDER BY ordinal_position;",
            (table_name,),
        )
        db_schema[table_name] = {
            column_name: column_type
            for column_name, column_type in cursor.fetchall()
        }

    return db_schema

if __name__ == "__main__":
    # get_database_uri() returns the schema dict on success and a non-dict
    # value on failure, so check the type instead of unpacking two values
    # (unpacking a schema dict raises ValueError for most table counts).
    result = get_database_uri()
    if isinstance(result, dict):
        print("Schema Representation:", result)
    else:
        print("Failed to get the database URI.")
2 changes: 2 additions & 0 deletions flask-backend/resources/output.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Database URI: mysql://admin:vuwva8-vZgabt-zip5m@phillygptdb1.czgma8k2kt6g.us-east-2.rds.amazonaws.com:3306/phillygpt
Schema Representation: {'Bike_Network': {'OBJECTID': 'int', 'SEG_ID': 'int', 'STREETNAME': 'varchar', 'ST_CODE': 'int', 'ONEWAY': 'varchar', 'CLASS': 'int', 'TYPE': 'varchar', 'Shape__Length': 'float'}, 'City_Landmarks': {'OBJECTID': 'int', 'NAME': 'varchar', 'ADDRESS': 'varchar', 'FEAT_TYPE': 'varchar', 'SUB_TYPE': 'varchar', 'VANITY_NAME': 'varchar', 'SECONDARY_NAME': 'varchar', 'BLDG': 'varchar', 'PARENT_NAME': 'varchar', 'PARENT_TYPE': 'varchar', 'ACREAGE': 'float', 'PARENT_ACREAGE': 'float', 'Shape__Area': 'float', 'Shape__Length': 'float'}, 'citywide_arrests': {'offense_category': 'varchar', 'day': 'datetime', 'defendant_race': 'varchar', 'count': 'int', 'objectid': 'bigint'}, 'covid_vaccine_totals': {'partially_vaccinated': 'int', 'fully_vaccinated': 'int', 'boosted': 'int'}, 'covid_vaccines_by_age': {'age': 'varchar', 'partially_vaccinated': 'int', 'fully_vaccinated': 'int', 'boosted': 'int'}, 'covid_vaccines_by_race': {'racial_identity': 'varchar', 'partially_vaccinated': 'int', 'fully_vaccinated': 'int', 'boosted': 'int'}, 'covid_vaccines_by_sex': {'SEX': 'char', 'Partially_Vaccinated': 'int', 'Fully_Vaccinated': 'int', 'Boosted': 'int'}, 'covid_vaccines_by_zip': {'zip_code': 'int', 'partially_vaccinated': 'int', 'fully_vaccinated': 'int', 'boosted': 'int'}, 'farmers_markets_location': {'objectid': 'int', 'X': 'float', 'Y': 'float', 'name': 'varchar', 'address': 'varchar', 'zip': 'int', 'hours_mon_start': 'time', 'hours_mon_end': 'time', 'hours_tues_start': 'time', 'hours_tues_end': 'time', 'hours_wed_start': 'time', 'hours_wed_end': 'time', 'hours_thurs_start': 'time', 'hours_thurs_end': 'time', 'hours_fri_start': 'time', 'hours_fri_end': 'time', 'hours_sat_start': 'time', 'hours_sat_end': 'time', 'hours_sat_exceptions': 'varchar', 'hours_sun_start': 'time', 'hours_sun_end': 'time', 'season_opening_month': 'varchar', 'season_opening_day': 'int', 'season_closing_month': 'varchar', 'season_closing_day': 'int'}, 'universities_colleges': {'OBJECTID': 'int', 
'NAME': 'varchar', 'ADDRESS': 'varchar', 'BUILDING_DESCRIPTION': 'varchar', 'PARCEL_ID': 'varchar', 'BRT_ID': 'varchar', 'TENCODE_ID': 'varchar', 'GROSS_AREA': 'float', 'Shape_Area': 'float', 'Shape_Length': 'float'}}
89 changes: 74 additions & 15 deletions flask-backend/resources/process_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,114 @@
from flask_restful import Resource
from openai import OpenAI
import os
import json
from dotenv import load_dotenv
from resources.database_connection import get_database_uri
from resources.prompts import SYSTEM_MESSAGE
from resources.validate_sql_query import validate_sql_query

# Load environment variables from the .env file.
load_dotenv()

# Runs at import time. The result is later indexed by table name
# (schemas[table_name]) and iterated to list the available tables.
schemas = get_database_uri()

# Module-level OpenAI client used by ProcessInput; the API key is read from
# the OPENAI_API_KEY environment variable.
client = OpenAI(
api_key=os.environ.get("OPENAI_API_KEY"),
)

class ProcessInput(Resource):
    """
    ProcessInput(Resource)
    This resource handles user input and generates a SQL query using OpenAI
    """

    def post(self):
        """
        Handles POST HTTP requests to the '/process_input' endpoint.
        Returns:
            JSON with "status": "success", the user input and the OpenAI
            response, or "status": "error" with a message when no response
            could be produced.
        """
        # Get user input; a JSON body that is not an object (null, list, ...)
        # is treated as carrying no input instead of raising AttributeError.
        data = request.get_json()
        if not isinstance(data, dict):
            data = {}
        user_input = data.get("user_input")
        response = self.openai_request(user_input)

        if response:
            return jsonify({"status": "success", "USER_INPUT": user_input, "OPENAI_RESPONSE": response})
        return jsonify({"status": "error", "message": "An error occurred during processing."})

    def openai_request(self, user_input):
        """
        Sends a request to OpenAI to generate a SQL Query based on the user input.
        Args:
            user_input (string): User's prompt for generating SQL query.
        Returns:
            The model's response as a string, or None when the input is
            empty, no table could be determined, the response fails
            validation, or the request raises. Failures return None (not an
            error string) so that post() reports them as errors.
        """
        if not user_input:
            print("User input is empty.")
            return None
        try:
            # Determine the proper table based on the schema and user input
            table_name = self.determine_proper_table(user_input)
            if table_name is None:
                print("Table not found")
                return None

            # Send the table name together with its columns; the column
            # mapping alone does not tell the model which table to query.
            formatted_system_message = SYSTEM_MESSAGE.format(
                schema={table_name: schemas[table_name]}
            )

            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": formatted_system_message},
                    {"role": "user", "content": user_input},
                ],
                temperature=1
            )
            generated_query = response.choices[0].message.content

            if not generated_query:
                print("OpenAI returned an empty response.")
                return None

            if not validate_sql_query(self, generated_query):
                print("Potential SQL injection detected!")
                return None  # Do not return the query

            return generated_query
        except Exception as e:
            # Request boundary: report the failure and signal it with None.
            print(f"Error from OpenAI: {e}")
            return None

    def determine_proper_table(self, user_input):
        """
        Determines the proper table that would best answer the user input based on the schema.
        Args:
            user_input (string): User's prompt for generating SQL query.
        Returns:
            The name of the table that best answers the user input, or None
            when no schema is loaded, the model reports an error, its reply
            is not a JSON object, or it names a table that is not in the
            schema.
        """
        # Without a loaded schema there is nothing to choose from.
        if not isinstance(schemas, dict) or not schemas:
            print("Error determining proper table: no schema available")
            return None

        # Create a prompt listing the available tables
        prompt = (
            "Determine which table has information that can best answer the user's question. You must always output your answer in JSON format with the following key-value pairs:- table: the table you found based on user_input - error: an error message if the query is invalid, or null if the query is valid. Available tables are: "
            + ", ".join(schemas)
        )

        # Call OpenAI API to determine the proper table
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": user_input},
            ],
            temperature=1
        )

        try:
            json_response = json.loads(response.choices[0].message.content)
        except (TypeError, ValueError) as e:
            # The model replied with something other than JSON text.
            print(f"Error determining proper table: {e}")
            return None

        if not isinstance(json_response, dict):
            print("Error determining proper table: response is not a JSON object")
            return None

        error_message = json_response.get("error")
        if error_message is not None:
            # If there's an error, print the error message
            print(f"Error determining proper table: {error_message}")
            return None

        table_name = json_response.get("table")
        # Only accept tables that really exist, so the later schema lookup
        # cannot fail on a name the model made up.
        if not isinstance(table_name, str) or table_name not in schemas:
            print(f"Error determining proper table: unknown table {table_name!r}")
            return None
        return table_name
8 changes: 8 additions & 0 deletions flask-backend/resources/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# System prompt template for SQL generation. The {schema} placeholder is
# filled in with str.format() by the code that sends the request.
# NOTE(review): the text names "City_Landmarks" no matter which table's schema
# is substituted for {schema} - confirm whether the name should be templated.
SYSTEM_MESSAGE = """You are an AI assistant that is able to convert natural language into a properly formatted SQL query.
The database you will be querying is called "City_Landmarks". Here is the schema of the table:
{schema}
You must always output your answer in JSON format with the following key-value pairs:
- "query": the SQL query that you generated
- "error": an error message if the query is invalid, or null if the query is valid"""
18 changes: 18 additions & 0 deletions flask-backend/resources/validate_sql_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import re

def validate_sql_query(self, sql_query):
    """
    Validates the generated SQL query by rejecting data-modifying keywords.

    This is a keyword deny-list: it only looks for the whole words DELETE,
    DROP, TRUNCATE, UPDATE, INSERT, ALTER and CREATE, case-insensitively.
    Args:
        self: Unused; kept because callers pass it positionally.
        sql_query (string): The generated SQL query to validate.
    Returns:
        True if the query passes validation, False otherwise. Anything that
        is not a string (for example None) fails validation.
    """
    # A missing or non-text response cannot be checked, so treat it as unsafe
    # instead of letting the regex search raise TypeError.
    if not isinstance(sql_query, str):
        return False

    # Regular expression pattern to match common SQL injection keywords;
    # \b keeps identifiers such as "updated_at" from matching.
    injection_pattern = re.compile(r'\b(DELETE|DROP|TRUNCATE|UPDATE|INSERT|ALTER|CREATE)\b', re.IGNORECASE)

    # Safe only when none of the keywords appear.
    return injection_pattern.search(sql_query) is None
14 changes: 12 additions & 2 deletions phillygpt/src/components/searchbar.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ const SearchBar = () => {
setUserInput(event.target.value);
};

// Status reporter: for now the message is only written to the console.
const setStatus = (message) => console.log('Status:', message);

const handleButtonClick = () => {
console.log(userInput);
if (userInput.trim() === 'ERROR') { //ENGINEER THE PROMPT TO RESPOND ONLY WITH 'ERROR' WHEN DATA IS NOT FOUND IN THE ANY TABLES.
Expand All @@ -36,12 +40,18 @@ const SearchBar = () => {
axios.post('http://127.0.0.1:5000/process_input', {user_input : userInput})
.then(response => {
console.log(response.data);
setResponseDataSQL(response.data.OPENAI_RESPONSE);
navigate(`/response?input=${encodeURIComponent(userInput)}`);
if (response.data.status === "success") {
setResponseDataSQL(response.data.OPENAI_RESPONSE);
navigate(`/response?input=${encodeURIComponent(userInput)}`);
} else {
setStatus('An error occurred during processing.');
navigate('/reprompt')
}
setLoading(false);
})
.catch(error => {
console.error('Error: ', error);
setStatus('An error occurred during processing.');
setLoading(false);
})
.finally(() => {
Expand Down

0 comments on commit ae66900

Please sign in to comment.