Skip to content

Commit

Permalink
Add GetDriverName function
Browse files Browse the repository at this point in the history
  • Loading branch information
jsafrane committed Feb 11, 2019
1 parent 763b95d commit e7886c7
Show file tree
Hide file tree
Showing 6 changed files with 5,608 additions and 10 deletions.
9 changes: 9 additions & 0 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Gopkg.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Refer to https://golang.github.io/dep/docs/Gopkg.toml.html
# for detailed Gopkg.toml documentation.

[[constraint]]
name = "github.com/container-storage-interface/spec"
version = "1.0.0"

[prune]
go-tests = true
non-go = true
Expand Down
17 changes: 17 additions & 0 deletions connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ package connection
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"

"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/kubernetes-csi/csi-lib-utils/protosanitizer"
"google.golang.org/grpc"
"k8s.io/klog"
Expand Down Expand Up @@ -160,3 +162,18 @@ func LogGRPC(ctx context.Context, method string, req, reply interface{}, cc *grp
klog.V(5).Infof("GRPC error: %v", err)
return err
}

func GetDriverName(ctx context.Context, conn *grpc.ClientConn) (string, error) {
client := csi.NewIdentityClient(conn)

req := csi.GetPluginInfoRequest{}
rsp, err := client.GetPluginInfo(ctx, &req)
if err != nil {
return "", err
}
name := rsp.GetName()
if name == "" {
return "", fmt.Errorf("driver name is empty")
}
return name, nil
}
110 changes: 100 additions & 10 deletions connection/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package connection

import (
"context"
"fmt"
"io/ioutil"
"net"
"os"
Expand All @@ -33,6 +34,8 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/container-storage-interface/spec/lib/go/csi"
)

