Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply retry logics in confidential computing API + workload image puller #511

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions go.work.sum

Large diffs are not rendered by default.

48 changes: 46 additions & 2 deletions launcher/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"bytes"
"context"
"crypto"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -24,11 +25,15 @@ import (
pb "github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm-tools/verifier"
"github.com/google/go-tpm-tools/verifier/oci"
"github.com/google/go-tpm-tools/verifier/rest"
"github.com/google/go-tpm-tools/verifier/util"
)

var defaultCELHashAlgo = []crypto.Hash{crypto.SHA256, crypto.SHA1}

// attestFunc is used for doAttest indirectly so that unit tests can stub it.
var attestFunc = doAttest

type principalIDTokenFetcher func(audience string) ([][]byte, error)

// AttestationAgent is an agent that interacts with GCE's Attestation Service
Expand Down Expand Up @@ -101,10 +106,49 @@ func (a *agent) MeasureEvent(event cel.Content) error {
return a.cosCel.AppendEvent(a.tpm, cel.CosEventPCR, defaultCELHashAlgo, event)
}

// Attest fetches the nonce and connection ID from the Attestation Service,
// Attest is a thin wrapper of AttestWithRetries with defaultRetryPolicy.
func (a *agent) Attest(ctx context.Context, opts AttestAgentOpts) ([]byte, error) {
return a.AttestWithRetries(ctx, opts, defaultRetryPolicy)
}

// Attest executes doAttest with retries when 500 errors originate from VerifyAttestation API.
func (a *agent) AttestWithRetries(ctx context.Context, opts AttestAgentOpts, retry func() backoff.BackOff) ([]byte, error) {
var token []byte
var err error

retryErr := backoff.Retry(
func() error {
var doErr error
token, doErr = attestFunc(ctx, a, opts)
var verifyErr *rest.VerifyAttestationError
// Retry for VerifyAttestation 500 errors.
if errors.As(doErr, &verifyErr) && verifyErr.StatusCode() == http.StatusInternalServerError {
return verifyErr
}

// Otherwise, save the error and exit the retry.
err = doErr
return nil
},
retry(),
)

// If retryErr is set, it means we retried maxAttempts for VerifyAttestation without success.
if retryErr != nil {
err = retryErr
}

if err != nil {
return nil, err
}

return token, nil
}

// doAttest fetches the nonce and connection ID from the Attestation Service,
// creates an attestation message, and returns the resultant
// principalIDTokens and Metadata Server-generated ID tokens for the instance.
func (a *agent) Attest(ctx context.Context, opts AttestAgentOpts) ([]byte, error) {
func doAttest(ctx context.Context, a *agent, opts AttestAgentOpts) ([]byte, error) {
challenge, err := a.client.CreateChallenge(ctx)
if err != nil {
return nil, err
Expand Down
75 changes: 75 additions & 0 deletions launcher/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/base64"
"fmt"
"math"
"net/http"
"runtime"
"sync"
"testing"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/google/go-tpm-tools/verifier/oci/cosign"
"github.com/google/go-tpm-tools/verifier/rest"
"golang.org/x/oauth2/google"
"google.golang.org/api/googleapi"
"google.golang.org/api/option"
"google.golang.org/protobuf/encoding/protojson"
)
Expand Down Expand Up @@ -78,6 +80,79 @@ func TestAttestRacing(t *testing.T) {
agent.Close()
}

func TestAttestWithRetries(t *testing.T) {
testCases := []struct {
name string
fn func(int) ([]byte, error)
wantPass bool
wantAttempts int
}{
{
name: "success",
fn: func(int) ([]byte, error) {
return []byte("test token"), nil
},
wantPass: true,
wantAttempts: 1,
},
{
name: "failed with 500, then success",
fn: func(attempts int) ([]byte, error) {
if attempts == 1 {
return nil, rest.NewVerifyAttestationError(nil, &googleapi.Error{Code: http.StatusInternalServerError})
}
return []byte("test token"), nil
},
wantPass: true,
wantAttempts: 2,
},
{
name: "failed with 500 after attempts exceed",
fn: func(int) ([]byte, error) {
return nil, rest.NewVerifyAttestationError(nil, &googleapi.Error{Code: http.StatusInternalServerError})
},
wantPass: false,
wantAttempts: 4,
},
{
name: "failed with non-500 error",
fn: func(int) ([]byte, error) {
return nil, fmt.Errorf("other error")
},
wantPass: false,
wantAttempts: 1,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Reset stub after test case is done.
af := attestFunc
t.Cleanup(func() { attestFunc = af })

attempts := 0
// Stub attestFunc.
attestFunc = func(context.Context, *agent, AttestAgentOpts) ([]byte, error) {
attempts++
return tc.fn(attempts)
}

a := &agent{}
testRetryPolicy := func() backoff.BackOff {
return backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Millisecond), 3)
}
_, err := a.AttestWithRetries(context.Background(), AttestAgentOpts{}, testRetryPolicy)
if gotPass := (err == nil); gotPass != tc.wantPass {
t.Errorf("AttestWithRetries failed, gotPass %v, but wantPass %v", gotPass, tc.wantPass)
}

if gotAttempts := attempts; gotAttempts != tc.wantAttempts {
t.Errorf("AttestWithRetries failed, gotAttempts %v, but wantAttempts %v", gotAttempts, tc.wantAttempts)
}
})
}
}

