Skip to content

Commit

Permalink
Fix a panic establishing sessions using NewSession
Browse files Browse the repository at this point in the history
Previously, the only way to establish the *first* connection was to use ClusterConfig.CreateSession(). This is due to the global prepared statement cache only being initialised here.
  • Loading branch information
obeattie committed Feb 6, 2015
1 parent f9e6755 commit 241b2b9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 22 deletions.
8 changes: 4 additions & 4 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -967,9 +967,9 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
conn := session.Pool.Pick(nil)
flight := new(inflightPrepare)
stmtsLRU.mu.Lock()
stmtsLRU.Lock()
stmtsLRU.lru.Add(conn.addr+stmt, flight)
stmtsLRU.mu.Unlock()
stmtsLRU.Unlock()
flight.info = &QueryInfo{
Id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
Args: []ColumnInfo{ColumnInfo{
Expand Down Expand Up @@ -1057,9 +1057,9 @@ func TestQueryInfo(t *testing.T) {
func TestPreparedCacheEviction(t *testing.T) {
session := createSession(t)
defer session.Close()
stmtsLRU.mu.Lock()
stmtsLRU.Lock()
stmtsLRU.Max(4)
stmtsLRU.mu.Unlock()
stmtsLRU.Unlock()

if err := createTable(session, "CREATE TABLE prepcachetest (id int,mod int,PRIMARY KEY (id))"); err != nil {
t.Fatalf("failed to create table with error '%v'", err)
Expand Down
20 changes: 12 additions & 8 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ var stmtsLRU preparedLRU

//preparedLRU is the prepared statement cache
type preparedLRU struct {
sync.Mutex
lru *lru.Cache
mu sync.Mutex
}

//Max adjusts the maximum size of the cache and cleans up the oldest records if
Expand All @@ -30,6 +30,14 @@ func (p *preparedLRU) Max(max int) {
p.lru.MaxEntries = max
}

func initStmtsLRU(max int) {
if stmtsLRU.lru != nil {
stmtsLRU.Max(max)
} else {
stmtsLRU.lru = lru.New(max)
}
}

// To enable periodic node discovery enable DiscoverHosts in ClusterConfig
type DiscoveryConfig struct {
// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
Expand Down Expand Up @@ -94,13 +102,9 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
pool := cfg.ConnPoolType(cfg)

//Adjust the size of the prepared statements cache to match the latest configuration
stmtsLRU.mu.Lock()
if stmtsLRU.lru != nil {
stmtsLRU.Max(cfg.MaxPreparedStmts)
} else {
stmtsLRU.lru = lru.New(cfg.MaxPreparedStmts)
}
stmtsLRU.mu.Unlock()
stmtsLRU.Lock()
initStmtsLRU(cfg.MaxPreparedStmts)
stmtsLRU.Unlock()

//See if there are any connections in the pool
if pool.Size() > 0 {
Expand Down
23 changes: 13 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,21 +364,24 @@ func (c *Conn) ping() error {
}

func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
stmtsLRU.mu.Lock()
stmtsLRU.Lock()
if stmtsLRU.lru == nil {
initStmtsLRU(1000)
}

stmtCacheKey := c.addr + c.currentKeyspace + stmt

if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
flight := val.(*inflightPrepare)
stmtsLRU.mu.Unlock()
stmtsLRU.Unlock()
flight.wg.Wait()
return flight.info, flight.err
}

flight := new(inflightPrepare)
flight.wg.Add(1)
stmtsLRU.lru.Add(stmtCacheKey, flight)
stmtsLRU.mu.Unlock()
stmtsLRU.Unlock()

resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
if err != nil {
Expand All @@ -402,9 +405,9 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
flight.wg.Done()

if err != nil {
stmtsLRU.mu.Lock()
stmtsLRU.Lock()
stmtsLRU.lru.Remove(stmtCacheKey)
stmtsLRU.mu.Unlock()
stmtsLRU.Unlock()
}

return flight.info, flight.err
Expand Down Expand Up @@ -471,14 +474,14 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
case resultKeyspaceFrame:
return &Iter{}
case RequestErrUnprepared:
stmtsLRU.mu.Lock()
stmtsLRU.Lock()
stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
stmtsLRU.lru.Remove(stmtCacheKey)
stmtsLRU.mu.Unlock()
stmtsLRU.Unlock()
return c.executeQuery(qry)
}
stmtsLRU.mu.Unlock()
stmtsLRU.Unlock()
return &Iter{err: x}
case error:
return &Iter{err: x}
Expand Down Expand Up @@ -602,9 +605,9 @@ func (c *Conn) executeBatch(batch *Batch) error {
case RequestErrUnprepared:
stmt, found := stmts[string(x.StatementId)]
if found {
stmtsLRU.mu.Lock()
stmtsLRU.Lock()
stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
stmtsLRU.mu.Unlock()
stmtsLRU.Unlock()
}
if found {
return c.executeBatch(batch)
Expand Down

0 comments on commit 241b2b9

Please sign in to comment.