Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 35 additions & 44 deletions cmd/mock-driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ limitations under the License.
package main

import (
"context"
"flag"
"fmt"
"io/ioutil"
"net"
"os"
"os/signal"
"strings"
"syscall"

"github.com/kubernetes-csi/csi-test/v4/driver"
"github.com/kubernetes-csi/csi-test/v4/internal/endpoint"
"github.com/kubernetes-csi/csi-test/v4/internal/proxy"
"github.com/kubernetes-csi/csi-test/v4/mock/service"
"gopkg.in/yaml.v2"
"k8s.io/klog/v2"
Expand All @@ -50,13 +50,37 @@ func main() {
flag.BoolVar(&config.DisableOnlineExpansion, "disable-online-expansion", false, "Disables online volume expansion capability.")
flag.BoolVar(&config.PermissiveTargetPath, "permissive-target-path", false, "Allows the CO to create PublishVolumeRequest.TargetPath, which violates the CSI spec.")
flag.StringVar(&hooksFile, "hooks-file", "", "YAML file with hook scripts.")
proxyEndpoint := flag.String("proxy-endpoint", "", "Instead of running the CSI driver code, just proxy connections from $CSI_ENDPOINT to the given listening socket.")
flag.Parse()

endpoint := os.Getenv("CSI_ENDPOINT")
csiEndpoint := os.Getenv("CSI_ENDPOINT")
controllerEndpoint := os.Getenv("CSI_CONTROLLER_ENDPOINT")
if len(controllerEndpoint) == 0 {
// If empty, set to the common endpoint.
controllerEndpoint = endpoint
controllerEndpoint = csiEndpoint
}

if *proxyEndpoint != "" {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
closer, err := proxy.Run(ctx, csiEndpoint, *proxyEndpoint)
if err != nil {
klog.Fatalf("failed to run proxy: %v", err)
}
defer closer.Close()

// Wait for signal
sigc := make(chan os.Signal, 1)
sigs := []os.Signal{
syscall.SIGTERM,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGQUIT,
}
signal.Notify(sigc, sigs...)

<-sigc
return
}

if hooksFile != "" {
Expand All @@ -71,7 +95,7 @@ func main() {
// Create mock driver
s := service.New(config)

if endpoint == controllerEndpoint {
if csiEndpoint == controllerEndpoint {
servers := &driver.CSIDriverServers{
Controller: s,
Identity: s,
Expand All @@ -86,10 +110,10 @@ func main() {
}

// Listen
l, cleanup, err := listen(endpoint)
l, cleanup, err := endpoint.Listen(csiEndpoint)
if err != nil {
klog.Exitf("Error: Unable to listen on %s socket: %v\n",
endpoint,
csiEndpoint,
err)
}
defer cleanup()
Expand Down Expand Up @@ -134,7 +158,7 @@ func main() {
}

// Listen controller.
l, cleanupController, err := listen(controllerEndpoint)
l, cleanupController, err := endpoint.Listen(controllerEndpoint)
if err != nil {
klog.Exitf("Error: Unable to listen on %s socket: %v\n",
controllerEndpoint,
Expand All @@ -150,10 +174,10 @@ func main() {
klog.Infof("mock controller driver started")

// Listen node.
l, cleanupNode, err := listen(endpoint)
l, cleanupNode, err := endpoint.Listen(csiEndpoint)
if err != nil {
klog.Exitf("Error: Unable to listen on %s socket: %v\n",
endpoint,
csiEndpoint,
err)
}
defer cleanupNode()
Expand Down Expand Up @@ -182,39 +206,6 @@ func main() {
}
}

func parseEndpoint(ep string) (string, string, error) {
if strings.HasPrefix(strings.ToLower(ep), "unix://") || strings.HasPrefix(strings.ToLower(ep), "tcp://") {
s := strings.SplitN(ep, "://", 2)
if s[1] != "" {
return s[0], s[1], nil
}
return "", "", fmt.Errorf("Invalid endpoint: %v", ep)
}
// Assume everything else is a file path for a Unix Domain Socket.
return "unix", ep, nil
}

func listen(endpoint string) (net.Listener, func(), error) {
proto, addr, err := parseEndpoint(endpoint)
if err != nil {
return nil, nil, err
}

cleanup := func() {}
if proto == "unix" {
addr = "/" + addr
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { //nolint: vetshadow
return nil, nil, fmt.Errorf("%s: %q", addr, err)
}
cleanup = func() {
os.Remove(addr)
}
}

l, err := net.Listen(proto, addr)
return l, cleanup, err
}

func parseHooksFile(file string) (*service.Hooks, error) {
var hooks service.Hooks

Expand Down
57 changes: 57 additions & 0 deletions internal/endpoint/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
Copyright 2020 Kubernetes Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package endpoint

import (
"fmt"
"net"
"os"
"strings"
)

func Parse(ep string) (string, string, error) {
if strings.HasPrefix(strings.ToLower(ep), "unix://") || strings.HasPrefix(strings.ToLower(ep), "tcp://") {
s := strings.SplitN(ep, "://", 2)
if s[1] != "" {
return s[0], s[1], nil
}
return "", "", fmt.Errorf("Invalid endpoint: %v", ep)
}
// Assume everything else is a file path for a Unix Domain Socket.
return "unix", ep, nil
}

func Listen(endpoint string) (net.Listener, func(), error) {
proto, addr, err := Parse(endpoint)
if err != nil {
return nil, nil, err
}

cleanup := func() {}
if proto == "unix" {
addr = "/" + addr
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { //nolint: vetshadow
return nil, nil, fmt.Errorf("%s: %q", addr, err)
}
cleanup = func() {
os.Remove(addr)
}
}

l, err := net.Listen(proto, addr)
return l, cleanup, err
}
146 changes: 146 additions & 0 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
Copyright 2020 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package proxy makes it possible to forward a listening socket in
// situations where the proxy cannot connect to some other address.
// Instead, it creates two listening sockets, pairs two incoming
// connections and then moves data back and forth. This matches
// the behavior of the following socat command:
// socat -d -d -d UNIX-LISTEN:/tmp/socat,fork TCP-LISTEN:9000,reuseport
//
// The advantage over that command is that both listening
// sockets are always open, in contrast to the socat solution
// where the TCP port is only open when there actually is a connection
// available.
//
// To establish a connection, someone has to poll the proxy with a dialer.
package proxy

import (
"context"
"fmt"
"io"
"net"

"k8s.io/klog/v2"

"github.com/kubernetes-csi/csi-test/v4/internal/endpoint"
)

// New listens on both endpoints and starts accepting connections
// until closed or the context is done.
func Run(ctx context.Context, endpoint1, endpoint2 string) (io.Closer, error) {
proxy := &proxy{}
failedProxy := proxy
defer func() {
if failedProxy != nil {
failedProxy.Close()
}
}()

proxy.ctx, proxy.cancel = context.WithCancel(ctx)

var err error
proxy.s1, proxy.cleanup1, err = endpoint.Listen(endpoint1)
if err != nil {
return nil, fmt.Errorf("listen %s: %v", endpoint1, err)
}
proxy.s2, proxy.cleanup2, err = endpoint.Listen(endpoint2)
if err != nil {
return nil, fmt.Errorf("listen %s: %v", endpoint2, err)
}

klog.V(3).Infof("proxy listening on %s and %s", endpoint1, endpoint2)

go func() {
for {
// We block on the first listening socket.
// The Linux kernel proactively accepts connections
// on the second one which we will take over below.
conn1 := accept(proxy.ctx, proxy.s1, endpoint1)
if conn1 == nil {
// Done, shut down.
klog.V(5).Infof("proxy endpoint %s closed, shutting down", endpoint1)
return
}
conn2 := accept(proxy.ctx, proxy.s2, endpoint2)
if conn2 == nil {
// Done, shut down. The already accepted
// connection gets closed.
klog.V(5).Infof("proxy endpoint %s closed, shutting down and close established connection", endpoint2)
conn1.Close()
return
}

klog.V(3).Infof("proxy established a new connection between %s and %s", endpoint1, endpoint2)
go copy(conn1, conn2, endpoint1, endpoint2)
go copy(conn2, conn1, endpoint2, endpoint1)
}
}()

failedProxy = nil
return proxy, nil
}

type proxy struct {
ctx context.Context
cancel func()
s1, s2 net.Listener
cleanup1, cleanup2 func()
}

func (p *proxy) Close() error {
if p.cancel != nil {
p.cancel()
}
if p.s1 != nil {
p.s1.Close()
}
if p.s2 != nil {
p.s2.Close()
}
if p.cleanup1 != nil {
p.cleanup1()
}
if p.cleanup2 != nil {
p.cleanup2()
}
return nil
}

func copy(from, to net.Conn, fromEndpoint, toEndpoint string) {
klog.V(5).Infof("starting to copy %s -> %s", fromEndpoint, toEndpoint)
// Signal recipient that no more data is going to come.
// This also stops reading from it.
defer to.Close()
// Copy data until EOF.
cnt, err := io.Copy(to, from)
klog.V(5).Infof("done copying %s -> %s: %d bytes, %v", fromEndpoint, toEndpoint, cnt, err)
}

func accept(ctx context.Context, s net.Listener, endpoint string) net.Conn {
for {
c, err := s.Accept()
if err == nil {
return c
}
// Ignore error if we are shutting down.
if ctx.Err() != nil {
return nil
}
klog.V(3).Infof("accept on %s failed: %v", endpoint, err)
}
}
Loading