Skip to content

Commit

Permalink
Apply retry logic in launcher
Browse files Browse the repository at this point in the history
  • Loading branch information
yawangwang committed Nov 4, 2024
1 parent c891518 commit b462410
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 7 deletions.
40 changes: 36 additions & 4 deletions launcher/container_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,12 @@ func getSignatureDiscoveryClient(cdClient *containerd.Client, mdsClient *metadat
return registryauth.RefreshResolver(ctx, mdsClient)
}
imageFetcher := func(ctx context.Context, imageRef string, opts ...containerd.RemoteOpt) (containerd.Image, error) {
image, err := cdClient.Pull(ctx, imageRef, opts...)
image, err := pullImageWithRetries(
func() (containerd.Image, error) {
return cdClient.Pull(ctx, imageRef, opts...)
},
pullImageBackoffPolicy,
)
if err != nil {
return nil, fmt.Errorf("cannot pull signature objects from the signature image [%s]: %w", imageRef, err)
}
Expand Down Expand Up @@ -529,6 +534,11 @@ func defaultRetryPolicy() *backoff.ExponentialBackOff {
return expBack
}

// pullImageBackoffPolicy returns the backoff policy used when pulling images:
// a constant 500ms interval, capped at 3 retries.
func pullImageBackoffPolicy() backoff.BackOff {
	const (
		interval   = 500 * time.Millisecond
		maxRetries = 3
	)
	return backoff.WithMaxRetries(backoff.NewConstantBackOff(interval), maxRetries)
}

// Run the container
// Container output will always be redirected to logger writer for now
func (r *ContainerRunner) Run(ctx context.Context) error {
Expand Down Expand Up @@ -621,17 +631,39 @@ func (r *ContainerRunner) Run(ctx context.Context) error {
return nil
}

// pullImageWithRetries runs f through backoff.Retry using a fresh policy
// obtained from retry. On success it returns the image produced by the last
// call to f; otherwise it returns a nil image and the error reported by
// backoff.Retry, wrapped with context.
func pullImageWithRetries(f func() (containerd.Image, error), retry func() backoff.BackOff) (containerd.Image, error) {
	var pulled containerd.Image
	attempt := func() error {
		img, pullErr := f()
		pulled = img
		return pullErr
	}
	if retryErr := backoff.Retry(attempt, retry()); retryErr != nil {
		return nil, fmt.Errorf("failed to pull image with retries, the last error is: %w", retryErr)
	}
	return pulled, nil
}

func initImage(ctx context.Context, cdClient *containerd.Client, launchSpec spec.LaunchSpec, token oauth2.Token) (containerd.Image, error) {
if token.Valid() {
remoteOpt := containerd.WithResolver(registryauth.Resolver(token.AccessToken))

image, err := cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack, remoteOpt)
image, err := pullImageWithRetries(
func() (containerd.Image, error) {
return cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack, remoteOpt)
},
pullImageBackoffPolicy,
)
if err != nil {
return nil, fmt.Errorf("cannot pull the image: %w", err)
}
return image, nil
}
image, err := cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack)
image, err := pullImageWithRetries(
func() (containerd.Image, error) {
return cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack)
},
pullImageBackoffPolicy,
)
if err != nil {
return nil, fmt.Errorf("cannot pull the image (no token, only works for a public image): %w", err)
}
Expand Down
51 changes: 51 additions & 0 deletions launcher/container_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,57 @@ func TestMeasureCELEvents(t *testing.T) {
}
}

func TestPullImageWithRetries(t *testing.T) {
testCases := []struct {
name string
imagePuller func(int) (containerd.Image, error)
wantPass bool
}{
{
name: "success with single attempt",
imagePuller: func(int) (containerd.Image, error) { return &fakeImage{}, nil },
wantPass: true,
},
{
name: "failure then success",
imagePuller: func(attempts int) (containerd.Image, error) {
if attempts%2 == 1 {
return nil, errors.New("fake error")
}
return &fakeImage{}, nil
},
wantPass: true,
},
{
name: "failure with attempts exceeded",
imagePuller: func(attempts int) (containerd.Image, error) {

Check warning on line 664 in launcher/container_runner_test.go

View workflow job for this annotation

GitHub Actions / Lint ./launcher (ubuntu-latest, Go 1.21.x)

unused-parameter: parameter 'attempts' seems to be unused, consider removing or renaming it as _ (revive)
return nil, errors.New("fake error")
},
wantPass: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
retryPolicy := func() backoff.BackOff {
b := backoff.NewExponentialBackOff()
return backoff.WithMaxRetries(b, 2)
}

attempts := 0
_, err := pullImageWithRetries(
func() (containerd.Image, error) {
attempts++
return tc.imagePuller(attempts)
},
retryPolicy)
if gotPass := (err == nil); gotPass != tc.wantPass {
t.Errorf("pullImageWithRetries failed, got %v, but want %v", gotPass, tc.wantPass)
}
})
}
}

