Skip to content

Commit

Permalink
Merge Add HostDialer interface (apache#1629)
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-sucha committed Jul 1, 2022
2 parents 4c7ec9a + 66ec152 commit 7a6cf00
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 71 deletions.
9 changes: 9 additions & 0 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type ClusterConfig struct {
ProtoVersion int

// Connection timeout (default: 600ms)
// ConnectTimeout is used to set up the default dialer and is ignored if Dialer or HostDialer is provided.
Timeout time.Duration

// Initial connection timeout, used during initial dial to server (default: 600ms)
Expand Down Expand Up @@ -98,6 +99,7 @@ type ClusterConfig struct {
ReconnectionPolicy ReconnectionPolicy

// The keepalive period to use, enabled if > 0 (default: 0)
// SocketKeepalive is used to set up the default dialer and is ignored if Dialer or HostDialer is provided.
SocketKeepalive time.Duration

// Maximum cache size for prepared statements globally for gocql.
Expand All @@ -116,6 +118,8 @@ type ClusterConfig struct {
// Default: unset
SerialConsistency SerialConsistency

// SslOpts configures TLS use when HostDialer is not set.
// SslOpts is ignored if HostDialer is set.
SslOpts *SslOptions

// Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server.
Expand Down Expand Up @@ -208,8 +212,13 @@ type ClusterConfig struct {

// Dialer will be used to establish all connections created for this Cluster.
// If not provided, a default dialer configured with ConnectTimeout will be used.
// Dialer is ignored if HostDialer is provided.
Dialer Dialer

// HostDialer will be used to establish all connections for this Cluster.
// If not provided, Dialer will be used instead.
HostDialer HostDialer

// Logger for this ClusterConfig.
// If not specified, defaults to the global gocql.Logger.
Logger StdLogger
Expand Down
60 changes: 9 additions & 51 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ type ConnConfig struct {
WriteTimeout time.Duration
ConnectTimeout time.Duration
Dialer Dialer
HostDialer HostDialer
Compressor Compressor
Authenticator Authenticator
AuthProvider func(h *HostInfo) (Authenticator, error)
Expand Down Expand Up @@ -231,53 +232,10 @@ func (s *Session) dial(ctx context.Context, host *HostInfo, connConfig *ConnConf
//
// dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead.
func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
ip := host.ConnectAddress()
port := host.port

// TODO(zariel): remove these
if !validIpAddr(ip) {
panic(fmt.Sprintf("host missing connect ip address: %v", ip))
} else if port == 0 {
panic(fmt.Sprintf("host missing port: %v", port))
}

dialer := cfg.Dialer
if dialer == nil {
d := &net.Dialer{
Timeout: cfg.ConnectTimeout,
}
if cfg.Keepalive > 0 {
d.KeepAlive = cfg.Keepalive
}
dialer = d
}

addr := host.HostnameAndPort()
conn, err := dialer.DialContext(ctx, "tcp", addr)
dialedHost, err := cfg.HostDialer.DialHost(ctx, host)
if err != nil {
return nil, err
}
if cfg.tlsConfig != nil {
// the TLS config is safe to be reused by connections but it must not
// be modified after being used.
tlsConfig := cfg.tlsConfig
if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
// clone config to avoid modifying the shared one.
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = hostname
}
tconn := tls.Client(conn, tlsConfig)
if err := tconn.Handshake(); err != nil {
conn.Close()
return nil, err
}
conn = tconn
}

writeTimeout := cfg.Timeout
if cfg.WriteTimeout > 0 {
Expand All @@ -286,20 +244,20 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *

ctx, cancel := context.WithCancel(ctx)
c := &Conn{
conn: conn,
r: bufio.NewReader(conn),
conn: dialedHost.Conn,
r: bufio.NewReader(dialedHost.Conn),
cfg: cfg,
calls: make(map[int]*callReq),
version: uint8(cfg.ProtoVersion),
addr: conn.RemoteAddr().String(),
addr: dialedHost.Conn.RemoteAddr().String(),
errorHandler: errorHandler,
compressor: cfg.Compressor,
session: s,
streams: streams.New(cfg.ProtoVersion),
host: host,
frameObserver: s.frameObserver,
w: &deadlineContextWriter{
w: conn,
w: dialedHost.Conn,
timeout: writeTimeout,
semaphore: make(chan struct{}, 1),
quit: make(chan struct{}),
Expand All @@ -311,7 +269,7 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
writeTimeout: writeTimeout,
}

if err := c.init(ctx); err != nil {
if err := c.init(ctx, dialedHost); err != nil {
cancel()
c.Close()
return nil, err
Expand All @@ -320,7 +278,7 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
return c, nil
}

func (c *Conn) init(ctx context.Context) error {
func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error {
if c.session.cfg.AuthProvider != nil {
var err error
c.auth, err = c.cfg.AuthProvider(c.host)
Expand All @@ -344,7 +302,7 @@ func (c *Conn) init(ctx context.Context) error {
c.timeout = c.cfg.Timeout

// dont coalesce startup frames
if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce {
if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce && !dialedHost.DisableCoalesce {
c.w = newWriteCoalescer(c.conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
}

Expand Down
60 changes: 40 additions & 20 deletions connectionpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,52 @@ type policyConnPool struct {

func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
var (
err error
tlsConfig *tls.Config
err error
hostDialer HostDialer
)

// TODO(zariel): move tls config setup into session init.
if cfg.SslOpts != nil {
tlsConfig, err = setupTLSConfig(cfg.SslOpts)
if err != nil {
return nil, err
hostDialer = cfg.HostDialer
if hostDialer == nil {
var tlsConfig *tls.Config

// TODO(zariel): move tls config setup into session init.
if cfg.SslOpts != nil {
tlsConfig, err = setupTLSConfig(cfg.SslOpts)
if err != nil {
return nil, err
}
}

dialer := cfg.Dialer
if dialer == nil {
d := &net.Dialer{
Timeout: cfg.ConnectTimeout,
}
if cfg.SocketKeepalive > 0 {
d.KeepAlive = cfg.SocketKeepalive
}
dialer = d
}

hostDialer = &defaultHostDialer{
dialer: dialer,
tlsConfig: tlsConfig,
}
}

return &ConnConfig{
ProtoVersion: cfg.ProtoVersion,
CQLVersion: cfg.CQLVersion,
Timeout: cfg.Timeout,
WriteTimeout: cfg.WriteTimeout,
ConnectTimeout: cfg.ConnectTimeout,
Dialer: cfg.Dialer,
Compressor: cfg.Compressor,
Authenticator: cfg.Authenticator,
AuthProvider: cfg.AuthProvider,
Keepalive: cfg.SocketKeepalive,
Logger: cfg.logger(),
tlsConfig: tlsConfig,
disableCoalesce: tlsConfig != nil, // write coalescing doesn't work with framing on top of TCP like in TLS.
ProtoVersion: cfg.ProtoVersion,
CQLVersion: cfg.CQLVersion,
Timeout: cfg.Timeout,
WriteTimeout: cfg.WriteTimeout,
ConnectTimeout: cfg.ConnectTimeout,
Dialer: cfg.Dialer,
HostDialer: hostDialer,
Compressor: cfg.Compressor,
Authenticator: cfg.Authenticator,
AuthProvider: cfg.AuthProvider,
Keepalive: cfg.SocketKeepalive,
Logger: cfg.logger(),
}, nil
}

Expand Down
90 changes: 90 additions & 0 deletions dial.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package gocql

import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
)

// HostDialer allows customizing connection to cluster nodes.
type HostDialer interface {
// DialHost establishes a connection to the host.
// The returned connection must be directly usable for CQL protocol,
// specifically DialHost is responsible also for setting up the TLS session if needed.
// DialHost should disable write coalescing if the returned net.Conn does not support writev.
// As of Go 1.18, only plain TCP connections support writev, TLS sessions should disable coalescing.
// You can use WrapTLS helper function if you don't need to override the TLS setup.
DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error)
}

// DialedHost contains information about established connection to a host.
type DialedHost struct {
// Conn used to communicate with the server.
Conn net.Conn

// DisableCoalesce disables write coalescing for the Conn.
// If true, the effect is the same as if WriteCoalesceWaitTime was configured to 0.
DisableCoalesce bool
}

// defaultHostDialer dials host in a default way.
type defaultHostDialer struct {
dialer Dialer
tlsConfig *tls.Config
}

func (hd *defaultHostDialer) DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) {
ip := host.ConnectAddress()
port := host.Port()

if !validIpAddr(ip) {
return nil, fmt.Errorf("host missing connect ip address: %v", ip)
} else if port == 0 {
return nil, fmt.Errorf("host missing port: %v", port)
}

addr := host.HostnameAndPort()
conn, err := hd.dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
return WrapTLS(ctx, conn, addr, hd.tlsConfig)
}

func tlsConfigForAddr(tlsConfig *tls.Config, addr string) *tls.Config {
// the TLS config is safe to be reused by connections but it must not
// be modified after being used.
if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
// clone config to avoid modifying the shared one.
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = hostname
}
return tlsConfig
}

// WrapTLS optionally wraps a net.Conn connected to addr with the given tlsConfig.
// If the tlsConfig is nil, conn is not wrapped into a TLS session, so is insecure.
// If the tlsConfig does not have server name set, it is updated based on the default gocql rules.
func WrapTLS(ctx context.Context, conn net.Conn, addr string, tlsConfig *tls.Config) (*DialedHost, error) {
if tlsConfig != nil {
tlsConfig := tlsConfigForAddr(tlsConfig, addr)
tconn := tls.Client(conn, tlsConfig)
if err := tconn.HandshakeContext(ctx); err != nil {
conn.Close()
return nil, err
}
conn = tconn
}

return &DialedHost{
Conn: conn,
DisableCoalesce: tlsConfig != nil, // write coalescing can't use writev when the connection is wrapped.
}, nil
}

0 comments on commit 7a6cf00

Please sign in to comment.