Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pkg/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ type CSIConnection interface {
// DeleteSnapshot deletes a snapshot from a volume
DeleteSnapshot(ctx context.Context, snapshotID string, snapshotterCredentials map[string]string) (err error)

// GetSnapshotStatus lists snapshot from a volume
GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, error)
// GetSnapshotStatus returns a snapshot's status, creation time, and restore size.
GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, int64, error)

// Probe checks that the CSI driver is ready to process requests
Probe(ctx context.Context) error
Expand Down Expand Up @@ -232,7 +232,7 @@ func (c *csiConnection) DeleteSnapshot(ctx context.Context, snapshotID string, s
return nil
}

func (c *csiConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, error) {
func (c *csiConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, int64, error) {
client := csi.NewControllerClient(c.conn)

req := csi.ListSnapshotsRequest{
Expand All @@ -241,14 +241,14 @@ func (c *csiConnection) GetSnapshotStatus(ctx context.Context, snapshotID string

rsp, err := client.ListSnapshots(ctx, &req)
if err != nil {
return nil, 0, err
return nil, 0, 0, err
}

if rsp.Entries == nil || len(rsp.Entries) == 0 {
return nil, 0, fmt.Errorf("can not find snapshot for snapshotID %s", snapshotID)
return nil, 0, 0, fmt.Errorf("can not find snapshot for snapshotID %s", snapshotID)
}

return rsp.Entries[0].Snapshot.Status, rsp.Entries[0].Snapshot.CreatedAt, nil
return rsp.Entries[0].Snapshot.Status, rsp.Entries[0].Snapshot.CreatedAt, rsp.Entries[0].Snapshot.SizeBytes, nil
}

func (c *csiConnection) Close() error {
Expand Down
10 changes: 8 additions & 2 deletions pkg/connection/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ func TestDeleteSnapshot(t *testing.T) {
func TestGetSnapshotStatus(t *testing.T) {
defaultID := "testid"
createdAt := time.Now().UnixNano()
size := int64(1000)

defaultRequest := &csi.ListSnapshotsRequest{
SnapshotId: defaultID,
Expand All @@ -668,7 +669,7 @@ func TestGetSnapshotStatus(t *testing.T) {
{
Snapshot: &csi.Snapshot{
Id: defaultID,
SizeBytes: 1000,
SizeBytes: size,
SourceVolumeId: "volumeid",
CreatedAt: createdAt,
Status: &csi.SnapshotStatus{
Expand All @@ -689,6 +690,7 @@ func TestGetSnapshotStatus(t *testing.T) {
expectError bool
expectStatus *csi.SnapshotStatus
expectCreateAt int64
expectSize int64
}{
{
name: "success",
Expand All @@ -701,6 +703,7 @@ func TestGetSnapshotStatus(t *testing.T) {
Details: "success",
},
expectCreateAt: createdAt,
expectSize: size,
},
{
name: "gRPC transient error",
Expand Down Expand Up @@ -741,7 +744,7 @@ func TestGetSnapshotStatus(t *testing.T) {
controllerServer.EXPECT().ListSnapshots(gomock.Any(), in).Return(out, injectedErr).Times(1)
}

status, createTime, err := csiConn.GetSnapshotStatus(context.Background(), test.snapshotID)
status, createTime, size, err := csiConn.GetSnapshotStatus(context.Background(), test.snapshotID)
if test.expectError && err == nil {
t.Errorf("test %q: Expected error, got none", test.name)
}
Expand All @@ -754,6 +757,9 @@ func TestGetSnapshotStatus(t *testing.T) {
if test.expectCreateAt != createTime {
t.Errorf("test %q: expected createTime: %v, got: %v", test.name, test.expectCreateAt, createTime)
}
if test.expectSize != size {
t.Errorf("test %q: expected size: %v, got: %v", test.name, test.expectSize, createTime)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

createTime --> size

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed by PR #24

}
}
}

Expand Down
13 changes: 7 additions & 6 deletions pkg/controller/csi_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
type Handler interface {
CreateSnapshot(snapshot *crdv1.VolumeSnapshot, volume *v1.PersistentVolume, parameters map[string]string, snapshotterCredentials map[string]string) (string, string, int64, int64, *csi.SnapshotStatus, error)
DeleteSnapshot(content *crdv1.VolumeSnapshotContent, snapshotterCredentials map[string]string) error
GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, error)
GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, int64, error)
}

// csiHandler is a handler that calls CSI to create/delete volume snapshot.
Expand Down Expand Up @@ -84,18 +84,19 @@ func (handler *csiHandler) DeleteSnapshot(content *crdv1.VolumeSnapshotContent,
return nil
}

func (handler *csiHandler) GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, error) {
func (handler *csiHandler) GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, int64, error) {
if content.Spec.CSI == nil {
return nil, 0, fmt.Errorf("CSISnapshot not defined in spec")
return nil, 0, 0, fmt.Errorf("CSISnapshot not defined in spec")
}
ctx, cancel := context.WithTimeout(context.Background(), handler.timeout)
defer cancel()

csiSnapshotStatus, timestamp, err := handler.csiConnection.GetSnapshotStatus(ctx, content.Spec.CSI.SnapshotHandle)
csiSnapshotStatus, timestamp, size, err := handler.csiConnection.GetSnapshotStatus(ctx, content.Spec.CSI.SnapshotHandle)
if err != nil {
return nil, 0, fmt.Errorf("failed to list snapshot data %s: %q", content.Name, err)
return nil, 0, 0, fmt.Errorf("failed to list snapshot data %s: %q", content.Name, err)
}
return csiSnapshotStatus, timestamp, nil
return csiSnapshotStatus, timestamp, size, nil

}

func makeSnapshotName(prefix, snapshotUID string, snapshotNameUUIDLength int) (string, error) {
Expand Down
9 changes: 5 additions & 4 deletions pkg/controller/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,7 @@ type listCall struct {
// information to return
status *csi.SnapshotStatus
createTime int64
size int64
err error
}

Expand Down Expand Up @@ -1203,10 +1204,10 @@ func (f *fakeCSIConnection) DeleteSnapshot(ctx context.Context, snapshotID strin
return call.err
}

func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, error) {
func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, int64, error) {
if f.listCallCounter >= len(f.listCalls) {
f.t.Errorf("Unexpected CSI list Snapshot call: snapshotID=%s, index: %d, calls: %+v", snapshotID, f.createCallCounter, f.createCalls)
return nil, 0, fmt.Errorf("unexpected call")
return nil, 0, 0, fmt.Errorf("unexpected call")
}
call := f.listCalls[f.listCallCounter]
f.listCallCounter++
Expand All @@ -1218,10 +1219,10 @@ func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID st
}

if err != nil {
return nil, 0, fmt.Errorf("unexpected call")
return nil, 0, 0, fmt.Errorf("unexpected call")
}

return call.status, call.createTime, call.err
return call.status, call.createTime, call.size, call.err
}

func (f *fakeCSIConnection) Close() error {
Expand Down
18 changes: 9 additions & 9 deletions pkg/controller/snapshot_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,12 +425,12 @@ func (ctrl *csiSnapshotController) checkandBindSnapshotContent(snapshot *crdv1.V
}

func (ctrl *csiSnapshotController) checkandUpdateSnapshotStatusOperation(snapshot *crdv1.VolumeSnapshot, content *crdv1.VolumeSnapshotContent) (*crdv1.VolumeSnapshot, error) {
status, _, err := ctrl.handler.GetSnapshotStatus(content)
status, _, size, err := ctrl.handler.GetSnapshotStatus(content)
if err != nil {
return nil, fmt.Errorf("failed to check snapshot status %s with error %v", snapshot.Name, err)
}

newSnapshot, err := ctrl.updateSnapshotStatus(snapshot, status, time.Now(), nil, IsSnapshotBound(snapshot, content))
timestamp := time.Now().UnixNano()
newSnapshot, err := ctrl.updateSnapshotStatus(snapshot, status, timestamp, size, IsSnapshotBound(snapshot, content))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -494,7 +494,7 @@ func (ctrl *csiSnapshotController) createSnapshotOperation(snapshot *crdv1.Volum
// Update snapshot status with timestamp
for i := 0; i < ctrl.createSnapshotContentRetryCount; i++ {
glog.V(5).Infof("createSnapshot [%s]: trying to update snapshot creation timestamp", snapshotKey(snapshot))
newSnapshot, err = ctrl.updateSnapshotStatus(snapshot, csiSnapshotStatus, time.Unix(0, timestamp), resource.NewQuantity(size, resource.BinarySI), false)
newSnapshot, err = ctrl.updateSnapshotStatus(snapshot, csiSnapshotStatus, timestamp, size, false)
if err == nil {
break
}
Expand Down Expand Up @@ -642,12 +642,12 @@ func (ctrl *csiSnapshotController) bindandUpdateVolumeSnapshot(snapshotContent *
}

// UpdateSnapshotStatus converts snapshot status to crdv1.VolumeSnapshotCondition
func (ctrl *csiSnapshotController) updateSnapshotStatus(snapshot *crdv1.VolumeSnapshot, csistatus *csi.SnapshotStatus, timestamp time.Time, size *resource.Quantity, bound bool) (*crdv1.VolumeSnapshot, error) {
glog.V(5).Infof("updating VolumeSnapshot[]%s, set status %v, timestamp %v", snapshotKey(snapshot), csistatus, timestamp)
func (ctrl *csiSnapshotController) updateSnapshotStatus(snapshot *crdv1.VolumeSnapshot, csistatus *csi.SnapshotStatus, createdAt, size int64, bound bool) (*crdv1.VolumeSnapshot, error) {
glog.V(5).Infof("updating VolumeSnapshot[]%s, set status %v, timestamp %v", snapshotKey(snapshot), csistatus, createdAt)
status := snapshot.Status
change := false
timeAt := &metav1.Time{
Time: timestamp,
Time: time.Unix(0, createdAt),
}

snapshotClone := snapshot.DeepCopy()
Expand Down Expand Up @@ -680,8 +680,8 @@ func (ctrl *csiSnapshotController) updateSnapshotStatus(snapshot *crdv1.VolumeSn
}
}
if change {
if size != nil {
status.RestoreSize = size
if size > 0 {
status.RestoreSize = resource.NewQuantity(size, resource.BinarySI)
}
snapshotClone.Status = status
newSnapshotObj, err := ctrl.clientset.VolumesnapshotV1alpha1().VolumeSnapshots(snapshotClone.Namespace).Update(snapshotClone)
Expand Down