func tmpDir(t *testing.T) string {
Expand All @@ -48,11 +51,14 @@ const (
// startServer creates a gRPC server without any registered services.
// The returned address can be used to connect to it. The cleanup
// function stops it. It can be called multiple times.
func startServer(t *testing.T, tmp string) (string, func()) {
func startServer(t *testing.T, tmp string, identity csi.IdentityServer) (string, func()) {
addr := path.Join(tmp, serverSock)
listener, err := net.Listen("unix", addr)
require.NoError(t, err, "listening on %s", addr)
server := grpc.NewServer()
if identity != nil {
csi.RegisterIdentityServer(server, identity)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
Expand All @@ -73,7 +79,7 @@ func startServer(t *testing.T, tmp string) (string, func()) {
func TestConnect(t *testing.T) {
tmp := tmpDir(t)
defer os.RemoveAll(tmp)
addr, stopServer := startServer(t, tmp)
addr, stopServer := startServer(t, tmp, nil)
defer stopServer()

conn, err := Connect(addr)
Expand All @@ -88,7 +94,7 @@ func TestConnect(t *testing.T) {
func TestConnectUnix(t *testing.T) {
tmp := tmpDir(t)
defer os.RemoveAll(tmp)
addr, stopServer := startServer(t, tmp)
addr, stopServer := startServer(t, tmp, nil)
defer stopServer()

conn, err := Connect("unix:///" + addr)
Expand Down Expand Up @@ -129,7 +135,7 @@ func TestWaitForServer(t *testing.T) {
t.Logf("sleeping %s before starting server", delay)
time.Sleep(delay)
startTimeServer = time.Now()
_, stopServer = startServer(t, tmp)
_, stopServer = startServer(t, tmp, nil)
}()
conn, err := Connect(path.Join(tmp, serverSock))
if assert.NoError(t, err, "connect via absolute path") {
Expand Down Expand Up @@ -163,7 +169,7 @@ func TestTimout(t *testing.T) {
func TestReconnect(t *testing.T) {
tmp := tmpDir(t)
defer os.RemoveAll(tmp)
addr, stopServer := startServer(t, tmp)
addr, stopServer := startServer(t, tmp, nil)
defer func() {
stopServer()
}()
Expand All @@ -190,7 +196,7 @@ func TestReconnect(t *testing.T) {
}

// No reconnection either when the server comes back.
_, stopServer = startServer(t, tmp)
_, stopServer = startServer(t, tmp, nil)
// We need to give gRPC some time. It does not attempt to reconnect
// immediately. If we send the method call too soon, the test passes
// even though a later method call will go through again.
Expand All @@ -208,7 +214,7 @@ func TestReconnect(t *testing.T) {
func TestDisconnect(t *testing.T) {
tmp := tmpDir(t)
defer os.RemoveAll(tmp)
addr, stopServer := startServer(t, tmp)
addr, stopServer := startServer(t, tmp, nil)
defer func() {
stopServer()
}()
Expand Down Expand Up @@ -239,7 +245,7 @@ func TestDisconnect(t *testing.T) {
}

// No reconnection either when the server comes back.
_, stopServer = startServer(t, tmp)
_, stopServer = startServer(t, tmp, nil)
// We need to give gRPC some time. It does not attempt to reconnect
// immediately. If we send the method call too soon, the test passes
// even though a later method call will go through again.
Expand All @@ -259,7 +265,7 @@ func TestDisconnect(t *testing.T) {
func TestExplicitReconnect(t *testing.T) {
tmp := tmpDir(t)
defer os.RemoveAll(tmp)
addr, stopServer := startServer(t, tmp)
addr, stopServer := startServer(t, tmp, nil)
defer func() {
stopServer()
}()
Expand Down Expand Up @@ -290,7 +296,7 @@ func TestExplicitReconnect(t *testing.T) {
}

// No reconnection either when the server comes back.
_, stopServer = startServer(t, tmp)
_, stopServer = startServer(t, tmp, nil)
// We need to give gRPC some time. It does not attempt to reconnect
// immediately. If we send the method call too soon, the test passes
// even though a later method call will go through again.
Expand All @@ -306,3 +312,87 @@ func TestExplicitReconnect(t *testing.T) {
assert.Equal(t, 1, reconnectCount, "connection loss callback should be called once")
}
}

func TestGetDriverName(t *testing.T) {
tests := []struct {
name string
output *csi.GetPluginInfoResponse
injectError bool
expectError bool
}{
{
name: "success",
output: &csi.GetPluginInfoResponse{
Name: "csi/example",
VendorVersion: "0.2.0",
Manifest: map[string]string{
"hello": "world",
},
},
expectError: false,
},
{
name: "gRPC error",
output: nil,
injectError: true,
expectError: true,
},
{
name: "empty name",
output: &csi.GetPluginInfoResponse{
Name: "",
},
expectError: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
out := test.output
var injectedErr error
if test.injectError {
injectedErr = fmt.Errorf("mock error")
}

tmp := tmpDir(t)
defer os.RemoveAll(tmp)
identity := &identityServer{out, injectedErr}
addr, stopServer := startServer(t, tmp, identity)
defer func() {
stopServer()
}()

conn, err := Connect(addr)

name, err := GetDriverName(context.Background(), conn)
if test.expectError && err == nil {
t.Errorf("test %q: Expected error, got none", test.name)
}
if !test.expectError && err != nil {
t.Errorf("test %q: got error: %v", test.name, err)
}
if err == nil && name != "csi/example" {
t.Errorf("got unexpected name: %q", name)
}
})
}
}

type identityServer struct {
response *csi.GetPluginInfoResponse
err error
}

var _ csi.IdentityServer = &identityServer{}

func (i *identityServer) GetPluginCapabilities(context.Context, *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) {
return nil, fmt.Errorf("Not implemented")
}

func (i *identityServer) GetPluginInfo(context.Context, *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) {
return i.response, i.err
}

func (i *identityServer) Probe(context.Context, *csi.ProbeRequest) (*csi.ProbeResponse, error) {
return nil, fmt.Errorf("Not implemented")
}
Loading

0 comments on commit e7886c7

Please sign in to comment.