Skip to content

Commit

Permalink
remove unknown handler
Browse files Browse the repository at this point in the history
  • Loading branch information
Gunther Hagleitner committed Sep 26, 2024
1 parent 9ca4511 commit d75232a
Showing 1 changed file with 27 additions and 20 deletions.
47 changes: 27 additions & 20 deletions conversational-analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,16 @@ def create_workflow(self) -> StateGraph:
workflow.add_node("Chart Generator", self.chart_gen)
workflow.add_node("Insight Generator", self.insight_generator)
workflow.add_node("Result Synthesizer", self.result_synthesizer)
workflow.add_node("Unknown Handler", self.unknown_handler)

# Define edges to control workflow execution
workflow.set_entry_point("Intent Classifier")
workflow.add_conditional_edges(
"Intent Classifier",
lambda state: state.path_decision,
{
"sql": "SQL Generator",
"data_visualization": "Chart Generator",
"insight": "Insight Generator",
"unknown": "Unknown Handler"
"database": "SQL Generator",
"visualization": "Chart Generator",
"insight": "Insight Generator"
}
)
workflow.add_edge("SQL Generator", "SQL Executor")
Expand All @@ -77,7 +75,6 @@ def create_workflow(self) -> StateGraph:

# Loop through the workflow
workflow.add_edge("Result Synthesizer", "Intent Classifier")
workflow.add_edge("Unknown Handler", "Intent Classifier")

return workflow

Expand All @@ -95,14 +92,9 @@ def intent_classifier(self, state: State) -> State:
# Classify the question to one of sql, insight, data_visualization, or unknown
intent = self.waii_intent_classification(query=state.query)

if intent in ["sql", "insight", "data_visualization", "unknown"]:
if intent in ["database", "insight", "visualization"]:
return state.model_copy(update={"path_decision": intent, "error": None})

def unknown_handler(self, state: State) -> State:
print(f"Unable to classify your question. Please enter a valid question.")
# reset the state and get back to the intent classifier again
return State()

def sql_generator(self, state: State) -> State:
print(f"Generating SQL for query: {state.query}")
if state.simulate_error_sql_gen:
Expand Down Expand Up @@ -136,7 +128,7 @@ def chart_gen(self, state: State) -> State:
return state.model_copy(update={"error": str(e)})

def insight_generator(self, state: State) -> dict:
print(f"Generating insight for data: {state.data}")
print(f"Generating insight for data: {state.query}")
if state.error:
return {}
# TODO: Need to fix this for integration with WAII
Expand All @@ -163,9 +155,9 @@ def waii_intent_classification(self, query: str) -> str:

# Create the chat prompt template
prompt = ChatPromptTemplate.from_messages([
("system", "You are an expert in classifying questions into 'sql', 'data_visualization', 'insight', or 'others'."),
("system", "You are an expert in classifying questions into 'database', 'visualization', or 'insight'. Use 'database' if the question can be answered from the movie and tv database, 'visualization' if the user would be best served by a graph, 'insight' if it's a general question you can answer from memory. Prefer 'database' if multiple apply."),
("human", "Can you classify the following question into one of these categories? Question: '{query}'. "
"Output: Strictly respond with either 'sql', 'data_visualization', 'insight', or 'unknown'. No additional text.")
"Output: Strictly respond with either 'database', 'visualization', or 'insight'. No additional text.")
])

# Create the chain
Expand All @@ -176,10 +168,10 @@ def waii_intent_classification(self, query: str) -> str:
classification = chain.invoke({"query": query}).strip().lower()

# Return the classification, mapping 'others' to 'unknown'
if classification in ["sql", "data_visualization", "insight"]:
if classification in ["database", "visualization", "insight"]:
return classification
else:
return "unknown"
return "insight"

def waii_sql_generator(self, question: str) -> str:
try:
Expand Down Expand Up @@ -210,9 +202,24 @@ def waii_chart_generator(self, data: List[Dict[str, Any]]) -> str:
print(f"Error generating chart: {e}")
raise e

def waii_insight_generator(param: List[str]) -> str:
# TODO: Need to integrate with WAII for generating query
return "Insight: These are the top 5 directors."
def waii_insight_generator(self, query: str) -> str:
# Create the language model
model = ChatOpenAI()

# Create the chat prompt template
prompt = ChatPromptTemplate.from_messages([
("system", "You are an AI assistant that generates insightful responses to any query. Provide a concise, relevant insight based on the user's question."),
("human", "Please provide an insightful response to the following question: '{query}'. "
"Your response should be informative and directly address the query.")
])

# Create the chain
chain = prompt | model | StrOutputParser()

# Generate the insight
insight = chain.invoke({"query": query})

return insight.strip()

def run_workflow(self):
while True:
Expand Down

0 comments on commit d75232a

Please sign in to comment.