Skip to content

Commit 1a60979

Browse files
committed
Small refactoring and adaptations
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent 614bb5e commit 1a60979

File tree

6 files changed

+59
-108
lines changed

6 files changed

+59
-108
lines changed

core/application/application.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func newApplication(appConfig *config.ApplicationConfig) *Application {
1818
backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
1919
modelLoader: model.NewModelLoader(appConfig.ModelPath),
2020
applicationConfig: appConfig,
21-
templatesEvaluator: templates.NewEvaluator(templates.NewTemplateCache(appConfig.ModelPath)),
21+
templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
2222
}
2323
}
2424

core/http/endpoints/openai/chat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
303303
predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)
304304

305305
log.Debug().Msgf("Prompt (after templating): %s", predInput)
306-
if shouldUseFn && config.Grammar != "" {
306+
if config.Grammar != "" {
307307
log.Debug().Msgf("Grammar: %+v", config.Grammar)
308308
}
309309
}

pkg/templates/cache.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,33 +20,33 @@ import (
2020
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
2121
type TemplateType int
2222

23-
type TemplateCache struct {
23+
type templateCache struct {
2424
mu sync.Mutex
2525
templatesPath string
2626
templates map[TemplateType]map[string]*template.Template
2727
jinjaTemplates map[TemplateType]map[string]*exec.Template
2828
}
2929

30-
func NewTemplateCache(templatesPath string) *TemplateCache {
31-
tc := &TemplateCache{
30+
func newTemplateCache(templatesPath string) *templateCache {
31+
tc := &templateCache{
3232
templatesPath: templatesPath,
3333
templates: make(map[TemplateType]map[string]*template.Template),
3434
jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
3535
}
3636
return tc
3737
}
3838

39-
func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
39+
func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
4040
if _, ok := tc.templates[tt]; !ok {
4141
tc.templates[tt] = make(map[string]*template.Template)
4242
}
4343
}
4444

45-
func (tc *TemplateCache) ExistsInModelPath(s string) bool {
45+
func (tc *templateCache) existsInModelPath(s string) bool {
4646
return utils.ExistsInPath(tc.templatesPath, s)
4747
}
4848

49-
func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
49+
func (tc *templateCache) EvaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
5050
tc.mu.Lock()
5151
defer tc.mu.Unlock()
5252

@@ -72,7 +72,7 @@ func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateNam
7272
return buf.String(), nil
7373
}
7474

75-
func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
75+
func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
7676

7777
// Check if the template was already loaded
7878
if _, ok := tc.templates[templateType][templateName]; ok {
@@ -92,7 +92,7 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
9292
}
9393

9494
// can either be a file in the system or a string with the template
95-
if tc.ExistsInModelPath(modelTemplateFile) {
95+
if tc.existsInModelPath(modelTemplateFile) {
9696
d, err := os.ReadFile(file)
9797
if err != nil {
9898
return err
@@ -112,13 +112,13 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
112112
return nil
113113
}
114114

115-
func (tc *TemplateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
115+
func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
116116
if _, ok := tc.jinjaTemplates[tt]; !ok {
117117
tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
118118
}
119119
}
120120

121-
func (tc *TemplateCache) EvaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
121+
func (tc *templateCache) EvaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
122122
tc.mu.Lock()
123123
defer tc.mu.Unlock()
124124

@@ -146,7 +146,7 @@ func (tc *TemplateCache) EvaluateJinjaTemplate(templateType TemplateType, templa
146146
return buf.String(), nil
147147
}
148148

149-
func (tc *TemplateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
149+
func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
150150
// Check if the template was already loaded
151151
if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
152152
return nil

pkg/templates/cache_test.go

Lines changed: 0 additions & 89 deletions
This file was deleted.

pkg/templates/evaluator.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,20 @@ const (
4444
)
4545

4646
type Evaluator struct {
47-
cache *TemplateCache
47+
cache *templateCache
4848
}
4949

50-
func NewEvaluator(cache *TemplateCache) *Evaluator {
50+
func NewEvaluator(modelPath string) *Evaluator {
5151
return &Evaluator{
52-
cache: cache,
52+
cache: newTemplateCache(modelPath),
5353
}
5454
}
5555

5656
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
5757
template := ""
5858

5959
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
60-
if e.cache.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
60+
if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
6161
template = config.Model
6262
}
6363

pkg/templates/evaluator_test.go

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ import (
1010
. "github.com/onsi/gomega"
1111
)
1212

13+
const toolCallJinja = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>
14+
15+
' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>
16+
17+
' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>
18+
19+
' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`
20+
1321
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
1422
{{- if .FunctionCall }}
1523
<tool_call>
@@ -183,11 +191,30 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
183191
},
184192
}
185193

194+
var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
195+
"user": {
196+
"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
197+
"config": &config.BackendConfig{
198+
TemplateConfig: config.TemplateConfig{
199+
ChatMessage: toolCallJinja,
200+
JinjaTemplate: true,
201+
},
202+
},
203+
"functions": []functions.Function{},
204+
"shouldUseFn": false,
205+
"messages": []schema.Message{
206+
{
207+
Role: "user",
208+
StringContent: "A long time ago in a galaxy far, far away...",
209+
},
210+
},
211+
},
212+
}
186213
var _ = Describe("Templates", func() {
187214
Context("chat message ChatML", func() {
188215
var evaluator *Evaluator
189216
BeforeEach(func() {
190-
evaluator = NewEvaluator(NewTemplateCache(""))
217+
evaluator = NewEvaluator("")
191218
})
192219
for key := range chatMLTestMatch {
193220
foo := chatMLTestMatch[key]
@@ -200,7 +227,7 @@ var _ = Describe("Templates", func() {
200227
Context("chat message llama3", func() {
201228
var evaluator *Evaluator
202229
BeforeEach(func() {
203-
evaluator = NewEvaluator(NewTemplateCache(""))
230+
evaluator = NewEvaluator("")
204231
})
205232
for key := range llama3TestMatch {
206233
foo := llama3TestMatch[key]
@@ -210,4 +237,17 @@ var _ = Describe("Templates", func() {
210237
})
211238
}
212239
})
240+
Context("chat message jinja", func() {
241+
var evaluator *Evaluator
242+
BeforeEach(func() {
243+
evaluator = NewEvaluator("")
244+
})
245+
for key := range jinjaTest {
246+
foo := jinjaTest[key]
247+
It("renders correctly `"+key+"`", func() {
248+
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
249+
Expect(templated).To(Equal(foo["expected"]), templated)
250+
})
251+
}
252+
})
213253
})

0 commit comments

Comments
 (0)