// This ensures fakeContainer implements containerd.Container interface.
var _ containerd.Container = &fakeContainer{}

Expand Down
12 changes: 9 additions & 3 deletions verifier/rest/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@ import (
"fmt"
"log"
"strings"
"time"

"github.com/cenkalti/backoff/v4"
sabi "github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/proto/sevsnp"
tabi "github.com/google/go-tdx-guest/abi"
"github.com/google/go-tdx-guest/proto/tdx"
"github.com/google/go-tpm-tools/verifier"
"github.com/google/go-tpm-tools/verifier/oci"

v1 "cloud.google.com/go/confidentialcomputing/apiv1"
confidentialcomputingpb "cloud.google.com/go/confidentialcomputing/apiv1/confidentialcomputingpb"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
Expand All @@ -42,11 +43,16 @@ func (e *BadRegionError) Unwrap() error {
return e.err
}

// confComputeBackoffPolicy returns the backoff policy handed to the retryable
// confidential computing client: a constant 500ms interval, capped at 3
// retries.
func confComputeBackoffPolicy() backoff.BackOff {
	const (
		interval   = 500 * time.Millisecond
		maxRetries = 3
	)
	return backoff.WithMaxRetries(backoff.NewConstantBackOff(interval), maxRetries)
}

// NewClient creates a new REST client which is configured to perform
// attestations in a particular project and region. Returns a *BadRegionError
// if the requested project is valid, but the region is invalid.
func NewClient(ctx context.Context, projectID string, region string, opts ...option.ClientOption) (verifier.Client, error) {
client, err := v1.NewRESTClient(ctx, opts...)
client, err := NewRetryableClient(ctx, confComputeBackoffPolicy, opts...)
if err != nil {
return nil, fmt.Errorf("can't create ConfidentialComputing v1 API client: %w", err)
}
Expand Down Expand Up @@ -89,7 +95,7 @@ func NewClient(ctx context.Context, projectID string, region string, opts ...opt
}

type restClient struct {
v1Client *v1.Client
v1Client *retryableClient
location *locationpb.Location
}

Expand Down
81 changes: 81 additions & 0 deletions verifier/rest/retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package rest

import (
"github.com/googleapis/gax-go/v2"

v1 "cloud.google.com/go/confidentialcomputing/apiv1"
confidentialcomputingpb "cloud.google.com/go/confidentialcomputing/apiv1/confidentialcomputingpb"
"context"
"fmt"
"github.com/cenkalti/backoff/v4"
"google.golang.org/api/option"
locationpb "google.golang.org/genproto/googleapis/cloud/location"
)

// retryableClient is a thin wrapper around confidential computing APIs with backoff policies.
//
// Each API is stored as a function field; NewRetryableClient fills them from a
// v1 REST client, and the tests in this package fill them with fakes.
type retryableClient struct {
// getLocation is the call that the GetLocation method retries.
getLocation func(context.Context, *locationpb.GetLocationRequest, ...gax.CallOption) (*locationpb.Location, error)
// listLocations is invoked once per ListLocations call, without retries.
listLocations func(context.Context, *locationpb.ListLocationsRequest, ...gax.CallOption) *v1.LocationIterator
// createChallenge is invoked once per CreateChallenge call, without retries.
createChallenge func(context.Context, *confidentialcomputingpb.CreateChallengeRequest, ...gax.CallOption) (*confidentialcomputingpb.Challenge, error)
// verifyAttestation is the call that the VerifyAttestation method retries.
verifyAttestation func(context.Context, *confidentialcomputingpb.VerifyAttestationRequest, ...gax.CallOption) (*confidentialcomputingpb.VerifyAttestationResponse, error)
// backoffPolicy is called once per retried API call to obtain the policy
// passed to backoff.Retry.
backoffPolicy func() backoff.BackOff
}

// NewRetryableClient creates a new retryable client.
//
// It builds a ConfidentialComputing v1 REST client from opts and stores that
// client's GetLocation, ListLocations, VerifyAttestation and CreateChallenge
// methods, together with retry, in the returned wrapper. An error from
// v1.NewRESTClient is returned as-is, without additional wrapping.
func NewRetryableClient(ctx context.Context, retry func() backoff.BackOff, opts ...option.ClientOption) (*retryableClient, error) {
client, err := v1.NewRESTClient(ctx, opts...)
if err != nil {
return nil, err
}
return &retryableClient{
getLocation: client.GetLocation,
listLocations: client.ListLocations,
verifyAttestation: client.VerifyAttestation,
createChallenge: client.CreateChallenge,
backoffPolicy: retry,
}, nil
}

// GetLocation calls the confidential computing GetLocation API, retrying
// failed calls according to the client's backoff policy.
//
// The policy is bound to ctx so that retrying stops once ctx is cancelled or
// its deadline passes, rather than sleeping through the remaining attempts.
// On failure a nil location and a wrapped error are returned.
func (c *retryableClient) GetLocation(ctx context.Context, req *locationpb.GetLocationRequest, opts ...gax.CallOption) (*locationpb.Location, error) {
	var location *locationpb.Location
	err := backoff.Retry(
		func() error {
			var err error
			location, err = c.getLocation(ctx, req, opts...)
			return err
		},
		// WithContext makes both the backoff wait and the decision to retry
		// observe ctx.
		backoff.WithContext(c.backoffPolicy(), ctx),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to GetLocation with retries, the last error is: %w", err)
	}
	return location, nil
}

// ListLocations forwards the request to the underlying confidential computing
// ListLocations API exactly once; no retry policy is applied.
func (c *retryableClient) ListLocations(ctx context.Context, req *locationpb.ListLocationsRequest, opts ...gax.CallOption) *v1.LocationIterator {
	iter := c.listLocations(ctx, req, opts...)
	return iter
}

// CreateChallenge forwards the request to the underlying confidential
// computing CreateChallenge API exactly once; no retry policy is applied.
func (c *retryableClient) CreateChallenge(ctx context.Context, req *confidentialcomputingpb.CreateChallengeRequest, opts ...gax.CallOption) (*confidentialcomputingpb.Challenge, error) {
	challenge, err := c.createChallenge(ctx, req, opts...)
	return challenge, err
}

// VerifyAttestation calls the confidential computing VerifyAttestation API,
// retrying failed calls according to the client's backoff policy.
//
// The policy is bound to ctx so that retrying stops once ctx is cancelled or
// its deadline passes, rather than sleeping through the remaining attempts.
// On failure a nil response and a wrapped error are returned.
func (c *retryableClient) VerifyAttestation(ctx context.Context, req *confidentialcomputingpb.VerifyAttestationRequest, opts ...gax.CallOption) (*confidentialcomputingpb.VerifyAttestationResponse, error) {
	var response *confidentialcomputingpb.VerifyAttestationResponse
	err := backoff.Retry(
		func() error {
			var err error
			response, err = c.verifyAttestation(ctx, req, opts...)
			return err
		},
		// WithContext makes both the backoff wait and the decision to retry
		// observe ctx.
		backoff.WithContext(c.backoffPolicy(), ctx),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to VerifyAttestation with retries, the last error is: %w", err)
	}
	return response, nil
}
125 changes: 125 additions & 0 deletions verifier/rest/retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package rest

import (
"context"
"errors"
"testing"

"cloud.google.com/go/confidentialcomputing/apiv1/confidentialcomputingpb"
"github.com/cenkalti/backoff/v4"
"github.com/googleapis/gax-go/v2"

locationpb "google.golang.org/genproto/googleapis/cloud/location"
)

// TestGetLocationWithRetry verifies that retryableClient.GetLocation stops on
// the first success, retries after a failure, and gives up once the retry
// budget of the policy is exhausted.
func TestGetLocationWithRetry(t *testing.T) {
	testCases := []struct {
		name string
		// getLocation receives the 1-based attempt number.
		getLocation  func(int) (*locationpb.Location, error)
		wantPass     bool
		wantAttempts int
	}{
		{
			name: "success with single attempt",
			getLocation: func(int) (*locationpb.Location, error) {
				return &locationpb.Location{}, nil
			},
			wantPass:     true,
			wantAttempts: 1,
		},
		{
			name: "failure then success",
			getLocation: func(attempts int) (*locationpb.Location, error) {
				if attempts%2 == 1 {
					return nil, errors.New("fake error")
				}
				return &locationpb.Location{}, nil
			},
			wantPass:     true,
			wantAttempts: 2,
		},
		{
			name: "failure with attempts exceeded",
			// The attempt number is irrelevant here, so the parameter is
			// left unnamed to keep the linter quiet.
			getLocation: func(int) (*locationpb.Location, error) {
				return nil, errors.New("fake error")
			},
			wantPass: false,
			// One initial call plus the two retries allowed by the policy.
			wantAttempts: 3,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// ZeroBackOff retries immediately, so the test exercises the
			// retry counting without sleeping between attempts.
			backoffPolicy := func() backoff.BackOff {
				return backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 2)
			}
			attempts := 0
			fakeClient := &retryableClient{
				getLocation: func(context.Context, *locationpb.GetLocationRequest, ...gax.CallOption) (*locationpb.Location, error) {
					attempts++
					return tc.getLocation(attempts)
				},
				backoffPolicy: backoffPolicy,
			}

			_, err := fakeClient.GetLocation(context.Background(), &locationpb.GetLocationRequest{})

			if gotPass := (err == nil); gotPass != tc.wantPass {
				t.Errorf("GetLocation retry failed, got %v, but want %v", gotPass, tc.wantPass)
			}
			if attempts != tc.wantAttempts {
				t.Errorf("GetLocation made %d attempts, but want %d", attempts, tc.wantAttempts)
			}
		})
	}
}

