Skip to content

Commit

Permalink
remove direct openai access
Browse files Browse the repository at this point in the history
  • Loading branch information
Gunther Hagleitner committed Sep 26, 2024
1 parent 9cae755 commit 9ca4511
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions conversational-analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import pandas as pd
from langgraph.graph import StateGraph
from langgraph.types import RetryPolicy
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from pydantic import BaseModel
from waii_sdk_py import WAII
from waii_sdk_py.query import QueryGenerationRequest, RunQueryRequest

import open_ai_utils


class State(BaseModel):
query: str = ''
sql: str = ''
Expand All @@ -24,7 +24,6 @@ class State(BaseModel):
simulate_error_sql_gen: bool = False
simulate_error_sql_exec: bool = False


class LanggraphWorkflowManager:

def __init__(self):
Expand Down Expand Up @@ -149,6 +148,7 @@ def result_synthesizer(self, state: State) -> State:
if state.error:
print(f"Error in previous step: {state.error}")
return state

# Create a response based on the data
response = "Here are the results of your query:\n"
for row in state.data:
Expand All @@ -158,12 +158,24 @@ def result_synthesizer(self, state: State) -> State:
return state.model_copy(update={"response": response}, deep=True)

def waii_intent_classification(self, query: str) -> str:
system_message = """You are an expert in classifying questions into 'sql', 'data_visualization', 'insight', or 'others'."""
question = f"Can you classify the following question into one of these categories? Question: '{query}'. " \
f"Output: Strictly respond with either 'sql', 'data_visualization', 'insight', or 'unknown'. No additional text."
# Create the language model
model = ChatOpenAI()

classification = open_ai_utils.run_prompt(system_message=system_message, question=question)
# Create the chat prompt template
prompt = ChatPromptTemplate.from_messages([
("system", "You are an expert in classifying questions into 'sql', 'data_visualization', 'insight', or 'others'."),
("human", "Can you classify the following question into one of these categories? Question: '{query}'. "
"Output: Strictly respond with either 'sql', 'data_visualization', 'insight', or 'unknown'. No additional text.")
])

# Create the chain
chain = prompt | model | StrOutputParser()

print(query)
# Invoke the chain and get the classification
classification = chain.invoke({"query": query}).strip().lower()

# Return the classification, mapping 'others' to 'unknown'
if classification in ["sql", "data_visualization", "insight"]:
return classification
else:
Expand Down Expand Up @@ -191,7 +203,7 @@ def waii_chart_generator(self, data: List[Dict[str, Any]]) -> str:
try:
df_data = pd.DataFrame(data)
response = WAII.Chart.generate_chart(df=df_data)
# TODO: Remove this later (may be dump the chart into some JPG?)

print(f"Chart spec: {response.chart_spec}")
return response.chart_spec
except Exception as e:
Expand All @@ -213,8 +225,6 @@ def run_workflow(self):
print(f"Error in workflow: {e}. Will restart.")


# Example usage
if __name__ == "__main__":
# Who are the top 5 directors with the highest number of titles?
workflow_manager = LanggraphWorkflowManager()
workflow_manager.run_workflow()

0 comments on commit 9ca4511

Please sign in to comment.