Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
"metadata": {},
"outputs": [],
"source": [
"agentic_text_2_sql = AutoGenText2Sql(engine_specific_rules=\"\", use_case=\"Analysing sales data\")"
"agentic_text_2_sql = AutoGenText2Sql(use_case=\"Analysing sales data\")"
]
},
{
Expand All @@ -100,9 +100,16 @@
"metadata": {},
"outputs": [],
"source": [
"async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What total number of orders in June 2008?\")):\n",
"async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What is the total number of sales?\")):\n",
" logging.info(\"Received %s Message from Text2SQL System\", message)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@


class AutoGenText2Sql:
def __init__(self, engine_specific_rules: str, **kwargs: dict):
def __init__(self, **kwargs: dict):
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()
self.engine_specific_rules = engine_specific_rules
self.kwargs = kwargs

def get_all_agents(self):
Expand All @@ -45,9 +44,7 @@ def get_all_agents(self):
"question_rewrite_agent", current_datetime=current_datetime
)

self.parallel_query_solving_agent = ParallelQuerySolvingAgent(
engine_specific_rules=self.engine_specific_rules, **self.kwargs
)
self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)

self.answer_agent = LLMAgentCreator.create("answer_agent")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from text_2_sql_core.prompts.load import load
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
from jinja2 import Template
import logging


class LLMAgentCreator:
Expand Down Expand Up @@ -89,6 +90,17 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:

sql_helper = ConnectorFactory.get_database_connector()

# Handle engine specific rules
if "engine_specific_rules" not in kwargs:
if sql_helper.engine_specific_rules is not None:
kwargs["engine_specific_fields"] = sql_helper.engine_specific_rules
logging.info(
"Engine specific fields pulled from in-built: %s",
kwargs["engine_specific_fields"],
)
else:
kwargs["engine_specific_fields"] = ""

tools = []
if "tools" in agent_file and len(agent_file["tools"]) > 0:
for tool in agent_file["tools"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@


class ParallelQuerySolvingAgent(BaseChatAgent):
def __init__(self, engine_specific_rules: str, **kwargs: dict):
def __init__(self, **kwargs: dict):
super().__init__(
"parallel_query_solving_agent",
"An agent that solves each query in parallel.",
)

self.engine_specific_rules = engine_specific_rules
self.kwargs = kwargs

@property
Expand Down Expand Up @@ -177,9 +176,7 @@ async def consume_inner_messages_from_agentic_flow(
for question_rewrite in question_rewrites["sub_questions"]:
logging.info(f"Processing sub-query: {question_rewrite}")
# Create an instance of the InnerAutoGenText2Sql class
inner_autogen_text_2_sql = InnerAutoGenText2Sql(
self.engine_specific_rules, **self.kwargs
)
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)

identifier = ", ".join(question_rewrite)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ async def on_messages_stream(self, messages, sender=None, config=None):


class InnerAutoGenText2Sql:
def __init__(self, engine_specific_rules: str, **kwargs: dict):
def __init__(self, **kwargs: dict):
self.pre_run_query_cache = False
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()
self.engine_specific_rules = engine_specific_rules
self.kwargs = kwargs
self.set_mode()

Expand Down Expand Up @@ -73,21 +72,18 @@ def get_all_agents(self):

self.sql_schema_selection_agent = SqlSchemaSelectionAgent(
target_engine=self.target_engine,
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)

self.sql_query_correction_agent = LLMAgentCreator.create(
"sql_query_correction_agent",
target_engine=self.target_engine,
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)

self.disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create(
"disambiguation_and_sql_query_generation_agent",
target_engine=self.target_engine,
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)
agents = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def __init__(self):

self.database_engine = DatabaseEngine.DATABRICKS

@property
def engine_specific_rules(self) -> str:
"""Get the engine specific rules."""
return

@property
def engine_specific_fields(self) -> list[str]:
"""Get the engine specific fields."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def __init__(self):

self.database_engine = DatabaseEngine.SNOWFLAKE

@property
def engine_specific_rules(self) -> str:
"""Get the engine specific rules."""
return """When an ORDER BY clause is included in the SQL query, always append the ORDER BY clause with 'NULLS LAST' to ensure that NULL values are at the end of the result set. e.g. 'ORDER BY column_name DESC NULLS LAST'."""

@property
def engine_specific_fields(self) -> list[str]:
"""Get the engine specific fields."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ def __init__(self):

self.database_engine = None

@property
@abstractmethod
def engine_specific_rules(self) -> str:
"""Get the engine specific rules."""
pass

@property
@abstractmethod
def invalid_identifiers(self) -> list[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ def __init__(self):

self.database_engine = DatabaseEngine.TSQL

@property
def engine_specific_rules(self) -> str:
"""Get the engine specific rules."""
return """Use TOP X instead of LIMIT X to limit the number of rows returned."""

@property
def engine_specific_fields(self) -> list[str]:
"""Get the engine specific fields."""
Expand Down
Loading