diff --git a/cluster.go b/cluster.go index fa9eaf85d..cf403dd6c 100644 --- a/cluster.go +++ b/cluster.go @@ -52,6 +52,7 @@ type ClusterConfig struct { ProtoVersion int // Connection timeout (default: 600ms) + // ConnectTimeout is used to set up the default dialer and is ignored if Dialer or HostDialer is provided. Timeout time.Duration // Initial connection timeout, used during initial dial to server (default: 600ms) @@ -98,6 +99,7 @@ type ClusterConfig struct { ReconnectionPolicy ReconnectionPolicy // The keepalive period to use, enabled if > 0 (default: 0) + // SocketKeepalive is used to set up the default dialer and is ignored if Dialer or HostDialer is provided. SocketKeepalive time.Duration // Maximum cache size for prepared statements globally for gocql. @@ -116,6 +118,8 @@ type ClusterConfig struct { // Default: unset SerialConsistency SerialConsistency + // SslOpts configures TLS use when HostDialer is not set. + // SslOpts is ignored if HostDialer is set. SslOpts *SslOptions // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. @@ -208,8 +212,13 @@ type ClusterConfig struct { // Dialer will be used to establish all connections created for this Cluster. // If not provided, a default dialer configured with ConnectTimeout will be used. + // Dialer is ignored if HostDialer is provided. Dialer Dialer + // HostDialer will be used to establish all connections for this Cluster. + // If not provided, Dialer will be used instead. + HostDialer HostDialer + // Logger for this ClusterConfig. // If not specified, defaults to the global gocql.Logger. Logger StdLogger diff --git a/conn.go b/conn.go index 1c80f3e0c..a7ca2d9b8 100644 --- a/conn.go +++ b/conn.go @@ -123,6 +123,7 @@ type ConnConfig struct { WriteTimeout time.Duration ConnectTimeout time.Duration Dialer Dialer + HostDialer HostDialer Compressor Compressor Authenticator Authenticator AuthProvider func(h *HostInfo) (Authenticator, error) @@ -231,53 +232,10 @@ func (s *Session) dial(ctx context.Context, host *HostInfo, connConfig *ConnConf // // dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead. func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) { - ip := host.ConnectAddress() - port := host.port - - // TODO(zariel): remove these - if !validIpAddr(ip) { - panic(fmt.Sprintf("host missing connect ip address: %v", ip)) - } else if port == 0 { - panic(fmt.Sprintf("host missing port: %v", port)) - } - - dialer := cfg.Dialer - if dialer == nil { - d := &net.Dialer{ - Timeout: cfg.ConnectTimeout, - } - if cfg.Keepalive > 0 { - d.KeepAlive = cfg.Keepalive - } - dialer = d - } - - addr := host.HostnameAndPort() - conn, err := dialer.DialContext(ctx, "tcp", addr) + dialedHost, err := cfg.HostDialer.DialHost(ctx, host) if err != nil { return nil, err } - if cfg.tlsConfig != nil { - // the TLS config is safe to be reused by connections but it must not - // be modified after being used. - tlsConfig := cfg.tlsConfig - if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" { - colonPos := strings.LastIndex(addr, ":") - if colonPos == -1 { - colonPos = len(addr) - } - hostname := addr[:colonPos] - // clone config to avoid modifying the shared one. - tlsConfig = tlsConfig.Clone() - tlsConfig.ServerName = hostname - } - tconn := tls.Client(conn, tlsConfig) - if err := tconn.Handshake(); err != nil { - conn.Close() - return nil, err - } - conn = tconn - } writeTimeout := cfg.Timeout if cfg.WriteTimeout > 0 { @@ -286,12 +244,12 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg * ctx, cancel := context.WithCancel(ctx) c := &Conn{ - conn: conn, - r: bufio.NewReader(conn), + conn: dialedHost.Conn, + r: bufio.NewReader(dialedHost.Conn), cfg: cfg, calls: make(map[int]*callReq), version: uint8(cfg.ProtoVersion), - addr: conn.RemoteAddr().String(), + addr: dialedHost.Conn.RemoteAddr().String(), errorHandler: errorHandler, compressor: cfg.Compressor, session: s, @@ -299,7 +257,7 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg * host: host, frameObserver: s.frameObserver, w: &deadlineContextWriter{ - w: conn, + w: dialedHost.Conn, timeout: writeTimeout, semaphore: make(chan struct{}, 1), quit: make(chan struct{}), @@ -311,7 +269,7 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg * writeTimeout: writeTimeout, } - if err := c.init(ctx); err != nil { + if err := c.init(ctx, dialedHost); err != nil { cancel() c.Close() return nil, err @@ -320,7 +278,7 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg * return c, nil } -func (c *Conn) init(ctx context.Context) error { +func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error { if c.session.cfg.AuthProvider != nil { var err error c.auth, err = c.cfg.AuthProvider(c.host) @@ -344,7 +302,7 @@ func (c *Conn) init(ctx context.Context) error { c.timeout = c.cfg.Timeout // dont coalesce startup frames - if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce { + if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce && !dialedHost.DisableCoalesce { c.w = newWriteCoalescer(c.conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done()) } diff --git a/connectionpool.go b/connectionpool.go index 24f1c7cdc..ca52c2956 100644 --- a/connectionpool.go +++ b/connectionpool.go @@ -89,32 +89,52 @@ type policyConnPool struct { func connConfig(cfg *ClusterConfig) (*ConnConfig, error) { var ( - err error - tlsConfig *tls.Config + err error + hostDialer HostDialer ) - // TODO(zariel): move tls config setup into session init. - if cfg.SslOpts != nil { - tlsConfig, err = setupTLSConfig(cfg.SslOpts) - if err != nil { - return nil, err + hostDialer = cfg.HostDialer + if hostDialer == nil { + var tlsConfig *tls.Config + + // TODO(zariel): move tls config setup into session init. + if cfg.SslOpts != nil { + tlsConfig, err = setupTLSConfig(cfg.SslOpts) + if err != nil { + return nil, err + } + } + + dialer := cfg.Dialer + if dialer == nil { + d := &net.Dialer{ + Timeout: cfg.ConnectTimeout, + } + if cfg.SocketKeepalive > 0 { + d.KeepAlive = cfg.SocketKeepalive + } + dialer = d + } + + hostDialer = &defaultHostDialer{ + dialer: dialer, + tlsConfig: tlsConfig, } } return &ConnConfig{ - ProtoVersion: cfg.ProtoVersion, - CQLVersion: cfg.CQLVersion, - Timeout: cfg.Timeout, - WriteTimeout: cfg.WriteTimeout, - ConnectTimeout: cfg.ConnectTimeout, - Dialer: cfg.Dialer, - Compressor: cfg.Compressor, - Authenticator: cfg.Authenticator, - AuthProvider: cfg.AuthProvider, - Keepalive: cfg.SocketKeepalive, - Logger: cfg.logger(), - tlsConfig: tlsConfig, - disableCoalesce: tlsConfig != nil, // write coalescing doesn't work with framing on top of TCP like in TLS. + ProtoVersion: cfg.ProtoVersion, + CQLVersion: cfg.CQLVersion, + Timeout: cfg.Timeout, + WriteTimeout: cfg.WriteTimeout, + ConnectTimeout: cfg.ConnectTimeout, + Dialer: cfg.Dialer, + HostDialer: hostDialer, + Compressor: cfg.Compressor, + Authenticator: cfg.Authenticator, + AuthProvider: cfg.AuthProvider, + Keepalive: cfg.SocketKeepalive, + Logger: cfg.logger(), }, nil } diff --git a/dial.go b/dial.go new file mode 100644 index 000000000..71c0611bc --- /dev/null +++ b/dial.go @@ -0,0 +1,90 @@ +package gocql + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "strings" +) + +// HostDialer allows customizing connection to cluster nodes. +type HostDialer interface { + // DialHost establishes a connection to the host. + // The returned connection must be directly usable for CQL protocol, + // specifically DialHost is responsible also for setting up the TLS session if needed. + // DialHost should disable write coalescing if the returned net.Conn does not support writev. + // As of Go 1.18, only plain TCP connections support writev, TLS sessions should disable coalescing. + // You can use WrapTLS helper function if you don't need to override the TLS setup. + DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) +} + +// DialedHost contains information about established connection to a host. +type DialedHost struct { + // Conn used to communicate with the server. + Conn net.Conn + + // DisableCoalesce disables write coalescing for the Conn. + // If true, the effect is the same as if WriteCoalesceWaitTime was configured to 0. + DisableCoalesce bool +} + +// defaultHostDialer dials host in a default way. +type defaultHostDialer struct { + dialer Dialer + tlsConfig *tls.Config +} + +func (hd *defaultHostDialer) DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) { + ip := host.ConnectAddress() + port := host.Port() + + if !validIpAddr(ip) { + return nil, fmt.Errorf("host missing connect ip address: %v", ip) + } else if port == 0 { + return nil, fmt.Errorf("host missing port: %v", port) + } + + addr := host.HostnameAndPort() + conn, err := hd.dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + return WrapTLS(ctx, conn, addr, hd.tlsConfig) +} + +func tlsConfigForAddr(tlsConfig *tls.Config, addr string) *tls.Config { + // the TLS config is safe to be reused by connections but it must not + // be modified after being used. + if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" { + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + // clone config to avoid modifying the shared one. + tlsConfig = tlsConfig.Clone() + tlsConfig.ServerName = hostname + } + return tlsConfig +} + +// WrapTLS optionally wraps a net.Conn connected to addr with the given tlsConfig. +// If the tlsConfig is nil, conn is not wrapped into a TLS session, so is insecure. +// If the tlsConfig does not have server name set, it is updated based on the default gocql rules. +func WrapTLS(ctx context.Context, conn net.Conn, addr string, tlsConfig *tls.Config) (*DialedHost, error) { + if tlsConfig != nil { + tlsConfig := tlsConfigForAddr(tlsConfig, addr) + tconn := tls.Client(conn, tlsConfig) + if err := tconn.HandshakeContext(ctx); err != nil { + conn.Close() + return nil, err + } + conn = tconn + } + + return &DialedHost{ + Conn: conn, + DisableCoalesce: tlsConfig != nil, // write coalescing can't use writev when the connection is wrapped. + }, nil +}