-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #61 from Capstone-Projects-2024-Spring/prompting-a…
…idan Prompting Testing
- Loading branch information
Showing
8 changed files
with
204 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,9 +9,7 @@ | |
.env.production.local | ||
__pycache__ | ||
.pytest_cache | ||
yarn.lock | ||
|
||
#API Keys/Tokens | ||
.env | ||
|
||
#virtual enviornment | ||
venv | ||
.env |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
2024-03-25 12:20:55,949 INFO: Flask application startup [in /Users/ayejay/Downloads/CapstoneProject/project-phillygpt/flask-backend/server.py:37] | ||
2024-03-25 12:20:56,350 INFO: Flask application startup [in /Users/ayejay/Downloads/CapstoneProject/project-phillygpt/flask-backend/server.py:37] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import os | ||
from dotenv import load_dotenv | ||
import mysql.connector | ||
from mysql.connector import Error | ||
from sshtunnel import SSHTunnelForwarder | ||
|
||
load_dotenv() | ||
|
||
def get_database_uri(): | ||
"""Constructs MySQL database URI and retrieves schema information""" | ||
ssh_host = os.getenv('SSH_HOST') | ||
ssh_port = int(os.getenv('SSH_PORT', 22)) # Default to port 22 if not specified | ||
ssh_username = os.getenv('SSH_USERNAME') | ||
ssh_private_key = os.getenv('SSH_PRIVATE_KEY') | ||
|
||
mysql_host = os.getenv('MYSQL_HOST') | ||
mysql_port = int(os.getenv('MYSQL_PORT', 3306)) # Default to port 3306 if not specified | ||
mysql_username = os.getenv('MYSQL_USERNAME') | ||
mysql_password = os.getenv('MYSQL_PASSWORD') | ||
mysql_db = os.getenv('MYSQL_DB') | ||
|
||
try: | ||
# Establish SSH tunnel | ||
with SSHTunnelForwarder( | ||
(ssh_host, ssh_port), | ||
ssh_username=ssh_username, | ||
ssh_pkey=ssh_private_key, | ||
remote_bind_address=(mysql_host, mysql_port) | ||
) as tunnel: | ||
# Connect to MySQL through the SSH tunnel | ||
connection = mysql.connector.connect( | ||
host='127.0.0.1', | ||
port=tunnel.local_bind_port, | ||
user=mysql_username, | ||
password=mysql_password, | ||
database=mysql_db | ||
) | ||
if connection.is_connected(): | ||
db_info = connection.get_server_info() | ||
uri = f'mysql://{mysql_username}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_db}' | ||
cursor = connection.cursor() | ||
|
||
# Retrieve schema representation | ||
schema_representation = get_schema_representation(cursor) | ||
|
||
cursor.close() | ||
connection.close() | ||
|
||
return schema_representation | ||
|
||
except Error as e: | ||
print("Error connecting to MySQL database:", e) | ||
return None, None | ||
|
||
def get_schema_representation(cursor): | ||
""" Get the database schema in a JSON-like format """ | ||
db_schema = {} | ||
|
||
# Query to get all table names | ||
cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE();") | ||
tables = cursor.fetchall() | ||
|
||
for table in tables: | ||
table_name = table[0] | ||
|
||
# Query to get column details for each table | ||
cursor.execute(f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';") | ||
columns = cursor.fetchall() | ||
|
||
column_details = {} | ||
for column in columns: | ||
column_name = column[0] | ||
column_type = column[1] | ||
column_details[column_name] = column_type | ||
|
||
db_schema[table_name] = column_details | ||
|
||
return db_schema | ||
|
||
if __name__ == "__main__": | ||
database_uri, schema_representation = get_database_uri() | ||
if database_uri: | ||
print("Database URI:", database_uri) | ||
print("Schema Representation:", schema_representation) | ||
else: | ||
print("Failed to get the database URI.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Database URI: mysql://admin:vuwva8-vZgabt-zip5m@phillygptdb1.czgma8k2kt6g.us-east-2.rds.amazonaws.com:3306/phillygpt | ||
Schema Representation: {'Bike_Network': {'OBJECTID': 'int', 'SEG_ID': 'int', 'STREETNAME': 'varchar', 'ST_CODE': 'int', 'ONEWAY': 'varchar', 'CLASS': 'int', 'TYPE': 'varchar', 'Shape__Length': 'float'}, 'City_Landmarks': {'OBJECTID': 'int', 'NAME': 'varchar', 'ADDRESS': 'varchar', 'FEAT_TYPE': 'varchar', 'SUB_TYPE': 'varchar', 'VANITY_NAME': 'varchar', 'SECONDARY_NAME': 'varchar', 'BLDG': 'varchar', 'PARENT_NAME': 'varchar', 'PARENT_TYPE': 'varchar', 'ACREAGE': 'float', 'PARENT_ACREAGE': 'float', 'Shape__Area': 'float', 'Shape__Length': 'float'}, 'citywide_arrests': {'offense_category': 'varchar', 'day': 'datetime', 'defendant_race': 'varchar', 'count': 'int', 'objectid': 'bigint'}, 'covid_vaccine_totals': {'partially_vaccinated': 'int', 'fully_vaccinated': 'int', 'boosted': 'int'}, 'covid_vaccines_by_age': {'age': 'varchar', 'partially_vaccinated': 'int', 'fully_vaccinated': 'int', 'boosted': 'int'}, 'covid_vaccines_by_race': {'racial_identity': 'varchar', 'partially_vaccinated': 'int', 'fully_vaccinated': 'int', 'boosted': 'int'}, 'covid_vaccines_by_sex': {'SEX': 'char', 'Partially_Vaccinated': 'int', 'Fully_Vaccinated': 'int', 'Boosted': 'int'}, 'covid_vaccines_by_zip': {'zip_code': 'int', 'partially_vaccinated': 'int', 'fully_vaccinated': 'int', 'boosted': 'int'}, 'farmers_markets_location': {'objectid': 'int', 'X': 'float', 'Y': 'float', 'name': 'varchar', 'address': 'varchar', 'zip': 'int', 'hours_mon_start': 'time', 'hours_mon_end': 'time', 'hours_tues_start': 'time', 'hours_tues_end': 'time', 'hours_wed_start': 'time', 'hours_wed_end': 'time', 'hours_thurs_start': 'time', 'hours_thurs_end': 'time', 'hours_fri_start': 'time', 'hours_fri_end': 'time', 'hours_sat_start': 'time', 'hours_sat_end': 'time', 'hours_sat_exceptions': 'varchar', 'hours_sun_start': 'time', 'hours_sun_end': 'time', 'season_opening_month': 'varchar', 'season_opening_day': 'int', 'season_closing_month': 'varchar', 'season_closing_day': 'int'}, 'universities_colleges': {'OBJECTID': 'int', 'NAME': 'varchar', 'ADDRESS': 'varchar', 'BUILDING_DESCRIPTION': 'varchar', 'PARCEL_ID': 'varchar', 'BRT_ID': 'varchar', 'TENCODE_ID': 'varchar', 'GROSS_AREA': 'float', 'Shape_Area': 'float', 'Shape_Length': 'float'}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
SYSTEM_MESSAGE = """You are an AI assistant that is able to convert natural language into a properly formatted SQL query. | ||
The database you will be querying is called "City_Landmarks". Here is the schema of the table: | ||
{schema} | ||
You must always output your answer in JSON format with the following key-value pairs: | ||
- "query": the SQL query that you generated | ||
- "error": an error message if the query is invalid, or null if the query is valid""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import re | ||
|
||
def validate_sql_query(self, sql_query): | ||
""" | ||
Validates the structure of the generated SQL query to mitigate injection vulnerabilities. | ||
Args: | ||
sql_query (string): The generated SQL query to validate. | ||
Returns: | ||
True if the query passes validation, False otherwise. | ||
""" | ||
# Regular expression pattern to match common SQL injection keywords | ||
injection_pattern = re.compile(r'\b(DELETE|DROP|TRUNCATE|UPDATE|INSERT|ALTER|CREATE)\b', re.IGNORECASE) | ||
|
||
# Check if the query contains any SQL injection keywords | ||
if injection_pattern.search(sql_query): | ||
return False # Injection vulnerability detected | ||
else: | ||
return True # Query is safe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters