Skip to content

Commit

Permalink
feat: checkpoint download through master (determined-ai#4989)
Browse files Browse the repository at this point in the history
Users do not always have access to checkpoint storage, and hence downloading
checkpoints with the CLI does not always work. This feature overcomes the
limitation by supporting checkpoint downloads through the master since it is
often easier to grant access to the master than to individual users.  This
feature only supports AWS S3 for now and relies on implicit IAM permissions on
the master.  In the future, it can be extended to support other checkpoint
storage types and take explicitly provided credentials.

A single endpoint is added:
- `/checkpoints/:checkpoint_uuid/`

The implementation downloads the checkpoint by making API calls to the S3
endpoint, packages it as a tgz or a zip file, and then sends the file to the
client in the HTTP response. The whole process operates as streaming. No
temporary files are stored or cached.
  • Loading branch information
hanyucui authored Sep 28, 2022
1 parent 2a16f81 commit 1d0963c
Show file tree
Hide file tree
Showing 13 changed files with 977 additions and 68 deletions.
3 changes: 2 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2182,7 +2182,8 @@ workflows:
test-intg:
jobs:
- test-intg-downstream
- test-intg-master
- test-intg-master:
context: storage-unit-tests
- test-intg-agent
- go-coverage:
requires:
Expand Down
3 changes: 3 additions & 0 deletions master/internal/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,9 @@ func (m *Master) Run(ctx context.Context) error {
experimentsGroup.PATCH("/:experiment_id", api.Route(m.patchExperiment))
experimentsGroup.POST("", api.Route(m.postExperiment))

checkpointsGroup := m.echo.Group("/checkpoints", authFuncs...)
checkpointsGroup.GET("/:checkpoint_uuid", m.getCheckpoint)

searcherGroup := m.echo.Group("/searcher", authFuncs...)
searcherGroup.POST("/preview", api.Route(m.getSearcherPreview))

Expand Down
157 changes: 157 additions & 0 deletions master/internal/core_checkpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package internal

import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/google/uuid"
"github.com/labstack/echo/v4"

"github.com/determined-ai/determined/master/pkg/checkpoints"
"github.com/determined-ai/determined/master/pkg/checkpoints/archive"

"github.com/determined-ai/determined/master/internal/api"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
)

// MIME types that getCheckpoint accepts in the request's Accept header. Each one
// selects the archive format of the downloaded checkpoint (see mimeToArchiveType).
const (
// MIMEApplicationGZip is GZip's MIME type; it maps to a tgz archive.
MIMEApplicationGZip = "application/gzip"
// MIMEApplicationZip is Zip's MIME type; it maps to a zip archive.
MIMEApplicationZip = "application/zip"
)

// mimeToArchiveType maps a MIME type string onto the archive format used to
// package a checkpoint: application/gzip yields archive.ArchiveTgz and
// application/zip yields archive.ArchiveZip. Any other value yields
// archive.ArchiveUnknown.
func mimeToArchiveType(mimeType string) archive.ArchiveType {
	if mimeType == MIMEApplicationGZip {
		return archive.ArchiveTgz
	}
	if mimeType == MIMEApplicationZip {
		return archive.ArchiveZip
	}
	return archive.ArchiveUnknown
}

// delayWriter is a buffering io.WriteCloser built on bufio.Writer.
//
// Echo does not send an http status code until the first write reaches the
// ResponseWriter. Routing writes through a buffer therefore postpones the status
// code until either the buffer fills up or Close is called, by which point we are
// more confident the download will succeed.
type delayWriter struct {
	next *bufio.Writer
}

// Compile-time proof that delayWriter can be used wherever an io.WriteCloser is
// expected.
var _ io.WriteCloser = (*delayWriter)(nil)

// newDelayWriter returns a delayWriter that buffers writes destined for w using a
// buffer of delayBytes bytes.
func newDelayWriter(w io.Writer, delayBytes int) *delayWriter {
	buffered := bufio.NewWriterSize(w, delayBytes)
	return &delayWriter{next: buffered}
}

// Write hands p to the underlying buffered writer and reports its result.
func (dw *delayWriter) Write(p []byte) (int, error) {
	return dw.next.Write(p)
}

// Close flushes whatever is still buffered to the wrapped writer. It does not
// close the wrapped writer itself.
func (dw *delayWriter) Close() error {
	return dw.next.Flush()
}

// getCheckpointStorageConfig looks up the checkpoint with the given UUID in the
// database and extracts the checkpoint storage section from the experiment config
// recorded in its training metadata.
//
// It returns (nil, nil) when the lookup succeeds but yields no checkpoint, and
// (nil, err) when the lookup, the JSON re-encoding of the experiment config, or
// the legacy-config parse fails.
func (m *Master) getCheckpointStorageConfig(id uuid.UUID) (
	*expconf.CheckpointStorageConfig, error) {
	ckpt, err := m.db.CheckpointByUUID(id)
	if err != nil {
		return nil, err
	}
	if ckpt == nil {
		// No such checkpoint; the caller tells this apart from a failure by the nil error.
		return nil, nil
	}

	// The stored experiment config is round-tripped through JSON so that it can be
	// handed to the legacy config parser.
	rawConfig, err := json.Marshal(ckpt.CheckpointTrainingMetadata.ExperimentConfig)
	if err != nil {
		return nil, err
	}

	parsed, err := expconf.ParseLegacyConfigJSON(rawConfig)
	if err != nil {
		return nil, err
	}

	storage := parsed.CheckpointStorage()
	return ptrs.Ptr(storage), nil
}

// getCheckpointImpl streams the checkpoint identified by id into content, packaged
// in the archive format that mimeType maps to.
//
// Errors are returned as *echo.HTTPError values:
//   - 500 if the checkpoint's storage config cannot be retrieved,
//   - 404 if the checkpoint does not exist,
//   - 400 if a downloader cannot be constructed for the request,
//   - 500 if the download itself or the final close/flush fails.
func (m *Master) getCheckpointImpl(
	ctx context.Context, id uuid.UUID, mimeType string, content io.Writer) error {
	// Number of bytes held back before anything reaches the response.
	const delayBytes = 16 * 1024

	ckptID := id.String()

	// Assume a checkpoint always has experiment configs
	storageConfig, err := m.getCheckpointStorageConfig(id)
	if err != nil {
		return echo.NewHTTPError(http.StatusInternalServerError,
			fmt.Sprintf("unable to retrieve experiment config for checkpoint %s: %s",
				ckptID, err.Error()))
	}
	if storageConfig == nil {
		return echo.NewHTTPError(http.StatusNotFound,
			fmt.Sprintf("checkpoint %s does not exist", ckptID))
	}

	// The delay writer holds back the first write until we have successfully
	// downloaded some bytes and are more confident that the download will succeed.
	buffered := newDelayWriter(content, delayBytes)
	archiveType := mimeToArchiveType(mimeType)
	downloader, err := checkpoints.NewDownloader(buffered, ckptID, storageConfig, archiveType)
	if err != nil {
		return echo.NewHTTPError(http.StatusBadRequest, err.Error())
	}

	if err = downloader.Download(ctx); err != nil {
		return echo.NewHTTPError(http.StatusInternalServerError,
			fmt.Sprintf("unable to download checkpoint %s: %s", ckptID, err.Error()))
	}

	// Closing the writers will cause Echo to send a 200 response to the client, so
	// this must not be deferred: it happens only on the success path. The downloader
	// is closed first, and the buffer is flushed only if that succeeded.
	closeErr := downloader.Close()
	if closeErr == nil {
		closeErr = buffered.Close()
	}
	if closeErr != nil {
		return echo.NewHTTPError(http.StatusInternalServerError,
			fmt.Sprintf("failed to complete checkpoint download: %s", closeErr.Error()))
	}

	return nil
}

// @Summary Get a checkpoint's contents in a tgz or zip file.
// @Tags Checkpoints
// @ID get-checkpoint
// @Accept json
// @Produce application/gzip,application/zip
// @Param checkpoint_uuid path string true "Checkpoint UUID"
// @Success 200 {} string ""
//nolint:godot
// @Router /checkpoints/{checkpoint_uuid} [get]
func (m *Master) getCheckpoint(c echo.Context) error {
	// The Accept header must be exactly one of the supported MIME types; it picks
	// the archive format and is echoed back as the response Content-Type.
	accept := c.Request().Header.Get("Accept")
	switch accept {
	case MIMEApplicationGZip, MIMEApplicationZip:
		// Supported; carry on.
	default:
		return echo.NewHTTPError(http.StatusUnsupportedMediaType,
			fmt.Sprintf("unsupported media type to download a checkpoint: '%s'", accept))
	}

	// Pull the checkpoint UUID out of the route path.
	var args struct {
		CheckpointUUID string `path:"checkpoint_uuid"`
	}
	if err := api.BindArgs(&args, c); err != nil {
		return echo.NewHTTPError(http.StatusBadRequest, "invalid checkpoint_uuid: "+err.Error())
	}

	id, err := uuid.Parse(args.CheckpointUUID)
	if err != nil {
		return echo.NewHTTPError(http.StatusBadRequest,
			fmt.Sprintf("unable to parse checkpoint UUID %s: %s",
				args.CheckpointUUID, err))
	}

	// Set the content type up front, then stream the archive into the response.
	resp := c.Response()
	resp.Header().Set(echo.HeaderContentType, accept)
	return m.getCheckpointImpl(c.Request().Context(), id, accept, resp)
}
Loading

0 comments on commit 1d0963c

Please sign in to comment.