
Add warm-up function for provided prompt #301

Merged · 1 commit · Jan 21, 2025
add warm-up for provided prompt
amakropoulos committed Jan 21, 2025
commit 6f58e064734c89fb715d59667f7a2821e7674898
Runtime/LLMCharacter.cs: 67 changes (46 additions, 21 deletions)
@@ -416,6 +416,24 @@ protected virtual async Task<string> CompletionRequest(string json, Callback<str
             return result;
         }
 
+        protected async Task<ChatRequest> PromptWithQuery(string query)
+        {
+            ChatRequest result = default;
+            await chatLock.WaitAsync();
+            try
+            {
+                AddPlayerMessage(query);
+                string prompt = template.ComputePrompt(chat, playerName, AIName);
+                result = GenerateRequest(prompt);
+                chat.RemoveAt(chat.Count - 1);
+            }
+            finally
+            {
+                chatLock.Release();
+            }
+            return result;
+        }
+
         /// <summary>
         /// Chat functionality of the LLM.
         /// It calls the LLM completion based on the provided query including the previous chat history.
@@ -436,20 +454,7 @@ public virtual async Task<string> Chat(string query, Callback<string> callback =
             if (!CheckTemplate()) return null;
             if (!await InitNKeep()) return null;
 
-            string json;
-            await chatLock.WaitAsync();
-            try
-            {
-                AddPlayerMessage(query);
-                string prompt = template.ComputePrompt(chat, playerName, AIName);
-                json = JsonUtility.ToJson(GenerateRequest(prompt));
-                chat.RemoveAt(chat.Count - 1);
-            }
-            finally
-            {
-                chatLock.Release();
-            }
-
+            string json = JsonUtility.ToJson(await PromptWithQuery(query));
             string result = await CompletionRequest(json, callback);
 
             if (addToHistory && result != null)
@@ -494,23 +499,43 @@ public virtual async Task<string> Complete(string prompt, Callback<string> callb
         }
 
         /// <summary>
-        /// Allow to warm-up a model by processing the prompt.
+        /// Allow to warm-up a model by processing the system prompt.
         /// The prompt processing will be cached (if cachePrompt=true) allowing for faster initialisation.
-        /// The function allows callback for when the prompt is processed and the response received.
+        ///
+        /// The function calls the Chat function with a predefined query without adding it to history.
+        /// The function allows a callback function for when the prompt is processed and the response received.
         /// </summary>
         /// <param name="completionCallback">callback function called when the full response has been received</param>
+        /// <param name="query">user prompt used during the initialisation (not added to history)</param>
         /// <returns>the LLM response</returns>
         public virtual async Task Warmup(EmptyCallback completionCallback = null)
         {
+            await Warmup(null, completionCallback);
+        }
+
+        /// <summary>
+        /// Allow to warm-up a model by processing the provided prompt without adding it to history.
+        /// The prompt processing will be cached (if cachePrompt=true) allowing for faster initialisation.
+        /// The function allows a callback function for when the prompt is processed and the response received.
+        ///
+        /// </summary>
+        /// <param name="query">user prompt used during the initialisation (not added to history)</param>
+        /// <param name="completionCallback">callback function called when the full response has been received</param>
+        /// <returns>the LLM response</returns>
+        public virtual async Task Warmup(string query, EmptyCallback completionCallback = null)
+        {
             await LoadTemplate();
             if (!CheckTemplate()) return;
             if (!await InitNKeep()) return;
 
-            string prompt = template.ComputePrompt(chat, playerName, AIName);
-            ChatRequest request = GenerateRequest(prompt);
+            ChatRequest request;
+            if (String.IsNullOrEmpty(query))
+            {
+                string prompt = template.ComputePrompt(chat, playerName, AIName);
+                request = GenerateRequest(prompt);
+            }
+            else
+            {
+                request = await PromptWithQuery(query);
+            }
 
             request.n_predict = 0;
             string json = JsonUtility.ToJson(request);
             await CompletionRequest(json);
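For illustration, a minimal usage sketch of the two Warmup overloads introduced by this PR, assuming the LLMUnity namespace and EmptyCallback delegate from the repository; the WarmupExample component, its field wiring, and the log messages are hypothetical, not part of the PR:

using UnityEngine;
using LLMUnity;

// Hypothetical example component: warm up an LLMCharacter at scene start.
public class WarmupExample : MonoBehaviour
{
    public LLMCharacter llmCharacter;  // assigned in the Inspector

    async void Start()
    {
        // Existing overload: process only the system prompt and cache it
        // (if cachePrompt=true) so the first real chat responds faster.
        await llmCharacter.Warmup(() => Debug.Log("system prompt warmed up"));

        // New overload: also process a representative user query. The query
        // is formatted with the chat template and evaluated, then removed,
        // so it never enters the chat history.
        await llmCharacter.Warmup("Hello!", () => Debug.Log("query warmed up"));
    }
}

Because the request is sent with n_predict = 0, the server evaluates the prompt without generating any tokens, so warm-up pays only the prompt-processing cost.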