-import logging
+import gradio as gr
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
 
-w = WorkspaceClient()
-foundation_llm_name = "databricks-meta-llama-3-1-405b-instruct"
-max_token = 4096
-messages = [
-    ChatMessage(role=ChatMessageRole.SYSTEM, content="You are an unhelpful assistant"),
-    ChatMessage(role=ChatMessageRole.USER, content="What is RAG?"),
-]
-
 
 class LLMCalls:
-    def __init__(self, foundation_llm_name, max_tokens):
+    def __init__(self, foundation_llm_name):
         self.w = WorkspaceClient()
         self.foundation_llm_name = foundation_llm_name
-        self.max_tokens = int(max_tokens)
 
-    def call_llm(self, messages):
+    def call_llm(self, messages, max_tokens, temperature):
         """
         Function to call the LLM model and return the response.
         :param messages: list of messages like
@@ -29,8 +20,17 @@ def call_llm(self, messages):
         ]
         :return: the response from the model
         """
+
+        max_tokens = int(max_tokens)
+        temperature = float(temperature)
+        # check to make sure temperature is between 0.0 and 1.0
+        if temperature < 0.0 or temperature > 1.0:
+            raise gr.Error("Temperature must be between 0.0 and 1.0")
         response = self.w.serving_endpoints.query(
-            name=foundation_llm_name, max_tokens=max_token, messages=messages
+            name=self.foundation_llm_name,
+            max_tokens=max_tokens,
+            messages=messages,
+            temperature=temperature,
         )
         message = response.choices[0].message.content
         return message
@@ -53,14 +53,16 @@ def convert_chat_to_llm_input(self, system_prompt, chat):
 
     # this is called to actually send a request and receive response from the llm endpoint.
 
-    def llm_translate(self, system_prompt, input_code):
+    def llm_translate(self, system_prompt, input_code, max_tokens, temperature):
         messages = [
             ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
             ChatMessage(role=ChatMessageRole.USER, content=input_code),
         ]
 
         # call the LLM end point.
-        llm_answer = self.call_llm(messages=messages)
+        llm_answer = self.call_llm(
+            messages=messages, max_tokens=max_tokens, temperature=temperature
+        )
         # Extract the code from in between the triple backticks (```), since LLM often prints the code like this.
         # Also removes the 'sql' prefix always added by the LLM.
         translation = llm_answer  # .split("Final answer:\n")[1].replace(">>", "").replace("<<", "")
@@ -73,12 +75,14 @@ def llm_chat(self, system_prompt, query, chat_history):
         llm_answer = self.call_llm(messages=messages)
         return llm_answer
 
-    def llm_intent(self, system_prompt, input_code):
+    def llm_intent(self, system_prompt, input_code, max_tokens, temperature):
         messages = [
             ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
             ChatMessage(role=ChatMessageRole.USER, content=input_code),
         ]
 
         # call the LLM end point.
-        llm_answer = self.call_llm(messages=messages)
+        llm_answer = self.call_llm(
+            messages=messages, max_tokens=max_tokens, temperature=temperature
+        )
         return llm_answer
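
For context, a minimal usage sketch of the updated interface: it assumes the class is importable from a module named `llm_calls` and that a serving endpoint such as `databricks-meta-llama-3-1-405b-instruct` exists in the workspace; the prompts and parameter values below are placeholders, not part of this diff.

```python
# Hypothetical driver for the updated LLMCalls interface.
# The module name and endpoint name are assumptions, not taken from this change.
from llm_calls import LLMCalls

# max_tokens and temperature now travel with each call instead of being
# fixed when the class is constructed.
llm = LLMCalls(foundation_llm_name="databricks-meta-llama-3-1-405b-instruct")

translation = llm.llm_translate(
    system_prompt="Translate the following T-SQL to Spark SQL.",  # placeholder prompt
    input_code="SELECT TOP 10 * FROM sales",  # placeholder input
    max_tokens=2048,
    temperature=0.0,
)
print(translation)
```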