// TestVerifyAttestationWithRetry verifies that
// retryableClient.VerifyAttestation stops on the first success, retries after
// a failure, and gives up once the retry budget of the policy is exhausted.
func TestVerifyAttestationWithRetry(t *testing.T) {
	testCases := []struct {
		name string
		// verifyAttestation receives the 1-based attempt number.
		verifyAttestation func(int) (*confidentialcomputingpb.VerifyAttestationResponse, error)
		wantPass          bool
		wantAttempts      int
	}{
		{
			name: "success with single attempt",
			verifyAttestation: func(int) (*confidentialcomputingpb.VerifyAttestationResponse, error) {
				return &confidentialcomputingpb.VerifyAttestationResponse{}, nil
			},
			wantPass:     true,
			wantAttempts: 1,
		},
		{
			name: "failure then success",
			verifyAttestation: func(attempts int) (*confidentialcomputingpb.VerifyAttestationResponse, error) {
				if attempts%2 == 1 {
					return nil, errors.New("fake error")
				}
				return &confidentialcomputingpb.VerifyAttestationResponse{}, nil
			},
			wantPass:     true,
			wantAttempts: 2,
		},
		{
			name: "failure with attempts exceeded",
			// The attempt number is irrelevant here, so the parameter is
			// left unnamed to keep the linter quiet.
			verifyAttestation: func(int) (*confidentialcomputingpb.VerifyAttestationResponse, error) {
				return nil, errors.New("fake error")
			},
			wantPass: false,
			// One initial call plus the two retries allowed by the policy.
			wantAttempts: 3,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// ZeroBackOff retries immediately, so the test exercises the
			// retry counting without sleeping between attempts.
			backoffPolicy := func() backoff.BackOff {
				return backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 2)
			}
			attempts := 0
			fakeClient := &retryableClient{
				verifyAttestation: func(context.Context, *confidentialcomputingpb.VerifyAttestationRequest, ...gax.CallOption) (*confidentialcomputingpb.VerifyAttestationResponse, error) {
					attempts++
					return tc.verifyAttestation(attempts)
				},
				backoffPolicy: backoffPolicy,
			}

			_, err := fakeClient.VerifyAttestation(context.Background(), &confidentialcomputingpb.VerifyAttestationRequest{})

			if gotPass := (err == nil); gotPass != tc.wantPass {
				t.Errorf("VerifyAttestation retry failed, got %v, but want %v", gotPass, tc.wantPass)
			}
			if attempts != tc.wantAttempts {
				t.Errorf("VerifyAttestation made %d attempts, but want %d", attempts, tc.wantAttempts)
			}
		})
	}
}

0 comments on commit b462410

Please sign in to comment.