func TestAttest(t *testing.T) {
ctx := context.Background()
testCases := []struct {
Expand Down
40 changes: 36 additions & 4 deletions launcher/container_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,12 @@ func getSignatureDiscoveryClient(cdClient *containerd.Client, mdsClient *metadat
return registryauth.RefreshResolver(ctx, mdsClient)
}
imageFetcher := func(ctx context.Context, imageRef string, opts ...containerd.RemoteOpt) (containerd.Image, error) {
image, err := cdClient.Pull(ctx, imageRef, opts...)
image, err := pullImageWithRetries(
func() (containerd.Image, error) {
return cdClient.Pull(ctx, imageRef, opts...)
},
pullImageBackoffPolicy,
)
if err != nil {
return nil, fmt.Errorf("cannot pull signature objects from the signature image [%s]: %w", imageRef, err)
}
Expand Down Expand Up @@ -529,6 +534,11 @@ func defaultRetryPolicy() *backoff.ExponentialBackOff {
return expBack
}

func pullImageBackoffPolicy() backoff.BackOff {
b := backoff.NewConstantBackOff(time.Millisecond * 500)
return backoff.WithMaxRetries(b, 3)
}

// Run the container
// Container output will always be redirected to logger writer for now
func (r *ContainerRunner) Run(ctx context.Context) error {
Expand Down Expand Up @@ -621,17 +631,39 @@ func (r *ContainerRunner) Run(ctx context.Context) error {
return nil
}

func pullImageWithRetries(f func() (containerd.Image, error), retry func() backoff.BackOff) (containerd.Image, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This retry method seems very generic, with only the return in the func being passed being specific to pulling images.
Is it possible to make this more generic? Or less generic?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intention is to make specific to pulling images. The reason I abstract this method is to make unit tests easier so that I can mock containerd API. Since cotainerd will make a network call to an actual docker registry, I want to avoid it in unit tests.

var err error
var image containerd.Image
err = backoff.Retry(func() error {
image, err = f()
return err
}, retry())
if err != nil {
return nil, fmt.Errorf("failed to pull image with retries, the last error is: %w", err)
}
return image, nil
}

func initImage(ctx context.Context, cdClient *containerd.Client, launchSpec spec.LaunchSpec, token oauth2.Token) (containerd.Image, error) {
if token.Valid() {
remoteOpt := containerd.WithResolver(registryauth.Resolver(token.AccessToken))

image, err := cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack, remoteOpt)
image, err := pullImageWithRetries(
func() (containerd.Image, error) {
return cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack, remoteOpt)
},
pullImageBackoffPolicy,
)
if err != nil {
return nil, fmt.Errorf("cannot pull the image: %w", err)
}
return image, nil
}
image, err := cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack)
image, err := pullImageWithRetries(
func() (containerd.Image, error) {
return cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack)
},
pullImageBackoffPolicy,
)
if err != nil {
return nil, fmt.Errorf("cannot pull the image (no token, only works for a public image): %w", err)
}
Expand Down
51 changes: 51 additions & 0 deletions launcher/container_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,57 @@ func TestMeasureCELEvents(t *testing.T) {
}
}

