Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ package driver

import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"sync"

Expand Down Expand Up @@ -102,7 +104,7 @@ func (c *CSIDriver) Start(l net.Listener) error {

// Create a new grpc server
c.server = grpc.NewServer(
grpc.UnaryInterceptor(c.authInterceptor),
grpc.UnaryInterceptor(c.callInterceptor),
)

// Register Mock servers
Expand Down Expand Up @@ -162,22 +164,49 @@ func (c *CSIDriver) SetDefaultCreds() {
}
}

func (c *CSIDriver) authInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
func (c *CSIDriver) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
err := c.authInterceptor(req)
if err != nil {
logGRPC(info.FullMethod, req, nil, err)
return nil, err
}
rsp, err := handler(ctx, req)
logGRPC(info.FullMethod, req, rsp, err)
return rsp, err
}

func (c *CSIDriver) authInterceptor(req interface{}) error {
if c.creds != nil {
authenticated, authErr := isAuthenticated(req, c.creds)
if !authenticated {
if authErr == ErrNoCredentials {
return nil, status.Error(codes.InvalidArgument, authErr.Error())
return status.Error(codes.InvalidArgument, authErr.Error())
}
if authErr == ErrAuthFailed {
return nil, status.Error(codes.Unauthenticated, authErr.Error())
return status.Error(codes.Unauthenticated, authErr.Error())
}
}
}
return nil
}

h, err := handler(ctx, req)

return h, err
func logGRPC(method string, request, reply interface{}, err error) {
// Log JSON with the request and response for easier parsing
logMessage := struct {
Method string
Request interface{}
Response interface{}
Error string
}{
Method: method,
Request: request,
Response: reply,
}
if err != nil {
logMessage.Error = err.Error()
}
msg, _ := json.Marshal(logMessage)
fmt.Printf("gRPCCall: %s\n", msg)
}

func isAuthenticated(req interface{}, creds *CSICreds) (bool, error) {
Expand Down
8 changes: 7 additions & 1 deletion mock/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
package main

import (
"flag"
"fmt"
"net"
"os"
Expand All @@ -28,6 +29,11 @@ import (
)

func main() {
var config service.Config
flag.BoolVar(&config.DisableAttach, "disable-attach", false, "Disables RPC_PUBLISH_UNPUBLISH_VOLUME capability.")
flag.StringVar(&config.DriverName, "name", service.Name, "CSI driver name.")
flag.Parse()

endpoint := os.Getenv("CSI_ENDPOINT")
if len(endpoint) == 0 {
fmt.Println("CSI_ENDPOINT must be defined and must be a path")
Expand All @@ -39,7 +45,7 @@ func main() {
}

// Create mock driver
s := service.New()
s := service.New(config)
servers := &driver.CSIDriverServers{
Controller: s,
Identity: s,
Expand Down
79 changes: 46 additions & 33 deletions mock/service/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ func (s *service) ControllerPublishVolume(
req *csi.ControllerPublishVolumeRequest) (
*csi.ControllerPublishVolumeResponse, error) {

if s.config.DisableAttach {
return nil, status.Error(codes.Unimplemented, "ControllerPublish is not supported")
}

if len(req.VolumeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
Expand Down Expand Up @@ -188,6 +192,10 @@ func (s *service) ControllerUnpublishVolume(
req *csi.ControllerUnpublishVolumeRequest) (
*csi.ControllerUnpublishVolumeResponse, error) {

if s.config.DisableAttach {
return nil, status.Error(codes.Unimplemented, "ControllerPublish is not supported")
}

if len(req.VolumeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
Expand Down Expand Up @@ -338,51 +346,56 @@ func (s *service) ControllerGetCapabilities(
req *csi.ControllerGetCapabilitiesRequest) (
*csi.ControllerGetCapabilitiesResponse, error) {

return &csi.ControllerGetCapabilitiesResponse{
Capabilities: []*csi.ControllerServiceCapability{
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME,
},
caps := []*csi.ControllerServiceCapability{
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME,
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME,
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_LIST_VOLUMES,
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_LIST_VOLUMES,
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_GET_CAPACITY,
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_GET_CAPACITY,
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS,
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS,
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT,
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT,
},
},
}

if !s.config.DisableAttach {
caps = append(caps, &csi.ControllerServiceCapability{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME,
},
},
},
})
}

return &csi.ControllerGetCapabilitiesResponse{
Capabilities: caps,
}, nil
}

Expand Down
2 changes: 1 addition & 1 deletion mock/service/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func (s *service) GetPluginInfo(
*csi.GetPluginInfoResponse, error) {

return &csi.GetPluginInfoResponse{
Name: Name,
Name: s.config.DriverName,
VendorVersion: VendorVersion,
Manifest: Manifest,
}, nil
Expand Down
20 changes: 14 additions & 6 deletions mock/service/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ func (s *service) NodeStageVolume(

device, ok := req.PublishInfo["device"]
if !ok {
return nil, status.Error(
codes.InvalidArgument,
"stage volume info 'device' key required")
if s.config.DisableAttach {
device = "mock device"
} else {
return nil, status.Error(
codes.InvalidArgument,
"stage volume info 'device' key required")
}
}

if len(req.GetVolumeId()) == 0 {
Expand Down Expand Up @@ -105,9 +109,13 @@ func (s *service) NodePublishVolume(

device, ok := req.PublishInfo["device"]
if !ok {
return nil, status.Error(
codes.InvalidArgument,
"publish volume info 'device' key required")
if s.config.DisableAttach {
device = "mock device"
} else {
return nil, status.Error(
codes.InvalidArgument,
"stage volume info 'device' key required")
}
}

if len(req.GetVolumeId()) == 0 {
Expand Down
13 changes: 11 additions & 2 deletions mock/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ var Manifest = map[string]string{
"url": "https://github.com/kubernetes-csi/csi-test/mock",
}

type Config struct {
DisableAttach bool
DriverName string
}

// Service is the CSI Mock service provider.
type Service interface {
csi.ControllerServer
Expand All @@ -40,6 +45,7 @@ type service struct {
volsNID uint64
snapshots cache.SnapshotCache
snapshotsNID uint64
config Config
}

type Volume struct {
Expand All @@ -55,8 +61,11 @@ type Volume struct {
var MockVolumes map[string]Volume

// New returns a new Service.
func New() Service {
s := &service{nodeID: Name}
func New(config Config) Service {
s := &service{
nodeID: config.DriverName,
config: config,
}
s.snapshots = cache.NewSnapshotCache()
s.vols = []csi.Volume{
s.newVolume("Mock Volume 1", gib100),
Expand Down