Commit 282bfaa

ollamarunner: Use a separate context per multimodal input

Currently there is a single context per sequence, shared by all multimodal inputs. Since we build a vision encoder graph per image, with a large number of inputs we can eventually hit the maximum number of graph nodes per context. This change uses a separate context for each image, ensuring that available resource limits are consistent.

1 parent 9679f40 commit 282bfaa
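The ownership change described above can be summarized in a small sketch. The `Backend` and `Context` interfaces and the `encode` callback below are placeholders for the real `ml` package, used only to illustrate the before/after structure, not the actual API:

```go
// Sketch only: Backend and Context stand in for ml.Backend and ml.Context.
package sketch

type Context interface{ Close() }

type Backend interface{ NewContext() Context }

// Before: every image's vision encoder graph is built in one shared
// context, so enough images can exhaust its graph-node limit.
func encodeShared(b Backend, images [][]byte, encode func(Context, []byte)) Context {
	ctx := b.NewContext()
	for _, img := range images {
		encode(ctx, img)
	}
	return ctx // closed once, when the sequence is removed
}

// After: one context per image, so each vision graph gets a fresh node
// budget; the returned contexts are all released when the sequence goes away.
func encodePerImage(b Backend, images [][]byte, encode func(Context, []byte)) []Context {
	ctxs := make([]Context, 0, len(images))
	for _, img := range images {
		ctx := b.NewContext()
		ctxs = append(ctxs, ctx)
		encode(ctx, img)
	}
	return ctxs
}
```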

File tree

4 files changed: +33, -19 lines

model/model.go

Lines changed: 1 addition & 1 deletion
```diff
@@ -60,7 +60,7 @@ type MultimodalProcessor interface {
 	// This function is also responsible for updating MultimodalHash for any Multimodal
 	// that is modified to ensure that there is a unique hash value that accurately
 	// represents the contents.
-	PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
+	PostTokenize([]input.Input) ([]input.Input, error)
 }
 
 // Base implements the common fields and methods for all models
```

model/models/gemma3/model.go

Lines changed: 1 addition & 1 deletion
```diff
@@ -111,7 +111,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return visionOutputs, nil
 }
 
-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
 
 	for _, inp := range inputs {
```

model/models/mllama/model.go

Lines changed: 7 additions & 4 deletions
```diff
@@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return m.Projector.Forward(ctx, crossAttentionStates), nil
 }
 
-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var images []input.Input
 	fnvHash := fnv.New64a()
 
 	for i := range inputs {
 		if inputs[i].Multimodal == nil {
 			if len(images) > 0 {
-				inputs[i].Multimodal = images[0].Multimodal
+				inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
 				inputs[i].MultimodalHash = images[0].MultimodalHash
 				for j := 1; j < len(images); j++ {
-					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
+					inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
 					fnvHash.Reset()
 					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
 					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
@@ -138,7 +138,10 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
 	if len(opts.Multimodal) > 0 {
-		crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
+		images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+		if len(images) > 0 {
+			crossAttentionStates = images[len(images)-1]
+		}
 	}
 
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
```
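In effect, an input's Multimodal payload for mllama is now a plain []ml.Tensor assembled without touching a context, and Forward picks the last entry for the cross-attention states. A rough sketch of that data flow, with a placeholder Tensor type rather than the real ml package:

```go
// Sketch only: Tensor stands in for ml.Tensor; the point is that the
// multimodal payload is now an ordinary Go slice built without a context.
package sketch

type Tensor interface{ Shape() []int }

// collect mirrors the PostTokenize side: gather per-image tensors into a
// slice, with no graph ops (and therefore no ml.Context) required.
func collect(images []Tensor) any {
	return append([]Tensor(nil), images...)
}

// crossAttentionStates mirrors the Forward side: type-assert the slice
// and use only the most recent image, as the diff above does.
func crossAttentionStates(multimodal any) (Tensor, bool) {
	images, ok := multimodal.([]Tensor)
	if !ok || len(images) == 0 {
		return nil, false
	}
	return images[len(images)-1], true
}
```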

runner/ollamarunner/runner.go

Lines changed: 24 additions & 13 deletions
```diff
@@ -34,10 +34,14 @@ import (
 	_ "github.com/ollama/ollama/model/models"
 )
 
+type contextList struct {
+	list []ml.Context
+}
+
 type Sequence struct {
-	// ctx for allocating tensors that last the lifetime of the sequence, such as
+	// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
 	// multimodal embeddings
-	ctx ml.Context
+	ctxs *contextList
 
 	// batch index
 	iBatch int
@@ -99,9 +103,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	s.ready.Wait()
 
 	startTime := time.Now()
-	ctx := s.model.Backend().NewContext()
 
-	inputs, err := s.inputs(ctx, prompt, images)
+	inputs, ctxs, err := s.inputs(prompt, images)
 	if err != nil {
 		return nil, fmt.Errorf("failed to process inputs: %w", err)
 	} else if len(inputs) == 0 {
@@ -127,7 +130,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	// TODO(jessegross): Ingest cached history for grammar
 
 	return &Sequence{
-		ctx:                 ctx,
+		ctxs:                ctxs,
 		inputs:              inputs,
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
@@ -146,7 +149,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
 	var inputs []input.Input
 	var parts []string
 	var matches [][]string
@@ -161,12 +164,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
 		parts = []string{prompt}
 	}
 
+	var contexts contextList
+	runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
+		for _, ctx := range ctxs {
+			ctx.Close()
+		}
+	}, contexts.list)
+
 	postTokenize := false
 	for i, part := range parts {
 		// text - tokenize
 		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 
 		for _, t := range tokens {
@@ -186,12 +196,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
 			}
 
 			if imageIndex < 0 {
-				return nil, fmt.Errorf("invalid image index: %d", n)
+				return nil, nil, fmt.Errorf("invalid image index: %d", n)
 			}
 
+			ctx := s.model.Backend().NewContext()
+			contexts.list = append(contexts.list, ctx)
 			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
 			if err != nil {
-				return nil, err
+				return nil, nil, err
 			}
 
 			s.multimodalHash.Reset()
@@ -205,13 +217,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
 
 	if visionModel && postTokenize {
 		var err error
-		inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
+		inputs, err = multimodalProcessor.PostTokenize(inputs)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 	}
 
-	return inputs, nil
+	return inputs, &contexts, nil
 }
 
 type Server struct {
@@ -306,7 +318,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
 	close(seq.responses)
 	close(seq.embedding)
 	seq.cache.InUse = false
-	seq.ctx.Close()
 	s.seqs[seqIndex] = nil
 	s.seqsSem.Release(1)
 }
```
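With the explicit seq.ctx.Close() gone from removeSequence, the per-image contexts are released when the contextList becomes unreachable, via runtime.AddCleanup (added in Go 1.24). A minimal, self-contained sketch of that mechanism, with a hypothetical resource type standing in for ml.Context:

```go
// Sketch of the runtime.AddCleanup pattern used above (Go 1.24+).
// "resource" and "resourceList" are hypothetical placeholders, not ollama types.
package main

import (
	"fmt"
	"runtime"
	"time"
)

type resource struct{ id int }

func (r *resource) Close() { fmt.Println("closed resource", r.id) }

type resourceList struct {
	list []*resource
}

func main() {
	owner := &resourceList{list: []*resource{{id: 1}, {id: 2}}}

	// Register a cleanup that runs some time after owner becomes unreachable.
	// The slice argument is copied at registration time, so it should already
	// refer to the resources to release, and it must not keep owner reachable.
	runtime.AddCleanup(owner, func(rs []*resource) {
		for _, r := range rs {
			r.Close()
		}
	}, owner.list)

	owner = nil // drop the last reference

	// Cleanups run on a background goroutine after GC sees the object is
	// unreachable; forcing a GC and sleeping briefly just makes the effect
	// visible in this toy program.
	runtime.GC()
	time.Sleep(100 * time.Millisecond)
}
```

Tying release to garbage collection avoids threading an explicit close path through every exit point, at the cost of less deterministic timing for when the contexts are actually freed.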
