diff --git a/pkg/snapshotter/snapshotter.go b/pkg/snapshotter/snapshotter.go index ff0f64eef..157f08972 100644 --- a/pkg/snapshotter/snapshotter.go +++ b/pkg/snapshotter/snapshotter.go @@ -101,13 +101,36 @@ func (s *snapshot) DeleteSnapshot(ctx context.Context, snapshotID string, snapsh return nil } +func (s *snapshot) isListSnapshotsSupported(ctx context.Context) (bool, error) { + client := csi.NewControllerClient(s.conn) + capRsp, err := client.ControllerGetCapabilities(ctx, &csi.ControllerGetCapabilitiesRequest{}) + if err != nil { + return false, err + } + + for _, cap := range capRsp.Capabilities { + if cap.GetRpc().GetType() == csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS { + return true, nil + } + } + + return false, nil +} + func (s *snapshot) GetSnapshotStatus(ctx context.Context, snapshotID string) (bool, time.Time, int64, error) { client := csi.NewControllerClient(s.conn) + // If the driver does not support ListSnapshots, assume the snapshot ID is valid. + listSnapshotsSupported, err := s.isListSnapshotsSupported(ctx) + if err != nil { + return false, time.Time{}, 0, fmt.Errorf("failed to check if ListSnapshots is supported: %s", err.Error()) + } + if !listSnapshotsSupported { + return true, time.Time{}, 0, nil + } req := csi.ListSnapshotsRequest{ SnapshotId: snapshotID, } - rsp, err := client.ListSnapshots(ctx, &req) if err != nil { return false, time.Time{}, 0, err diff --git a/pkg/snapshotter/snapshotter_test.go b/pkg/snapshotter/snapshotter_test.go index 588030ae0..b0ff50f10 100644 --- a/pkg/snapshotter/snapshotter_test.go +++ b/pkg/snapshotter/snapshotter_test.go @@ -370,41 +370,56 @@ func TestGetSnapshotStatus(t *testing.T) { } tests := []struct { - name string - snapshotID string - input *csi.ListSnapshotsRequest - output *csi.ListSnapshotsResponse - injectError codes.Code - expectError bool - expectReady bool - expectCreateAt time.Time - expectSize int64 + name string + snapshotID string + listSnapshotsSupported bool + input *csi.ListSnapshotsRequest + output *csi.ListSnapshotsResponse + injectError codes.Code + expectError bool + expectReady bool + expectCreateAt time.Time + expectSize int64 }{ { - name: "success", - snapshotID: defaultID, - input: defaultRequest, - output: defaultResponse, - expectError: false, - expectReady: true, - expectCreateAt: createTime, - expectSize: size, + name: "success", + snapshotID: defaultID, + listSnapshotsSupported: true, + input: defaultRequest, + output: defaultResponse, + expectError: false, + expectReady: true, + expectCreateAt: createTime, + expectSize: size, }, { - name: "gRPC transient error", - snapshotID: defaultID, - input: defaultRequest, - output: nil, - injectError: codes.DeadlineExceeded, - expectError: true, + name: "ListSnapshots not supported", + snapshotID: defaultID, + listSnapshotsSupported: false, + input: defaultRequest, + output: defaultResponse, + expectError: false, + expectReady: true, + expectCreateAt: time.Time{}, + expectSize: 0, }, { - name: "gRPC final error", - snapshotID: defaultID, - input: defaultRequest, - output: nil, - injectError: codes.NotFound, - expectError: true, + name: "gRPC transient error", + snapshotID: defaultID, + listSnapshotsSupported: true, + input: defaultRequest, + output: nil, + injectError: codes.DeadlineExceeded, + expectError: true, + }, + { + name: "gRPC final error", + snapshotID: defaultID, + listSnapshotsSupported: true, + input: defaultRequest, + output: nil, + injectError: codes.NotFound, + expectError: true, }, } @@ -425,8 +440,25 @@ func TestGetSnapshotStatus(t *testing.T) { } // Setup expectation + listSnapshotsCap := &csi.ControllerServiceCapability{ + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS, + }, + }, + } + + var controllerCapabilities []*csi.ControllerServiceCapability + if test.listSnapshotsSupported { + controllerCapabilities = append(controllerCapabilities, listSnapshotsCap) + } if in != nil { - controllerServer.EXPECT().ListSnapshots(gomock.Any(), in).Return(out, injectedErr).Times(1) + controllerServer.EXPECT().ControllerGetCapabilities(gomock.Any(), gomock.Any()).Return(&csi.ControllerGetCapabilitiesResponse{ + Capabilities: controllerCapabilities, + }, nil).Times(1) + if test.listSnapshotsSupported { + controllerServer.EXPECT().ListSnapshots(gomock.Any(), in).Return(out, injectedErr).Times(1) + } } s := NewSnapshotter(csiConn)