Skip to content

Commit

Permalink
Apply extra volume tags to EBS snapshots
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishenzie committed Oct 8, 2020
1 parent ee64872 commit 7c17cb5
Show file tree
Hide file tree
Showing 11 changed files with 264 additions and 43 deletions.
1 change: 1 addition & 0 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func main() {

drv, err := driver.NewDriver(
driver.WithEndpoint(options.ServerOptions.Endpoint),
driver.WithExtraTags(options.ControllerOptions.ExtraTags),
driver.WithExtraVolumeTags(options.ControllerOptions.ExtraVolumeTags),
driver.WithMode(options.DriverMode),
driver.WithVolumeAttachLimit(options.NodeOptions.VolumeAttachLimit),
Expand Down
10 changes: 7 additions & 3 deletions cmd/options/controller_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,19 @@ import (

// ControllerOptions contains options and configuration settings for the controller service.
type ControllerOptions struct {
// ExtraTags is a map of tags that will be attached to each dynamically provisioned
// resource.
ExtraTags map[string]string
// ExtraVolumeTags is a map of tags that will be attached to each dynamically provisioned
// volume.
// DEPRECATED: Use ExtraTags instead.
ExtraVolumeTags map[string]string
// ID of the kubernetes cluster. This is used only to create the same tags on volumes that
// in-tree volume volume plugin does.
// ID of the kubernetes cluster.
KubernetesClusterID string
}

func (s *ControllerOptions) AddFlags(fs *flag.FlagSet) {
fs.Var(cliflag.NewMapStringString(&s.ExtraVolumeTags), "extra-volume-tags", "Extra volume tags to attach to each dynamically provisioned volume. It is a comma separated list of key value pairs like '<key1>=<value1>,<key2>=<value2>'")
fs.Var(cliflag.NewMapStringString(&s.ExtraTags), "extra-tags", "Extra tags to attach to each dynamically provisioned resource. It is a comma separated list of key value pairs like '<key1>=<value1>,<key2>=<value2>'")
fs.Var(cliflag.NewMapStringString(&s.ExtraVolumeTags), "extra-volume-tags", "DEPRECATED: Please use --extra-tags instead. Extra volume tags to attach to each dynamically provisioned volume. It is a comma separated list of key value pairs like '<key1>=<value1>,<key2>=<value2>'")
fs.StringVar(&s.KubernetesClusterID, "k8s-tag-cluster-id", "", "ID of the Kubernetes cluster used for tagging provisioned EBS volumes (optional).")
}
22 changes: 11 additions & 11 deletions cmd/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ func TestGetOptions(t *testing.T) {
endpointFlagName := "endpoint"
endpoint := "foo"

extraVolumeTagsFlagName := "extra-volume-tags"
extraVolumeTagKey := "bar"
extraVolumeTagValue := "baz"
extraVolumeTags := map[string]string{
extraVolumeTagKey: extraVolumeTagValue,
extraTagsFlagName := "extra-tags"
extraTagKey := "bar"
extraTagValue := "baz"
extraTags := map[string]string{
extraTagKey: extraTagValue,
}

VolumeAttachLimitFlagName := "volume-attach-limit"
Expand All @@ -57,7 +57,7 @@ func TestGetOptions(t *testing.T) {
args = append(args, "-"+endpointFlagName+"="+endpoint)
}
if withControllerOptions {
args = append(args, "-"+extraVolumeTagsFlagName+"="+extraVolumeTagKey+"="+extraVolumeTagValue)
args = append(args, "-"+extraTagsFlagName+"="+extraTagKey+"="+extraTagValue)
}
if withNodeOptions {
args = append(args, "-"+VolumeAttachLimitFlagName+"="+strconv.FormatInt(VolumeAttachLimit, 10))
Expand All @@ -80,12 +80,12 @@ func TestGetOptions(t *testing.T) {
}

if withControllerOptions {
extraVolumeTagsFlag := flagSet.Lookup(extraVolumeTagsFlagName)
if extraVolumeTagsFlag == nil {
t.Fatalf("expected %q flag to be added but it is not", extraVolumeTagsFlagName)
extraTagsFlag := flagSet.Lookup(extraTagsFlagName)
if extraTagsFlag == nil {
t.Fatalf("expected %q flag to be added but it is not", extraTagsFlagName)
}
if !reflect.DeepEqual(options.ControllerOptions.ExtraVolumeTags, extraVolumeTags) {
t.Fatalf("expected extra volume tags to be %q but it is %q", extraVolumeTags, options.ControllerOptions.ExtraVolumeTags)
if !reflect.DeepEqual(options.ControllerOptions.ExtraTags, extraTags) {
t.Fatalf("expected extra tags to be %q but it is %q", extraTags, options.ControllerOptions.ExtraTags)
}
}

Expand Down
4 changes: 3 additions & 1 deletion pkg/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,9 @@ func (c *cloud) CreateSnapshot(ctx context.Context, volumeID string, snapshotOpt

var tags []*ec2.Tag
for key, value := range snapshotOptions.Tags {
tags = append(tags, &ec2.Tag{Key: &key, Value: &value})
copiedKey := key
copiedValue := value
tags = append(tags, &ec2.Tag{Key: &copiedKey, Value: &copiedValue})
}
tagSpec := ec2.TagSpecification{
ResourceType: aws.String("snapshot"),
Expand Down
59 changes: 58 additions & 1 deletion pkg/cloud/cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"context"
"errors"
"fmt"
"reflect"
"sort"
"strings"
"testing"

Expand Down Expand Up @@ -562,6 +564,7 @@ func TestCreateSnapshot(t *testing.T) {
name string
snapshotName string
snapshotOptions *SnapshotOptions
expInput *ec2.CreateSnapshotInput
expSnapshot *Snapshot
expErr error
}{
Expand All @@ -571,8 +574,29 @@ func TestCreateSnapshot(t *testing.T) {
snapshotOptions: &SnapshotOptions{
Tags: map[string]string{
SnapshotNameTagKey: "snap-test-name",
"extra-tag-key": "extra-tag-value",
},
},
expInput: &ec2.CreateSnapshotInput{
VolumeId: aws.String("snap-test-volume"),
DryRun: aws.Bool(false),
TagSpecifications: []*ec2.TagSpecification{
{
ResourceType: aws.String("snapshot"),
Tags: []*ec2.Tag{
{
Key: aws.String(SnapshotNameTagKey),
Value: aws.String("snap-test-name"),
},
{
Key: aws.String("extra-tag-key"),
Value: aws.String("extra-tag-value"),
},
},
},
},
Description: aws.String("Created by AWS EBS CSI driver for volume snap-test-volume"),
},
expSnapshot: &Snapshot{
SourceVolumeID: "snap-test-volume",
},
Expand All @@ -593,7 +617,7 @@ func TestCreateSnapshot(t *testing.T) {
}

ctx := context.Background()
mockEC2.EXPECT().CreateSnapshotWithContext(gomock.Eq(ctx), gomock.Any()).Return(ec2snapshot, tc.expErr)
mockEC2.EXPECT().CreateSnapshotWithContext(gomock.Eq(ctx), eqCreateSnapshotInput(tc.expInput)).Return(ec2snapshot, tc.expErr)
mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{Snapshots: []*ec2.Snapshot{ec2snapshot}}, nil).AnyTimes()

snapshot, err := c.CreateSnapshot(ctx, tc.expSnapshot.SourceVolumeID, tc.snapshotOptions)
Expand Down Expand Up @@ -845,6 +869,7 @@ func TestGetSnapshotByName(t *testing.T) {
snapshotOptions: &SnapshotOptions{
Tags: map[string]string{
SnapshotNameTagKey: "snap-test-name",
"extra-tag-key": "extra-tag-value",
},
},
expSnapshot: &Snapshot{
Expand Down Expand Up @@ -899,6 +924,7 @@ func TestGetSnapshotByID(t *testing.T) {
snapshotOptions: &SnapshotOptions{
Tags: map[string]string{
SnapshotNameTagKey: "snap-test-name",
"extra-tag-key": "extra-tag-value",
},
},
expSnapshot: &Snapshot{
Expand Down Expand Up @@ -1166,3 +1192,34 @@ func newDescribeInstancesOutput(nodeID string) *ec2.DescribeInstancesOutput {
}},
}
}

type eqCreateSnapshotInputMatcher struct {
expected *ec2.CreateSnapshotInput
}

func eqCreateSnapshotInput(expected *ec2.CreateSnapshotInput) gomock.Matcher {
return &eqCreateSnapshotInputMatcher{expected}
}

func (m *eqCreateSnapshotInputMatcher) Matches(x interface{}) bool {
input, ok := x.(*ec2.CreateSnapshotInput)
if !ok {
return false
}

if input != nil {
for _, ts := range input.TagSpecifications {
// Because these tags are generated from a map
// which has a random order.
sort.SliceStable(ts.Tags, func(i, j int) bool {
return *ts.Tags[i].Key < *ts.Tags[j].Key
})
}
}

return reflect.DeepEqual(m.expected, input)
}

func (m *eqCreateSnapshotInputMatcher) String() string {
return m.expected.String()
}
17 changes: 15 additions & 2 deletions pkg/driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol
volumeTags[resourceLifecycleTag] = ResourceLifecycleOwned
volumeTags[NameTag] = d.driverOptions.kubernetesClusterID + "-dynamic-" + volName
}
for k, v := range d.driverOptions.extraVolumeTags {
for k, v := range d.driverOptions.extraTags {
volumeTags[k] = v
}

Expand Down Expand Up @@ -441,9 +441,22 @@ func (d *controllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS
klog.V(4).Infof("Snapshot %s of volume %s already exists; nothing to do", snapshotName, volumeID)
return newCreateSnapshotResponse(snapshot)
}

snapshotTags := map[string]string{
cloud.SnapshotNameTagKey: snapshotName,
}
if d.driverOptions.kubernetesClusterID != "" {
resourceLifecycleTag := ResourceLifecycleTagPrefix + d.driverOptions.kubernetesClusterID
snapshotTags[resourceLifecycleTag] = ResourceLifecycleOwned
snapshotTags[NameTag] = d.driverOptions.kubernetesClusterID + "-dynamic-" + snapshotName
}
for k, v := range d.driverOptions.extraTags {
snapshotTags[k] = v
}
opts := &cloud.SnapshotOptions{
Tags: map[string]string{cloud.SnapshotNameTagKey: snapshotName},
Tags: snapshotTags,
}

snapshot, err = d.cloud.CreateSnapshot(ctx, volumeID, opts)

if err != nil {
Expand Down
114 changes: 113 additions & 1 deletion pkg/driver/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,7 @@ func TestCreateVolume(t *testing.T) {
awsDriver := controllerService{
cloud: mockCloud,
driverOptions: &DriverOptions{
extraVolumeTags: map[string]string{
extraTags: map[string]string{
extraVolumeTagKey: extraVolumeTagValue,
},
},
Expand Down Expand Up @@ -1723,6 +1723,118 @@ func TestCreateSnapshot(t *testing.T) {
}
},
},
{
name: "success with cluster-id",
testFunc: func(t *testing.T) {
const (
snapshotName = "test-snapshot"
clusterID = "test-cluster-id"
expectedOwnerTag = "kubernetes.io/cluster/test-cluster-id"
expectedOwnerTagValue = "owned"
expectedNameTag = "Name"
expectedNameTagValue = "test-cluster-id-dynamic-test-snapshot"
)
req := &csi.CreateSnapshotRequest{
Name: snapshotName,
Parameters: nil,
SourceVolumeId: "vol-test",
}
expSnapshot := &csi.Snapshot{
ReadyToUse: true,
}

ctx := context.Background()
mockSnapshot := &cloud.Snapshot{
SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()),
SourceVolumeID: req.SourceVolumeId,
Size: 1,
CreationTime: time.Now(),
}
snapshotOptions := &cloud.SnapshotOptions{
Tags: map[string]string{
cloud.SnapshotNameTagKey: snapshotName,
expectedOwnerTag: expectedOwnerTagValue,
expectedNameTag: expectedNameTagValue,
},
}
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()

mockCloud := mocks.NewMockCloud(mockCtl)
mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Eq(snapshotOptions)).Return(mockSnapshot, nil)
mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound)

awsDriver := controllerService{
cloud: mockCloud,
driverOptions: &DriverOptions{
kubernetesClusterID: clusterID,
},
}
resp, err := awsDriver.CreateSnapshot(context.Background(), req)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if snap := resp.GetSnapshot(); snap == nil {
t.Fatalf("Expected snapshot %v, got nil", expSnapshot)
}
},
},
{
name: "success with extra tags",
testFunc: func(t *testing.T) {
const (
snapshotName = "test-snapshot"
extraVolumeTagKey = "extra-tag-key"
extraVolumeTagValue = "extra-tag-value"
)
req := &csi.CreateSnapshotRequest{
Name: snapshotName,
Parameters: nil,
SourceVolumeId: "vol-test",
}
expSnapshot := &csi.Snapshot{
ReadyToUse: true,
}

ctx := context.Background()
mockSnapshot := &cloud.Snapshot{
SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()),
SourceVolumeID: req.SourceVolumeId,
Size: 1,
CreationTime: time.Now(),
}
snapshotOptions := &cloud.SnapshotOptions{
Tags: map[string]string{
cloud.SnapshotNameTagKey: snapshotName,
extraVolumeTagKey: extraVolumeTagValue,
},
}
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()

mockCloud := mocks.NewMockCloud(mockCtl)
mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Eq(snapshotOptions)).Return(mockSnapshot, nil)
mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound)

awsDriver := controllerService{
cloud: mockCloud,
driverOptions: &DriverOptions{
extraTags: map[string]string{
extraVolumeTagKey: extraVolumeTagValue,
},
},
}
resp, err := awsDriver.CreateSnapshot(context.Background(), req)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if snap := resp.GetSnapshot(); snap == nil {
t.Fatalf("Expected snapshot %v, got nil", expSnapshot)
}
},
},
{
name: "fail no name",
testFunc: func(t *testing.T) {
Expand Down
13 changes: 11 additions & 2 deletions pkg/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type Driver struct {

type DriverOptions struct {
endpoint string
extraVolumeTags map[string]string
extraTags map[string]string
mode Mode
volumeAttachLimit int64
kubernetesClusterID string
Expand Down Expand Up @@ -150,9 +150,18 @@ func WithEndpoint(endpoint string) func(*DriverOptions) {
}
}

func WithExtraTags(extraTags map[string]string) func(*DriverOptions) {
return func(o *DriverOptions) {
o.extraTags = extraTags
}
}

func WithExtraVolumeTags(extraVolumeTags map[string]string) func(*DriverOptions) {
return func(o *DriverOptions) {
o.extraVolumeTags = extraVolumeTags
if o.extraTags == nil && extraVolumeTags != nil {
klog.Warning("DEPRECATION WARNING: --extra-volume-tags is deprecated, please use --extra-tags instead")
o.extraTags = extraVolumeTags
}
}
}

Expand Down
Loading

0 comments on commit 7c17cb5

Please sign in to comment.