func TestPullImageWithRetries(t *testing.T) {
testCases := []struct {
name string
imagePuller func(int) (containerd.Image, error)
wantPass bool
}{
{
name: "success with single attempt",
imagePuller: func(int) (containerd.Image, error) { return &fakeImage{}, nil },
wantPass: true,
},
{
name: "failure then success",
imagePuller: func(attempts int) (containerd.Image, error) {
if attempts%2 == 1 {
return nil, errors.New("fake error")
}
return &fakeImage{}, nil
},
wantPass: true,
},
{
name: "failure with attempts exceeded",
imagePuller: func(int) (containerd.Image, error) {
return nil, errors.New("fake error")
},
wantPass: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
retryPolicy := func() backoff.BackOff {
b := backoff.NewExponentialBackOff()
return backoff.WithMaxRetries(b, 2)
}

attempts := 0
_, err := pullImageWithRetries(
func() (containerd.Image, error) {
attempts++
return tc.imagePuller(attempts)
},
retryPolicy)
if gotPass := (err == nil); gotPass != tc.wantPass {
t.Errorf("pullImageWithRetries failed, got %v, but want %v", gotPass, tc.wantPass)
}
})
}
}

// This ensures fakeContainer implements containerd.Container interface.
var _ containerd.Container = &fakeContainer{}

Expand Down
52 changes: 26 additions & 26 deletions launcher/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module github.com/google/go-tpm-tools/launcher
go 1.21

require (
cloud.google.com/go/compute/metadata v0.5.0
cloud.google.com/go/logging v1.10.0
cloud.google.com/go/compute/metadata v0.5.2
cloud.google.com/go/logging v1.12.0
github.com/cenkalti/backoff/v4 v4.2.1
github.com/containerd/containerd v1.7.16
github.com/coreos/go-systemd/v22 v22.5.0
Expand All @@ -16,18 +16,18 @@ require (
github.com/opencontainers/go-digest v1.0.0
github.com/opencontainers/image-spec v1.1.0
github.com/opencontainers/runtime-spec v1.1.0
golang.org/x/oauth2 v0.21.0
google.golang.org/api v0.189.0
google.golang.org/genproto/googleapis/api v0.0.0-20240722135656-d784300faade
google.golang.org/protobuf v1.34.2
golang.org/x/oauth2 v0.23.0
google.golang.org/api v0.205.0
google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53
google.golang.org/protobuf v1.35.1
)

require (
cloud.google.com/go v0.115.0 // indirect
cloud.google.com/go/auth v0.7.2 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.3 // indirect
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/auth v0.10.1 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.5 // indirect
cloud.google.com/go/confidentialcomputing v1.6.0 // indirect
cloud.google.com/go/longrunning v0.5.9 // indirect
cloud.google.com/go/longrunning v0.6.1 // indirect
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 // indirect
github.com/AdamKorcz/go-118-fuzz-build v0.0.0-20230306123547-8075edf89bb0 // indirect
github.com/Microsoft/go-winio v0.6.1 // indirect
Expand Down Expand Up @@ -55,9 +55,9 @@ require (
github.com/google/go-tdx-guest v0.3.2-0.20240902060211-1f7f7b9b42b9 // indirect
github.com/google/go-tspi v0.3.0 // indirect
github.com/google/logger v1.1.1 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/googleapis/gax-go/v2 v2.13.0 // indirect
github.com/klauspost/compress v1.16.7 // indirect
github.com/moby/locker v1.0.1 // indirect
Expand All @@ -70,24 +70,24 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
go.opentelemetry.io/otel v1.29.0 // indirect
go.opentelemetry.io/otel/metric v1.29.0 // indirect
go.opentelemetry.io/otel/trace v1.29.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.25.0 // indirect
golang.org/x/crypto v0.28.0 // indirect
golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.27.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.22.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/time v0.7.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
google.golang.org/genproto v0.0.0-20240722135656-d784300faade // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240722135656-d784300faade // indirect
google.golang.org/grpc v1.64.1 // indirect
google.golang.org/genproto v0.0.0-20241021214115-324edc3d5d38 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect
google.golang.org/grpc v1.67.1 // indirect
)

replace (
Expand Down
Loading
Loading