Skip to content

Commit 6f336da

Browse files
committed
feat(templates): use a single template for multimodals messages
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent 015835d commit 6f336da

File tree

4 files changed

+140
-29
lines changed

4 files changed

+140
-29
lines changed

core/config/backend_config.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,7 @@ type TemplateConfig struct {
197197
// It defaults to \n
198198
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
199199

200-
Video string `yaml:"video"`
201-
Image string `yaml:"image"`
202-
Audio string `yaml:"audio"`
200+
Multimodal string `yaml:"multimodal"`
203201
}
204202

205203
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {

core/http/endpoints/openai/request.go

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -149,18 +149,27 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
149149
// Decode each request's message content
150150
imgIndex, vidIndex, audioIndex := 0, 0, 0
151151
for i, m := range input.Messages {
152+
nrOfImgsInMessage := 0
153+
nrOfVideosInMessage := 0
154+
nrOfAudiosInMessage := 0
155+
152156
switch content := m.Content.(type) {
153157
case string:
154158
input.Messages[i].StringContent = content
155159
case []interface{}:
156160
dat, _ := json.Marshal(content)
157161
c := []schema.Content{}
158162
json.Unmarshal(dat, &c)
163+
164+
textContent := ""
165+
// we will template this at the end
166+
159167
CONTENT:
160168
for _, pp := range c {
161169
switch pp.Type {
162170
case "text":
163-
input.Messages[i].StringContent = pp.Text
171+
textContent += pp.Text
172+
//input.Messages[i].StringContent = pp.Text
164173
case "video", "video_url":
165174
// Decode content as base64 either if it's an URL or base64 text
166175
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
@@ -169,14 +178,8 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
169178
continue CONTENT
170179
}
171180
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
172-
173-
t := "[vid-{{.ID}}]{{.Text}}"
174-
if config.TemplateConfig.Video != "" {
175-
t = config.TemplateConfig.Video
176-
}
177-
// set a placeholder for each image
178-
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, vidIndex, input.Messages[i].StringContent)
179181
vidIndex++
182+
nrOfVideosInMessage++
180183
case "audio_url", "audio":
181184
// Decode content as base64 either if it's an URL or base64 text
182185
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
@@ -185,13 +188,8 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
185188
continue CONTENT
186189
}
187190
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
188-
// set a placeholder for each image
189-
t := "[audio-{{.ID}}]{{.Text}}"
190-
if config.TemplateConfig.Audio != "" {
191-
t = config.TemplateConfig.Audio
192-
}
193-
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, audioIndex, input.Messages[i].StringContent)
194191
audioIndex++
192+
nrOfAudiosInMessage++
195193
case "image_url", "image":
196194
// Decode content as base64 either if it's an URL or base64 text
197195
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
@@ -200,16 +198,21 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
200198
continue CONTENT
201199
}
202200

203-
t := "[img-{{.ID}}]{{.Text}}"
204-
if config.TemplateConfig.Image != "" {
205-
t = config.TemplateConfig.Image
206-
}
207201
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
208-
// set a placeholder for each image
209-
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, imgIndex, input.Messages[i].StringContent)
202+
210203
imgIndex++
204+
nrOfImgsInMessage++
211205
}
212206
}
207+
208+
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
209+
TotalImages: imgIndex,
210+
TotalVideos: vidIndex,
211+
TotalAudios: audioIndex,
212+
ImagesInMessage: nrOfImgsInMessage,
213+
VideosInMessage: nrOfVideosInMessage,
214+
AudiosInMessage: nrOfAudiosInMessage,
215+
}, textContent)
213216
}
214217
}
215218

pkg/templates/multimodal.go

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,60 @@ import (
77
"github.com/Masterminds/sprig/v3"
88
)
99

