diff --git a/deploy/example/cloning/README.md b/deploy/example/cloning/README.md index c81e11482..797026748 100644 --- a/deploy/example/cloning/README.md +++ b/deploy/example/cloning/README.md @@ -21,7 +21,7 @@ outfile ## Create a PVC from an existing PVC > Make sure application is not writing data to source blob container ```console -kubectl apply -f https://raw.githubusercontent.com/kubernetes-sigs/blob-csi-driver/master/deploy/example/cloning/pvc-blob-cloning.yaml +kubectl apply -f https://raw.githubusercontent.com/kubernetes-sigs/blob-csi-driver/master/deploy/example/cloning/pvc-blob-csi-cloning.yaml ``` ### Check the Creation Status diff --git a/hack/update-mock.sh b/hack/update-mock.sh old mode 100644 new mode 100755 diff --git a/pkg/blob/controllerserver.go b/pkg/blob/controllerserver.go index e10044c0d..b1c6c13ae 100644 --- a/pkg/blob/controllerserver.go +++ b/pkg/blob/controllerserver.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "net/url" + "os" "os/exec" "strconv" "strings" @@ -50,6 +51,15 @@ import ( const ( privateEndpoint = "privateendpoint" + azcopyAutoLoginType = "AZCOPY_AUTO_LOGIN_TYPE" + azcopySPAApplicationID = "AZCOPY_SPA_APPLICATION_ID" + azcopySPAClientSecret = "AZCOPY_SPA_CLIENT_SECRET" + azcopyTenantID = "AZCOPY_TENANT_ID" + azcopyMSIClientID = "AZCOPY_MSI_CLIENT_ID" + MSI = "MSI" + SPN = "SPN" + authorizationPermissionMismatch = "AuthorizationPermissionMismatch" + waitForCopyInterval = 5 * time.Second waitForCopyTimeout = 3 * time.Minute ) @@ -73,7 +83,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) if acquired := d.volumeLocks.TryAcquire(volName); !acquired { // logging the job status if it's volume cloning if req.GetVolumeContentSource() != nil { - jobState, percent, err := d.azcopy.GetAzcopyJob(volName) + jobState, percent, err := d.azcopy.GetAzcopyJob(volName, []string{}) klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err) } return nil, status.Errorf(codes.Aborted, volumeOperationAlreadyExistsFmt, volName) @@ -412,12 +422,11 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } if req.GetVolumeContentSource() != nil { - if accountKey == "" { - if _, accountKey, err = d.GetStorageAccesskey(ctx, accountOptions, secrets, secretName, secretNamespace); err != nil { - return nil, status.Errorf(codes.Internal, "failed to GetStorageAccesskey on account(%s) rg(%s), error: %v", accountOptions.Name, accountOptions.ResourceGroup, err) - } + var accountSASToken string + if accountSASToken, err = d.getSASToken(ctx, accountName, accountKey, storageEndpointSuffix, accountOptions, secrets, secretName, secretNamespace); err != nil { + return nil, status.Errorf(codes.Internal, "failed to getSASToken on account(%s) rg(%s), error: %v", accountOptions.Name, accountOptions.ResourceGroup, err) } - if err := d.copyVolume(ctx, req, accountKey, validContainerName, storageEndpointSuffix); err != nil { + if err := d.copyVolume(ctx, req, accountSASToken, validContainerName, storageEndpointSuffix); err != nil { return nil, err } } else { @@ -712,7 +721,7 @@ func (d *Driver) DeleteBlobContainer(ctx context.Context, subsID, resourceGroupN } // CopyBlobContainer copies a blob container in the same storage account -func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeRequest, accountKey, dstContainerName, storageEndpointSuffix string) error { +func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeRequest, accountSasToken, dstContainerName, storageEndpointSuffix string) error { var sourceVolumeID string if req.GetVolumeContentSource() != nil && req.GetVolumeContentSource().GetVolume() != nil { sourceVolumeID = req.GetVolumeContentSource().GetVolume().GetVolumeId() @@ -726,10 +735,11 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque return fmt.Errorf("srcContainerName(%s) or dstContainerName(%s) is empty", srcContainerName, dstContainerName) } - klog.V(2).Infof("generate sas token for account(%s)", accountName) - accountSasToken, genErr := generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes) - if genErr != nil { - return genErr + var authAzcopyEnv []string + if accountSasToken == "" { + if authAzcopyEnv, err = d.authorizeAzcopyWithIdentity(); err != nil { + return err + } } timeAfter := time.After(waitForCopyTimeout) @@ -737,7 +747,7 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque srcPath := fmt.Sprintf("https://%s.blob.%s/%s%s", accountName, storageEndpointSuffix, srcContainerName, accountSasToken) dstPath := fmt.Sprintf("https://%s.blob.%s/%s%s", accountName, storageEndpointSuffix, dstContainerName, accountSasToken) - jobState, percent, err := d.azcopy.GetAzcopyJob(dstContainerName) + jobState, percent, err := d.azcopy.GetAzcopyJob(dstContainerName, authAzcopyEnv) klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err) if jobState == util.AzcopyJobError || jobState == util.AzcopyJobCompleted { return err @@ -746,14 +756,18 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque for { select { case <-timeTick: - jobState, percent, err := d.azcopy.GetAzcopyJob(dstContainerName) + jobState, percent, err := d.azcopy.GetAzcopyJob(dstContainerName, authAzcopyEnv) klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err) switch jobState { case util.AzcopyJobError, util.AzcopyJobCompleted: return err case util.AzcopyJobNotFound: klog.V(2).Infof("copy blob container %s to %s", srcContainerName, dstContainerName) - out, copyErr := exec.Command("azcopy", "copy", srcPath, dstPath, "--recursive", "--check-length=false").CombinedOutput() + cmd := exec.Command("azcopy", "copy", srcPath, dstPath, "--recursive", "--check-length=false") + if len(authAzcopyEnv) > 0 { + cmd.Env = append(os.Environ(), authAzcopyEnv...) + } + out, copyErr := cmd.CombinedOutput() if copyErr != nil { klog.Warningf("CopyBlobContainer(%s, %s, %s) failed with error(%v): %v", resourceGroupName, accountName, dstPath, copyErr, string(out)) } else { @@ -768,18 +782,77 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque } // copyVolume copies a volume form volume or snapshot, snapshot is not supported now -func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, accountKey, dstContainerName, storageEndpointSuffix string) error { +func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, accountSASToken, dstContainerName, storageEndpointSuffix string) error { vs := req.VolumeContentSource switch vs.Type.(type) { case *csi.VolumeContentSource_Snapshot: return status.Errorf(codes.InvalidArgument, "copy volume from volumeSnapshot is not supported") case *csi.VolumeContentSource_Volume: - return d.copyBlobContainer(ctx, req, accountKey, dstContainerName, storageEndpointSuffix) + return d.copyBlobContainer(ctx, req, accountSASToken, dstContainerName, storageEndpointSuffix) default: return status.Errorf(codes.InvalidArgument, "%v is not a proper volume source", vs) } } +func (d *Driver) authorizeAzcopyWithIdentity() ([]string, error) { + azureAuthConfig := d.cloud.Config.AzureAuthConfig + var authAzcopyEnv []string + if azureAuthConfig.UseManagedIdentityExtension { + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopyAutoLoginType, MSI)) + if len(azureAuthConfig.UserAssignedIdentityID) > 0 { + klog.V(2).Infof("use user assigned managed identity to authorize azcopy") + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopyMSIClientID, azureAuthConfig.UserAssignedIdentityID)) + } else { + klog.V(2).Infof("use system-assigned managed identity to authorize azcopy") + } + return authAzcopyEnv, nil + } + if len(azureAuthConfig.AADClientSecret) > 0 { + klog.V(2).Infof("use service principal to authorize azcopy") + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopyAutoLoginType, SPN)) + if azureAuthConfig.AADClientID == "" || azureAuthConfig.TenantID == "" { + return []string{}, fmt.Errorf("AADClientID and TenantID must be set when use service principal") + } + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopySPAApplicationID, azureAuthConfig.AADClientID)) + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopySPAClientSecret, azureAuthConfig.AADClientSecret)) + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopyTenantID, azureAuthConfig.TenantID)) + klog.V(2).Infof(fmt.Sprintf("set AZCOPY_SPA_APPLICATION_ID=%s, AZCOPY_TENANT_ID=%s successfully", azureAuthConfig.AADClientID, azureAuthConfig.TenantID)) + + return authAzcopyEnv, nil + } + return []string{}, fmt.Errorf("service principle or managed identity are both not set") +} + +// getSASToken will only generate sas token for azcopy in following conditions: +// 1. secrets is not empty +// 2. driver is not using managed identity and service principal +// 3. azcopy returns AuthorizationPermissionMismatch error when using service principal or managed identity +func (d *Driver) getSASToken(ctx context.Context, accountName, accountKey, storageEndpointSuffix string, accountOptions *azure.AccountOptions, secrets map[string]string, secretName, secretNamespace string) (string, error) { + authAzcopyEnv, _ := d.authorizeAzcopyWithIdentity() + useSasTokenFallBack := false + if len(authAzcopyEnv) > 0 { + out, testErr := d.azcopy.TestListJobs(accountName, storageEndpointSuffix, authAzcopyEnv) + if testErr != nil { + return "", fmt.Errorf("azcopy list command failed with error(%v): %v", testErr, out) + } + if strings.Contains(out, authorizationPermissionMismatch) { + klog.Warningf("azcopy list failed with AuthorizationPermissionMismatch error, should assign \"Storage Blob Data Contributor\" role to controller identity, fall back to use sas token, original output: %v", out) + useSasTokenFallBack = true + } + } + if len(secrets) > 0 || len(authAzcopyEnv) == 0 || useSasTokenFallBack { + var err error + if accountKey == "" { + if _, accountKey, err = d.GetStorageAccesskey(ctx, accountOptions, secrets, secretName, secretNamespace); err != nil { + return "", err + } + } + klog.V(2).Infof("generate sas token for account(%s)", accountName) + return generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes) + } + return "", nil +} + // isValidVolumeCapabilities validates the given VolumeCapability array is valid func isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) error { if len(volCaps) == 0 { diff --git a/pkg/blob/controllerserver_test.go b/pkg/blob/controllerserver_test.go index 7446c2012..25eb5d707 100644 --- a/pkg/blob/controllerserver_test.go +++ b/pkg/blob/controllerserver_test.go @@ -33,9 +33,11 @@ import ( "k8s.io/apimachinery/pkg/util/wait" "k8s.io/utils/pointer" "sigs.k8s.io/blob-csi-driver/pkg/util" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/blobclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/storageaccountclient/mockstorageaccountclient" azure "sigs.k8s.io/cloud-provider-azure/pkg/provider" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) @@ -756,6 +758,7 @@ func TestCreateVolume(t *testing.T) { Value: &fakeValue, }) d.cloud.StorageAccountClient = NewMockSAClient(context.Background(), gomock.NewController(t), "subID", "unit-test", "unit-test", &keyList) + d.cloud.Config.AzureAuthConfig.UseManagedIdentityExtension = true errorType := NULL d.cloud.BlobClient = &mockBlobClient{errorType: &errorType} @@ -789,6 +792,15 @@ func TestCreateVolume(t *testing.T) { controllerServiceCapability, } + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + m := util.NewMockEXEC(ctrl) + listStr := "no error" + m.EXPECT().RunCommand(gomock.Any(), gomock.Any()).Return(listStr, nil) + + d.azcopy.ExecCmd = m + expectedErr := status.Errorf(codes.InvalidArgument, "copy volume from volumeSnapshot is not supported") _, err := d.CreateVolume(context.Background(), req) if !reflect.DeepEqual(err, expectedErr) { @@ -812,6 +824,7 @@ func TestCreateVolume(t *testing.T) { Value: &fakeValue, }) d.cloud.StorageAccountClient = NewMockSAClient(context.Background(), gomock.NewController(t), "subID", "unit-test", "unit-test", &keyList) + d.cloud.Config.AzureAuthConfig.UseManagedIdentityExtension = true errorType := NULL d.cloud.BlobClient = &mockBlobClient{errorType: &errorType} @@ -845,6 +858,15 @@ func TestCreateVolume(t *testing.T) { controllerServiceCapability, } + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + m := util.NewMockEXEC(ctrl) + listStr := "no error" + m.EXPECT().RunCommand(gomock.Any(), gomock.Any()).Return(listStr, nil) + + d.azcopy.ExecCmd = m + expectedErr := status.Errorf(codes.NotFound, "error parsing volume id: \"unit-test\", should at least contain two #") _, err := d.CreateVolume(context.Background(), req) if !reflect.DeepEqual(err, expectedErr) { @@ -1598,6 +1620,38 @@ func TestCopyVolume(t *testing.T) { } }, }, + { + name: "AADClientSecret shouldn't be nil or useManagedIdentityExtension must be set to true when accountSASToken is empty", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + mp := map[string]string{} + + volumeSource := &csi.VolumeContentSource_VolumeSource{ + VolumeId: "vol_1#f5713de20cde511e8ba4900#fileshare#", + } + volumeContentSourceVolumeSource := &csi.VolumeContentSource_Volume{ + Volume: volumeSource, + } + volumecontensource := csi.VolumeContentSource{ + Type: volumeContentSourceVolumeSource, + } + + req := &csi.CreateVolumeRequest{ + Name: "unit-test", + VolumeCapabilities: stdVolumeCapabilities, + Parameters: mp, + VolumeContentSource: &volumecontensource, + } + + ctx := context.Background() + + expectedErr := fmt.Errorf("service principle or managed identity are both not set") + err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net") + if !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected error: %v", err) + } + }, + }, { name: "azcopy job is already completed", testFunc: func(t *testing.T) { @@ -1626,7 +1680,7 @@ func TestCopyVolume(t *testing.T) { m := util.NewMockEXEC(ctrl) listStr := "JobId: ed1c3833-eaff-fe42-71d7-513fb065a9d9\nStart Time: Monday, 07-Aug-23 03:29:54 UTC\nStatus: Completed\nCommand: copy https://{accountName}.file.core.windows.net/{srcFileshare}{SAStoken} https://{accountName}.file.core.windows.net/{dstFileshare}{SAStoken} --recursive --check-length=false" - m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstContainer -B 3")).Return(listStr, nil) + m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstContainer -B 3"), gomock.Any()).Return(listStr, nil) // if test.enableShow { // m.EXPECT().RunCommand(gomock.Not("azcopy jobs list | grep dstContainer -B 3")).Return(test.showStr, test.showErr) // } @@ -1636,7 +1690,7 @@ func TestCopyVolume(t *testing.T) { ctx := context.Background() var expectedErr error - err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net") + err := d.copyVolume(ctx, req, "sastoken", "dstContainer", "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1671,9 +1725,9 @@ func TestCopyVolume(t *testing.T) { m := util.NewMockEXEC(ctrl) listStr1 := "JobId: ed1c3833-eaff-fe42-71d7-513fb065a9d9\nStart Time: Monday, 07-Aug-23 03:29:54 UTC\nStatus: InProgress\nCommand: copy https://{accountName}.file.core.windows.net/{srcFileshare}{SAStoken} https://{accountName}.file.core.windows.net/{dstFileshare}{SAStoken} --recursive --check-length=false" listStr2 := "JobId: ed1c3833-eaff-fe42-71d7-513fb065a9d9\nStart Time: Monday, 07-Aug-23 03:29:54 UTC\nStatus: Completed\nCommand: copy https://{accountName}.file.core.windows.net/{srcFileshare}{SAStoken} https://{accountName}.file.core.windows.net/{dstFileshare}{SAStoken} --recursive --check-length=false" - o1 := m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstContainer -B 3")).Return(listStr1, nil).Times(1) - m.EXPECT().RunCommand(gomock.Not("azcopy jobs list | grep dstBlobContainer -B 3")).Return("Percent Complete (approx): 50.0", nil) - o2 := m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstContainer -B 3")).Return(listStr2, nil) + o1 := m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstContainer -B 3"), gomock.Any()).Return(listStr1, nil).Times(1) + m.EXPECT().RunCommand(gomock.Not("azcopy jobs list | grep dstBlobContainer -B 3"), gomock.Any()).Return("Percent Complete (approx): 50.0", nil) + o2 := m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstContainer -B 3"), gomock.Any()).Return(listStr2, nil) gomock.InOrder(o1, o2) d.azcopy.ExecCmd = m @@ -1681,7 +1735,7 @@ func TestCopyVolume(t *testing.T) { ctx := context.Background() var expectedErr error - err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net") + err := d.copyVolume(ctx, req, "sastoken", "dstContainer", "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1780,3 +1834,256 @@ func Test_generateSASToken(t *testing.T) { }) } } + +func Test_authorizeAzcopyWithIdentity(t *testing.T) { + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "use service principal to authorize azcopy", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{ + ARMClientConfig: azclient.ARMClientConfig{ + TenantID: "TenantID", + }, + AzureAuthConfig: azclient.AzureAuthConfig{ + AADClientID: "AADClientID", + AADClientSecret: "AADClientSecret", + }, + }, + }, + } + expectedAuthAzcopyEnv := []string{ + fmt.Sprintf(azcopyAutoLoginType + "=SPN"), + fmt.Sprintf(azcopySPAApplicationID + "=AADClientID"), + fmt.Sprintf(azcopySPAClientSecret + "=AADClientSecret"), + fmt.Sprintf(azcopyTenantID + "=TenantID"), + } + var expectedErr error + authAzcopyEnv, err := d.authorizeAzcopyWithIdentity() + if !reflect.DeepEqual(authAzcopyEnv, expectedAuthAzcopyEnv) || !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected authAzcopyEnv: %v, Unexpected error: %v", authAzcopyEnv, err) + } + }, + }, + { + name: "use service principal to authorize azcopy but client id is empty", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{ + ARMClientConfig: azclient.ARMClientConfig{ + TenantID: "TenantID", + }, + AzureAuthConfig: azclient.AzureAuthConfig{ + AADClientSecret: "AADClientSecret", + }, + }, + }, + } + expectedAuthAzcopyEnv := []string{} + expectedErr := fmt.Errorf("AADClientID and TenantID must be set when use service principal") + authAzcopyEnv, err := d.authorizeAzcopyWithIdentity() + if !reflect.DeepEqual(authAzcopyEnv, expectedAuthAzcopyEnv) || !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected authAzcopyEnv: %v, Unexpected error: %v", authAzcopyEnv, err) + } + }, + }, + { + name: "use user assigned managed identity to authorize azcopy", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{ + AzureAuthConfig: azclient.AzureAuthConfig{ + UseManagedIdentityExtension: true, + UserAssignedIdentityID: "UserAssignedIdentityID", + }, + }, + }, + } + expectedAuthAzcopyEnv := []string{ + fmt.Sprintf(azcopyAutoLoginType + "=MSI"), + fmt.Sprintf(azcopyMSIClientID + "=UserAssignedIdentityID"), + } + var expectedErr error + authAzcopyEnv, err := d.authorizeAzcopyWithIdentity() + if !reflect.DeepEqual(authAzcopyEnv, expectedAuthAzcopyEnv) || !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected authAzcopyEnv: %v, Unexpected error: %v", authAzcopyEnv, err) + } + }, + }, + { + name: "use system assigned managed identity to authorize azcopy", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{ + AzureAuthConfig: azclient.AzureAuthConfig{ + UseManagedIdentityExtension: true, + }, + }, + }, + } + expectedAuthAzcopyEnv := []string{ + fmt.Sprintf(azcopyAutoLoginType + "=MSI"), + } + var expectedErr error + authAzcopyEnv, err := d.authorizeAzcopyWithIdentity() + if !reflect.DeepEqual(authAzcopyEnv, expectedAuthAzcopyEnv) || !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected authAzcopyEnv: %v, Unexpected error: %v", authAzcopyEnv, err) + } + }, + }, + { + name: "AADClientSecret be nil and useManagedIdentityExtension is false", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{}, + }, + } + expectedAuthAzcopyEnv := []string{} + expectedErr := fmt.Errorf("service principle or managed identity are both not set") + authAzcopyEnv, err := d.authorizeAzcopyWithIdentity() + if !reflect.DeepEqual(authAzcopyEnv, expectedAuthAzcopyEnv) || !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected authAzcopyEnv: %v, Unexpected error: %v", authAzcopyEnv, err) + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} + +func Test_getSASToken(t *testing.T) { + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "failed to get accountKey in secrets", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{}, + } + secrets := map[string]string{ + defaultSecretAccountName: "accountName", + } + + ctx := context.Background() + expectedAccountSASToken := "" + expectedErr := fmt.Errorf("could not find accountkey or azurestorageaccountkey field in secrets") + accountSASToken, err := d.getSASToken(ctx, "accountName", "", "core.windows.net", &azure.AccountOptions{}, secrets, "secretsName", "secretsNamespace") + if !reflect.DeepEqual(err, expectedErr) || !reflect.DeepEqual(accountSASToken, expectedAccountSASToken) { + t.Errorf("Unexpected accountSASToken: %s, Unexpected error: %v", accountSASToken, err) + } + }, + }, + { + name: "failed to test azcopy list command", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{ + AzureAuthConfig: azclient.AzureAuthConfig{ + UseManagedIdentityExtension: true, + }, + }, + }, + } + secrets := map[string]string{ + defaultSecretAccountName: "accountName", + } + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + m := util.NewMockEXEC(ctrl) + listStr := "error" + m.EXPECT().RunCommand(gomock.Any(), gomock.Any()).Return(listStr, fmt.Errorf("error")) + + d.azcopy.ExecCmd = m + + ctx := context.Background() + expectedAccountSASToken := "" + expectedErr := fmt.Errorf("azcopy list command failed with error(%v): %v", fmt.Errorf("error"), "error") + accountSASToken, err := d.getSASToken(ctx, "accountName", "", "core.windows.net", &azure.AccountOptions{}, secrets, "secretsName", "secretsNamespace") + if !reflect.DeepEqual(err, expectedErr) || !reflect.DeepEqual(accountSASToken, expectedAccountSASToken) { + t.Errorf("Unexpected accountSASToken: %s, Unexpected error: %v", accountSASToken, err) + } + }, + }, + { + name: "fall back to generate SAS token failed for illegal account key", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{ + AzureAuthConfig: azclient.AzureAuthConfig{ + UseManagedIdentityExtension: true, + }, + }, + }, + } + secrets := map[string]string{ + defaultSecretAccountName: "accountName", + defaultSecretAccountKey: "fakeValue", + } + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + m := util.NewMockEXEC(ctrl) + listStr := "RESPONSE 403: 403 This request is not authorized to perform this operation using this permission.\nERROR CODE: AuthorizationPermissionMismatch" + m.EXPECT().RunCommand(gomock.Any(), gomock.Any()).Return(listStr, nil) + + d.azcopy.ExecCmd = m + + ctx := context.Background() + expectedAccountSASToken := "" + expectedErr := status.Errorf(codes.Internal, fmt.Sprintf("failed to generate sas token in creating new shared key credential, accountName: %s, err: %s", "accountName", "decode account key: illegal base64 data at input byte 8")) + accountSASToken, err := d.getSASToken(ctx, "accountName", "", "core.windows.net", &azure.AccountOptions{}, secrets, "secretsName", "secretsNamespace") + if !reflect.DeepEqual(err, expectedErr) || !reflect.DeepEqual(accountSASToken, expectedAccountSASToken) { + t.Errorf("Unexpected accountSASToken: %s, Unexpected error: %v", accountSASToken, err) + } + }, + }, + { + name: "generate SAS token failed for illegal account key", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{}, + } + secrets := map[string]string{ + defaultSecretAccountName: "accountName", + defaultSecretAccountKey: "fakeValue", + } + + ctx := context.Background() + expectedAccountSASToken := "" + expectedErr := status.Errorf(codes.Internal, fmt.Sprintf("failed to generate sas token in creating new shared key credential, accountName: %s, err: %s", "accountName", "decode account key: illegal base64 data at input byte 8")) + accountSASToken, err := d.getSASToken(ctx, "accountName", "", "core.windows.net", &azure.AccountOptions{}, secrets, "secretsName", "secretsNamespace") + if !reflect.DeepEqual(err, expectedErr) || !reflect.DeepEqual(accountSASToken, expectedAccountSASToken) { + t.Errorf("Unexpected accountSASToken: %s, Unexpected error: %v", accountSASToken, err) + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} diff --git a/pkg/util/util.go b/pkg/util/util.go index 691984027..c010db385 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -205,14 +205,18 @@ func TrimDuplicatedSpace(s string) string { } type EXEC interface { - RunCommand(string) (string, error) + RunCommand(string, []string) (string, error) } type ExecCommand struct { } -func (ec *ExecCommand) RunCommand(cmd string) (string, error) { - out, err := exec.Command("sh", "-c", cmd).CombinedOutput() +func (ec *ExecCommand) RunCommand(cmdStr string, authEnv []string) (string, error) { + cmd := exec.Command("sh", "-c", cmdStr) + if len(authEnv) > 0 { + cmd.Env = append(os.Environ(), authEnv...) + } + out, err := cmd.CombinedOutput() return string(out), err } @@ -221,7 +225,7 @@ type Azcopy struct { } // GetAzcopyJob get the azcopy job status if job existed -func (ac *Azcopy) GetAzcopyJob(dstBlobContainer string) (AzcopyJobState, string, error) { +func (ac *Azcopy) GetAzcopyJob(dstBlobContainer string, authAzcopyEnv []string) (AzcopyJobState, string, error) { cmdStr := fmt.Sprintf("azcopy jobs list | grep %s -B 3", dstBlobContainer) // cmd output example: // JobId: ed1c3833-eaff-fe42-71d7-513fb065a9d9 @@ -236,7 +240,7 @@ func (ac *Azcopy) GetAzcopyJob(dstBlobContainer string) (AzcopyJobState, string, if ac.ExecCmd == nil { ac.ExecCmd = &ExecCommand{} } - out, err := ac.ExecCmd.RunCommand(cmdStr) + out, err := ac.ExecCmd.RunCommand(cmdStr, authAzcopyEnv) // if grep command returns nothing, the exec will return exit status 1 error, so filter this error if err != nil && err.Error() != "exit status 1" { klog.Warningf("failed to get azcopy job with error: %v, jobState: %v", err, AzcopyJobError) @@ -256,7 +260,7 @@ func (ac *Azcopy) GetAzcopyJob(dstBlobContainer string) (AzcopyJobState, string, cmdPercentStr := fmt.Sprintf("azcopy jobs show %s | grep Percent", jobid) // cmd out example: // Percent Complete (approx): 100.0 - summary, err := ac.ExecCmd.RunCommand(cmdPercentStr) + summary, err := ac.ExecCmd.RunCommand(cmdPercentStr, authAzcopyEnv) if err != nil { klog.Warningf("failed to get azcopy job with error: %v, jobState: %v", err, AzcopyJobError) return AzcopyJobError, "", fmt.Errorf("couldn't show jobs summary in azcopy %v", err) @@ -269,6 +273,15 @@ func (ac *Azcopy) GetAzcopyJob(dstBlobContainer string) (AzcopyJobState, string, return jobState, percent, nil } +// TestListJobs test azcopy jobs list command with authAzcopyEnv +func (ac *Azcopy) TestListJobs(accountName, storageEndpointSuffix string, authAzcopyEnv []string) (string, error) { + cmdStr := fmt.Sprintf("azcopy list %s", fmt.Sprintf("https://%s.blob.%s", accountName, storageEndpointSuffix)) + if ac.ExecCmd == nil { + ac.ExecCmd = &ExecCommand{} + } + return ac.ExecCmd.RunCommand(cmdStr, authAzcopyEnv) +} + // parseAzcopyJobList parse command azcopy jobs list, get jobid and state from joblist func parseAzcopyJobList(joblist string) (string, AzcopyJobState, error) { jobid := "" diff --git a/pkg/util/util_mock.go b/pkg/util/util_mock.go index f381ec968..6764d3fcf 100644 --- a/pkg/util/util_mock.go +++ b/pkg/util/util_mock.go @@ -51,16 +51,16 @@ func (m *MockEXEC) EXPECT() *MockEXECMockRecorder { } // RunCommand mocks base method. -func (m *MockEXEC) RunCommand(arg0 string) (string, error) { +func (m *MockEXEC) RunCommand(arg0 string, arg1 []string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RunCommand", arg0) + ret := m.ctrl.Call(m, "RunCommand", arg0, arg1) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // RunCommand indicates an expected call of RunCommand. -func (mr *MockEXECMockRecorder) RunCommand(arg0 interface{}) *gomock.Call { +func (mr *MockEXECMockRecorder) RunCommand(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunCommand", reflect.TypeOf((*MockEXEC)(nil).RunCommand), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunCommand", reflect.TypeOf((*MockEXEC)(nil).RunCommand), arg0, arg1) } diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index 06ba4a400..4c3832b25 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -419,14 +419,14 @@ func TestGetAzcopyJob(t *testing.T) { defer ctrl.Finish() m := NewMockEXEC(ctrl) - m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstBlobContainer -B 3")).Return(test.listStr, test.listErr) + m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstBlobContainer -B 3"), []string{}).Return(test.listStr, test.listErr) if test.enableShow { - m.EXPECT().RunCommand(gomock.Not("azcopy jobs list | grep dstBlobContainer -B 3")).Return(test.showStr, test.showErr) + m.EXPECT().RunCommand(gomock.Not("azcopy jobs list | grep dstBlobContainer -B 3"), []string{}).Return(test.showStr, test.showErr) } azcopyFunc := &Azcopy{} azcopyFunc.ExecCmd = m - jobState, percent, err := azcopyFunc.GetAzcopyJob(dstBlobContainer) + jobState, percent, err := azcopyFunc.GetAzcopyJob(dstBlobContainer, []string{}) if jobState != test.expectedJobState || percent != test.expectedPercent || !reflect.DeepEqual(err, test.expectedErr) { t.Errorf("test[%s]: unexpected jobState: %v, percent: %v, err: %v, expected jobState: %v, percent: %v, err: %v", test.desc, jobState, percent, err, test.expectedJobState, test.expectedPercent, test.expectedErr) }