Skip to content

Commit

Permalink
more cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Gunther Hagleitner committed Sep 27, 2024
1 parent 8b39b9a commit c37d010
Showing 1 changed file with 28 additions and 46 deletions.
74 changes: 28 additions & 46 deletions conversational-analytics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import sys
import uuid
from typing import List, Optional, Dict, Any

import pandas as pd
from pydantic import BaseModel
from langgraph.graph import StateGraph
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from pydantic import BaseModel
from waii_sdk_py import WAII
from waii_sdk_py.query import QueryGenerationRequest, RunQueryRequest

Expand All @@ -24,20 +24,9 @@ class State(BaseModel):

class LanggraphWorkflowManager:

def __init__(self):
self.workflow = StateGraph(State)
self.workflow = self.create_workflow()
self.app = self.workflow.compile()
self.init_waii()
print(self.app.get_graph().draw_ascii())

def init_waii(self):
url = 'http://localhost:9859/api/'
api_key = ''
db_connection_str = 'snowflake://WAII_USER@gqobxjv-bhb91428/MOVIE_DB?role=WAII_USER_ROLE&warehouse=COMPUTE_WH'
WAII.initialize(url=url, api_key=api_key)
WAII.Database.activate_connection(db_connection_str)
print(f"Initialized WAII with connection: {db_connection_str}")
WAII.initialize(url=os.getenv("WAII_URL"), api_key=os.getenv("WAII_API_KEY"))
WAII.Database.activate_connection(os.getenv("DB_CONNECTION"))

def create_workflow(self) -> StateGraph:
workflow = StateGraph(State)
Expand All @@ -64,37 +53,19 @@ def create_workflow(self) -> StateGraph:
workflow.add_edge("SQL Executor", "Result Synthesizer")
workflow.add_edge("Chart Generator", "Result Synthesizer")
workflow.add_edge("Insight Generator", "Result Synthesizer")

workflow.add_edge("Result Synthesizer", "Question Classifier")

return workflow

def format_catalog_info(self, catalogs):
formatted_info = []

for catalog in catalogs.catalogs:
catalog_name = catalog.name
formatted_info.append(f"Database: {catalog_name}")

for schema in catalog.schemas:
schema_name = schema.name.schema_name
schema_description = schema.description

formatted_info.append(f" Schema: {schema_name}")
formatted_info.append(f" Description: {schema_description}")

formatted_info.append("")

return "\n".join(formatted_info)

def question_classifier(self, state: State) -> State:

state.database_description = self.format_catalog_info(WAII.Database.get_catalogs())
state.query = input("Enter your question: ")
state.query = input("Question: ")

prompt = ChatPromptTemplate.from_messages([
("system", "You are an expert in classifying questions into 'database', 'visualization', or 'insight'. Use 'database' if the question can be answered from the movie and tv database, 'visualization' if the user would be best served by a graph, 'insight' if it's a general question you can answer from memory. Prefer 'database' if multiple apply. Here is a description of what's in the database: '\n---\n{database_description}\n---\n'"),
("human", "Can you classify the following question into one of these categories? Question: '{query}'. "
("human",
"Database info: \n---\n{database_description}\n---\n"
"Can you classify the following question into one of 'database' or 'insight'? Question: '{query}'. "
"Output: Strictly respond with either 'database', 'visualization', or 'insight'. No additional text.")
])
chain = prompt | ChatOpenAI() | StrOutputParser()
Expand All @@ -103,6 +74,7 @@ def question_classifier(self, state: State) -> State:
return state.model_copy(update={"path_decision": classification, "error": None})

def sql_generator(self, state: State) -> State:
insight = ''
sql = WAII.Query.generate(QueryGenerationRequest(ask=state.query)).query
return state.model_copy(update={"sql": sql})

Expand All @@ -121,13 +93,6 @@ def insight_generator(self, state: State) -> dict:
insight = chain.invoke({"query": state.query})
return state.model_copy(update={"insight": insight, "error": None}, deep=True)

def format_data(self, data: List[Dict[str, Any]]) -> str:
formatted_data = ""
for row in data:
formatted_data += " | ".join([f"{key}: {value}" for key, value in row.items()])
formatted_data += "\n"
return formatted_data

def result_synthesizer(self, state: State) -> State:

model = ChatOpenAI()
Expand All @@ -139,10 +104,27 @@ def result_synthesizer(self, state: State) -> State:
"\n\n Instructions: Answer the user with this information.")
])
chain = prompt | model | StrOutputParser()
output = chain.invoke({"query": state.query, "data": self.format_data(state.data), "insight": state.insight}).strip().lower()
print(output)
data = "\n".join(" | ".join(f"{key}: {value}" for key, value in row.items()) for row in state.data)
output = chain.invoke({"query": state.query, "data": data, "insight": state.insight}).strip().lower()
print('Answer: '+output)
return state.model_copy(update={"response": output}, deep=True)

def __init__(self):
self.workflow = self.create_workflow()
self.app = self.workflow.compile()
self.init_waii()
print(self.app.get_graph().draw_ascii())

def format_catalog_info(self, catalogs):
return "\n".join([
f"Database: {catalog.name}\n" +
"\n".join([
f" Schema: {schema.name.schema_name}\n Description: {schema.description}"
for schema in catalog.schemas
]) + "\n"
for catalog in catalogs.catalogs
])

def run_workflow(self):
while True:
try:
Expand Down

0 comments on commit c37d010

Please sign in to comment.