-
Notifications
You must be signed in to change notification settings - Fork 92
fix: add daemon-side model repackaging for Linux support #639
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -214,7 +214,7 @@ type builderInitResult struct { | |
| } | ||
|
|
||
| // initializeBuilder creates a package builder from GGUF, Safetensors, DDUF, or existing model | ||
| func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitResult, error) { | ||
| func initializeBuilder(ctx context.Context, cmd *cobra.Command, client *desktop.Client, opts packageOptions) (*builderInitResult, error) { | ||
| result := &builderInitResult{} | ||
|
|
||
| if opts.fromModel != "" { | ||
|
|
@@ -238,10 +238,14 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes | |
| // Package from existing model | ||
| cmd.PrintErrf("Reading model from store: %q\n", opts.fromModel) | ||
|
|
||
| // Get the model from the local store | ||
| mdl, err := distClient.GetModel(opts.fromModel) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("get model from store: %w", err) | ||
| cmd.PrintErrf("Model not found in local store, fetching from daemon...\n") | ||
|
|
||
| mdl, result.distClient, result.cleanupFunc, err = fetchModelFromDaemon(ctx, cmd, client, opts.fromModel) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("get model from store: %w", err) | ||
| } | ||
| } | ||
|
|
||
| // Type assert to ModelArtifact - the Model from store implements both interfaces | ||
|
|
@@ -306,7 +310,74 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes | |
| return result, nil | ||
| } | ||
|
|
||
| func fetchModelFromDaemon(ctx context.Context, cmd *cobra.Command, client *desktop.Client, modelRef string) (types.Model, *distribution.Client, func(), error) { | ||
| exportReader, err := client.ExportModel(ctx, modelRef) | ||
| if err != nil { | ||
| return nil, nil, nil, fmt.Errorf("export model from daemon: %w", err) | ||
| } | ||
| defer exportReader.Close() | ||
|
|
||
| tempDir, err := os.MkdirTemp("", "docker-model-package-*") | ||
| if err != nil { | ||
| return nil, nil, nil, fmt.Errorf("create temp directory: %w", err) | ||
| } | ||
| cleanup := func() { | ||
| os.RemoveAll(tempDir) | ||
| } | ||
|
|
||
| tempClient, err := distribution.NewClient(distribution.WithStoreRootPath(tempDir)) | ||
| if err != nil { | ||
| cleanup() | ||
| return nil, nil, nil, fmt.Errorf("create temp distribution client: %w", err) | ||
| } | ||
|
Comment on lines
+328
to
+332
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The For example: var success bool
defer func() {
if !success {
cleanup()
}
}()
// ... your logic ...
success = true
return mdl, tempClient, cleanup, nil |
||
|
|
||
| cmd.PrintErrf("Loading model from daemon...\n") | ||
| modelID, err := tempClient.LoadModel(exportReader, nil) | ||
| if err != nil { | ||
| cleanup() | ||
| return nil, nil, nil, fmt.Errorf("load model into temp store: %w", err) | ||
| } | ||
|
|
||
| mdl, err := tempClient.GetModel(modelID) | ||
| if err != nil { | ||
| cleanup() | ||
| return nil, nil, nil, fmt.Errorf("get model from temp store: %w", err) | ||
| } | ||
|
|
||
| return mdl, tempClient, cleanup, nil | ||
| } | ||
|
|
||
| func packageModel(ctx context.Context, cmd *cobra.Command, client *desktop.Client, opts packageOptions) error { | ||
| // Use daemon-side repackaging for simple config-only changes (no new layers) | ||
| canUseDaemonRepackage := opts.fromModel != "" && | ||
| !opts.push && | ||
| len(opts.licensePaths) == 0 && | ||
| opts.chatTemplatePath == "" && | ||
| opts.mmprojPath == "" && | ||
| len(opts.dirTarPaths) == 0 && | ||
| cmd.Flags().Changed("context-size") | ||
|
|
||
| if canUseDaemonRepackage { | ||
| cmd.PrintErrf("Reading model from daemon: %q\n", opts.fromModel) | ||
| cmd.PrintErrf("Setting context size %d\n", opts.contextSize) | ||
| cmd.PrintErrln("Creating lightweight model variant...") | ||
|
|
||
| // Ensure standalone runner is available | ||
| if _, err := ensureStandaloneRunnerAvailable(ctx, asPrinter(cmd), false); err != nil { | ||
| return fmt.Errorf("unable to initialize standalone model runner: %w", err) | ||
| } | ||
|
|
||
| repackageOpts := desktop.RepackageOptions{ | ||
| ContextSize: &opts.contextSize, | ||
| } | ||
| if err := client.RepackageModel(ctx, opts.fromModel, opts.tag, repackageOpts); err != nil { | ||
| return fmt.Errorf("failed to create lightweight model: %w", err) | ||
| } | ||
|
|
||
| cmd.PrintErrln("Model variant created successfully") | ||
| return nil | ||
| } | ||
|
|
||
| var ( | ||
| target builder.Target | ||
| err error | ||
|
|
@@ -327,7 +398,7 @@ func packageModel(ctx context.Context, cmd *cobra.Command, client *desktop.Clien | |
| } | ||
|
|
||
| // Initialize the package builder based on model format | ||
| initResult, err := initializeBuilder(cmd, opts) | ||
| initResult, err := initializeBuilder(ctx, cmd, client, opts) | ||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -938,3 +938,66 @@ func (c *Client) LoadModel(ctx context.Context, r io.Reader) error { | |
| } | ||
| return nil | ||
| } | ||
|
|
||
| func (c *Client) ExportModel(ctx context.Context, model string) (io.ReadCloser, error) { | ||
| exportPath := fmt.Sprintf("%s/%s/export", inference.ModelsPrefix, model) | ||
| req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.modelRunner.URL(exportPath), http.NoBody) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to create request: %w", err) | ||
| } | ||
| req.Header.Set("User-Agent", "docker-model-cli/"+Version) | ||
|
|
||
| resp, err := c.modelRunner.Client().Do(req) | ||
| if err != nil { | ||
|
Comment on lines
942
to
951
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (bug_risk): ExportModel bypasses doRequestWithAuthContext, so it likely misses auth and standard desktop context headers. Since this bypasses Suggested implementation: resp, err := c.doRequestWithAuthContext(ctx, req)This change assumes that |
||
| return nil, c.handleQueryError(err, exportPath) | ||
| } | ||
|
|
||
| if resp.StatusCode == http.StatusNotFound { | ||
| resp.Body.Close() | ||
| return nil, errors.Wrap(ErrNotFound, model) | ||
| } | ||
| if resp.StatusCode != http.StatusOK { | ||
| body, _ := io.ReadAll(resp.Body) | ||
| resp.Body.Close() | ||
| return nil, fmt.Errorf("export failed with status %s: %s", resp.Status, string(body)) | ||
| } | ||
|
|
||
| return resp.Body, nil | ||
| } | ||
|
|
||
// RepackageOptions holds the optional settings a caller passes to
// Client.RepackageModel. ContextSize is a pointer so that it can be left
// unset; the omitempty tag drops it from the JSON encoding when nil.
| type RepackageOptions struct { | ||
| ContextSize *uint64 `json:"context_size,omitempty"` | ||
| } | ||
|
Comment on lines
+968
to
+970
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
| func (c *Client) RepackageModel(ctx context.Context, source, target string, opts RepackageOptions) error { | ||
| repackagePath := fmt.Sprintf("%s/%s/repackage", inference.ModelsPrefix, source) | ||
|
|
||
| reqBody := struct { | ||
| Target string `json:"target"` | ||
| ContextSize *uint64 `json:"context_size,omitempty"` | ||
| }{ | ||
| Target: target, | ||
| ContextSize: opts.ContextSize, | ||
| } | ||
|
|
||
| jsonData, err := json.Marshal(reqBody) | ||
| if err != nil { | ||
| return fmt.Errorf("error marshaling request: %w", err) | ||
| } | ||
|
|
||
| resp, err := c.doRequestWithAuthContext(ctx, http.MethodPost, repackagePath, bytes.NewReader(jsonData)) | ||
| if err != nil { | ||
| return c.handleQueryError(err, repackagePath) | ||
| } | ||
| defer resp.Body.Close() | ||
|
|
||
| if resp.StatusCode == http.StatusNotFound { | ||
| return errors.Wrap(ErrNotFound, source) | ||
| } | ||
| if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { | ||
| body, _ := io.ReadAll(resp.Body) | ||
| return fmt.Errorf("repackage failed with status %s: %s", resp.Status, string(body)) | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ import ( | |
| "strings" | ||
|
|
||
| "github.com/docker/model-runner/pkg/distribution/huggingface" | ||
| "github.com/docker/model-runner/pkg/distribution/internal/mutate" | ||
| "github.com/docker/model-runner/pkg/distribution/internal/progress" | ||
| "github.com/docker/model-runner/pkg/distribution/internal/store" | ||
| "github.com/docker/model-runner/pkg/distribution/oci" | ||
|
|
@@ -615,6 +616,59 @@ func (c *Client) ResetStore() error { | |
| return nil | ||
| } | ||
|
|
||
| func (c *Client) ExportModel(reference string, w io.Writer) error { | ||
| c.log.Infoln("Exporting model:", utils.SanitizeForLog(reference)) | ||
| normalizedRef := c.normalizeModelName(reference) | ||
| mdl, err := c.store.Read(normalizedRef) | ||
| if err != nil { | ||
| c.log.Errorln("Failed to get model for export:", err, "reference:", utils.SanitizeForLog(reference)) | ||
| return fmt.Errorf("get model '%q': %w", utils.SanitizeForLog(reference), err) | ||
| } | ||
|
|
||
| target, err := tarball.NewTarget(w) | ||
| if err != nil { | ||
| return fmt.Errorf("create tarball target: %w", err) | ||
| } | ||
|
|
||
| if err := target.Write(context.Background(), mdl, nil); err != nil { | ||
| c.log.Errorln("Failed to export model:", err, "reference:", utils.SanitizeForLog(reference)) | ||
| return fmt.Errorf("export model: %w", err) | ||
| } | ||
|
|
||
| c.log.Infoln("Successfully exported model:", utils.SanitizeForLog(reference)) | ||
| return nil | ||
| } | ||
|
|
||
// RepackageOptions holds the optional changes applied by
// Client.RepackageModel. A nil ContextSize leaves the model's context size
// untouched.
| type RepackageOptions struct { | ||
| ContextSize *uint64 | ||
| } | ||
|
Comment on lines
+642
to
+644
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| func (c *Client) RepackageModel(sourceRef string, targetRef string, opts RepackageOptions) error { | ||
| c.log.Infoln("Repackaging model:", utils.SanitizeForLog(sourceRef), "->", utils.SanitizeForLog(targetRef)) | ||
|
|
||
| normalizedSource := c.normalizeModelName(sourceRef) | ||
| normalizedTarget := c.normalizeModelName(targetRef) | ||
|
|
||
| mdl, err := c.store.Read(normalizedSource) | ||
| if err != nil { | ||
| c.log.Errorln("Failed to get model for repackaging:", err, "reference:", utils.SanitizeForLog(sourceRef)) | ||
| return fmt.Errorf("get model '%q': %w", utils.SanitizeForLog(sourceRef), err) | ||
| } | ||
|
|
||
| var modifiedModel types.ModelArtifact = mdl | ||
| if opts.ContextSize != nil { | ||
| modifiedModel = mutate.ContextSize(modifiedModel, int32(*opts.ContextSize)) | ||
| } | ||
|
|
||
| if err := c.store.WriteLightweight(modifiedModel, []string{normalizedTarget}); err != nil { | ||
| c.log.Errorln("Failed to write repackaged model:", err, "target:", utils.SanitizeForLog(targetRef)) | ||
| return fmt.Errorf("write repackaged model: %w", err) | ||
| } | ||
|
|
||
| c.log.Infoln("Successfully repackaged model:", utils.SanitizeForLog(sourceRef), "->", utils.SanitizeForLog(targetRef)) | ||
| return nil | ||
| } | ||
|
|
||
| // GetBundle returns a types.Bundle containing the model, creating one as necessary | ||
| func (c *Client) GetBundle(ref string) (types.ModelBundle, error) { | ||
| normalizedRef := c.normalizeModelName(ref) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -84,7 +84,7 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc { | |
| "POST " + inference.ModelsPrefix + "/create": h.handleCreateModel, | ||
| "POST " + inference.ModelsPrefix + "/load": h.handleLoadModel, | ||
| "GET " + inference.ModelsPrefix: h.handleGetModels, | ||
| "GET " + inference.ModelsPrefix + "/{name...}": h.handleGetModel, | ||
| "GET " + inference.ModelsPrefix + "/{nameAndAction...}": h.handleModelGetAction, | ||
| "DELETE " + inference.ModelsPrefix + "/{name...}": h.handleDeleteModel, | ||
| "POST " + inference.ModelsPrefix + "/{nameAndAction...}": h.handleModelAction, | ||
| "DELETE " + inference.ModelsPrefix + "/purge": h.handlePurge, | ||
|
|
@@ -142,6 +142,35 @@ func (h *HTTPHandler) handleLoadModel(w http.ResponseWriter, r *http.Request) { | |
| } | ||
| } | ||
|
|
||
| func (h *HTTPHandler) handleModelGetAction(w http.ResponseWriter, r *http.Request) { | ||
| nameAndAction := r.PathValue("nameAndAction") | ||
| model, action := path.Split(nameAndAction) | ||
| model = strings.TrimRight(model, "/") | ||
|
|
||
| if action == "export" { | ||
| h.handleExportModel(w, r, model) | ||
| return | ||
| } | ||
|
|
||
| h.handleGetModelByRef(w, r, nameAndAction) | ||
| } | ||
|
|
||
| func (h *HTTPHandler) handleExportModel(w http.ResponseWriter, r *http.Request, modelRef string) { | ||
| w.Header().Set("Content-Type", "application/x-tar") | ||
| w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", modelRef+".tar")) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: Using modelRef directly in the Content-Disposition filename can produce awkward or potentially confusing filenames with slashes/colons. Because Suggested implementation: func (h *HTTPHandler) handleExportModel(w http.ResponseWriter, r *http.Request, modelRef string) {
w.Header().Set("Content-Type", "application/x-tar")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", safeExportFilename(modelRef)))
err := h.manager.Export(modelRef, w)
if err != nil {
if errors.Is(err, distribution.ErrModelNotFound) {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
h.log.Warnln("Error while exporting model:", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
// safeExportFilename derives a safe filename for the exported model tarball from a model reference.
// It keeps the full modelRef for export semantics, but ensures the header filename is UI/FS-friendly.
func safeExportFilename(modelRef string) string {
// Use the last path segment as the base name (e.g. "org/repo:tag" -> "repo:tag").
base := path.Base(modelRef)
// Normalize potentially confusing characters for filenames.
// Colons (used for tags) are replaced with '-', and any remaining slashes are also normalized.
base = strings.ReplaceAll(base, ":", "-")
base = strings.ReplaceAll(base, "/", "-")
// Fall back to a generic name if we somehow end up empty.
if base == "" || base == "." || base == "/" {
base = "model"
}
return base + ".tar"
}You’ll also need to update the imports at the top of
Make sure these are merged into the existing import block following the file’s import style (grouping with stdlib imports, etc.). |
||
|
|
||
| err := h.manager.Export(modelRef, w) | ||
| if err != nil { | ||
| if errors.Is(err, distribution.ErrModelNotFound) { | ||
| http.Error(w, err.Error(), http.StatusNotFound) | ||
| return | ||
| } | ||
| h.log.Warnln("Error while exporting model:", err) | ||
| http.Error(w, err.Error(), http.StatusInternalServerError) | ||
| return | ||
| } | ||
| } | ||
|
|
||
| // handleGetModels handles GET <inference-prefix>/models requests. | ||
| func (h *HTTPHandler) handleGetModels(w http.ResponseWriter, r *http.Request) { | ||
| apiModels, err := h.manager.List() | ||
|
|
@@ -160,7 +189,10 @@ func (h *HTTPHandler) handleGetModels(w http.ResponseWriter, r *http.Request) { | |
| // handleGetModel handles GET <inference-prefix>/models/{name} requests. | ||
| func (h *HTTPHandler) handleGetModel(w http.ResponseWriter, r *http.Request) { | ||
| modelRef := r.PathValue("name") | ||
| h.handleGetModelByRef(w, r, modelRef) | ||
| } | ||
|
|
||
| func (h *HTTPHandler) handleGetModelByRef(w http.ResponseWriter, r *http.Request, modelRef string) { | ||
| // Parse remote query parameter | ||
| remote := false | ||
| if r.URL.Query().Has("remote") { | ||
|
|
@@ -355,10 +387,8 @@ func (h *HTTPHandler) handleOpenAIGetModel(w http.ResponseWriter, r *http.Reques | |
| } | ||
| } | ||
|
|
||
| // handleTagModel handles POST <inference-prefix>/models/{nameAndAction} requests. | ||
| // Action is one of: | ||
| // - tag: tag the model with a repository and tag (e.g. POST <inference-prefix>/models/my-org/my-repo:latest/tag}) | ||
| // - push: pushes a tagged model to the registry | ||
| // handleModelAction handles POST <inference-prefix>/models/{nameAndAction} requests. | ||
| // Actions: tag, push, repackage | ||
| func (h *HTTPHandler) handleModelAction(w http.ResponseWriter, r *http.Request) { | ||
| model, action := path.Split(r.PathValue("nameAndAction")) | ||
| model = strings.TrimRight(model, "/") | ||
|
|
@@ -368,6 +398,8 @@ func (h *HTTPHandler) handleModelAction(w http.ResponseWriter, r *http.Request) | |
| h.handleTagModel(w, r, model) | ||
| case "push": | ||
| h.handlePushModel(w, r, model) | ||
| case "repackage": | ||
| h.handleRepackageModel(w, r, model) | ||
| default: | ||
| http.Error(w, fmt.Sprintf("unknown action %q", action), http.StatusNotFound) | ||
| } | ||
|
|
@@ -438,6 +470,49 @@ func (h *HTTPHandler) handlePushModel(w http.ResponseWriter, r *http.Request, mo | |
| } | ||
| } | ||
|
|
||
// RepackageRequest is the JSON body decoded by handleRepackageModel.
// Target is required (an empty value is rejected with 400 Bad Request);
// ContextSize is optional and omitted from the encoding when nil.
| type RepackageRequest struct { | ||
| Target string `json:"target"` | ||
| ContextSize *uint64 `json:"context_size,omitempty"` | ||
| } | ||
|
Comment on lines
+473
to
+476
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This |
||
|
|
||
| func (h *HTTPHandler) handleRepackageModel(w http.ResponseWriter, r *http.Request, model string) { | ||
| var req RepackageRequest | ||
| if err := json.NewDecoder(r.Body).Decode(&req); err != nil { | ||
| http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest) | ||
| return | ||
| } | ||
|
|
||
| if req.Target == "" { | ||
| http.Error(w, "target is required", http.StatusBadRequest) | ||
| return | ||
| } | ||
|
|
||
| opts := RepackageOptions{ | ||
| ContextSize: req.ContextSize, | ||
| } | ||
|
|
||
| if err := h.manager.Repackage(model, req.Target, opts); err != nil { | ||
| if errors.Is(err, distribution.ErrModelNotFound) { | ||
| http.Error(w, err.Error(), http.StatusNotFound) | ||
| return | ||
| } | ||
| h.log.Warnf("Failed to repackage model %q: %v", utils.SanitizeForLog(model, -1), err) | ||
| http.Error(w, err.Error(), http.StatusInternalServerError) | ||
| return | ||
| } | ||
|
|
||
| w.Header().Set("Content-Type", "application/json") | ||
| w.WriteHeader(http.StatusCreated) | ||
| response := map[string]string{ | ||
| "message": fmt.Sprintf("Model repackaged successfully as %q", req.Target), | ||
| "source": model, | ||
| "target": req.Target, | ||
| } | ||
| if err := json.NewEncoder(w).Encode(response); err != nil { | ||
| h.log.Warnln("Error while encoding repackage response:", err) | ||
| } | ||
| } | ||
|
|
||
| // handlePurge handles DELETE <inference-prefix>/models/purge requests. | ||
| func (h *HTTPHandler) handlePurge(w http.ResponseWriter, _ *http.Request) { | ||
| err := h.manager.Purge() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -416,3 +416,23 @@ func (m *Manager) Purge() error { | |
| } | ||
| return nil | ||
| } | ||
|
|
||
| func (m *Manager) Export(ref string, w io.Writer) error { | ||
| if m.distributionClient == nil { | ||
| return fmt.Errorf("model distribution service unavailable") | ||
| } | ||
| return m.distributionClient.ExportModel(ref, w) | ||
| } | ||
|
|
||
// RepackageOptions holds the optional settings accepted by Manager.Repackage.
// ContextSize is forwarded unchanged to distribution.RepackageOptions.
| type RepackageOptions struct { | ||
| ContextSize *uint64 `json:"context_size,omitempty"` | ||
| } | ||
|
Comment on lines
+427
to
+429
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| func (m *Manager) Repackage(sourceRef string, targetRef string, opts RepackageOptions) error { | ||
| if m.distributionClient == nil { | ||
| return fmt.Errorf("model distribution service unavailable") | ||
| } | ||
| return m.distributionClient.RepackageModel(sourceRef, targetRef, distribution.RepackageOptions{ | ||
| ContextSize: opts.ContextSize, | ||
| }) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (bug_risk): Original distClient cleanup is skipped when falling back to fetching from the daemon, which can leak temporary resources.
In the `opts.fromModel` path, when `distClient.GetModel` fails and you switch to `fetchModelFromDaemon`, the original `cleanupFunc` from `constructDistClient` is never called. If that client holds a temp dir or similar resources, they’ll leak in this path. Before overwriting `result.distClient`/`result.cleanupFunc`, store and invoke the existing cleanup so those resources are released when falling back to the daemon.