diff --git a/connection/connection_test.go b/connection/connection_test.go index dbaf70ee..8c5bfb6b 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -18,12 +18,10 @@ package connection import ( "context" - "fmt" "io/ioutil" "net" "os" "path" - "reflect" "sync" "testing" "time" @@ -33,7 +31,6 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/status" - "github.com/golang/protobuf/ptypes/wrappers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -317,534 +314,3 @@ func TestExplicitReconnect(t *testing.T) { assert.Equal(t, 1, reconnectCount, "connection loss callback should be called once") } } - -func TestGetDriverName(t *testing.T) { - tests := []struct { - name string - output *csi.GetPluginInfoResponse - injectError bool - expectError bool - }{ - { - name: "success", - output: &csi.GetPluginInfoResponse{ - Name: "csi/example", - VendorVersion: "0.2.0", - Manifest: map[string]string{ - "hello": "world", - }, - }, - expectError: false, - }, - { - name: "gRPC error", - output: nil, - injectError: true, - expectError: true, - }, - { - name: "empty name", - output: &csi.GetPluginInfoResponse{ - Name: "", - }, - expectError: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - out := test.output - var injectedErr error - if test.injectError { - injectedErr = fmt.Errorf("mock error") - } - - tmp := tmpDir(t) - defer os.RemoveAll(tmp) - identity := &identityServer{ - pluginInfoResponse: out, - err: injectedErr, - } - addr, stopServer := startServer(t, tmp, identity, nil) - defer func() { - stopServer() - }() - - conn, err := Connect(addr) - if err != nil { - t.Fatalf("Failed to connect to CSI driver: %s", err) - } - - name, err := GetDriverName(context.Background(), conn) - if test.expectError && err == nil { - t.Errorf("Expected error, got none") - } - if !test.expectError && err != nil { - t.Errorf("Got error: %v", err) - } - if err == nil && name != "csi/example" { - t.Errorf("Got unexpected name: %q", name) - } - }) - } -} - -func TestGetPluginCapabilities(t *testing.T) { - tests := []struct { - name string - output *csi.GetPluginCapabilitiesResponse - injectError bool - expectCapabilities PluginCapabilitySet - expectError bool - }{ - { - name: "success", - output: &csi.GetPluginCapabilitiesResponse{ - Capabilities: []*csi.PluginCapability{ - { - Type: &csi.PluginCapability_Service_{ - Service: &csi.PluginCapability_Service{ - Type: csi.PluginCapability_Service_CONTROLLER_SERVICE, - }, - }, - }, - { - Type: &csi.PluginCapability_Service_{ - Service: &csi.PluginCapability_Service{ - Type: csi.PluginCapability_Service_UNKNOWN, - }, - }, - }, - }, - }, - expectCapabilities: PluginCapabilitySet{ - csi.PluginCapability_Service_CONTROLLER_SERVICE: true, - csi.PluginCapability_Service_UNKNOWN: true, - }, - expectError: false, - }, - { - name: "gRPC error", - output: nil, - injectError: true, - expectError: true, - }, - { - name: "no controller service", - output: &csi.GetPluginCapabilitiesResponse{ - Capabilities: []*csi.PluginCapability{ - { - Type: &csi.PluginCapability_Service_{ - Service: &csi.PluginCapability_Service{ - Type: csi.PluginCapability_Service_UNKNOWN, - }, - }, - }, - }, - }, - expectCapabilities: PluginCapabilitySet{ - csi.PluginCapability_Service_UNKNOWN: true, - }, - expectError: false, - }, - { - name: "empty capability", - output: &csi.GetPluginCapabilitiesResponse{ - Capabilities: []*csi.PluginCapability{ - { - Type: nil, - }, - }, - }, - expectCapabilities: PluginCapabilitySet{}, - expectError: false, - }, - { - name: "no capabilities", - output: &csi.GetPluginCapabilitiesResponse{ - Capabilities: []*csi.PluginCapability{}, - }, - expectCapabilities: PluginCapabilitySet{}, - expectError: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var injectedErr error - if test.injectError { - injectedErr = fmt.Errorf("mock error") - } - - tmp := tmpDir(t) - defer os.RemoveAll(tmp) - identity := &identityServer{ - getPluginCapabilitiesResponse: test.output, - err: injectedErr, - } - addr, stopServer := startServer(t, tmp, identity, nil) - defer func() { - stopServer() - }() - - conn, err := Connect(addr) - if err != nil { - t.Fatalf("Failed to connect to CSI driver: %s", err) - } - - caps, err := GetPluginCapabilities(context.Background(), conn) - if test.expectError && err == nil { - t.Errorf("Expected error, got none") - } - if !test.expectError && err != nil { - t.Errorf("Got error: %v", err) - } - if !reflect.DeepEqual(test.expectCapabilities, caps) { - t.Errorf("expected capabilities %+v, got %+v", test.expectCapabilities, caps) - } - }) - } -} - -func TestGetControllerCapabilities(t *testing.T) { - tests := []struct { - name string - output *csi.ControllerGetCapabilitiesResponse - injectError bool - expectCapabilities ControllerCapabilitySet - expectError bool - }{ - { - name: "success", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, - }, - }, - }, - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, - }, - }, - }, - }, - }, - expectCapabilities: ControllerCapabilitySet{ - csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME: true, - csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME: true, - }, - expectError: false, - }, - { - name: "supports read only", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_PUBLISH_READONLY, - }, - }, - }, - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, - }, - }, - }, - }, - }, - expectCapabilities: ControllerCapabilitySet{ - csi.ControllerServiceCapability_RPC_PUBLISH_READONLY: true, - csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME: true, - }, - expectError: false, - }, - { - name: "gRPC error", - output: nil, - injectError: true, - expectError: true, - }, - { - name: "empty capability", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: nil, - }, - }, - }, - expectCapabilities: ControllerCapabilitySet{}, - expectError: false, - }, - { - name: "no capabilities", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{}, - }, - expectCapabilities: ControllerCapabilitySet{}, - expectError: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var injectedErr error - if test.injectError { - injectedErr = fmt.Errorf("mock error") - } - - tmp := tmpDir(t) - defer os.RemoveAll(tmp) - controller := &controllerServer{ - controllerGetCapabilitiesResponse: test.output, - err: injectedErr, - } - addr, stopServer := startServer(t, tmp, nil, controller) - defer func() { - stopServer() - }() - - conn, err := Connect(addr) - if err != nil { - t.Fatalf("Failed to connect to CSI driver: %s", err) - } - - caps, err := GetControllerCapabilities(context.Background(), conn) - if test.expectError && err == nil { - t.Errorf("Expected error, got none") - } - if !test.expectError && err != nil { - t.Errorf("Got error: %v", err) - } - if !reflect.DeepEqual(test.expectCapabilities, caps) { - t.Errorf("expected capabilities %+v, got %+v", test.expectCapabilities, caps) - } - }) - } -} - -func TestProbeForever(t *testing.T) { - tests := []struct { - name string - probeCalls []probeCall - expectError bool - }{ - { - name: "success", - probeCalls: []probeCall{ - { - response: &csi.ProbeResponse{ - Ready: &wrappers.BoolValue{Value: true}, - }, - }, - }, - expectError: false, - }, - { - name: "success with empty Ready field (true is assumed)", - probeCalls: []probeCall{ - { - response: &csi.ProbeResponse{ - Ready: nil, - }, - }, - }, - expectError: false, - }, - { - name: "error", - probeCalls: []probeCall{ - { - err: fmt.Errorf("mock error"), - }, - }, - expectError: true, - }, - { - name: "timeout + failure", - probeCalls: []probeCall{ - { - err: status.Error(codes.DeadlineExceeded, "timeout"), - }, - { - err: fmt.Errorf("mock error"), - }, - }, - expectError: true, - }, - { - name: "timeout + success", - probeCalls: []probeCall{ - { - err: status.Error(codes.DeadlineExceeded, "timeout"), - }, - { - err: status.Error(codes.DeadlineExceeded, "timeout"), - }, - { - response: &csi.ProbeResponse{ - Ready: &wrappers.BoolValue{Value: true}, - }, - }, - }, - expectError: false, - }, - { - name: "unready + failure", - probeCalls: []probeCall{ - { - response: &csi.ProbeResponse{ - Ready: &wrappers.BoolValue{Value: false}, - }, - }, - { - err: fmt.Errorf("mock error"), - }, - }, - expectError: true, - }, - { - name: "unready + success", - probeCalls: []probeCall{ - { - response: &csi.ProbeResponse{ - Ready: &wrappers.BoolValue{Value: false}, - }, - }, - { - response: &csi.ProbeResponse{ - Ready: &wrappers.BoolValue{Value: false}, - }, - }, - { - response: &csi.ProbeResponse{ - Ready: &wrappers.BoolValue{Value: true}, - }, - }, - }, - expectError: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - tmp := tmpDir(t) - defer os.RemoveAll(tmp) - identity := &identityServer{ - probeCalls: test.probeCalls, - } - addr, stopServer := startServer(t, tmp, identity, nil) - defer func() { - stopServer() - }() - - conn, err := Connect(addr) - if err != nil { - t.Fatalf("Failed to connect to CSI driver: %s", err) - } - - err = ProbeForever(conn, time.Second) - if test.expectError && err == nil { - t.Errorf("Expected error, got none") - } - if !test.expectError && err != nil { - t.Errorf("Got error: %v", err) - } - if len(identity.probeCalls) != identity.probeCallCount { - t.Errorf("Expected %d probe calls, got %d", len(identity.probeCalls), identity.probeCallCount) - } - }) - } -} - -type identityServer struct { - pluginInfoResponse *csi.GetPluginInfoResponse - getPluginCapabilitiesResponse *csi.GetPluginCapabilitiesResponse - err error - - probeCalls []probeCall - probeCallCount int -} - -type probeCall struct { - response *csi.ProbeResponse - err error -} - -var _ csi.IdentityServer = &identityServer{} - -func (i *identityServer) GetPluginCapabilities(context.Context, *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) { - return i.getPluginCapabilitiesResponse, i.err -} - -func (i *identityServer) GetPluginInfo(context.Context, *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) { - return i.pluginInfoResponse, i.err -} - -func (i *identityServer) Probe(context.Context, *csi.ProbeRequest) (*csi.ProbeResponse, error) { - if i.probeCallCount >= len(i.probeCalls) { - return nil, fmt.Errorf("Unexpected Probe() call") - } - call := i.probeCalls[i.probeCallCount] - i.probeCallCount++ - return call.response, call.err -} - -type controllerServer struct { - controllerGetCapabilitiesResponse *csi.ControllerGetCapabilitiesResponse - err error -} - -var _ csi.ControllerServer = &controllerServer{} - -func (c *controllerServer) CreateVolume(context.Context, *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { - return nil, fmt.Errorf("unimplemented") -} - -func (c *controllerServer) DeleteVolume(context.Context, *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { - return nil, fmt.Errorf("unimplemented") -} - -func (c *controllerServer) ControllerPublishVolume(context.Context, *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { - return nil, fmt.Errorf("unimplemented") -} - -func (c *controllerServer) ControllerUnpublishVolume(context.Context, *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { - return nil, fmt.Errorf("unimplemented") -} - -func (c *controllerServer) ValidateVolumeCapabilities(context.Context, *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { - return nil, fmt.Errorf("unimplemented") -} - -func (c *controllerServer) ListVolumes(context.Context, *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) { - return nil, fmt.Errorf("unimplemented") -} - -func (c *controllerServer) GetCapacity(context.Context, *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) { - return nil, fmt.Errorf("unimplemented") -} - -func (c *controllerServer) ControllerGetCapabilities(context.Context, *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) { - return c.controllerGetCapabilitiesResponse, c.err -} - -func (c *controllerServer) CreateSnapshot(context.Context, *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) { - return nil, fmt.Errorf("unimplemented") -} - -func (c *controllerServer) DeleteSnapshot(context.Context, *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) { - return nil, fmt.Errorf("unimplemented") -} - -func (c *controllerServer) ListSnapshots(context.Context, *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { - return nil, fmt.Errorf("unimplemented") -} diff --git a/rpc/common.go b/rpc/common.go new file mode 100644 index 00000000..bb4a5c44 --- /dev/null +++ b/rpc/common.go @@ -0,0 +1,160 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package rpc + +import ( + "context" + "fmt" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/container-storage-interface/spec/lib/go/csi" + + "k8s.io/klog" +) + +const ( + // Interval of trying to call Probe() until it succeeds + probeInterval = 1 * time.Second +) + +// GetDriverName returns name of CSI driver. +func GetDriverName(ctx context.Context, conn *grpc.ClientConn) (string, error) { + client := csi.NewIdentityClient(conn) + + req := csi.GetPluginInfoRequest{} + rsp, err := client.GetPluginInfo(ctx, &req) + if err != nil { + return "", err + } + name := rsp.GetName() + if name == "" { + return "", fmt.Errorf("driver name is empty") + } + return name, nil +} + +// PluginCapabilitySet is set of CSI plugin capabilities. Only supported capabilities are in the map. +type PluginCapabilitySet map[csi.PluginCapability_Service_Type]bool + +// GetPluginCapabilities returns set of supported capabilities of CSI driver. +func GetPluginCapabilities(ctx context.Context, conn *grpc.ClientConn) (PluginCapabilitySet, error) { + client := csi.NewIdentityClient(conn) + req := csi.GetPluginCapabilitiesRequest{} + rsp, err := client.GetPluginCapabilities(ctx, &req) + if err != nil { + return nil, err + } + caps := PluginCapabilitySet{} + for _, cap := range rsp.GetCapabilities() { + if cap == nil { + continue + } + srv := cap.GetService() + if srv == nil { + continue + } + t := srv.GetType() + caps[t] = true + } + return caps, nil +} + +// ControllerCapabilitySet is set of CSI controller capabilities. Only supported capabilities are in the map. +type ControllerCapabilitySet map[csi.ControllerServiceCapability_RPC_Type]bool + +// GetControllerCapabilities returns set of supported controller capabilities of CSI driver. +func GetControllerCapabilities(ctx context.Context, conn *grpc.ClientConn) (ControllerCapabilitySet, error) { + client := csi.NewControllerClient(conn) + req := csi.ControllerGetCapabilitiesRequest{} + rsp, err := client.ControllerGetCapabilities(ctx, &req) + if err != nil { + return nil, err + } + + caps := ControllerCapabilitySet{} + for _, cap := range rsp.GetCapabilities() { + if cap == nil { + continue + } + rpc := cap.GetRpc() + if rpc == nil { + continue + } + t := rpc.GetType() + caps[t] = true + } + return caps, nil +} + +// ProbeForever calls Probe() of a CSI driver and waits until the driver becomes ready. +// Any error other than timeout is returned. +func ProbeForever(conn *grpc.ClientConn, singleProbeTimeout time.Duration) error { + for { + klog.Info("Probing CSI driver for readiness") + ready, err := probeOnce(conn, singleProbeTimeout) + if err != nil { + st, ok := status.FromError(err) + if !ok { + // This is not gRPC error. The probe must have failed before gRPC + // method was called, otherwise we would get gRPC error. + return fmt.Errorf("CSI driver probe failed: %s", err) + } + if st.Code() != codes.DeadlineExceeded { + return fmt.Errorf("CSI driver probe failed: %s", err) + } + // Timeout -> driver is not ready. Fall through to sleep() below. + klog.Warning("CSI driver probe timed out") + } else { + if ready { + return nil + } + klog.Warning("CSI driver is not ready") + } + // Timeout was returned or driver is not ready. + time.Sleep(probeInterval) + } +} + +// probeOnce is a helper to simplify defer cancel() +func probeOnce(conn *grpc.ClientConn, timeout time.Duration) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return Probe(ctx, conn) +} + +// Probe calls driver Probe() just once and returns its result without any processing. +func Probe(ctx context.Context, conn *grpc.ClientConn) (ready bool, err error) { + client := csi.NewIdentityClient(conn) + + req := csi.ProbeRequest{} + rsp, err := client.Probe(ctx, &req) + + if err != nil { + return false, err + } + + r := rsp.GetReady() + if r == nil { + // "If not present, the caller SHALL assume that the plugin is in a ready state" + return true, nil + } + return r.GetValue(), nil +} diff --git a/rpc/common_test.go b/rpc/common_test.go new file mode 100644 index 00000000..77ffbc1b --- /dev/null +++ b/rpc/common_test.go @@ -0,0 +1,611 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package rpc + +import ( + "context" + "fmt" + "io/ioutil" + "net" + "os" + "path" + "reflect" + "sync" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/golang/protobuf/ptypes/wrappers" + "github.com/kubernetes-csi/csi-lib-utils/connection" + "github.com/stretchr/testify/require" +) + +func tmpDir(t *testing.T) string { + dir, err := ioutil.TempDir("", "connect") + require.NoError(t, err, "creating temp directory") + return dir +} + +const ( + serverSock = "server.sock" +) + +// startServer creates a gRPC server without any registered services. +// The returned address can be used to connect to it. The cleanup +// function stops it. It can be called multiple times. +func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controller csi.ControllerServer) (string, func()) { + addr := path.Join(tmp, serverSock) + listener, err := net.Listen("unix", addr) + require.NoError(t, err, "listening on %s", addr) + server := grpc.NewServer() + if identity != nil { + csi.RegisterIdentityServer(server, identity) + } + if controller != nil { + csi.RegisterControllerServer(server, controller) + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if err := server.Serve(listener); err != nil { + t.Logf("starting server failed: %s", err) + } + }() + return addr, func() { + server.Stop() + wg.Wait() + if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { + t.Logf("remove Unix socket: %s", err) + } + } +} + +func TestGetDriverName(t *testing.T) { + tests := []struct { + name string + output *csi.GetPluginInfoResponse + injectError bool + expectError bool + }{ + { + name: "success", + output: &csi.GetPluginInfoResponse{ + Name: "csi/example", + VendorVersion: "0.2.0", + Manifest: map[string]string{ + "hello": "world", + }, + }, + expectError: false, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + }, + { + name: "empty name", + output: &csi.GetPluginInfoResponse{ + Name: "", + }, + expectError: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + out := test.output + var injectedErr error + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + tmp := tmpDir(t) + defer os.RemoveAll(tmp) + identity := &fakeIdentityServer{ + pluginInfoResponse: out, + err: injectedErr, + } + addr, stopServer := startServer(t, tmp, identity, nil) + defer func() { + stopServer() + }() + + conn, err := connection.Connect(addr) + if err != nil { + t.Fatalf("Failed to connect to CSI driver: %s", err) + } + + name, err := GetDriverName(context.Background(), conn) + if test.expectError && err == nil { + t.Errorf("Expected error, got none") + } + if !test.expectError && err != nil { + t.Errorf("Got error: %v", err) + } + if err == nil && name != "csi/example" { + t.Errorf("Got unexpected name: %q", name) + } + }) + } +} + +func TestGetPluginCapabilities(t *testing.T) { + tests := []struct { + name string + output *csi.GetPluginCapabilitiesResponse + injectError bool + expectCapabilities PluginCapabilitySet + expectError bool + }{ + { + name: "success", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{ + { + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_CONTROLLER_SERVICE, + }, + }, + }, + { + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_UNKNOWN, + }, + }, + }, + }, + }, + expectCapabilities: PluginCapabilitySet{ + csi.PluginCapability_Service_CONTROLLER_SERVICE: true, + csi.PluginCapability_Service_UNKNOWN: true, + }, + expectError: false, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + }, + { + name: "no controller service", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{ + { + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_UNKNOWN, + }, + }, + }, + }, + }, + expectCapabilities: PluginCapabilitySet{ + csi.PluginCapability_Service_UNKNOWN: true, + }, + expectError: false, + }, + { + name: "empty capability", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{ + { + Type: nil, + }, + }, + }, + expectCapabilities: PluginCapabilitySet{}, + expectError: false, + }, + { + name: "no capabilities", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{}, + }, + expectCapabilities: PluginCapabilitySet{}, + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var injectedErr error + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + tmp := tmpDir(t) + defer os.RemoveAll(tmp) + identity := &fakeIdentityServer{ + getPluginCapabilitiesResponse: test.output, + err: injectedErr, + } + addr, stopServer := startServer(t, tmp, identity, nil) + defer func() { + stopServer() + }() + + conn, err := connection.Connect(addr) + if err != nil { + t.Fatalf("Failed to connect to CSI driver: %s", err) + } + + caps, err := GetPluginCapabilities(context.Background(), conn) + if test.expectError && err == nil { + t.Errorf("Expected error, got none") + } + if !test.expectError && err != nil { + t.Errorf("Got error: %v", err) + } + if !reflect.DeepEqual(test.expectCapabilities, caps) { + t.Errorf("expected capabilities %+v, got %+v", test.expectCapabilities, caps) + } + }) + } +} + +func TestGetControllerCapabilities(t *testing.T) { + tests := []struct { + name string + output *csi.ControllerGetCapabilitiesResponse + injectError bool + expectCapabilities ControllerCapabilitySet + expectError bool + }{ + { + name: "success", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, + }, + }, + }, + }, + }, + expectCapabilities: ControllerCapabilitySet{ + csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME: true, + csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME: true, + }, + expectError: false, + }, + { + name: "supports read only", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_PUBLISH_READONLY, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, + }, + }, + }, + }, + }, + expectCapabilities: ControllerCapabilitySet{ + csi.ControllerServiceCapability_RPC_PUBLISH_READONLY: true, + csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME: true, + }, + expectError: false, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + }, + { + name: "empty capability", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: nil, + }, + }, + }, + expectCapabilities: ControllerCapabilitySet{}, + expectError: false, + }, + { + name: "no capabilities", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{}, + }, + expectCapabilities: ControllerCapabilitySet{}, + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var injectedErr error + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + tmp := tmpDir(t) + defer os.RemoveAll(tmp) + controller := &fakeControllerServer{ + controllerGetCapabilitiesResponse: test.output, + err: injectedErr, + } + addr, stopServer := startServer(t, tmp, nil, controller) + defer func() { + stopServer() + }() + + conn, err := connection.Connect(addr) + if err != nil { + t.Fatalf("Failed to connect to CSI driver: %s", err) + } + + caps, err := GetControllerCapabilities(context.Background(), conn) + if test.expectError && err == nil { + t.Errorf("Expected error, got none") + } + if !test.expectError && err != nil { + t.Errorf("Got error: %v", err) + } + if !reflect.DeepEqual(test.expectCapabilities, caps) { + t.Errorf("expected capabilities %+v, got %+v", test.expectCapabilities, caps) + } + }) + } +} + +func TestProbeForever(t *testing.T) { + tests := []struct { + name string + probeCalls []probeCall + expectError bool + }{ + { + name: "success", + probeCalls: []probeCall{ + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: true}, + }, + }, + }, + expectError: false, + }, + { + name: "success with empty Ready field (true is assumed)", + probeCalls: []probeCall{ + { + response: &csi.ProbeResponse{ + Ready: nil, + }, + }, + }, + expectError: false, + }, + { + name: "error", + probeCalls: []probeCall{ + { + err: fmt.Errorf("mock error"), + }, + }, + expectError: true, + }, + { + name: "timeout + failure", + probeCalls: []probeCall{ + { + err: status.Error(codes.DeadlineExceeded, "timeout"), + }, + { + err: fmt.Errorf("mock error"), + }, + }, + expectError: true, + }, + { + name: "timeout + success", + probeCalls: []probeCall{ + { + err: status.Error(codes.DeadlineExceeded, "timeout"), + }, + { + err: status.Error(codes.DeadlineExceeded, "timeout"), + }, + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: true}, + }, + }, + }, + expectError: false, + }, + { + name: "unready + failure", + probeCalls: []probeCall{ + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: false}, + }, + }, + { + err: fmt.Errorf("mock error"), + }, + }, + expectError: true, + }, + { + name: "unready + success", + probeCalls: []probeCall{ + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: false}, + }, + }, + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: false}, + }, + }, + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: true}, + }, + }, + }, + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tmp := tmpDir(t) + defer os.RemoveAll(tmp) + identity := &fakeIdentityServer{ + probeCalls: test.probeCalls, + } + addr, stopServer := startServer(t, tmp, identity, nil) + defer func() { + stopServer() + }() + + conn, err := connection.Connect(addr) + if err != nil { + t.Fatalf("Failed to connect to CSI driver: %s", err) + } + + err = ProbeForever(conn, time.Second) + if test.expectError && err == nil { + t.Errorf("Expected error, got none") + } + if !test.expectError && err != nil { + t.Errorf("Got error: %v", err) + } + if len(identity.probeCalls) != identity.probeCallCount { + t.Errorf("Expected %d probe calls, got %d", len(identity.probeCalls), identity.probeCallCount) + } + }) + } +} + +type fakeIdentityServer struct { + pluginInfoResponse *csi.GetPluginInfoResponse + getPluginCapabilitiesResponse *csi.GetPluginCapabilitiesResponse + err error + + probeCalls []probeCall + probeCallCount int +} + +type probeCall struct { + response *csi.ProbeResponse + err error +} + +var _ csi.IdentityServer = &fakeIdentityServer{} + +func (i *fakeIdentityServer) GetPluginCapabilities(context.Context, *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) { + return i.getPluginCapabilitiesResponse, i.err +} + +func (i *fakeIdentityServer) GetPluginInfo(context.Context, *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) { + return i.pluginInfoResponse, i.err +} + +func (i *fakeIdentityServer) Probe(context.Context, *csi.ProbeRequest) (*csi.ProbeResponse, error) { + if i.probeCallCount >= len(i.probeCalls) { + return nil, fmt.Errorf("Unexpected Probe() call") + } + call := i.probeCalls[i.probeCallCount] + i.probeCallCount++ + return call.response, call.err +} + +type fakeControllerServer struct { + controllerGetCapabilitiesResponse *csi.ControllerGetCapabilitiesResponse + err error +} + +var _ csi.ControllerServer = &fakeControllerServer{} + +func (c *fakeControllerServer) CreateVolume(context.Context, *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeControllerServer) DeleteVolume(context.Context, *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeControllerServer) ControllerPublishVolume(context.Context, *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeControllerServer) ControllerUnpublishVolume(context.Context, *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeControllerServer) ValidateVolumeCapabilities(context.Context, *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeControllerServer) ListVolumes(context.Context, *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeControllerServer) GetCapacity(context.Context, *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeControllerServer) ControllerGetCapabilities(context.Context, *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) { + return c.controllerGetCapabilitiesResponse, c.err +} + +func (c *fakeControllerServer) CreateSnapshot(context.Context, *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeControllerServer) DeleteSnapshot(context.Context, *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeControllerServer) ListSnapshots(context.Context, *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { + return nil, fmt.Errorf("unimplemented") +}