Skip to content

Commit

Permalink
Add HostDialer interface
Browse files Browse the repository at this point in the history
There are multiple use cases where the Dialer interface is not
sufficient.

When using TLS, users need to control also TLS context per host.
For example, host certificates might be either UUID in common name,
some hostname, IP address per host, etc.

Hosted instances or instances deployed in Kubernetes cluster tend
to be behind proxies. A proxy might use TLS server name indication
to identify which database host to connect to.
  • Loading branch information
martin-sucha committed Jun 28, 2022
1 parent ae365fa commit 66ec152
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 71 deletions.
9 changes: 9 additions & 0 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type ClusterConfig struct {
ProtoVersion int

// Connection timeout (default: 600ms)
// ConnectTimeout is used to set up the default dialer and is ignored if Dialer or HostDialer is provided.
Timeout time.Duration

// Initial connection timeout, used during initial dial to server (default: 600ms)
Expand Down Expand Up @@ -98,6 +99,7 @@ type ClusterConfig struct {
ReconnectionPolicy ReconnectionPolicy

// The keepalive period to use, enabled if > 0 (default: 0)
// SocketKeepalive is used to set up the default dialer and is ignored if Dialer or HostDialer is provided.
SocketKeepalive time.Duration

// Maximum cache size for prepared statements globally for gocql.
Expand All @@ -116,6 +118,8 @@ type ClusterConfig struct {
// Default: unset
SerialConsistency SerialConsistency

// SslOpts configures TLS use when HostDialer is not set.
// SslOpts is ignored if HostDialer is set.
SslOpts *SslOptions

// Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server.
Expand Down Expand Up @@ -208,8 +212,13 @@ type ClusterConfig struct {

// Dialer will be used to establish all connections created for this Cluster.
// If not provided, a default dialer configured with ConnectTimeout will be used.
// Dialer is ignored if HostDialer is provided.
Dialer Dialer

// HostDialer will be used to establish all connections for this Cluster.
// If not provided, Dialer will be used instead.
HostDialer HostDialer

// Logger for this ClusterConfig.
// If not specified, defaults to the global gocql.Logger.
Logger StdLogger
Expand Down
60 changes: 9 additions & 51 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ type ConnConfig struct {
WriteTimeout time.Duration
ConnectTimeout time.Duration
Dialer Dialer
HostDialer HostDialer
Compressor Compressor
Authenticator Authenticator
AuthProvider func(h *HostInfo) (Authenticator, error)
Expand Down Expand Up @@ -231,53 +232,10 @@ func (s *Session) dial(ctx context.Context, host *HostInfo, connConfig *ConnConf
//
// dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead.
func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
ip := host.ConnectAddress()
port := host.port

// TODO(zariel): remove these
if !validIpAddr(ip) {
panic(fmt.Sprintf("host missing connect ip address: %v", ip))
} else if port == 0 {
panic(fmt.Sprintf("host missing port: %v", port))
}

dialer := cfg.Dialer
if dialer == nil {
d := &net.Dialer{
Timeout: cfg.ConnectTimeout,
}
if cfg.Keepalive > 0 {
d.KeepAlive = cfg.Keepalive
}
dialer = d
}

addr := host.HostnameAndPort()
conn, err := dialer.DialContext(ctx, "tcp", addr)
dialedHost, err := cfg.HostDialer.DialHost(ctx, host)
if err != nil {
return nil, err
}
if cfg.tlsConfig != nil {
// the TLS config is safe to be reused by connections but it must not
// be modified after being used.
tlsConfig := cfg.tlsConfig
if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
// clone config to avoid modifying the shared one.
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = hostname
}
tconn := tls.Client(conn, tlsConfig)
if err := tconn.Handshake(); err != nil {
conn.Close()
return nil, err
}
conn = tconn
}

writeTimeout := cfg.Timeout
if cfg.WriteTimeout > 0 {
Expand All @@ -286,20 +244,20 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *

ctx, cancel := context.WithCancel(ctx)
c := &Conn{
conn: conn,
r: bufio.NewReader(conn),
conn: dialedHost.Conn,
r: bufio.NewReader(dialedHost.Conn),
cfg: cfg,
calls: make(map[int]*callReq),
version: uint8(cfg.ProtoVersion),
addr: conn.RemoteAddr().String(),
addr: dialedHost.Conn.RemoteAddr().String(),
errorHandler: errorHandler,
compressor: cfg.Compressor,
session: s,
streams: streams.New(cfg.ProtoVersion),
host: host,
frameObserver: s.frameObserver,
w: &deadlineContextWriter{
w: conn,
w: dialedHost.Conn,
timeout: writeTimeout,
semaphore: make(chan struct{}, 1),
quit: make(chan struct{}),
Expand All @@ -311,7 +269,7 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
writeTimeout: writeTimeout,
}

if err := c.init(ctx); err != nil {
if err := c.init(ctx, dialedHost); err != nil {
cancel()
c.Close()
return nil, err
Expand All @@ -320,7 +278,7 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
return c, nil
}

func (c *Conn) init(ctx context.Context) error {
func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error {
if c.session.cfg.AuthProvider != nil {
var err error
c.auth, err = c.cfg.AuthProvider(c.host)
Expand All @@ -344,7 +302,7 @@ func (c *Conn) init(ctx context.Context) error {
c.timeout = c.cfg.Timeout

// dont coalesce startup frames
if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce {
if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce && !dialedHost.DisableCoalesce {
c.w = newWriteCoalescer(c.conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
}

Expand Down
60 changes: 40 additions & 20 deletions connectionpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,32 +91,52 @@ type policyConnPool struct {

func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
var (
err error
tlsConfig *tls.Config
err error
hostDialer HostDialer
)

// TODO(zariel): move tls config setup into session init.
if cfg.SslOpts != nil {
tlsConfig, err = setupTLSConfig(cfg.SslOpts)
if err != nil {
return nil, err
hostDialer = cfg.HostDialer
if hostDialer == nil {
var tlsConfig *tls.Config

// TODO(zariel): move tls config setup into session init.
if cfg.SslOpts != nil {
tlsConfig, err = setupTLSConfig(cfg.SslOpts)
if err != nil {
return nil, err
}
}

dialer := cfg.Dialer
if dialer == nil {
d := &net.Dialer{
Timeout: cfg.ConnectTimeout,
}
if cfg.SocketKeepalive > 0 {
d.KeepAlive = cfg.SocketKeepalive
}
dialer = d
}

hostDialer = &defaultHostDialer{
dialer: dialer,
tlsConfig: tlsConfig,
}
}

return &ConnConfig{
ProtoVersion: cfg.ProtoVersion,
CQLVersion: cfg.CQLVersion,
Timeout: cfg.Timeout,
WriteTimeout: cfg.WriteTimeout,
ConnectTimeout: cfg.ConnectTimeout,
Dialer: cfg.Dialer,
Compressor: cfg.Compressor,
Authenticator: cfg.Authenticator,
AuthProvider: cfg.AuthProvider,
Keepalive: cfg.SocketKeepalive,
Logger: cfg.logger(),
tlsConfig: tlsConfig,
disableCoalesce: tlsConfig != nil, // write coalescing doesn't work with framing on top of TCP like in TLS.
ProtoVersion: cfg.ProtoVersion,
CQLVersion: cfg.CQLVersion,
Timeout: cfg.Timeout,
WriteTimeout: cfg.WriteTimeout,
ConnectTimeout: cfg.ConnectTimeout,
Dialer: cfg.Dialer,
HostDialer: hostDialer,
Compressor: cfg.Compressor,
Authenticator: cfg.Authenticator,
AuthProvider: cfg.AuthProvider,
Keepalive: cfg.SocketKeepalive,
Logger: cfg.logger(),
}, nil
}

Expand Down
90 changes: 90 additions & 0 deletions dial.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package gocql

import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
)

// HostDialer allows customizing connection to cluster nodes.
type HostDialer interface {
// DialHost establishes a connection to the host.
// The returned connection must be directly usable for CQL protocol,
// specifically DialHost is responsible also for setting up the TLS session if needed.
// DialHost should disable write coalescing if the returned net.Conn does not support writev.
// As of Go 1.18, only plain TCP connections support writev, TLS sessions should disable coalescing.
// You can use WrapTLS helper function if you don't need to override the TLS setup.
DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error)
}

// DialedHost contains information about established connection to a host.
type DialedHost struct {
// Conn used to communicate with the server.
Conn net.Conn

// DisableCoalesce disables write coalescing for the Conn.
// If true, the effect is the same as if WriteCoalesceWaitTime was configured to 0.
DisableCoalesce bool
}

// defaultHostDialer dials host in a default way.
type defaultHostDialer struct {
dialer Dialer
tlsConfig *tls.Config
}

func (hd *defaultHostDialer) DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) {
ip := host.ConnectAddress()
port := host.Port()

if !validIpAddr(ip) {
return nil, fmt.Errorf("host missing connect ip address: %v", ip)
} else if port == 0 {
return nil, fmt.Errorf("host missing port: %v", port)
}

addr := host.HostnameAndPort()
conn, err := hd.dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
return WrapTLS(ctx, conn, addr, hd.tlsConfig)
}

func tlsConfigForAddr(tlsConfig *tls.Config, addr string) *tls.Config {
// the TLS config is safe to be reused by connections but it must not
// be modified after being used.
if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
// clone config to avoid modifying the shared one.
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = hostname
}
return tlsConfig
}

// WrapTLS optionally wraps a net.Conn connected to addr with the given tlsConfig.
// If the tlsConfig is nil, conn is not wrapped into a TLS session, so is insecure.
// If the tlsConfig does not have server name set, it is updated based on the default gocql rules.
func WrapTLS(ctx context.Context, conn net.Conn, addr string, tlsConfig *tls.Config) (*DialedHost, error) {
if tlsConfig != nil {
tlsConfig := tlsConfigForAddr(tlsConfig, addr)
tconn := tls.Client(conn, tlsConfig)
if err := tconn.HandshakeContext(ctx); err != nil {
conn.Close()
return nil, err
}
conn = tconn
}

return &DialedHost{
Conn: conn,
DisableCoalesce: tlsConfig != nil, // write coalescing can't use writev when the connection is wrapped.
}, nil
}

0 comments on commit 66ec152

Please sign in to comment.