Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@
import logging
import json

from text_2_sql_core.utils.database import DatabaseEngine


class DatabricksSqlConnector(SqlConnector):
def __init__(self):
super().__init__()

self.database_engine = DatabaseEngine.DATABRICKS

async def query_execution(
self,
sql_query: Annotated[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@
import logging
import json

from text_2_sql_core.utils.database import DatabaseEngine


class SnowflakeSqlConnector(SqlConnector):
def __init__(self):
super().__init__()

self.database_engine = DatabaseEngine.SNOWFLAKE

async def query_execution(
self,
sql_query: Annotated[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(self):

self.ai_search_connector = ConnectorFactory.get_ai_search_connector()

self.database_engine = None

def get_current_datetime(self) -> str:
"""Get the current datetime."""
return datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
Expand Down Expand Up @@ -138,7 +140,11 @@ async def query_validation(
"""Validate the SQL query."""
try:
logging.info("Validating SQL Query: %s", sql_query)
sqlglot.transpile(sql_query)
sqlglot.transpile(
sql_query,
read=self.database_engine.value.lower(),
error_level=sqlglot.ErrorLevel.ERROR,
)
except sqlglot.errors.ParseError as e:
logging.error("SQL Query is invalid: %s", e.errors)
return e.errors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@
import logging
import json

from text_2_sql_core.utils.database import DatabaseEngine


class TSQLSqlConnector(SqlConnector):
def __init__(self):
super().__init__()

self.database_engine = DatabaseEngine.TSQL

async def query_execution(
self,
sql_query: Annotated[
Expand Down
Loading