
Commit 3892c3a

llm: remove internal subprocess req and resp types (ollama#9324)
This commit refactors the LLM subsystem by removing internal subprocess request and response types. It consolidates duplicate type definitions across the codebase, moving them to centralized locations. The change also standardizes interfaces between components, simplifies the ServerStatusResp struct, and moves the ParseDurationMs function to a common package. This cleanup reduces code duplication between different runner implementations (llamarunner and ollamarunner).
1 parent 4e320b8 commit 3892c3a
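
The core of the refactor is that both ends of the subprocess boundary now share one exported Go type instead of keeping mirrored private copies that must be updated in lockstep. A minimal standalone sketch of that pattern, assuming illustrative writeResponse/readResponse helpers that are not part of this commit:

package llm

import (
	"encoding/json"
	"io"
)

// A wire type defined once and imported by both the runner subprocess
// and the client that spawns it.
type CompletionResponse struct {
	Content string `json:"content"`
	Done    bool   `json:"done"`
}

// Runner side (illustrative): encode the shared struct onto the response body.
func writeResponse(w io.Writer, resp CompletionResponse) error {
	return json.NewEncoder(w).Encode(resp)
}

// Client side (illustrative): decode straight back into the same type,
// so the two halves cannot drift apart field by field.
func readResponse(r io.Reader) (CompletionResponse, error) {
	var resp CompletionResponse
	err := json.NewDecoder(r).Decode(&resp)
	return resp, err
}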

File tree

4 files changed, +125 -354 lines changed

llm/server.go

Lines changed: 39 additions & 97 deletions
@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
 	}
 
-	slog.Info("starting llama server", "cmd", s.cmd.String())
+	slog.Info("starting llama server", "cmd", s.cmd)
 	if envconfig.Debug() {
 		filteredEnv := []string{}
 		for _, ev := range s.cmd.Env {
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
 	ServerStatusError
 )
 
-func (s ServerStatus) ToString() string {
+func (s ServerStatus) String() string {
 	switch s {
 	case ServerStatusReady:
 		return "llm server ready"
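
Renaming ToString to String is not just cosmetic: it makes ServerStatus satisfy fmt.Stringer, which is why the call sites in the hunks below can hand the value directly to slog and fmt without an explicit conversion. A standalone sketch of the convention (not the committed code):

package main

import "fmt"

type ServerStatus int

const ServerStatusReady ServerStatus = 0

func (s ServerStatus) String() string {
	if s == ServerStatusReady {
		return "llm server ready"
	}
	return "llm server error"
}

func main() {
	// %s (and slog's default attribute formatting) detect fmt.Stringer
	// and call String() automatically.
	fmt.Printf("status: %s\n", ServerStatusReady) // status: llm server ready
}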
@@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string {
 	}
 }
 
-type ServerStatusResp struct {
-	Status          string  `json:"status"`
-	SlotsIdle       int     `json:"slots_idle"`
-	SlotsProcessing int     `json:"slots_processing"`
-	Error           string  `json:"error"`
-	Progress        float32 `json:"progress"`
+type ServerStatusResponse struct {
+	Status   ServerStatus `json:"status"`
+	Progress float32      `json:"progress"`
 }
 
 func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
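
The simplified struct types Status as the ServerStatus enum rather than a string, so the value crosses the subprocess JSON boundary as its underlying integer with no custom marshaling, and because runner and client now import the same constants, the numbers cannot drift. A standalone sketch of the round trip, assuming encoding/json's default int encoding:

package main

import (
	"encoding/json"
	"fmt"
)

type ServerStatus int

const (
	ServerStatusReady ServerStatus = iota
	ServerStatusNoSlotsAvailable
	ServerStatusLoadingModel
)

type ServerStatusResponse struct {
	Status   ServerStatus `json:"status"`
	Progress float32      `json:"progress"`
}

func main() {
	// Runner side: the enum serializes as its underlying int.
	b, _ := json.Marshal(ServerStatusResponse{Status: ServerStatusLoadingModel, Progress: 0.5})
	fmt.Println(string(b)) // {"status":2,"progress":0.5}

	// Client side: it deserializes back into the shared type.
	var ssr ServerStatusResponse
	_ = json.Unmarshal(b, &ssr)
	fmt.Println(ssr.Status == ServerStatusLoadingModel) // true
}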
@@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	}
 	if s.cmd.ProcessState.ExitCode() == -1 {
 		// Most likely a signal killed it, log some more details to try to help troubleshoot
-		slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
+		slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
 	}
 	return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 }
@@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 		return ServerStatusError, fmt.Errorf("read health request: %w", err)
 	}
 
-	var status ServerStatusResp
-	if err := json.Unmarshal(body, &status); err != nil {
+	var ssr ServerStatusResponse
+	if err := json.Unmarshal(body, &ssr); err != nil {
 		return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
 	}
 
-	switch status.Status {
-	case "ok":
-		return ServerStatusReady, nil
-	case "no slot available":
-		return ServerStatusNoSlotsAvailable, nil
-	case "loading model":
-		s.loadProgress = status.Progress
-		return ServerStatusLoadingModel, nil
+	switch ssr.Status {
+	case ServerStatusLoadingModel:
+		s.loadProgress = ssr.Progress
+		return ssr.Status, nil
+	case ServerStatusReady, ServerStatusNoSlotsAvailable:
+		return ssr.Status, nil
 	default:
-		return ServerStatusError, fmt.Errorf("server error: %+v", status)
+		return ssr.Status, fmt.Errorf("server error: %+v", ssr)
 	}
 }
 
@@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 		status, _ := s.getServerStatus(ctx)
 		if lastStatus != status && status != ServerStatusReady {
 			// Only log on status changes
-			slog.Info("waiting for server to become available", "status", status.ToString())
+			slog.Info("waiting for server to become available", "status", status)
 		}
 		switch status {
 		case ServerStatusReady:
@@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
 				stallTimer = time.Now().Add(stallDuration)
 			} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
-				slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
+				slog.Debug("model load completed, waiting for server to become available", "status", status)
 				stallTimer = time.Now().Add(stallDuration)
 				fullyLoaded = true
 			}
@@ -671,71 +666,34 @@ type ImageData struct {
 	AspectRatioID int `json:"aspect_ratio_id"`
 }
 
-type completion struct {
-	Content      string `json:"content"`
-	Model        string `json:"model"`
-	Prompt       string `json:"prompt"`
-	Stop         bool   `json:"stop"`
-	StoppedLimit bool   `json:"stopped_limit"`
-
-	Timings struct {
-		PredictedN  int     `json:"predicted_n"`
-		PredictedMS float64 `json:"predicted_ms"`
-		PromptN     int     `json:"prompt_n"`
-		PromptMS    float64 `json:"prompt_ms"`
-	}
-}
-
 type CompletionRequest struct {
 	Prompt  string
 	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
+
+	Grammar string // set before sending the request to the subprocess
 }
 
 type CompletionResponse struct {
-	Content            string
-	DoneReason         string
-	Done               bool
-	PromptEvalCount    int
-	PromptEvalDuration time.Duration
-	EvalCount          int
-	EvalDuration       time.Duration
+	Content            string        `json:"content"`
+	DoneReason         string        `json:"done_reason"`
+	Done               bool          `json:"done"`
+	PromptEvalCount    int           `json:"prompt_eval_count"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+	EvalCount          int           `json:"eval_count"`
+	EvalDuration       time.Duration `json:"eval_duration"`
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
-	request := map[string]any{
-		"prompt":            req.Prompt,
-		"stream":            true,
-		"n_predict":         req.Options.NumPredict,
-		"n_keep":            req.Options.NumKeep,
-		"main_gpu":          req.Options.MainGPU,
-		"temperature":       req.Options.Temperature,
-		"top_k":             req.Options.TopK,
-		"top_p":             req.Options.TopP,
-		"min_p":             req.Options.MinP,
-		"typical_p":         req.Options.TypicalP,
-		"repeat_last_n":     req.Options.RepeatLastN,
-		"repeat_penalty":    req.Options.RepeatPenalty,
-		"presence_penalty":  req.Options.PresencePenalty,
-		"frequency_penalty": req.Options.FrequencyPenalty,
-		"mirostat":          req.Options.Mirostat,
-		"mirostat_tau":      req.Options.MirostatTau,
-		"mirostat_eta":      req.Options.MirostatEta,
-		"seed":              req.Options.Seed,
-		"stop":              req.Options.Stop,
-		"image_data":        req.Images,
-		"cache_prompt":      true,
-	}
-
 	if len(req.Format) > 0 {
 		switch string(req.Format) {
 		case `null`, `""`:
 			// Field was set, but "missing" a value. We accept
 			// these as "not set".
 			break
 		case `"json"`:
-			request["grammar"] = grammarJSON
+			req.Grammar = grammarJSON
 		default:
 			if req.Format[0] != '{' {
 				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
@@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if g == nil {
 				return fmt.Errorf("invalid JSON schema in format")
 			}
-			request["grammar"] = string(g)
+			req.Grammar = string(g)
 		}
 	}
 
+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		if errors.Is(err, context.Canceled) {
 			slog.Info("aborting completion request due to client closing the connection")
@@ -770,15 +733,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
-		return fmt.Errorf("unexpected server status: %s", status.ToString())
+		return fmt.Errorf("unexpected server status: %s", status)
 	}
 
 	// Handling JSON marshaling with special characters unescaped.
 	buffer := &bytes.Buffer{}
 	enc := json.NewEncoder(buffer)
 	enc.SetEscapeHTML(false)
 
-	if err := enc.Encode(request); err != nil {
+	if err := enc.Encode(req); err != nil {
 		return fmt.Errorf("failed to marshal data: %v", err)
 	}
 
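With the ad-hoc request map gone, the CompletionRequest struct itself is what gets encoded onto the wire, and the encoder's SetEscapeHTML(false) keeps characters like < and & literal in the prompt. A standalone sketch of why that setting matters (the struct and tags here are illustrative, not the committed definitions):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

type CompletionRequest struct {
	Prompt  string `json:"prompt"`
	Grammar string `json:"grammar,omitempty"`
}

func main() {
	req := CompletionRequest{Prompt: "a < b && c > d"}

	// Plain json.Marshal escapes HTML-significant characters.
	b, _ := json.Marshal(req)
	fmt.Println(string(b)) // {"prompt":"a \u003c b \u0026\u0026 c \u003e d"}

	// An Encoder with SetEscapeHTML(false) leaves them intact.
	buf := &bytes.Buffer{}
	enc := json.NewEncoder(buf)
	enc.SetEscapeHTML(false)
	_ = enc.Encode(req)
	fmt.Print(buf.String()) // {"prompt":"a < b && c > d"}
}
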
@@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			evt = line
 		}
 
-		var c completion
+		var c CompletionResponse
 		if err := json.Unmarshal(evt, &c); err != nil {
 			return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 		}
@@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			})
 		}
 
-		if c.Stop {
-			doneReason := "stop"
-			if c.StoppedLimit {
-				doneReason = "length"
-			}
-
-			fn(CompletionResponse{
-				Done:               true,
-				DoneReason:         doneReason,
-				PromptEvalCount:    c.Timings.PromptN,
-				PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-				EvalCount:          c.Timings.PredictedN,
-				EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-			})
+		if c.Done {
+			fn(c)
 			return nil
 		}
 	}
@@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
+		return nil, fmt.Errorf("unexpected server status: %s", status)
 	}
 
 	data, err := json.Marshal(EmbeddingRequest{Content: input})
@@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
 	}
 	return 0
 }
-
-func parseDurationMs(ms float64) time.Duration {
-	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-	if err != nil {
-		panic(err)
-	}
-
-	return dur
-}
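
parseDurationMs is deleted here rather than rewritten; per the commit message it survives as ParseDurationMs in a common package shared by llamarunner and ollamarunner. Its job is just millisecond-float to time.Duration. An equivalent standalone sketch; the direct-arithmetic body below is an assumption, not necessarily the committed implementation:

package main

import (
	"fmt"
	"time"
)

// ParseDurationMs converts a float64 millisecond count into a time.Duration.
// Multiplying by float64(time.Millisecond) avoids the string round-trip
// (and the panic path) of the old time.ParseDuration-based helper.
func ParseDurationMs(ms float64) time.Duration {
	return time.Duration(ms * float64(time.Millisecond))
}

func main() {
	fmt.Println(ParseDurationMs(1500.5)) // 1.5005s
}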
