runner.go: Enforce NUM_PARALLEL directly in the runner
NUM_PARALLEL is currently enforced by the Ollama server process: it
only issues requests to the runner when the maximum number of
concurrent requests has not been exceeded. Although this should
be sufficient, it is good for the runner to protect its own data
structures as well. Currently, if too many requests get through to
the runner, they simply get stuck and never return.

This may help with reports of Ollama hanging, though it is unclear
how that situation would actually occur.

Bug ollama#7573
jessegross committed Nov 14, 2024
1 parent 549c2bd commit 17b386a
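
The fix relies on the weighted semaphore from golang.org/x/sync/semaphore. The sketch below shows the general technique of bounding concurrent work to a fixed number of slots; it is a minimal illustration, not the runner code, and the names and values in it are invented for the example.

```go
package main

import (
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/semaphore"
)

func main() {
	const parallel = 2 // stand-in for NUM_PARALLEL
	sem := semaphore.NewWeighted(parallel)

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			// At most `parallel` goroutines get past this point at once.
			if err := sem.Acquire(context.Background(), 1); err != nil {
				return
			}
			defer sem.Release(1)
			fmt.Println("request", id, "holds one of the slots")
		}(i)
	}
	wg.Wait()
}
```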
Showing 1 changed file with 47 additions and 20 deletions.
llama/runner/runner.go
@@ -20,6 +20,8 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
 )
@@ -203,38 +205,51 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 }
 
 type Server struct {
+	// is the server ready to process requests?
+	// protects access to model and image
+	ready sync.WaitGroup
+
+	// loaded model
 	model *llama.Model
-	lc    *llama.Context
 
-	// required for image embeddings
+	// image model context for multi-modal models
 	image *ImageContext
 
+	// status for external health reporting - loading, ready to serve, etc.
+	status ServerStatus
+
+	// current progress on loading the model
+	progress float32
+
+	// number of simultaneous requests to handle
+	parallel int
+
+	// maximum number of elements in a batch (per sequence)
 	// TODO (jmorganca): make this n_batch
 	batchSize int
 
-	// parallel is the number of parallel requests to handle
-	parallel int
+	// protects access to everything below this line
+	// this is context state needed for decoding
+	mu sync.Mutex
+
+	// indicates that data is ready for processing
+	cond *sync.Cond
+
+	// decoding state
+	lc *llama.Context
 
-	// seqs is the list of parallel sequences being evaluated
-	// TODO (jmorganca): this can probably be moved into run()
+	// the list of simultaneous sequences being evaluated
 	seqs []*Sequence
 
+	// seqs can have a maximum of parallel entries, which
+	// is enforced by seqsSem
+	seqsSem *semaphore.Weighted
+
 	// KV cache
 	cache *InputCache
 
 	// next sequence for prompt processing to avoid starvation
 	nextSeq int
-
-	// is the server ready to process requests?
-	ready sync.WaitGroup
-
-	mu sync.Mutex
-
-	cond *sync.Cond
-
-	progress float32
-
-	status ServerStatus
 }
 
 func (s *Server) allNil() bool {
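
The new seqsSem field encodes an invariant: seqs holds at most parallel active entries at a time. A self-contained sketch of this slot-pool pattern follows; the pool type and its names are invented for illustration (the runner's real admission code appears in the hunks below).

```go
package main

import (
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/semaphore"
)

// pool bounds admissions with a weighted semaphore so the fixed-size
// slot slice is guaranteed to have a free entry for each admitted item.
type pool struct {
	mu    sync.Mutex
	slots []*string // stand-in for []*Sequence
	sem   *semaphore.Weighted
}

func newPool(n int) *pool {
	return &pool{slots: make([]*string, n), sem: semaphore.NewWeighted(int64(n))}
}

func (p *pool) admit(ctx context.Context, v *string) (int, error) {
	if err := p.sem.Acquire(ctx, 1); err != nil {
		return -1, err // caller gave up while waiting for capacity
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	for i, s := range p.slots {
		if s == nil {
			p.slots[i] = v
			return i, nil
		}
	}
	// Unreachable while the semaphore invariant holds.
	panic("no free slot despite semaphore")
}

func (p *pool) release(i int) {
	p.mu.Lock()
	p.slots[i] = nil
	p.mu.Unlock()
	p.sem.Release(1)
}

func main() {
	p := newPool(2)
	v := "seq"
	i, _ := p.admit(context.Background(), &v)
	fmt.Println("admitted at slot", i)
	p.release(i)
}
```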
@@ -616,8 +631,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO (jmorganca): add to sequence queue instead of
-	// failing if a slot isn't available
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
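
Acquire is passed r.Context(), so a client that disconnects while waiting for capacity unblocks the acquire with an error instead of waiting forever. A small runnable illustration of that behavior (not taken from the commit):

```go
package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/sync/semaphore"
)

func main() {
	sem := semaphore.NewWeighted(1)
	_ = sem.Acquire(context.Background(), 1) // the only slot is now taken

	// Stands in for r.Context() on a request that gives up after 100ms.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	if err := sem.Acquire(ctx, 1); err != nil {
		fmt.Println("acquire failed:", err) // prints: context deadline exceeded
	}
}
```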
@@ -700,7 +720,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
@@ -855,6 +881,7 @@ func main() {
 		batchSize: *batchSize,
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
+		seqsSem:   semaphore.NewWeighted(int64(*parallel)),
 		status:    ServerStatusLoadingModel,
 	}
