Skip to content

Commit d9e60f6

Browse files
authored
add image support to the chat api (ollama#1490)
1 parent 4251b34 commit d9e60f6

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

api/types.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ type ChatRequest struct {
5757
}
5858

5959
type Message struct {
60-
Role string `json:"role"` // one of ["system", "user", "assistant"]
61-
Content string `json:"content"`
60+
Role string `json:"role"` // one of ["system", "user", "assistant"]
61+
Content string `json:"content"`
62+
Images []ImageData `json:"images, omitempty"`
6263
}
6364

6465
type ChatResponse struct {

server/images.go

+9-7
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,10 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
8686
return prompt.String(), nil
8787
}
8888

89-
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
89+
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
9090
// build the prompt from the list of messages
9191
var prompt strings.Builder
92+
var currentImages []api.ImageData
9293
currentVars := PromptVars{
9394
First: true,
9495
}
@@ -108,35 +109,36 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
108109
case "system":
109110
if currentVars.System != "" {
110111
if err := writePrompt(); err != nil {
111-
return "", err
112+
return "", nil, err
112113
}
113114
}
114115
currentVars.System = msg.Content
115116
case "user":
116117
if currentVars.Prompt != "" {
117118
if err := writePrompt(); err != nil {
118-
return "", err
119+
return "", nil, err
119120
}
120121
}
121122
currentVars.Prompt = msg.Content
123+
currentImages = msg.Images
122124
case "assistant":
123125
currentVars.Response = msg.Content
124126
if err := writePrompt(); err != nil {
125-
return "", err
127+
return "", nil, err
126128
}
127129
default:
128-
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
130+
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
129131
}
130132
}
131133

132134
// Append the last set of vars if they are non-empty
133135
if currentVars.Prompt != "" || currentVars.System != "" {
134136
if err := writePrompt(); err != nil {
135-
return "", err
137+
return "", nil, err
136138
}
137139
}
138140

139-
return prompt.String(), nil
141+
return prompt.String(), currentImages, nil
140142
}
141143

142144
type ManifestV2 struct {

server/routes.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ func ChatHandler(c *gin.Context) {
994994

995995
checkpointLoaded := time.Now()
996996

997-
prompt, err := model.ChatPrompt(req.Messages)
997+
prompt, images, err := model.ChatPrompt(req.Messages)
998998
if err != nil {
999999
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
10001000
return
@@ -1037,6 +1037,7 @@ func ChatHandler(c *gin.Context) {
10371037
Format: req.Format,
10381038
CheckpointStart: checkpointStart,
10391039
CheckpointLoaded: checkpointLoaded,
1040+
Images: images,
10401041
}
10411042
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
10421043
ch <- gin.H{"error": err.Error()}

0 commit comments

Comments
 (0)