diff --git a/cmd/do-csi-plugin/main.go b/cmd/do-csi-plugin/main.go index 6667be88..323b001c 100644 --- a/cmd/do-csi-plugin/main.go +++ b/cmd/do-csi-plugin/main.go @@ -30,6 +30,7 @@ func main() { endpoint = flag.String("endpoint", "unix:///var/lib/kubelet/plugins/"+driver.DriverName+"/csi.sock", "CSI endpoint") token = flag.String("token", "", "DigitalOcean access token") url = flag.String("url", "https://api.digitalocean.com/", "DigitalOcean API URL") + doTag = flag.String("do-tag", "", "Tag DigitalOcean volumes on Create/Attach") version = flag.Bool("version", false, "Print the version and exit.") ) flag.Parse() @@ -39,7 +40,7 @@ func main() { os.Exit(0) } - drv, err := driver.NewDriver(*endpoint, *token, *url) + drv, err := driver.NewDriver(*endpoint, *token, *url, *doTag) if err != nil { log.Fatalln(err) } diff --git a/driver/controller.go b/driver/controller.go index 923e4a35..e6a6487b 100644 --- a/driver/controller.go +++ b/driver/controller.go @@ -58,6 +58,10 @@ const ( // createdByDO is used to tag volumes that are created by this CSI plugin createdByDO = "Created by DigitalOcean CSI driver" + + // doAPITimeout sets the timeout we will use when communicating with the + // Digital Ocean API. NOTE: some queries inherit the context timeout + doAPITimeout = 10 * time.Second ) var ( @@ -148,6 +152,10 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) SizeGigaBytes: size / GB, } + if d.doTag != "" { + volumeReq.Tags = append(volumeReq.Tags, d.doTag) + } + ll.Info("checking volume limit") if err := d.checkLimit(ctx); err != nil { return nil, err @@ -252,6 +260,14 @@ func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.Controlle return nil, err } + if d.doTag != "" { + err = d.tagVolume(ctx, vol) + if err != nil { + ll.Errorf("error tagging volume: %s", err) + return nil, status.Errorf(codes.Internal, "failed to tag volume") + } + } + // check if droplet exist before trying to attach the volume to the droplet _, resp, err = d.droplets.Get(ctx, dropletID) if err != nil { @@ -970,3 +986,45 @@ func validateCapabilities(caps []*csi.VolumeCapability) bool { return supported } + +func (d *Driver) tagVolume(parentCtx context.Context, vol *godo.Volume) error { + for _, tag := range vol.Tags { + if tag == d.doTag { + return nil + } + } + + tagReq := &godo.TagResourcesRequest{ + Resources: []godo.Resource{ + godo.Resource{ + ID: vol.ID, + Type: godo.VolumeResourceType, + }, + }, + } + + ctx, cancel := context.WithTimeout(parentCtx, doAPITimeout) + defer cancel() + resp, err := d.tags.TagResources(ctx, d.doTag, tagReq) + if resp == nil || resp.StatusCode != http.StatusNotFound { + // either success or irrecoverable failure + return err + } + + // godo.TagsService returns 404 if the tag has not yet been + // created, if that happens we need to create the tag + // and then retry tagging the volume resource. + ctx, cancel = context.WithTimeout(parentCtx, doAPITimeout) + defer cancel() + _, _, err = d.tags.Create(parentCtx, &godo.TagCreateRequest{ + Name: d.doTag, + }) + if err != nil { + return err + } + + ctx, cancel = context.WithTimeout(parentCtx, doAPITimeout) + defer cancel() + _, err = d.tags.TagResources(ctx, d.doTag, tagReq) + return err +} diff --git a/driver/controller_test.go b/driver/controller_test.go new file mode 100644 index 00000000..858a0814 --- /dev/null +++ b/driver/controller_test.go @@ -0,0 +1,153 @@ +package driver + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/digitalocean/godo" +) + +func TestTagger(t *testing.T) { + tag := "k8s:my-cluster-id" + tcs := []struct { + name string + vol *godo.Volume + createTagFunc func(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) + tagResourcesFunc func(context.Context, string, *godo.TagResourcesRequest) (*godo.Response, error) + tagExists bool + expectCreates int + expectTagResources int + expectError bool + expectTags int + }{ + { + name: "success existing tag", + vol: &godo.Volume{ID: "hello-world"}, + expectTagResources: 1, + tagExists: true, + expectTags: 1, + }, + { + name: "success with new tag", + vol: &godo.Volume{ID: "hello-world"}, + expectCreates: 1, + expectTagResources: 2, + expectTags: 1, + }, + { + name: "success already tagged", + vol: &godo.Volume{ + ID: "hello-world", + Tags: []string{tag}, + }, + expectCreates: 0, + expectTagResources: 0, + }, + { + name: "failed first tag", + vol: &godo.Volume{ID: "hello-world"}, + expectCreates: 0, + expectTagResources: 1, + expectError: true, + tagResourcesFunc: func(context.Context, string, *godo.TagResourcesRequest) (*godo.Response, error) { + return nil, errors.New("an error") + }, + }, + { + name: "failed create tag", + vol: &godo.Volume{ID: "hello-world"}, + expectCreates: 1, + expectTagResources: 1, + expectError: true, + createTagFunc: func(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) { + return nil, nil, errors.New("an error") + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + + tagService := &fakeTagsDriver{ + createFunc: tc.createTagFunc, + tagResourcesFunc: tc.tagResourcesFunc, + exists: tc.tagExists, + } + driver := &Driver{ + doTag: tag, + tags: tagService, + } + + err := driver.tagVolume(context.Background(), tc.vol) + + if err != nil && !tc.expectError { + t.Errorf("expected success but got error %v", err) + } else if tc.expectError && err == nil { + t.Error("expected error but got success") + } + + if tagService.createCount != tc.expectCreates { + t.Errorf("createCount was %d, expected %d", tagService.createCount, tc.expectCreates) + } + if tagService.tagResourcesCount != tc.expectTagResources { + t.Errorf("tagResourcesCount was %d, expected %d", tagService.tagResourcesCount, tc.expectTagResources) + } + if tc.expectTags != len(tagService.resources) { + t.Errorf("expected %d tagged volume, %d found", tc.expectTags, len(tagService.resources)) + } + }) + } +} + +type fakeTagsDriver struct { + createFunc func(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) + tagResourcesFunc func(context.Context, string, *godo.TagResourcesRequest) (*godo.Response, error) + exists bool + resources []godo.Resource + createCount int + tagResourcesCount int +} + +func (*fakeTagsDriver) List(context.Context, *godo.ListOptions) ([]godo.Tag, *godo.Response, error) { + panic("not implemented") +} + +func (*fakeTagsDriver) Get(context.Context, string) (*godo.Tag, *godo.Response, error) { + panic("not implemented") +} + +func (f *fakeTagsDriver) Create(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) { + f.createCount++ + if f.createFunc != nil { + return f.createFunc(ctx, req) + } + f.exists = true + return &godo.Tag{ + Name: req.Name, + }, godoResponse(), nil +} + +func (*fakeTagsDriver) Delete(context.Context, string) (*godo.Response, error) { + panic("not implemented") +} + +func (f *fakeTagsDriver) TagResources(ctx context.Context, tag string, req *godo.TagResourcesRequest) (*godo.Response, error) { + f.tagResourcesCount++ + if f.tagResourcesFunc != nil { + return f.tagResourcesFunc(ctx, tag, req) + } + if !f.exists { + return &godo.Response{ + Response: &http.Response{StatusCode: 404}, + Rate: godo.Rate{Limit: 10, Remaining: 10}, + }, errors.New("An error occured") + } + f.resources = append(f.resources, req.Resources...) + return godoResponse(), nil +} + +func (*fakeTagsDriver) UntagResources(context.Context, string, *godo.UntagResourcesRequest) (*godo.Response, error) { + panic("not implemented") +} diff --git a/driver/driver.go b/driver/driver.go index 8f407d53..75996b31 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -57,6 +57,7 @@ type Driver struct { endpoint string nodeId string region string + doTag string srv *grpc.Server log *logrus.Entry @@ -67,6 +68,7 @@ type Driver struct { droplets godo.DropletsService snapshots godo.SnapshotsService account godo.AccountService + tags godo.TagsService // ready defines whether the driver is ready to function. This value will // be used by the `Identity` service via the `Probe()` method. @@ -77,7 +79,7 @@ type Driver struct { // NewDriver returns a CSI plugin that contains the necessary gRPC // interfaces to interact with Kubernetes over unix domain sockets for // managaing DigitalOcean Block Storage -func NewDriver(ep, token, url string) (*Driver, error) { +func NewDriver(ep, token, url, doTag string) (*Driver, error) { tokenSource := oauth2.StaticTokenSource(&oauth2.Token{ AccessToken: token, }) @@ -106,6 +108,7 @@ func NewDriver(ep, token, url string) (*Driver, error) { }) return &Driver{ + doTag: doTag, endpoint: ep, nodeId: nodeId, region: region, @@ -117,6 +120,7 @@ func NewDriver(ep, token, url string) (*Driver, error) { droplets: doClient.Droplets, snapshots: doClient.Snapshots, account: doClient.Account, + tags: doClient.Tags, }, nil } diff --git a/driver/driver_test.go b/driver/driver_test.go index a2e53345..787f0bf5 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -59,6 +59,7 @@ func TestDriverSuite(t *testing.T) { } nodeID := 987654 + doTag := "k8s:cluster-id" volumes := make(map[string]*godo.Volume, 0) snapshots := make(map[string]*godo.Snapshot, 0) droplets := map[int]*godo.Droplet{ @@ -70,6 +71,7 @@ func TestDriverSuite(t *testing.T) { driver := &Driver{ endpoint: endpoint, nodeId: strconv.Itoa(nodeID), + doTag: doTag, region: "nyc3", mounter: &fakeMounter{}, log: logrus.New().WithField("test_enabed", true), @@ -89,6 +91,7 @@ func TestDriverSuite(t *testing.T) { snapshots: snapshots, }, account: &fakeAccountDriver{}, + tags: &fakeTagsDriver{}, } defer driver.Stop()