Skip to content

Commit

Permalink
Tag volumes on Create/Attach (#130)
Browse files Browse the repository at this point in the history
* tag DO volumes on create/attach
  • Loading branch information
jcodybaker committed Apr 25, 2019
1 parent afdbad8 commit ed8c3e7
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 2 deletions.
3 changes: 2 additions & 1 deletion cmd/do-csi-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func main() {
endpoint = flag.String("endpoint", "unix:///var/lib/kubelet/plugins/"+driver.DriverName+"/csi.sock", "CSI endpoint")
token = flag.String("token", "", "DigitalOcean access token")
url = flag.String("url", "https://api.digitalocean.com/", "DigitalOcean API URL")
doTag = flag.String("do-tag", "", "Tag DigitalOcean volumes on Create/Attach")
version = flag.Bool("version", false, "Print the version and exit.")
)
flag.Parse()
Expand All @@ -39,7 +40,7 @@ func main() {
os.Exit(0)
}

drv, err := driver.NewDriver(*endpoint, *token, *url)
drv, err := driver.NewDriver(*endpoint, *token, *url, *doTag)
if err != nil {
log.Fatalln(err)
}
Expand Down
58 changes: 58 additions & 0 deletions driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ const (

// createdByDO is used to tag volumes that are created by this CSI plugin
createdByDO = "Created by DigitalOcean CSI driver"

// doAPITimeout sets the timeout we will use when communicating with the
// Digital Ocean API. NOTE: some queries inherit the context timeout
doAPITimeout = 10 * time.Second
)

var (
Expand Down Expand Up @@ -148,6 +152,10 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
SizeGigaBytes: size / GB,
}

if d.doTag != "" {
volumeReq.Tags = append(volumeReq.Tags, d.doTag)
}

ll.Info("checking volume limit")
if err := d.checkLimit(ctx); err != nil {
return nil, err
Expand Down Expand Up @@ -252,6 +260,14 @@ func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.Controlle
return nil, err
}

if d.doTag != "" {
err = d.tagVolume(ctx, vol)
if err != nil {
ll.Errorf("error tagging volume: %s", err)
return nil, status.Errorf(codes.Internal, "failed to tag volume")
}
}

// check if droplet exist before trying to attach the volume to the droplet
_, resp, err = d.droplets.Get(ctx, dropletID)
if err != nil {
Expand Down Expand Up @@ -970,3 +986,45 @@ func validateCapabilities(caps []*csi.VolumeCapability) bool {

return supported
}

func (d *Driver) tagVolume(parentCtx context.Context, vol *godo.Volume) error {
for _, tag := range vol.Tags {
if tag == d.doTag {
return nil
}
}

tagReq := &godo.TagResourcesRequest{
Resources: []godo.Resource{
godo.Resource{
ID: vol.ID,
Type: godo.VolumeResourceType,
},
},
}

ctx, cancel := context.WithTimeout(parentCtx, doAPITimeout)
defer cancel()
resp, err := d.tags.TagResources(ctx, d.doTag, tagReq)
if resp == nil || resp.StatusCode != http.StatusNotFound {
// either success or irrecoverable failure
return err
}

// godo.TagsService returns 404 if the tag has not yet been
// created, if that happens we need to create the tag
// and then retry tagging the volume resource.
ctx, cancel = context.WithTimeout(parentCtx, doAPITimeout)
defer cancel()
_, _, err = d.tags.Create(parentCtx, &godo.TagCreateRequest{
Name: d.doTag,
})
if err != nil {
return err
}

ctx, cancel = context.WithTimeout(parentCtx, doAPITimeout)
defer cancel()
_, err = d.tags.TagResources(ctx, d.doTag, tagReq)
return err
}
153 changes: 153 additions & 0 deletions driver/controller_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package driver

import (
"context"
"errors"
"net/http"
"testing"

"github.com/digitalocean/godo"
)

func TestTagger(t *testing.T) {
tag := "k8s:my-cluster-id"
tcs := []struct {
name string
vol *godo.Volume
createTagFunc func(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error)
tagResourcesFunc func(context.Context, string, *godo.TagResourcesRequest) (*godo.Response, error)
tagExists bool
expectCreates int
expectTagResources int
expectError bool
expectTags int
}{
{
name: "success existing tag",
vol: &godo.Volume{ID: "hello-world"},
expectTagResources: 1,
tagExists: true,
expectTags: 1,
},
{
name: "success with new tag",
vol: &godo.Volume{ID: "hello-world"},
expectCreates: 1,
expectTagResources: 2,
expectTags: 1,
},
{
name: "success already tagged",
vol: &godo.Volume{
ID: "hello-world",
Tags: []string{tag},
},
expectCreates: 0,
expectTagResources: 0,
},
{
name: "failed first tag",
vol: &godo.Volume{ID: "hello-world"},
expectCreates: 0,
expectTagResources: 1,
expectError: true,
tagResourcesFunc: func(context.Context, string, *godo.TagResourcesRequest) (*godo.Response, error) {
return nil, errors.New("an error")
},
},
{
name: "failed create tag",
vol: &godo.Volume{ID: "hello-world"},
expectCreates: 1,
expectTagResources: 1,
expectError: true,
createTagFunc: func(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) {
return nil, nil, errors.New("an error")
},
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {

tagService := &fakeTagsDriver{
createFunc: tc.createTagFunc,
tagResourcesFunc: tc.tagResourcesFunc,
exists: tc.tagExists,
}
driver := &Driver{
doTag: tag,
tags: tagService,
}

err := driver.tagVolume(context.Background(), tc.vol)

if err != nil && !tc.expectError {
t.Errorf("expected success but got error %v", err)
} else if tc.expectError && err == nil {
t.Error("expected error but got success")
}

if tagService.createCount != tc.expectCreates {
t.Errorf("createCount was %d, expected %d", tagService.createCount, tc.expectCreates)
}
if tagService.tagResourcesCount != tc.expectTagResources {
t.Errorf("tagResourcesCount was %d, expected %d", tagService.tagResourcesCount, tc.expectTagResources)
}
if tc.expectTags != len(tagService.resources) {
t.Errorf("expected %d tagged volume, %d found", tc.expectTags, len(tagService.resources))
}
})
}
}

type fakeTagsDriver struct {
createFunc func(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error)
tagResourcesFunc func(context.Context, string, *godo.TagResourcesRequest) (*godo.Response, error)
exists bool
resources []godo.Resource
createCount int
tagResourcesCount int
}

func (*fakeTagsDriver) List(context.Context, *godo.ListOptions) ([]godo.Tag, *godo.Response, error) {
panic("not implemented")
}

func (*fakeTagsDriver) Get(context.Context, string) (*godo.Tag, *godo.Response, error) {
panic("not implemented")
}

func (f *fakeTagsDriver) Create(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) {
f.createCount++
if f.createFunc != nil {
return f.createFunc(ctx, req)
}
f.exists = true
return &godo.Tag{
Name: req.Name,
}, godoResponse(), nil
}

func (*fakeTagsDriver) Delete(context.Context, string) (*godo.Response, error) {
panic("not implemented")
}

func (f *fakeTagsDriver) TagResources(ctx context.Context, tag string, req *godo.TagResourcesRequest) (*godo.Response, error) {
f.tagResourcesCount++
if f.tagResourcesFunc != nil {
return f.tagResourcesFunc(ctx, tag, req)
}
if !f.exists {
return &godo.Response{
Response: &http.Response{StatusCode: 404},
Rate: godo.Rate{Limit: 10, Remaining: 10},
}, errors.New("An error occured")
}
f.resources = append(f.resources, req.Resources...)
return godoResponse(), nil
}

func (*fakeTagsDriver) UntagResources(context.Context, string, *godo.UntagResourcesRequest) (*godo.Response, error) {
panic("not implemented")
}
6 changes: 5 additions & 1 deletion driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type Driver struct {
endpoint string
nodeId string
region string
doTag string

srv *grpc.Server
log *logrus.Entry
Expand All @@ -67,6 +68,7 @@ type Driver struct {
droplets godo.DropletsService
snapshots godo.SnapshotsService
account godo.AccountService
tags godo.TagsService

// ready defines whether the driver is ready to function. This value will
// be used by the `Identity` service via the `Probe()` method.
Expand All @@ -77,7 +79,7 @@ type Driver struct {
// NewDriver returns a CSI plugin that contains the necessary gRPC
// interfaces to interact with Kubernetes over unix domain sockets for
// managaing DigitalOcean Block Storage
func NewDriver(ep, token, url string) (*Driver, error) {
func NewDriver(ep, token, url, doTag string) (*Driver, error) {
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: token,
})
Expand Down Expand Up @@ -106,6 +108,7 @@ func NewDriver(ep, token, url string) (*Driver, error) {
})

return &Driver{
doTag: doTag,
endpoint: ep,
nodeId: nodeId,
region: region,
Expand All @@ -117,6 +120,7 @@ func NewDriver(ep, token, url string) (*Driver, error) {
droplets: doClient.Droplets,
snapshots: doClient.Snapshots,
account: doClient.Account,
tags: doClient.Tags,
}, nil
}

Expand Down
3 changes: 3 additions & 0 deletions driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func TestDriverSuite(t *testing.T) {
}

nodeID := 987654
doTag := "k8s:cluster-id"
volumes := make(map[string]*godo.Volume, 0)
snapshots := make(map[string]*godo.Snapshot, 0)
droplets := map[int]*godo.Droplet{
Expand All @@ -70,6 +71,7 @@ func TestDriverSuite(t *testing.T) {
driver := &Driver{
endpoint: endpoint,
nodeId: strconv.Itoa(nodeID),
doTag: doTag,
region: "nyc3",
mounter: &fakeMounter{},
log: logrus.New().WithField("test_enabed", true),
Expand All @@ -89,6 +91,7 @@ func TestDriverSuite(t *testing.T) {
snapshots: snapshots,
},
account: &fakeAccountDriver{},
tags: &fakeTagsDriver{},
}
defer driver.Stop()

Expand Down

0 comments on commit ed8c3e7

Please sign in to comment.