diff --git a/br/pkg/storage/BUILD.bazel b/br/pkg/storage/BUILD.bazel index c67a17713b2ca..8c98a13e59500 100644 --- a/br/pkg/storage/BUILD.bazel +++ b/br/pkg/storage/BUILD.bazel @@ -35,6 +35,7 @@ go_library( "@com_github_aws_aws_sdk_go//service/s3", "@com_github_aws_aws_sdk_go//service/s3/s3iface", "@com_github_aws_aws_sdk_go//service/s3/s3manager", + "@com_github_azure_azure_sdk_for_go_sdk_azcore//policy", "@com_github_azure_azure_sdk_for_go_sdk_azidentity//:azidentity", "@com_github_azure_azure_sdk_for_go_sdk_storage_azblob//:azblob", "@com_github_golang_snappy//:snappy", diff --git a/br/pkg/storage/azblob.go b/br/pkg/storage/azblob.go index c557a79e3ac8f..41d8fa88f559f 100644 --- a/br/pkg/storage/azblob.go +++ b/br/pkg/storage/azblob.go @@ -12,6 +12,7 @@ import ( "path" "strings" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" "github.com/google/uuid" @@ -30,6 +31,16 @@ const ( azblobAccountKey = "azblob.account-key" ) +const azblobRetryTimes int32 = 5 + +func getDefaultClientOptions() *azblob.ClientOptions { + return &azblob.ClientOptions{ + Retry: policy.RetryOptions{ + MaxRetries: azblobRetryTimes, + }, + } +} + // AzblobBackendOptions is the options for Azure Blob storage. type AzblobBackendOptions struct { Endpoint string `json:"endpoint" toml:"endpoint"` @@ -99,7 +110,7 @@ type sharedKeyClientBuilder struct { } func (b *sharedKeyClientBuilder) GetServiceClient() (azblob.ServiceClient, error) { - return azblob.NewServiceClientWithSharedKey(b.serviceURL, b.cred, nil) + return azblob.NewServiceClientWithSharedKey(b.serviceURL, b.cred, getDefaultClientOptions()) } func (b *sharedKeyClientBuilder) GetAccountName() string { @@ -114,7 +125,7 @@ type tokenClientBuilder struct { } func (b *tokenClientBuilder) GetServiceClient() (azblob.ServiceClient, error) { - return azblob.NewServiceClient(b.serviceURL, b.cred, nil) + return azblob.NewServiceClient(b.serviceURL, b.cred, getDefaultClientOptions()) } func (b *tokenClientBuilder) GetAccountName() string { @@ -285,7 +296,9 @@ func (s *AzureBlobStorage) ReadFile(ctx context.Context, name string) ([]byte, e return nil, errors.Annotatef(err, "Failed to download azure blob file, file info: bucket(container)='%s', key='%s'", s.options.Bucket, s.withPrefix(name)) } defer resp.RawResponse.Body.Close() - data, err := io.ReadAll(resp.Body(azblob.RetryReaderOptions{})) + data, err := io.ReadAll(resp.Body(azblob.RetryReaderOptions{ + MaxRetryRequests: int(azblobRetryTimes), + })) if err != nil { return nil, errors.Annotatef(err, "Failed to read azure blob file, file info: bucket(container)='%s', key='%s'", s.options.Bucket, s.withPrefix(name)) } diff --git a/br/pkg/storage/azblob_test.go b/br/pkg/storage/azblob_test.go index c099037ea51b2..74ddfa7125699 100644 --- a/br/pkg/storage/azblob_test.go +++ b/br/pkg/storage/azblob_test.go @@ -4,9 +4,13 @@ package storage import ( "context" + "fmt" "io" + "net/http" + "net/http/httptest" "os" "strings" + "sync" "testing" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" @@ -298,3 +302,52 @@ func TestNewAzblobStorage(t *testing.T) { require.Equal(t, "http://127.0.0.1:1000", b.serviceURL) } } + +type fakeClientBuilder struct { + Endpoint string +} + +func (b *fakeClientBuilder) GetServiceClient() (azblob.ServiceClient, error) { + connStr := fmt.Sprintf("DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=%s/devstoreaccount1;", b.Endpoint) + return azblob.NewServiceClientFromConnectionString(connStr, getDefaultClientOptions()) +} + +func (b *fakeClientBuilder) GetAccountName() string { + return "devstoreaccount1" +} + +func TestDownloadRetry(t *testing.T) { + var count int32 = 0 + var lock sync.Mutex + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Log(r.URL) + if strings.Contains(r.URL.String(), "restype=container") { + w.WriteHeader(201) + return + } + lock.Lock() + count += 1 + lock.Unlock() + header := w.Header() + header.Add("Etag", "0x1") + header.Add("Content-Length", "5") + w.WriteHeader(200) + w.Write([]byte("1234567")) + })) + + defer server.Close() + t.Log(server.URL) + + options := &backuppb.AzureBlobStorage{ + Bucket: "test", + Prefix: "a/b/", + } + + ctx := context.Background() + builder := &fakeClientBuilder{Endpoint: server.URL} + s, err := newAzureBlobStorageWithClientBuilder(ctx, options, builder) + require.NoError(t, err) + _, err = s.ReadFile(ctx, "c") + require.Error(t, err) + require.Less(t, azblobRetryTimes, count) +}