40 changes: 10 additions & 30 deletions sql_migration_assistant/app/llm.py
@@ -5,8 +5,8 @@


 class LLMCalls:
-    def __init__(self, foundation_llm_name):
-        self.w = WorkspaceClient()
+    def __init__(self, openai_client, foundation_llm_name):
+        self.o = openai_client
         self.foundation_llm_name = foundation_llm_name

     def call_llm(self, messages, max_tokens, temperature):
@@ -26,27 +26,15 @@ def call_llm(self, messages, max_tokens, temperature):
         # check to make sure temperature is between 0.0 and 1.0
         if temperature < 0.0 or temperature > 1.0:
             raise gr.Error("Temperature must be between 0.0 and 1.0")
-        response = self.w.serving_endpoints.query(
-            name=self.foundation_llm_name,
+        response = self.o.chat.completions.create(
+            model=self.foundation_llm_name,
             max_tokens=max_tokens,
             messages=messages,
             temperature=temperature,
         )
         message = response.choices[0].message.content
         return message

-    def convert_chat_to_llm_input(self, system_prompt, chat):
-        # Convert the chat list of lists to the required format for the LLM
-        messages = [ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt)]
-        for q, a in chat:
-            messages.extend(
-                [
-                    ChatMessage(role=ChatMessageRole.USER, content=q),
-                    ChatMessage(role=ChatMessageRole.ASSISTANT, content=a),
-                ]
-            )
-        return messages
-
     ################################################################################
     # FUNCTION FOR TRANSLATING CODE
     ################################################################################
@@ -55,30 +43,22 @@ def convert_chat_to_llm_input(self, system_prompt, chat):

     def llm_translate(self, system_prompt, input_code, max_tokens, temperature):
         messages = [
-            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
-            ChatMessage(role=ChatMessageRole.USER, content=input_code),
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": input_code}
         ]

         # call the LLM end point.
         llm_answer = self.call_llm(
             messages=messages, max_tokens=max_tokens, temperature=temperature
         )
-        # Extract the code from in between the triple backticks (```), since LLM often prints the code like this.
-        # Also removes the 'sql' prefix always added by the LLM.
-        translation = llm_answer # .split("Final answer:\n")[1].replace(">>", "").replace("<<", "")
-        return translation
-
-    def llm_chat(self, system_prompt, query, chat_history):
-        messages = self.convert_chat_to_llm_input(system_prompt, chat_history)
-        messages.append(ChatMessage(role=ChatMessageRole.USER, content=query))
-        # call the LLM end point.
-        llm_answer = self.call_llm(messages=messages)
-        return llm_answer
+        translation = llm_answer
+        return translation

     def llm_intent(self, system_prompt, input_code, max_tokens, temperature):
         messages = [
-            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
-            ChatMessage(role=ChatMessageRole.USER, content=input_code),
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": input_code}
         ]

         # call the LLM end point.
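Since `LLMCalls` no longer builds its own `WorkspaceClient`, and `convert_chat_to_llm_input` / `llm_chat` are removed, callers now have to construct and pass in an OpenAI-compatible client themselves. Below is a minimal sketch of how the refactored class might be wired up, assuming the foundation model sits behind a Databricks OpenAI-compatible serving endpoint; the host, token, endpoint name, and import path are placeholders inferred from this diff, not part of the PR:

```python
from openai import OpenAI

from sql_migration_assistant.app.llm import LLMCalls

# Placeholder credentials and endpoint name -- substitute real values.
client = OpenAI(
    api_key="<databricks-token>",
    base_url="https://<workspace-host>/serving-endpoints",
)

llm = LLMCalls(openai_client=client, foundation_llm_name="<endpoint-name>")

# llm_translate now returns the raw completion text; any stripping of code
# fences or "sql" prefixes is left to the caller after this change.
translated = llm.llm_translate(
    system_prompt="Translate the following T-SQL to Databricks SQL.",
    input_code="SELECT TOP 10 * FROM sales;",
    max_tokens=1024,
    temperature=0.0,
)
print(translated)
```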
25 changes: 25 additions & 0 deletions sql_migration_assistant/app/prompt_helper.py
@@ -0,0 +1,25 @@
+import gradio as gr
+class PromptHelper:
+    def __init__(self, see, catalog, schema, prompt_table):
+        self.see = see
+        self.CATALOG = catalog
+        self.SCHEMA = schema
+        self.PROMPT_TABLE = prompt_table
+
+    def get_prompts(self, agent):
+        gr.Info("Retrieving Prompts...")
+        response = self.see.execute(
+            f"SELECT id, prompt, temperature, token_limit, save_time FROM {self.CATALOG}.{self.SCHEMA}.{self.PROMPT_TABLE} "
+            f"WHERE agent = '{agent}' "
+            f"ORDER BY save_time DESC "
+        )
+        return response.result.data_array
+
+    def save_prompt(self, agent, prompt, temperature, token_limit):
+        gr.Info("Saving prompt...")
+        self.see.execute(
+            f"INSERT INTO {self.CATALOG}.{self.SCHEMA}.{self.PROMPT_TABLE} "
+            f"(agent, prompt, temperature, token_limit, save_time) "
+            f"VALUES ('{agent}', '{prompt}',{temperature}, {token_limit}, CURRENT_TIMESTAMP())"
+        )
+        gr.Info("Prompt saved")
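A rough usage sketch for the new `PromptHelper`. The `see` argument isn't defined in this diff; judging from `get_prompts`, it only needs an `execute(sql)` method whose return value exposes `result.data_array`, which matches a Databricks statement-execution response. The wrapper class, warehouse id, and table names below are hypothetical and only illustrate that interface:

```python
from databricks.sdk import WorkspaceClient

from sql_migration_assistant.app.prompt_helper import PromptHelper


class StatementExecutor:
    """Hypothetical wrapper illustrating the interface PromptHelper expects."""

    def __init__(self, warehouse_id):
        self.w = WorkspaceClient()
        self.warehouse_id = warehouse_id

    def execute(self, sql):
        # execute_statement returns a response whose .result.data_array holds
        # the row data that get_prompts returns to the UI.
        return self.w.statement_execution.execute_statement(
            statement=sql, warehouse_id=self.warehouse_id
        )


see = StatementExecutor(warehouse_id="<warehouse-id>")
helper = PromptHelper(
    see, catalog="main", schema="migration_assistant", prompt_table="prompts"
)

helper.save_prompt("translation_agent", "Translate T-SQL to Spark SQL.", 0.0, 1024)
latest_first = helper.get_prompts("translation_agent")
```

One thing worth noting at review time: `save_prompt` interpolates the prompt text directly into the INSERT statement, so a prompt containing a single quote would break the SQL unless it is escaped or the statement is parameterized.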