Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cmd/crane/cmd/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command {
var (
cachePath, format string
annotateRef bool
resumable bool
)

cmd := &cobra.Command{
Expand All @@ -49,6 +50,10 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command {
return fmt.Errorf("parsing reference %q: %w", src, err)
}

if resumable {
o.Remote = append(o.Remote, remote.WithResumable())
}

rmt, err := remote.Get(ref, o.Remote...)
if err != nil {
return err
Expand Down Expand Up @@ -133,6 +138,7 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command {
cmd.Flags().StringVarP(&cachePath, "cache_path", "c", "", "Path to cache image layers")
cmd.Flags().StringVar(&format, "format", "tarball", fmt.Sprintf("Format in which to save images (%q, %q, or %q)", "tarball", "legacy", "oci"))
cmd.Flags().BoolVar(&annotateRef, "annotate-ref", false, "Preserves image reference used to pull as an annotation when used with --format=oci")
cmd.Flags().BoolVar(&resumable, "resumable", false, "Enable resumable transport for pulling images")

return cmd
}
44 changes: 44 additions & 0 deletions pkg/v1/remote/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package remote
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -747,3 +749,45 @@ func TestData(t *testing.T) {
t.Fatal(err)
}
}

// TestImageResumable fetches an image with WithResumable enabled, streams the
// compressed bytes of every layer through SHA-256 and compares the result
// with the digest reported by the layer.
//
// NOTE(review): the reference points at ghcr.io, so this presumably needs
// network access to a third-party registry — confirm that is acceptable in CI.
func TestImageResumable(t *testing.T) {
	ref, err := name.ParseReference("ghcr.io/labring/fastgpt:v4.9.0")
	if err != nil {
		t.Fatal(err)
	}

	img, err := Image(ref, WithResumable())
	if err != nil {
		t.Fatal(err)
	}

	layers, err := img.Layers()
	if err != nil {
		t.Fatal(err)
	}

	for i := range layers {
		l := layers[i]

		want, err := l.Digest()
		if err != nil {
			t.Fatal(err)
		}

		body, err := l.Compressed()
		if err != nil {
			t.Fatal(err)
		}

		// Close the stream before inspecting the copy error so the body is
		// released on both the success and the failure path.
		sum := sha256.New()
		_, copyErr := io.Copy(sum, body)
		body.Close()
		if copyErr != nil {
			t.Fatal(copyErr)
		}

		got := hex.EncodeToString(sum.Sum(nil))
		if want.Hex != got {
			t.Errorf("digest mismatch: %s != %s", want, got)
			continue
		}
		t.Logf("digest matches: %s", want)
	}
}
13 changes: 13 additions & 0 deletions pkg/v1/remote/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type options struct {
retryBackoff Backoff
retryPredicate retry.Predicate
retryStatusCodes []int
resumable bool

// Only these options can overwrite Reuse()d options.
platform v1.Platform
Expand Down Expand Up @@ -170,6 +171,11 @@ func makeOptions(opts ...Option) (*options, error) {

// Wrap the transport in something that can retry network flakes.
o.transport = transport.NewRetry(o.transport, transport.WithRetryBackoff(o.retryBackoff), transport.WithRetryPredicate(predicate), transport.WithRetryStatusCodes(o.retryStatusCodes...))

if o.resumable {
o.transport = transport.NewResumable(o.transport)
}

// Wrap this last to prevent transport.New from double-wrapping.
if o.userAgent != "" {
o.transport = transport.NewUserAgent(o.transport, o.userAgent)
Expand All @@ -192,6 +198,13 @@ func WithTransport(t http.RoundTripper) Option {
}
}

// WithResumable is a functional option that turns on the resumable flag of
// the remote options; makeOptions then wraps the transport with
// transport.NewResumable.
func WithResumable() Option {
	return func(opts *options) error {
		opts.resumable = true
		return nil
	}
}

// WithAuth is a functional option for overriding the default authenticator
// for remote operations.
// It is an error to use both WithAuth and WithAuthFromKeychain in the same Option set.
Expand Down
273 changes: 273 additions & 0 deletions pkg/v1/remote/transport/resumable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
package transport

import (
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"sync/atomic"

"github.com/google/go-containerregistry/pkg/logs"
)

// NewResumable returns an http.RoundTripper that wraps inner and resumes
// interrupted HTTP GET response bodies instead of surfacing the read error.
//
// inner should itself be wrapped in the retry transport: each resume attempt
// is sent through inner once, so without retries an error from that attempt
// aborts the transfer.
func NewResumable(inner http.RoundTripper) http.RoundTripper {
	rt := new(resumableTransport)
	rt.inner = inner
	return rt
}

var (
// contentRangeRe matches a complete Content-Range response header value of
// the form "bytes <start>-<end>/<size>", where <size> may also be "*".
contentRangeRe = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)$`)
// rangeRe matches a Range request header value of the form "bytes=<start>-"
// or "bytes=<start>-<end>". The pattern is not anchored, so it finds the
// first such range anywhere in the value; suffix ranges ("bytes=-N") do not
// match.
rangeRe = regexp.MustCompile(`bytes=(\d+)-(\d+)?`)
)

// resumableTransport is the http.RoundTripper returned by NewResumable. It
// delegates every request to inner and wraps eligible GET response bodies in a
// resumableBody.
type resumableTransport struct {
// inner performs the actual round trips, including the follow-up range
// requests issued while resuming.
inner http.RoundTripper
}

// RoundTrip sends in through the inner transport and, for a GET whose total
// size is known, replaces the response body with a resumableBody that
// re-requests the missing bytes when a read fails.
//
// The response is returned unmodified when:
//   - the request method is not GET;
//   - the status is neither 200 nor 206;
//   - a request carrying a parsable Range header is answered with 200;
//   - a 206 response has a missing, malformed or size-less Content-Range;
//   - the resulting total is not positive (for example an unknown
//     Content-Length).
//
// An error is returned without contacting inner when the Range header of a GET
// request matches rangeRe but holds an unparsable or inverted range.
func (rt *resumableTransport) RoundTrip(in *http.Request) (resp *http.Response, err error) {
	// Only GET responses are resumed; forward everything else untouched and
	// without validating headers this transport does not act on.
	if in.Method != http.MethodGet {
		return rt.inner.RoundTrip(in)
	}

	var total, start int64
	// end is the last byte position requested by the caller, -1 when the
	// request has no explicit upper bound. ranged records whether a Range
	// header was recognised at all: a zero end cannot double as "no range",
	// because "bytes=0-0" is a legal request for the first byte only.
	end := int64(-1)
	ranged := false

	// check initial request, maybe resumable transport is already enabled
	if contentRange := in.Header.Get("Range"); contentRange != "" {
		if matches := rangeRe.FindStringSubmatch(contentRange); len(matches) == 3 {
			ranged = true

			if start, err = strconv.ParseInt(matches[1], 10, 64); err != nil {
				return nil, fmt.Errorf("invalid content range %q: %w", contentRange, err)
			}

			// An empty second group means "until the end of the resource".
			if matches[2] != "" {
				if end, err = strconv.ParseInt(matches[2], 10, 64); err != nil {
					return nil, fmt.Errorf("invalid content range %q: %w", contentRange, err)
				}

				if start > end {
					return nil, fmt.Errorf("invalid content range %q", contentRange)
				}
			}
		}
	}

	if resp, err = rt.inner.RoundTrip(in); err != nil {
		return resp, err
	}

	switch resp.StatusCode {
	case http.StatusOK:
		if ranged {
			// A range was requested but the server answered with the full
			// entity; this request cannot be resumed reliably.
			return resp, nil
		}

		total = resp.ContentLength
	case http.StatusPartialContent:
		// Keep the original response status code; it is handled by the caller.
		if start, _, total, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil || total <= 0 {
			return resp, nil
		}

		// Stop at the requested last byte, but never beyond what the server
		// says exists: otherwise the body would try to resume past the end of
		// the resource. Written as end < total-1 to avoid overflowing end+1.
		if end >= 0 && end < total-1 {
			total = end + 1
		}
	default:
		return resp, nil
	}

	if total > 0 {
		resp.Body = &resumableBody{
			rc:          resp.Body,
			inner:       rt.inner,
			req:         in,
			total:       total,
			transferred: start,
		}
	}

	return resp, nil
}

// resumableBody is the io.ReadCloser that resumableTransport.RoundTrip installs
// in place of a GET response body. It counts the bytes it has delivered so
// that resume can request the remainder after a read error.
type resumableBody struct {
// rc is the body currently being read; resume replaces it with the body of
// the follow-up response.
rc io.ReadCloser

// inner performs the follow-up requests issued by resume.
inner http.RoundTripper
// req is the original request; resume clones it and overrides its Range
// header.
req *http.Request

// transferred is the offset of the next byte to deliver. It starts at the
// first byte position of the initial response and is used as the start of
// the Range header when resuming.
transferred int64
// total is the offset at which Read stops and reports io.EOF.
total int64

// closed is set to 1 by Close; it is read and written through sync/atomic.
closed uint32
}

// Read implements io.Reader on top of the current response body. Delivered
// bytes are added to transferred; if the body fails before total is reached,
// resume is asked to continue from the current offset.
//
// Bytes beyond total are dropped and reported as io.EOF. After Close, Read
// returns http.ErrBodyReadAfterClose.
//
// NOTE(review): the loop keeps resuming for as long as resume succeeds and the
// new body yields zero bytes; there is no cap on such attempts here.
func (rb *resumableBody) Read(p []byte) (n int, err error) {
	switch {
	case atomic.LoadUint32(&rb.closed) == 1:
		// response body already closed
		return 0, http.ErrBodyReadAfterClose
	case rb.total >= 0 && rb.transferred >= rb.total:
		return 0, io.EOF
	}

	for {
		n, err = rb.rc.Read(p)
		if n > 0 {
			// Truncate anything past total and turn it into a clean EOF.
			if remaining := rb.total - rb.transferred; int64(n) >= remaining {
				n = int(remaining)
				err = io.EOF
			}
			rb.transferred += int64(n)
		}

		if err == nil {
			return n, nil
		}

		complete := rb.total >= 0 && rb.transferred >= rb.total
		if complete && errors.Is(err, io.EOF) {
			return n, err
		}

		// The body ended early or failed: try to continue with a new one.
		// When bytes were already read from the old body they are returned
		// now and the new body is used by the next Read call; with zero bytes
		// read, loop and read from the new body straight away.
		err = rb.resume(err)
		if err != nil || n > 0 {
			return n, err
		}
	}
}

// Close marks the body as closed and closes the current response body. Only
// the first call closes anything; later calls return nil.
func (rb *resumableBody) Close() (err error) {
	if atomic.CompareAndSwapUint32(&rb.closed, 0, 1) {
		err = rb.rc.Close()
	}

	return err
}

// resume replaces the current response body with the body of a new request
// that starts at the transferred offset. reason is the read error that
// triggered the resume; it is logged and wrapped into the error returned when
// the new response fails validation.
//
// It returns the context error if the request context is already done, the
// round-trip error if the follow-up request fails, and
// http.ErrBodyReadAfterClose if the body was closed in the meantime.
func (rb *resumableBody) resume(reason error) error {
	if reason != nil {
		logs.Debug.Printf("Resume http transporting from error: %v", reason)
	}

	ctx := rb.req.Context()
	select {
	case <-ctx.Done():
		// The caller gave up; do not start another request.
		return ctx.Err()
	default:
	}

	// Ask for everything from the current offset onwards.
	offset := strconv.FormatInt(rb.transferred, 10)
	next := rb.req.Clone(ctx)
	next.Header.Set("Range", "bytes="+offset+"-")

	resp, err := rb.inner.RoundTrip(next)
	if err != nil {
		return err
	}

	if invalid := rb.validate(resp); invalid != nil {
		resp.Body.Close()
		// wraps original error
		return fmt.Errorf("%w, %v", reason, invalid)
	}

	// Close may have run while the request was in flight; drop the new body
	// instead of installing it.
	if atomic.LoadUint32(&rb.closed) == 1 {
		resp.Body.Close()
		return http.ErrBodyReadAfterClose
	}

	rb.rc.Close()
	rb.rc = resp.Body

	return nil
}

// size100m is 100 MiB: the largest amount of already delivered data that
// validate will re-download and discard when a resume request is answered
// with 200 instead of 206.
const size100m = 100 << 20

// validate checks that resp, the answer to a resume request, can continue the
// transfer at the transferred offset, and positions resp.Body there.
//
//   - 206: the Content-Range start must not be after transferred; bytes that
//     overlap with data already delivered are read and discarded.
//   - 200: the full entity was sent; the first transferred bytes are discarded,
//     unless more than size100m would have to be thrown away.
//   - 416: io.EOF is returned when Content-Range ("bytes */<size>") shows that
//     transferred already covers the whole resource.
//
// Anything else yields an "unexpected status code" error.
func (rb *resumableBody) validate(resp *http.Response) (err error) {
	switch resp.StatusCode {
	case http.StatusPartialContent:
		// Do not use the total size from Content-Range; rb.total stays as is.
		var start int64
		if start, _, _, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil {
			return err
		}

		overlap := rb.transferred - start
		if overlap < 0 {
			return fmt.Errorf("unexpected resume start %d, wanted: %d", start, rb.transferred)
		}
		if overlap > 0 {
			// The server started earlier than asked; skip what we already have.
			if _, skipErr := io.CopyN(io.Discard, resp.Body, overlap); skipErr != nil {
				return fmt.Errorf("discard overlapped data failed, %v", skipErr)
			}
		}

		return nil
	case http.StatusOK:
		if rb.transferred <= 0 {
			return nil
		}

		// Range is not supported by the server; give up when too much data
		// would have to be downloaded again.
		if rb.transferred > size100m {
			return fmt.Errorf("too large data transferred: %d", rb.transferred)
		}

		// Fast-forward the full response to the current offset.
		_, err = io.CopyN(io.Discard, resp.Body, rb.transferred)
		return err
	case http.StatusRequestedRangeNotSatisfiable:
		const unsatisfied = "bytes */"
		if contentRange := resp.Header.Get("Content-Range"); strings.HasPrefix(contentRange, unsatisfied) {
			size, parseErr := strconv.ParseInt(strings.TrimPrefix(contentRange, unsatisfied), 10, 64)
			if parseErr == nil && size >= 0 && rb.transferred >= size {
				return io.EOF
			}
		}
	}

	return fmt.Errorf("unexpected status code %d", resp.StatusCode)
}

// parseContentRange parses a Content-Range header value of the form
// "bytes <start>-<end>/<size>". size is -1 when the header gives "*" instead
// of a number.
//
// On any failure — empty value, no match against contentRangeRe, a number
// that does not fit in int64, start > end, or end >= size — all three
// numbers are returned as -1 together with a non-nil error.
func parseContentRange(contentRange string) (start, end, size int64, err error) {
	reject := func(cause error) (int64, int64, int64, error) {
		return -1, -1, -1, cause
	}

	if contentRange == "" {
		return reject(errors.New("unexpected empty content range"))
	}

	fields := contentRangeRe.FindStringSubmatch(contentRange)
	if len(fields) != 4 {
		return reject(fmt.Errorf("invalid content range: %s", contentRange))
	}

	first, perr := strconv.ParseInt(fields[1], 10, 64)
	if perr != nil {
		return reject(fmt.Errorf("unexpected start from content range '%s', %v", contentRange, perr))
	}

	last, perr := strconv.ParseInt(fields[2], 10, 64)
	if perr != nil {
		return reject(fmt.Errorf("unexpected end from content range '%s', %v", contentRange, perr))
	}

	if first > last {
		return reject(fmt.Errorf("invalid content range: %s", contentRange))
	}

	// "*" means the complete length is unknown.
	length := int64(-1)
	if fields[3] != "*" {
		length, perr = strconv.ParseInt(fields[3], 10, 64)
		if perr != nil {
			return reject(fmt.Errorf("unexpected total from content range '%s', %v", contentRange, perr))
		}

		// The last byte position must lie inside the resource.
		if last >= length {
			return reject(fmt.Errorf("invalid content range: %s", contentRange))
		}
	}

	return first, last, length, nil
}
Loading