
Commit 84a7e72

UI, llms, prompt saving (#285)
This PR makes three main changes to the Legion migration assistant:
- Replace the Databricks SDK with the OpenAI client for making LLM calls. This works around the 5-minute timeout in the Databricks SDK; creating a workspace client with a config whose timeout was set to 10 minutes did not fix it, so the code reverts to the OpenAI client.
- Add functionality for saving and loading user prompts in the UI.
- Streamline the UI to make it easier to use.
1 parent cac167b commit 84a7e72
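For context on the first change: the refactored code below expects an already-constructed OpenAI client. A minimal sketch of how such a client might be pointed at Databricks Model Serving with a longer timeout is shown here; the environment variable names, the `serving-endpoints` base URL, and the 10-minute timeout are illustrative assumptions, not taken from this commit.

```python
# Minimal sketch (not from this commit): building an OpenAI client that targets
# a Databricks Model Serving endpoint, with a timeout longer than the
# Databricks SDK's 5-minute default. Host/token wiring is a placeholder.
import os

from openai import OpenAI

openai_client = OpenAI(
    api_key=os.environ["DATABRICKS_TOKEN"],                         # assumed env var holding a workspace token
    base_url=f"{os.environ['DATABRICKS_HOST']}/serving-endpoints",  # OpenAI-compatible serving route (assumed)
    timeout=600.0,                                                  # allow long-running generations (10 minutes)
)
```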

File tree: 9 files changed, +204 −434 lines changed


sql_migration_assistant/app/llm.py

Lines changed: 10 additions & 30 deletions
```diff
@@ -5,8 +5,8 @@
 
 
 class LLMCalls:
-    def __init__(self, foundation_llm_name):
-        self.w = WorkspaceClient()
+    def __init__(self, openai_client, foundation_llm_name):
+        self.o = openai_client
         self.foundation_llm_name = foundation_llm_name
 
     def call_llm(self, messages, max_tokens, temperature):
@@ -26,27 +26,15 @@ def call_llm(self, messages, max_tokens, temperature):
         # check to make sure temperature is between 0.0 and 1.0
         if temperature < 0.0 or temperature > 1.0:
             raise gr.Error("Temperature must be between 0.0 and 1.0")
-        response = self.w.serving_endpoints.query(
-            name=self.foundation_llm_name,
+        response = self.o.chat.completions.create(
+            model=self.foundation_llm_name,
             max_tokens=max_tokens,
             messages=messages,
             temperature=temperature,
         )
         message = response.choices[0].message.content
         return message
 
-    def convert_chat_to_llm_input(self, system_prompt, chat):
-        # Convert the chat list of lists to the required format for the LLM
-        messages = [ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt)]
-        for q, a in chat:
-            messages.extend(
-                [
-                    ChatMessage(role=ChatMessageRole.USER, content=q),
-                    ChatMessage(role=ChatMessageRole.ASSISTANT, content=a),
-                ]
-            )
-        return messages
-
     ################################################################################
     # FUNCTION FOR TRANSLATING CODE
     ################################################################################
@@ -55,30 +43,22 @@ def convert_chat_to_llm_input(self, system_prompt, chat):
 
     def llm_translate(self, system_prompt, input_code, max_tokens, temperature):
         messages = [
-            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
-            ChatMessage(role=ChatMessageRole.USER, content=input_code),
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": input_code}
         ]
 
         # call the LLM end point.
         llm_answer = self.call_llm(
             messages=messages, max_tokens=max_tokens, temperature=temperature
         )
-        # Extract the code from in between the triple backticks (```), since LLM often prints the code like this.
-        # Also removes the 'sql' prefix always added by the LLM.
-        translation = llm_answer # .split("Final answer:\n")[1].replace(">>", "").replace("<<", "")
-        return translation
 
-    def llm_chat(self, system_prompt, query, chat_history):
-        messages = self.convert_chat_to_llm_input(system_prompt, chat_history)
-        messages.append(ChatMessage(role=ChatMessageRole.USER, content=query))
-        # call the LLM end point.
-        llm_answer = self.call_llm(messages=messages)
-        return llm_answer
+        translation = llm_answer
+        return translation
 
     def llm_intent(self, system_prompt, input_code, max_tokens, temperature):
         messages = [
-            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
-            ChatMessage(role=ChatMessageRole.USER, content=input_code),
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": input_code}
         ]
 
         # call the LLM end point.
```
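Not part of the diff: a short usage sketch of the refactored class, assuming an `openai_client` built as in the sketch near the top of this page; the served-model name and prompts are placeholders.

```python
# Illustrative only: driving the refactored LLMCalls with OpenAI-style dict messages.
# "databricks-meta-llama-3-1-70b-instruct" is a placeholder endpoint name.
llm = LLMCalls(openai_client, foundation_llm_name="databricks-meta-llama-3-1-70b-instruct")

translation = llm.llm_translate(
    system_prompt="Translate the following T-SQL to Databricks SQL. Return only code.",
    input_code="SELECT TOP 10 * FROM sales ORDER BY amount DESC;",
    max_tokens=1024,
    temperature=0.0,
)
print(translation)
```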
Lines changed: 25 additions & 0 deletions

```diff
@@ -0,0 +1,25 @@
+import gradio as gr
+class PromptHelper:
+    def __init__(self, see, catalog, schema, prompt_table):
+        self.see = see
+        self.CATALOG = catalog
+        self.SCHEMA = schema
+        self.PROMPT_TABLE = prompt_table
+
+    def get_prompts(self, agent):
+        gr.Info("Retrieving Prompts...")
+        response = self.see.execute(
+            f"SELECT id, prompt, temperature, token_limit, save_time FROM {self.CATALOG}.{self.SCHEMA}.{self.PROMPT_TABLE} "
+            f"WHERE agent = '{agent}' "
+            f"ORDER BY save_time DESC "
+        )
+        return response.result.data_array
+
+    def save_prompt(self, agent, prompt, temperature, token_limit):
+        gr.Info("Saving prompt...")
+        self.see.execute(
+            f"INSERT INTO {self.CATALOG}.{self.SCHEMA}.{self.PROMPT_TABLE} "
+            f"(agent, prompt, temperature, token_limit, save_time) "
+            f"VALUES ('{agent}', '{prompt}',{temperature}, {token_limit}, CURRENT_TIMESTAMP())"
+        )
+        gr.Info("Prompt saved")
```
