-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c891518
commit b462410
Showing
5 changed files
with
302 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
package rest | ||
|
||
import ( | ||
"github.com/googleapis/gax-go/v2" | ||
|
||
v1 "cloud.google.com/go/confidentialcomputing/apiv1" | ||
confidentialcomputingpb "cloud.google.com/go/confidentialcomputing/apiv1/confidentialcomputingpb" | ||
"context" | ||
"fmt" | ||
"github.com/cenkalti/backoff/v4" | ||
"google.golang.org/api/option" | ||
locationpb "google.golang.org/genproto/googleapis/cloud/location" | ||
) | ||
|
||
// retryableClient is a thin wrapper around confidential computing APIs with backoff policies. | ||
type retryableClient struct { | ||
getLocation func(context.Context, *locationpb.GetLocationRequest, ...gax.CallOption) (*locationpb.Location, error) | ||
listLocations func(context.Context, *locationpb.ListLocationsRequest, ...gax.CallOption) *v1.LocationIterator | ||
createChallenge func(context.Context, *confidentialcomputingpb.CreateChallengeRequest, ...gax.CallOption) (*confidentialcomputingpb.Challenge, error) | ||
verifyAttestation func(context.Context, *confidentialcomputingpb.VerifyAttestationRequest, ...gax.CallOption) (*confidentialcomputingpb.VerifyAttestationResponse, error) | ||
backoffPolicy func() backoff.BackOff | ||
} | ||
|
||
// NewRetryableCleint creates a new retryable client. | ||
func NewRetryableClient(ctx context.Context, retry func() backoff.BackOff, opts ...option.ClientOption) (*retryableClient, error) { | ||
client, err := v1.NewRESTClient(ctx, opts...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &retryableClient{ | ||
getLocation: client.GetLocation, | ||
listLocations: client.ListLocations, | ||
verifyAttestation: client.VerifyAttestation, | ||
createChallenge: client.CreateChallenge, | ||
backoffPolicy: retry, | ||
}, nil | ||
} | ||
|
||
// GetLocation is a thin wrapper of calling confidential computing GetLocation API with retries. | ||
func (c *retryableClient) GetLocation(ctx context.Context, req *locationpb.GetLocationRequest, opts ...gax.CallOption) (*locationpb.Location, error) { | ||
var location *locationpb.Location | ||
err := backoff.Retry( | ||
func() error { | ||
var err error | ||
location, err = c.getLocation(ctx, req, opts...) | ||
return err | ||
}, | ||
c.backoffPolicy(), | ||
) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to GetLocation with retries, the last error is: %w", err) | ||
} | ||
return location, nil | ||
} | ||
|
||
// ListLocations calls the underlying confidential computing ListLocations API with no-op. | ||
func (c *retryableClient) ListLocations(ctx context.Context, req *locationpb.ListLocationsRequest, opts ...gax.CallOption) *v1.LocationIterator { | ||
return c.listLocations(ctx, req, opts...) | ||
} | ||
|
||
// CreateChallenge calls the underlying confidential computing CreateChallenge API with no-op. | ||
func (c *retryableClient) CreateChallenge(ctx context.Context, req *confidentialcomputingpb.CreateChallengeRequest, opts ...gax.CallOption) (*confidentialcomputingpb.Challenge, error) { | ||
return c.createChallenge(ctx, req, opts...) | ||
} | ||
|
||
// VerifyAttestation is a thin wrapper of calling confidential computing VerifyAttestation API with retries. | ||
func (c *retryableClient) VerifyAttestation(ctx context.Context, req *confidentialcomputingpb.VerifyAttestationRequest, opts ...gax.CallOption) (*confidentialcomputingpb.VerifyAttestationResponse, error) { | ||
var response *confidentialcomputingpb.VerifyAttestationResponse | ||
err := backoff.Retry( | ||
func() error { | ||
var err error | ||
response, err = c.verifyAttestation(ctx, req, opts...) | ||
return err | ||
}, | ||
c.backoffPolicy(), | ||
) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to VerifyAttestation with retries, the last error is: %w", err) | ||
} | ||
return response, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
package rest | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"testing" | ||
|
||
"cloud.google.com/go/confidentialcomputing/apiv1/confidentialcomputingpb" | ||
"github.com/cenkalti/backoff/v4" | ||
"github.com/googleapis/gax-go/v2" | ||
|
||
locationpb "google.golang.org/genproto/googleapis/cloud/location" | ||
) | ||
|
||
func TestGetLocationWithRetry(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
getLocation func(int) (*locationpb.Location, error) | ||
wantPass bool | ||
}{ | ||
{ | ||
name: "success with single attempt", | ||
getLocation: func(int) (*locationpb.Location, error) { | ||
return &locationpb.Location{}, nil | ||
}, | ||
wantPass: true, | ||
}, | ||
{ | ||
name: "failure then success", | ||
getLocation: func(attempts int) (*locationpb.Location, error) { | ||
if attempts%2 == 1 { | ||
return nil, errors.New("fake error") | ||
} | ||
return &locationpb.Location{}, nil | ||
}, | ||
wantPass: true, | ||
}, | ||
{ | ||
name: "failure with attempts exceeded", | ||
getLocation: func(attempts int) (*locationpb.Location, error) { | ||
return nil, errors.New("fake error") | ||
}, | ||
wantPass: false, | ||
}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
backoffPolicy := func() backoff.BackOff { | ||
b := backoff.NewExponentialBackOff() | ||
return backoff.WithMaxRetries(b, 2) | ||
} | ||
attempts := 0 | ||
fakeClient := &retryableClient{ | ||
getLocation: func(context.Context, *locationpb.GetLocationRequest, ...gax.CallOption) (*locationpb.Location, error) { | ||
attempts++ | ||
return tc.getLocation(attempts) | ||
}, | ||
backoffPolicy: backoffPolicy, | ||
} | ||
|
||
_, err := fakeClient.GetLocation(context.Background(), &locationpb.GetLocationRequest{}) | ||
|
||
if gotPass := (err == nil); gotPass != tc.wantPass { | ||
t.Errorf("GetLocation retry failed, got %v, but want %v", gotPass, tc.wantPass) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestVerifyAttestationWithRetry(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
verifyAttestation func(int) (*confidentialcomputingpb.VerifyAttestationResponse, error) | ||
wantPass bool | ||
}{ | ||
{ | ||
name: "success with single attempt", | ||
verifyAttestation: func(int) (*confidentialcomputingpb.VerifyAttestationResponse, error) { | ||
return &confidentialcomputingpb.VerifyAttestationResponse{}, nil | ||
}, | ||
wantPass: true, | ||
}, | ||
{ | ||
name: "failure then success", | ||
verifyAttestation: func(attempts int) (*confidentialcomputingpb.VerifyAttestationResponse, error) { | ||
if attempts%2 == 1 { | ||
return nil, errors.New("fake error") | ||
} | ||
return &confidentialcomputingpb.VerifyAttestationResponse{}, nil | ||
}, | ||
wantPass: true, | ||
}, | ||
{ | ||
name: "failure with attempts exceeded", | ||
verifyAttestation: func(attempts int) (*confidentialcomputingpb.VerifyAttestationResponse, error) { | ||
return nil, errors.New("fake error") | ||
}, | ||
wantPass: false, | ||
}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
backoffPolicy := func() backoff.BackOff { | ||
b := backoff.NewExponentialBackOff() | ||
return backoff.WithMaxRetries(b, 2) | ||
} | ||
attempts := 0 | ||
fakeClient := &retryableClient{ | ||
verifyAttestation: func(context.Context, *confidentialcomputingpb.VerifyAttestationRequest, ...gax.CallOption) (*confidentialcomputingpb.VerifyAttestationResponse, error) { | ||
attempts++ | ||
return tc.verifyAttestation(attempts) | ||
}, | ||
backoffPolicy: backoffPolicy, | ||
} | ||
|
||
_, err := fakeClient.VerifyAttestation(context.Background(), &confidentialcomputingpb.VerifyAttestationRequest{}) | ||
|
||
if gotPass := (err == nil); gotPass != tc.wantPass { | ||
t.Errorf("VerifyAttestation retry failed, got %v, but want %v", gotPass, tc.wantPass) | ||
} | ||
}) | ||
} | ||
} |