Skip to content

Commit

Permalink
conn: only setup the TLS config once
Browse files Browse the repository at this point in the history
we can reuse the TLS Config between connections, so avoid the
expensive read and cert parsing when dialing connections, instead
do this in the connection pool once during startup.
  • Loading branch information
Zariel committed Apr 7, 2015
1 parent a4a7de8 commit ef3e59c
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 36 deletions.
5 changes: 4 additions & 1 deletion cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
cfg.NumStreams = maxStreams
}

pool := cfg.ConnPoolType(cfg)
pool, err := cfg.ConnPoolType(cfg)
if err != nil {
return nil, err
}

//Adjust the size of the prepared statements cache to match the latest configuration
stmtsLRU.Lock()
Expand Down
33 changes: 5 additions & 28 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ package gocql
import (
"bufio"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"strconv"
Expand Down Expand Up @@ -81,7 +79,7 @@ type ConnConfig struct {
Compressor Compressor
Authenticator Authenticator
Keepalive time.Duration
SslOpts *SslOptions
TLSConfig *tls.Config
}

// Conn is a single connection to a Cassandra node. It can be used to execute
Expand Down Expand Up @@ -115,31 +113,10 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
conn net.Conn
)

if cfg.SslOpts != nil {
certPool := x509.NewCertPool()
//ca cert is optional
if cfg.SslOpts.CaPath != "" {
pem, err := ioutil.ReadFile(cfg.SslOpts.CaPath)
if err != nil {
return nil, err
}
if !certPool.AppendCertsFromPEM(pem) {
return nil, errors.New("Failed parsing or appending certs")
}
}

mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
if err != nil {
return nil, err
}

config := tls.Config{
Certificates: []tls.Certificate{mycert},
RootCAs: certPool,
}

config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
if conn, err = tls.Dial("tcp", addr, &config); err != nil {
if cfg.TLSConfig != nil {
// the TLS config is safe to be reused by connections but it must not
// be modified after being used.
if conn, err = tls.Dial("tcp", addr, cfg.TLSConfig); err != nil {
return nil, err
}
} else if conn, err = net.DialTimeout("tcp", addr, cfg.Timeout); err != nil {
Expand Down
52 changes: 48 additions & 4 deletions connectionpool.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package gocql

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"log"
"sync"
"time"
Expand Down Expand Up @@ -91,7 +96,7 @@ type ConnectionPool interface {
}

//NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
type NewPoolFunc func(*ClusterConfig) ConnectionPool
type NewPoolFunc func(*ClusterConfig) (ConnectionPool, error)

//SimplePool is the current implementation of the connection pool inside gocql. This
//pool is meant to be a simple default used by gocql so users can get up and running
Expand All @@ -115,11 +120,42 @@ type SimplePool struct {
quit bool
quitWait chan bool
quitOnce sync.Once

tlsConfig *tls.Config
}

func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
certPool := x509.NewCertPool()
// ca cert is optional
if sslOpts.CaPath != "" {
pem, err := ioutil.ReadFile(sslOpts.CaPath)
if err != nil {
return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
}

if !certPool.AppendCertsFromPEM(pem) {
return nil, errors.New("connectionpool: failed parsing or CA certs")
}
}

mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
if err != nil {
return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
}

config := &tls.Config{
Certificates: []tls.Certificate{mycert},
RootCAs: certPool,
}

config.InsecureSkipVerify = !sslOpts.EnableHostVerification

return config, nil
}

//NewSimplePool is the function used by gocql to create the simple connection pool.
//This is the default if no other pool type is specified.
func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
func NewSimplePool(cfg *ClusterConfig) (ConnectionPool, error) {
pool := &SimplePool{
cfg: cfg,
hostPool: NewRoundRobin(),
Expand All @@ -137,6 +173,14 @@ func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
pool.hosts[host] = &HostInfo{Peer: host}
}

if cfg.SslOpts != nil {
config, err := setupTLSConfig(cfg.SslOpts)
if err != nil {
return nil, err
}
pool.tlsConfig = config
}

//Walk through connecting to hosts. As soon as one host connects
//defer the remaining connections to cluster.fillPool()
for i := 0; i < len(cfg.Hosts); i++ {
Expand All @@ -149,7 +193,7 @@ func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
}
}

return pool
return pool, nil
}

func (c *SimplePool) connect(addr string) error {
Expand All @@ -162,7 +206,7 @@ func (c *SimplePool) connect(addr string) error {
Compressor: c.cfg.Compressor,
Authenticator: c.cfg.Authenticator,
Keepalive: c.cfg.SocketKeepalive,
SslOpts: c.cfg.SslOpts,
TLSConfig: c.tlsConfig,
}

conn, err := Connect(addr, cfg, c)
Expand Down
12 changes: 9 additions & 3 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import (
func TestSessionAPI(t *testing.T) {

cfg := ClusterConfig{}
pool := NewSimplePool(&cfg)
pool, err := NewSimplePool(&cfg)
if err != nil {
t.Fatal(err)
}

s := NewSession(pool, cfg)

Expand Down Expand Up @@ -60,7 +63,7 @@ func TestSessionAPI(t *testing.T) {

testBatch := s.NewBatch(LoggedBatch)
testBatch.Query("test")
err := s.ExecuteBatch(testBatch)
err = s.ExecuteBatch(testBatch)

if err != ErrNoConnections {
t.Fatalf("expected session.ExecuteBatch to return '%v', got '%v'", ErrNoConnections, err)
Expand Down Expand Up @@ -151,7 +154,10 @@ func TestBatchBasicAPI(t *testing.T) {

cfg := ClusterConfig{}
cfg.RetryPolicy = &SimpleRetryPolicy{NumRetries: 2}
pool := NewSimplePool(&cfg)
pool, err := NewSimplePool(&cfg)
if err != nil {
t.Fatal(err)
}

s := NewSession(pool, cfg)
b := s.NewBatch(UnloggedBatch)
Expand Down

0 comments on commit ef3e59c

Please sign in to comment.