From 0af613abbde0c7760f118017a515a7617f32af73 Mon Sep 17 00:00:00 2001
From: Jonathan Hecl
Date: Sat, 3 Aug 2024 01:00:32 -0300
Subject: [PATCH] SystemPrompt

---
 v2/ollamaclient.go | 42 +++++++++++++++++++++++++-----------------
 v2/stream.go       |  2 ++
 v2/tools_test.go   |  1 +
 3 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/v2/ollamaclient.go b/v2/ollamaclient.go
index 766ea16..4470106 100644
--- a/v2/ollamaclient.go
+++ b/v2/ollamaclient.go
@@ -33,6 +33,7 @@ type RequestOptions struct {
 // GenerateRequest represents the request payload for generating output
 type GenerateRequest struct {
 	Model  string   `json:"model"`
+	System string   `json:"system,omitempty"`
 	Prompt string   `json:"prompt,omitempty"`
 	Images []string `json:"images,omitempty"` // base64 encoded images
 	Stream bool     `json:"stream,omitempty"`
@@ -85,6 +86,7 @@ type Config struct {
 	TrimSpace     bool
 	Verbose       bool
 	ContextLength int64
+	SystemPrompt  string
 	Tools         []json.RawMessage
 }
 
@@ -161,6 +163,11 @@ func (oc *Config) SetReproducible(optionalSeed ...int) {
 	oc.SeedOrNegative = defaultFixedSeed
 }
 
+// SetSystemPrompt sets the system prompt for this Ollama config
+func (oc *Config) SetSystemPrompt(prompt string) {
+	oc.SystemPrompt = prompt
+}
+
 // SetRandom configures the generated output to not be reproducible
 func (oc *Config) SetRandom() {
 	oc.SeedOrNegative = -1
@@ -193,18 +200,24 @@ func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputChat,
 	if seed < 0 {
 		temperature = oc.TemperatureIfNegativeSeed
 	}
+	messages := []Message{}
+	if oc.SystemPrompt != "" {
+		messages = append(messages, Message{
+			Role:    "system",
+			Content: oc.SystemPrompt,
+		})
+	}
+	messages = append(messages, Message{
+		Role:    "user",
+		Content: prompt,
+	})
 	var reqBody GenerateChatRequest
 	if len(images) > 0 {
 		reqBody = GenerateChatRequest{
-			Model: oc.ModelName,
-			Messages: []Message{
-				{
-					Role:    "user",
-					Content: prompt,
-				},
-			},
-			Images: images,
-			Tools:  oc.Tools,
+			Model:    oc.ModelName,
+			Messages: messages,
+			Images:   images,
+			Tools:    oc.Tools,
 			Options: RequestOptions{
 				Seed:        seed,        // set to -1 to make it random
 				Temperature: temperature, // set to 0 together with a specific seed to make output reproducible
@@ -212,14 +225,9 @@ func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputChat,
 		}
 	} else {
 		reqBody = GenerateChatRequest{
-			Model: oc.ModelName,
-			Messages: []Message{
-				{
-					Role:    "user",
-					Content: prompt,
-				},
-			},
-			Tools: oc.Tools,
+			Model:    oc.ModelName,
+			Messages: messages,
+			Tools:    oc.Tools,
 			Options: RequestOptions{
 				Seed:        seed,        // set to -1 to make it random
 				Temperature: temperature, // set to 0 together with a specific seed to make output reproducible
diff --git a/v2/stream.go b/v2/stream.go
index 0cba35d..cb632ef 100644
--- a/v2/stream.go
+++ b/v2/stream.go
@@ -69,6 +69,7 @@ func (oc *Config) StreamOutput(callbackFunction func(string, bool), promptAndOpt
 	if len(images) > 0 {
 		reqBody = GenerateRequest{
 			Model:  oc.ModelName,
+			System: oc.SystemPrompt,
 			Prompt: prompt,
 			Images: images,
 			Stream: true,
@@ -80,6 +81,7 @@ func (oc *Config) StreamOutput(callbackFunction func(string, bool), promptAndOpt
 	} else {
 		reqBody = GenerateRequest{
 			Model:  oc.ModelName,
+			System: oc.SystemPrompt,
 			Prompt: prompt,
 			Stream: true,
 			Options: RequestOptions{
diff --git a/v2/tools_test.go b/v2/tools_test.go
index fd092fc..ef8e453 100644
--- a/v2/tools_test.go
+++ b/v2/tools_test.go
@@ -18,6 +18,7 @@ func TestTools(t *testing.T) {
 		t.Error("Expected to have 'llama3.1' model downloaded, but it's not present")
 	}
 
+	oc.SetSystemPrompt("You are a helpful assistant.")
 	oc.SetRandom()
 	oc.SetTool(json.RawMessage(`{
 		"type": "function",