10-
func TemplateMultiModal(templateString string, templateID int, text string) (string, error) {
10+
type MultiModalOptions struct {
11+
TotalImages int
12+
TotalAudios int
13+
TotalVideos int
14+
15+
ImagesInMessage int
16+
AudiosInMessage int
17+
VideosInMessage int
18+
}
19+
20+
type MultimodalContent struct {
21+
ID int
22+
}
23+
24+
const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}[img-{{.ID}}]{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"
25+
26+
func TemplateMultiModal(templateString string, opts MultiModalOptions, text string) (string, error) {
27+
if templateString == "" {
28+
templateString = DefaultMultiModalTemplate
29+
}
30+
1131
// compile the template
1232
tmpl, err := template.New("template").Funcs(sprig.FuncMap()).Parse(templateString)
1333
if err != nil {
1434
return "", err
1535
}
36+
37+
videos := []MultimodalContent{}
38+
for i := 0; i < opts.VideosInMessage; i++ {
39+
videos = append(videos, MultimodalContent{ID: i + (opts.TotalVideos - opts.VideosInMessage)})
40+
}
41+
42+
audios := []MultimodalContent{}
43+
for i := 0; i < opts.AudiosInMessage; i++ {
44+
audios = append(audios, MultimodalContent{ID: i + (opts.TotalAudios - opts.AudiosInMessage)})
45+
}
46+
47+
images := []MultimodalContent{}
48+
for i := 0; i < opts.ImagesInMessage; i++ {
49+
images = append(images, MultimodalContent{ID: i + (opts.TotalImages - opts.ImagesInMessage)})
50+
}
51+
1652
result := bytes.NewBuffer(nil)
1753
// execute the template
1854
err = tmpl.Execute(result, struct {
19-
ID int
20-
Text string
55+
Audio []MultimodalContent
56+
Images []MultimodalContent
57+
Video []MultimodalContent
58+
Text string
2159
}{
22-
ID: templateID,
23-
Text: text,
60+
Audio: audios,
61+
Images: images,
62+
Video: videos,
63+
Text: text,
2464
})
2565
return result.String(), err
2666
}

pkg/templates/multimodal_test.go

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,77 @@ import (
1111
var _ = Describe("EvaluateTemplate", func() {
1212
Context("templating simple strings for multimodal chat", func() {
1313
It("should template messages correctly", func() {
14-
result, err := TemplateMultiModal("[img-{{.ID}}]{{.Text}}", 1, "bar")
14+
result, err := TemplateMultiModal("", MultiModalOptions{
15+
TotalImages: 1,
16+
TotalAudios: 0,
17+
TotalVideos: 0,
18+
ImagesInMessage: 1,
19+
AudiosInMessage: 0,
20+
VideosInMessage: 0,
21+
}, "bar")
22+
Expect(err).NotTo(HaveOccurred())
23+
Expect(result).To(Equal("[img-0]bar"))
24+
})
25+
26+
It("should handle messages with more images correctly", func() {
27+
result, err := TemplateMultiModal("", MultiModalOptions{
28+
TotalImages: 2,
29+
TotalAudios: 0,
30+
TotalVideos: 0,
31+
ImagesInMessage: 2,
32+
AudiosInMessage: 0,
33+
VideosInMessage: 0,
34+
}, "bar")
35+
Expect(err).NotTo(HaveOccurred())
36+
Expect(result).To(Equal("[img-0][img-1]bar"))
37+
})
38+
It("should handle messages with more images correctly", func() {
39+
result, err := TemplateMultiModal("", MultiModalOptions{
40+
TotalImages: 4,
41+
TotalAudios: 1,
42+
TotalVideos: 0,
43+
ImagesInMessage: 2,
44+
AudiosInMessage: 1,
45+
VideosInMessage: 0,
46+
}, "bar")
47+
Expect(err).NotTo(HaveOccurred())
48+
Expect(result).To(Equal("[audio-0][img-2][img-3]bar"))
49+
})
50+
It("should handle messages with more images correctly", func() {
51+
result, err := TemplateMultiModal("", MultiModalOptions{
52+
TotalImages: 3,
53+
TotalAudios: 1,
54+
TotalVideos: 0,
55+
ImagesInMessage: 1,
56+
AudiosInMessage: 1,
57+
VideosInMessage: 0,
58+
}, "bar")
59+
Expect(err).NotTo(HaveOccurred())
60+
Expect(result).To(Equal("[audio-0][img-2]bar"))
61+
})
62+
It("should handle messages with more images correctly", func() {
63+
result, err := TemplateMultiModal("", MultiModalOptions{
64+
TotalImages: 0,
65+
TotalAudios: 0,
66+
TotalVideos: 0,
67+
ImagesInMessage: 0,
68+
AudiosInMessage: 0,
69+
VideosInMessage: 0,
70+
}, "bar")
71+
Expect(err).NotTo(HaveOccurred())
72+
Expect(result).To(Equal("bar"))
73+
})
74+
})
75+
Context("templating with custom defaults", func() {
76+
It("should handle messages with more images correctly", func() {
77+
result, err := TemplateMultiModal("{{ range .Audio }}[audio-{{ add1 .ID}}]{{end}}{{ range .Images }}[img-{{ add1 .ID}}]{{end}}{{ range .Video }}[vid-{{ add1 .ID}}]{{end}}{{.Text}}", MultiModalOptions{
78+
TotalImages: 1,
79+
TotalAudios: 0,
80+
TotalVideos: 0,
81+
ImagesInMessage: 1,
82+
AudiosInMessage: 0,
83+
VideosInMessage: 0,
84+
}, "bar")
1585
Expect(err).NotTo(HaveOccurred())
1686
Expect(result).To(Equal("[img-1]bar"))
1787
})

0 commit comments

Comments
 (0)