11 changes: 11 additions & 0 deletions deploy_ai_search/pyproject.toml
@@ -26,3 +26,14 @@ dev = [

[tool.uv.sources]
text_2_sql_core = { workspace = true }

[project.optional-dependencies]
snowflake = [
"text_2_sql_core[snowflake]",
]
databricks = [
"text_2_sql_core[databricks]",
]
postgresql = [
"text_2_sql_core[postgresql]",
]
@@ -26,6 +26,7 @@
)
import os
from text_2_sql_core.utils.database import DatabaseEngine
from text_2_sql_core.connectors.factory import ConnectorFactory


class Text2SqlSchemaStoreAISearch(AISearch):
@@ -49,29 +50,13 @@ def __init__(
os.environ["Text2Sql__DatabaseEngine"].upper()
]

self.database_connector = ConnectorFactory.get_database_connector()

if single_data_dictionary_file:
self.parsing_mode = BlobIndexerParsingMode.JSON_ARRAY
else:
self.parsing_mode = BlobIndexerParsingMode.JSON

@property
def excluded_fields_for_database_engine(self):
"""A method to get the excluded fields for the database engine."""

all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
if self.database_engine == DatabaseEngine.SNOWFLAKE:
engine_specific_fields = ["Warehouse", "Database"]
elif self.database_engine == DatabaseEngine.TSQL:
engine_specific_fields = ["Database"]
elif self.database_engine == DatabaseEngine.DATABRICKS:
engine_specific_fields = ["Catalog"]

return [
field
for field in all_engine_specific_fields
if field not in engine_specific_fields
]

def get_index_fields(self) -> list[SearchableField]:
"""This function returns the index fields for sql index.

@@ -196,7 +181,7 @@ def get_index_fields(self) -> list[SearchableField]:
fields = [
field
for field in fields
if field.name not in self.excluded_fields_for_database_engine
if field.name not in self.database_connector.excluded_engine_specific_fields
]

return fields
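
With the exclusion logic moved onto the connector, the schema store no longer hard-codes per-engine field lists; it just filters its index fields against `self.database_connector.excluded_engine_specific_fields`. A minimal sketch of that filtering, using a stub connector and illustrative field names in place of the real `SearchableField` objects:

```python
# Sketch of the new filtering path; StubConnector/StubField stand in for the real classes.
class StubConnector:
    # For a T-SQL-style engine only "Database" is engine-specific,
    # so "Warehouse" and "Catalog" end up in the excluded list.
    excluded_engine_specific_fields = ["Warehouse", "Catalog"]


class StubField:
    def __init__(self, name: str):
        self.name = name


database_connector = StubConnector()
fields = [StubField(n) for n in ["Entity", "Schema", "Warehouse", "Database", "Catalog"]]

# Same comprehension as in get_index_fields(): drop fields the engine does not use.
fields = [
    field
    for field in fields
    if field.name not in database_connector.excluded_engine_specific_fields
]

print([f.name for f in fields])  # ['Entity', 'Schema', 'Database']
```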
13 changes: 12 additions & 1 deletion text_2_sql/autogen/pyproject.toml
@@ -11,7 +11,7 @@ dependencies = [
"autogen-ext[azure,openai]==0.4.0.dev11",
"grpcio>=1.68.1",
"pyyaml>=6.0.2",
"text_2_sql_core[snowflake,databricks]",
"text_2_sql_core",
]

[dependency-groups]
@@ -28,3 +28,14 @@ dev = [

[tool.uv.sources]
text_2_sql_core = { workspace = true }

[project.optional-dependencies]
snowflake = [
"text_2_sql_core[snowflake]",
]
databricks = [
"text_2_sql_core[databricks]",
]
postgresql = [
"text_2_sql_core[postgresql]",
]
4 changes: 4 additions & 0 deletions text_2_sql/text_2_sql_core/pyproject.toml
@@ -46,6 +46,10 @@ databricks = [
"databricks-sql-connector>=3.0.1",
"pyarrow>=14.0.2,<17",
]
postgresql = [
"psycopg>=3.2.3",
]


[build-system]
requires = ["hatchling"]
@@ -8,7 +8,7 @@
import logging
import json

from text_2_sql_core.utils.database import DatabaseEngine
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields


class DatabricksSqlConnector(SqlConnector):
@@ -17,6 +17,11 @@ def __init__(self):

self.database_engine = DatabaseEngine.DATABRICKS

@property
def engine_specific_fields(self) -> list[str]:
"""Get the engine specific fields."""
return [DatabaseEngineSpecificFields.CATALOG]

@property
def invalid_identifiers(self) -> list[str]:
"""Get the invalid identifiers upon which a sql query is rejected."""
@@ -25,6 +25,12 @@ def get_database_connector():
from text_2_sql_core.connectors.tsql_sql import TSQLSqlConnector

return TSQLSqlConnector()
elif os.environ["Text2Sql__DatabaseEngine"].upper() == "POSTGRESQL":
from text_2_sql_core.connectors.postgresql_sql import (
PostgresqlSqlConnector,
)

return PostgresqlSqlConnector()
else:
raise ValueError(
f"""Database engine {
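
Connector selection is driven by the `Text2Sql__DatabaseEngine` environment variable, so PostgreSQL is just another branch in the factory. A minimal usage sketch, assuming the matching extra is installed and the rest of the connector's environment is configured:

```python
import os

from text_2_sql_core.connectors.factory import ConnectorFactory

# The factory upper-cases this value and returns the matching connector,
# e.g. "POSTGRESQL" -> PostgresqlSqlConnector.
os.environ["Text2Sql__DatabaseEngine"] = "PostgreSQL"

connector = ConnectorFactory.get_database_connector()
print(type(connector).__name__)                   # PostgresqlSqlConnector
print(connector.excluded_engine_specific_fields)  # e.g. ['Warehouse', 'Catalog']
```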
@@ -0,0 +1,126 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from text_2_sql_core.connectors.sql import SqlConnector
import psycopg
from typing import Annotated
import os
import logging
import json

from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields


class PostgresqlSqlConnector(SqlConnector):
def __init__(self):
super().__init__()

self.database_engine = DatabaseEngine.POSTGRESQL

@property
def engine_specific_fields(self) -> list[str]:
"""Get the engine specific fields."""
return [DatabaseEngineSpecificFields.DATABASE]

@property
def invalid_identifiers(self) -> list[str]:
"""Get the invalid identifiers upon which a sql query is rejected."""

return [
"CURRENT_USER", # Returns the name of the current user
"SESSION_USER", # Returns the name of the user that initiated the session
"USER", # Returns the name of the current user
"CURRENT_ROLE", # Returns the current role
"CURRENT_DATABASE", # Returns the name of the current database
"CURRENT_SCHEMA()", # Returns the name of the current schema
"CURRENT_SETTING()", # Returns the value of a specified configuration parameter
"PG_CURRENT_XACT_ID()", # Returns the current transaction ID
# (if the extension is enabled) Provides a view of query statistics
"PG_STAT_STATEMENTS()",
"PG_SLEEP()", # Delays execution by the specified number of seconds
"CLIENT_ADDR()", # Returns the IP address of the client (from pg_stat_activity)
"CLIENT_HOSTNAME()", # Returns the hostname of the client (from pg_stat_activity)
"PGP_SYM_DECRYPT()", # (from pgcrypto extension) Symmetric decryption function
"PGP_PUB_DECRYPT()", # (from pgcrypto extension) Asymmetric decryption function
]

async def query_execution(
self,
sql_query: Annotated[str, "The SQL query to run against the database."],
cast_to: any = None,
limit=None,
) -> list[dict]:
"""Run the SQL query against the PostgreSQL database asynchronously.

Args:
----
sql_query (str): The SQL query to run against the database.

Returns:
-------
list[dict]: The results of the SQL query.
"""
logging.info(f"Running query: {sql_query}")
results = []
connection_string = os.environ["Text2Sql__DatabaseConnectionString"]

# Establish an asynchronous connection to the PostgreSQL database
async with psycopg.AsyncConnection.connect(connection_string) as conn:
# Create an asynchronous cursor
async with conn.cursor() as cursor:
await cursor.execute(sql_query)

# Fetch column names
columns = [column[0] for column in cursor.description]

# Fetch rows based on the limit
if limit is not None:
rows = await cursor.fetchmany(limit)
else:
rows = await cursor.fetchall()

# Process the rows
for row in rows:
if cast_to:
results.append(cast_to.from_sql_row(row, columns))
else:
results.append(dict(zip(columns, row)))

logging.debug("Results: %s", results)
return results

async def get_entity_schemas(
self,
text: Annotated[
str,
"The text to run a semantic search against. Relevant entities will be returned.",
],
excluded_entities: Annotated[
list[str],
"The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
] = [],
as_json: bool = True,
) -> str:
"""Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.

Args:
----
text (str): The text to run the search against.

Returns:
str: The schema of the views or tables in JSON format.
"""

schemas = await self.ai_search_connector.get_entity_schemas(
text, excluded_entities
)

for schema in schemas:
schema["SelectFromEntity"] = ".".join([schema["Schema"], schema["Entity"]])

del schema["Entity"]
del schema["Schema"]

if as_json:
return json.dumps(schemas, default=str)
else:
return schemas
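
A short usage sketch of the new connector, assuming `Text2Sql__DatabaseConnectionString` holds a standard PostgreSQL connection string and the rest of the connector's environment is configured; the credentials and query below are placeholders:

```python
import asyncio
import os

from text_2_sql_core.connectors.postgresql_sql import PostgresqlSqlConnector

# Placeholder connection string read by query_execution().
os.environ["Text2Sql__DatabaseConnectionString"] = (
    "postgresql://user:password@localhost:5432/mydb"
)


async def main():
    connector = PostgresqlSqlConnector()

    # Returns a list of dicts keyed by column name; limit caps the rows fetched.
    rows = await connector.query_execution(
        "SELECT schemaname, tablename FROM pg_catalog.pg_tables",
        limit=5,
    )
    print(rows)


asyncio.run(main())
```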
@@ -8,7 +8,7 @@
import logging
import json

from text_2_sql_core.utils.database import DatabaseEngine
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields


class SnowflakeSqlConnector(SqlConnector):
@@ -17,6 +17,14 @@ def __init__(self):

self.database_engine = DatabaseEngine.SNOWFLAKE

@property
def engine_specific_fields(self) -> list[str]:
"""Get the engine specific fields."""
return [
DatabaseEngineSpecificFields.WAREHOUSE,
DatabaseEngineSpecificFields.DATABASE,
]

@property
def invalid_identifiers(self) -> list[str]:
"""Get the invalid identifiers upon which a sql query is rejected."""
@@ -10,6 +10,7 @@
from abc import ABC, abstractmethod
from jinja2 import Template
import json
from text_2_sql_core.utils.database import DatabaseEngineSpecificFields


class SqlConnector(ABC):
@@ -36,6 +37,22 @@ def invalid_identifiers(self) -> list[str]:
"""Get the invalid identifiers upon which a sql query is rejected."""
pass

@property
@abstractmethod
def engine_specific_fields(self) -> list[str]:
"""Get the engine specific fields."""
pass

@property
def excluded_engine_specific_fields(self):
"""A method to get the excluded fields for the database engine."""

return [
field.value.capitalize()
for field in DatabaseEngineSpecificFields
if field not in self.engine_specific_fields
]

@abstractmethod
async def query_execution(
self,
@@ -155,7 +172,7 @@ def handle_node(node):

for token in expressions + identifiers:
if isinstance(token, Parameter):
identifier = token.this.this
identifier = str(token.this.this).upper()
else:
identifier = str(token).strip("()").upper()

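
The base class derives the excluded fields by set difference over `DatabaseEngineSpecificFields`, so each connector only declares what its engine actually uses. A self-contained sketch of that computation, with an assumed three-member enum matching the 'Warehouse'/'Database'/'Catalog' field names used in the index (the real enum lives in `text_2_sql_core.utils.database` and may differ):

```python
from enum import Enum


# Assumed shape of DatabaseEngineSpecificFields, for illustration only.
class DatabaseEngineSpecificFields(Enum):
    WAREHOUSE = "warehouse"
    DATABASE = "database"
    CATALOG = "catalog"


# A Snowflake-style connector declares warehouse and database as engine-specific...
engine_specific_fields = [
    DatabaseEngineSpecificFields.WAREHOUSE,
    DatabaseEngineSpecificFields.DATABASE,
]

# ...so the excluded_engine_specific_fields property capitalises and excludes the rest.
excluded = [
    field.value.capitalize()
    for field in DatabaseEngineSpecificFields
    if field not in engine_specific_fields
]

print(excluded)  # ['Catalog']
```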
@@ -7,7 +7,7 @@
import logging
import json

from text_2_sql_core.utils.database import DatabaseEngine
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields


class TSQLSqlConnector(SqlConnector):
@@ -16,6 +16,11 @@ def __init__(self):

self.database_engine = DatabaseEngine.TSQL

@property
def engine_specific_fields(self) -> list[str]:
"""Get the engine specific fields."""
return [DatabaseEngineSpecificFields.DATABASE]

@property
def invalid_identifiers(self) -> list[str]:
"""Get the invalid identifiers upon which a sql query is rejected."""
@@ -11,7 +11,6 @@
import random
import re
import networkx as nx
from text_2_sql_core.utils.database import DatabaseEngine
from tenacity import retry, stop_after_attempt, wait_exponential
from text_2_sql_core.connectors.open_ai import OpenAIConnector

@@ -751,21 +750,9 @@ async def build_entity_entry(self, entity: EntityItem) -> EntityItem:
def excluded_fields_for_database_engine(self):
"""A method to get the excluded fields for the database engine."""

all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
if self.database_engine == DatabaseEngine.SNOWFLAKE:
engine_specific_fields = ["Warehouse", "Database"]
elif self.database_engine == DatabaseEngine.TSQL:
engine_specific_fields = ["Database"]
elif self.database_engine == DatabaseEngine.DATABRICKS:
engine_specific_fields = ["Catalog"]
else:
engine_specific_fields = []

# Determine top-level fields to exclude
filtered_entitiy_specific_fields = {
field.lower(): ...
for field in all_engine_specific_fields
if field not in engine_specific_fields
field.lower(): ... for field in self.excluded_engine_specific_fields
}

if filtered_entitiy_specific_fields: