 
 
 class LLMCalls:
-    def __init__(self, foundation_llm_name):
-        self.w = WorkspaceClient()
+    def __init__(self, openai_client, foundation_llm_name):
+        self.o = openai_client
         self.foundation_llm_name = foundation_llm_name
 
     def call_llm(self, messages, max_tokens, temperature):
@@ -26,27 +26,15 @@ def call_llm(self, messages, max_tokens, temperature):
         # check to make sure temperature is between 0.0 and 1.0
         if temperature < 0.0 or temperature > 1.0:
             raise gr.Error("Temperature must be between 0.0 and 1.0")
-        response = self.w.serving_endpoints.query(
-            name=self.foundation_llm_name,
+        response = self.o.chat.completions.create(
+            model=self.foundation_llm_name,
             max_tokens=max_tokens,
             messages=messages,
             temperature=temperature,
         )
         message = response.choices[0].message.content
         return message
 
-    def convert_chat_to_llm_input(self, system_prompt, chat):
-        # Convert the chat list of lists to the required format for the LLM
-        messages = [ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt)]
-        for q, a in chat:
-            messages.extend(
-                [
-                    ChatMessage(role=ChatMessageRole.USER, content=q),
-                    ChatMessage(role=ChatMessageRole.ASSISTANT, content=a),
-                ]
-            )
-        return messages
-
     ################################################################################
     # FUNCTION FOR TRANSLATING CODE
     ################################################################################
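
Side note on the deleted `convert_chat_to_llm_input`: the OpenAI-style client takes plain role/content dicts, so the `ChatMessage`/`ChatMessageRole` wrappers from the Databricks SDK are no longer needed. If the chat-history path is ever restored, a dict-based equivalent might look like this (a sketch mirroring the deleted code, not part of this commit):

```python
# Hypothetical dict-based rewrite of the removed convert_chat_to_llm_input;
# mirrors the deleted ChatMessage-based logic one-for-one.
def convert_chat_to_llm_input(self, system_prompt, chat):
    # Convert the chat list of (question, answer) pairs into OpenAI-style messages.
    messages = [{"role": "system", "content": system_prompt}]
    for question, answer in chat:
        messages.append({"role": "user", "content": question})
        messages.append({"role": "assistant", "content": answer})
    return messages
```
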
@@ -55,30 +43,22 @@ def convert_chat_to_llm_input(self, system_prompt, chat):
 
     def llm_translate(self, system_prompt, input_code, max_tokens, temperature):
         messages = [
-            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
-            ChatMessage(role=ChatMessageRole.USER, content=input_code),
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": input_code}
         ]
 
         # call the LLM end point.
         llm_answer = self.call_llm(
             messages=messages, max_tokens=max_tokens, temperature=temperature
         )
-        # Extract the code from in between the triple backticks (```), since LLM often prints the code like this.
-        # Also removes the 'sql' prefix always added by the LLM.
-        translation = llm_answer  # .split("Final answer:\n")[1].replace(">>", "").replace("<<", "")
-        return translation
 
-    def llm_chat(self, system_prompt, query, chat_history):
-        messages = self.convert_chat_to_llm_input(system_prompt, chat_history)
-        messages.append(ChatMessage(role=ChatMessageRole.USER, content=query))
-        # call the LLM end point.
-        llm_answer = self.call_llm(messages=messages)
-        return llm_answer
+        translation = llm_answer
+        return translation
 
     def llm_intent(self, system_prompt, input_code, max_tokens, temperature):
         messages = [
-            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
-            ChatMessage(role=ChatMessageRole.USER, content=input_code),
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": input_code}
         ]
 
         # call the LLM end point.
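
Overall, the commit swaps the Databricks `WorkspaceClient().serving_endpoints.query(...)` call for an injected OpenAI-compatible client, so the class no longer constructs its own connection. A minimal wiring sketch, assuming the `openai` v1 SDK and a model-serving endpoint that speaks the OpenAI chat-completions protocol; the import path, host, token, and endpoint name below are placeholders:

```python
from openai import OpenAI

# Hypothetical wiring for the refactored class; host, token, and endpoint
# name are placeholders, and the import path depends on the repo layout.
from llm_calls import LLMCalls

client = OpenAI(
    api_key="<workspace-token>",  # placeholder credential
    base_url="https://<workspace-host>/serving-endpoints",  # OpenAI-compatible route (assumed)
)

llm = LLMCalls(openai_client=client, foundation_llm_name="<endpoint-name>")
print(
    llm.llm_translate(
        system_prompt="Translate T-SQL to Spark SQL.",
        input_code="SELECT TOP 10 * FROM sales;",
        max_tokens=1024,
        temperature=0.0,
    )
)
```
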