Skip to content

Commit

Permalink
simplify the prepared cache access
Browse files Browse the repository at this point in the history
  • Loading branch information
Zariel committed Jan 31, 2016
1 parent ddf78cc commit 45c7cec
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 96 deletions.
79 changes: 39 additions & 40 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -991,12 +991,14 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
)`); err != nil {
t.Fatal("create:", err)
}

stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
_, conn := session.pool.Pick(nil)

flight := new(inflightPrepare)
session.stmtsLRU.Lock()
session.stmtsLRU.lru.Add(conn.addr+stmt, flight)
session.stmtsLRU.Unlock()
key := session.stmtsLRU.keyFor(conn.addr, "", stmt)
session.stmtsLRU.add(key, flight)

flight.preparedStatment = &preparedStatment{
id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
request: preparedMetadata{
Expand All @@ -1016,10 +1018,11 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
},
},
}

return stmt, conn
}

func TestMissingSchemaPrepare(t *testing.T) {
func TestPrepare_MissingSchemaPrepare(t *testing.T) {
s := createSession(t)
_, conn := s.pool.Pick(nil)
defer s.Close()
Expand All @@ -1041,7 +1044,7 @@ func TestMissingSchemaPrepare(t *testing.T) {
}
}

func TestReprepareStatement(t *testing.T) {
func TestPrepare_ReprepareStatement(t *testing.T) {
session := createSession(t)
defer session.Close()
stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement")
Expand All @@ -1051,7 +1054,7 @@ func TestReprepareStatement(t *testing.T) {
}
}

func TestReprepareBatch(t *testing.T) {
func TestPrepare_ReprepareBatch(t *testing.T) {
if *flagProto == 1 {
t.Skip("atomic batches not supported. Please use Cassandra >= 2.0")
}
Expand Down Expand Up @@ -1089,8 +1092,9 @@ func TestQueryInfo(t *testing.T) {
}

//TestPreparedCacheEviction will make sure that the cache size is maintained
func TestPreparedCacheEviction(t *testing.T) {
func TestPrepare_PreparedCacheEviction(t *testing.T) {
const maxPrepared = 4

cluster := createCluster()
cluster.MaxPreparedStmts = maxPrepared
cluster.Events.DisableSchemaEvents = true
Expand All @@ -1101,6 +1105,9 @@ func TestPreparedCacheEviction(t *testing.T) {
if err := createTable(session, "CREATE TABLE gocql_test.prepcachetest (id int,mod int,PRIMARY KEY (id))"); err != nil {
t.Fatalf("failed to create table with error '%v'", err)
}
// clear the cache
session.stmtsLRU.clear()

//Fill the table
for i := 0; i < 2; i++ {
if err := session.Query("INSERT INTO prepcachetest (id,mod) VALUES (?, ?)", i, 10000%(i+1)).Exec(); err != nil {
Expand Down Expand Up @@ -1134,52 +1141,44 @@ func TestPreparedCacheEviction(t *testing.T) {
t.Fatalf("insert into prepcachetest failed, error '%v'", err)
}

session.stmtsLRU.Lock()
session.stmtsLRU.mu.Lock()
defer session.stmtsLRU.mu.Unlock()

//Make sure the cache size is maintained
if session.stmtsLRU.lru.Len() != session.stmtsLRU.lru.MaxEntries {
t.Fatalf("expected cache size of %v, got %v", session.stmtsLRU.lru.MaxEntries, session.stmtsLRU.lru.Len())
}

//Walk through all the configured hosts and test cache retention and eviction
var selFound, insFound, updFound, delFound, selEvict bool
for i := range session.cfg.Hosts {
_, ok := session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 1")
selFound = selFound || ok

_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testINSERT INTO prepcachetest (id,mod) VALUES (?, ?)")
insFound = insFound || ok

_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testUPDATE prepcachetest SET mod = ? WHERE id = ?")
updFound = updFound || ok
// Walk through all the configured hosts and test cache retention and eviction
for _, host := range session.cfg.Hosts {
_, ok := session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host+":9042", session.cfg.Keyspace, "SELECT id,mod FROM prepcachetest WHERE id = 0"))
if ok {
t.Errorf("expected first select to be purged but was in cache for host=%q", host)
}

_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testDELETE FROM prepcachetest WHERE id = ?")
delFound = delFound || ok
_, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host+":9042", session.cfg.Keyspace, "SELECT id,mod FROM prepcachetest WHERE id = 1"))
if !ok {
t.Errorf("exepected second select to be in cache for host=%q", host)
}

_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 0")
selEvict = selEvict || !ok
}
_, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host+":9042", session.cfg.Keyspace, "INSERT INTO prepcachetest (id,mod) VALUES (?, ?)"))
if !ok {
t.Errorf("expected insert to be in cache for host=%q", host)
}

session.stmtsLRU.Unlock()
_, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host+":9042", session.cfg.Keyspace, "UPDATE prepcachetest SET mod = ? WHERE id = ?"))
if !ok {
t.Errorf("expected update to be in cached for host=%q", host)
}

if !selEvict {
t.Fatalf("expected first select statement to be purged, but statement was found in the cache.")
}
if !selFound {
t.Fatalf("expected second select statement to be cached, but statement was purged or not prepared.")
}
if !insFound {
t.Fatalf("expected insert statement to be cached, but statement was purged or not prepared.")
}
if !updFound {
t.Fatalf("expected update statement to be cached, but statement was purged or not prepared.")
}
if !delFound {
t.Error("expected delete statement to be cached, but statement was purged or not prepared.")
_, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host+":9042", session.cfg.Keyspace, "DELETE FROM prepcachetest WHERE id = ?"))
if !ok {
t.Errorf("expected delete to be cached for host=%q", host)
}
}
}

func TestPreparedCacheKey(t *testing.T) {
func TestPrepare_PreparedCacheKey(t *testing.T) {
session := createSession(t)
defer session.Close()

Expand Down
32 changes: 0 additions & 32 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,9 @@ package gocql

import (
"errors"
"sync"
"time"

"github.com/gocql/gocql/internal/lru"
)

const defaultMaxPreparedStmts = 1000

//preparedLRU is the prepared statement cache
type preparedLRU struct {
sync.Mutex
lru *lru.Cache
}

//Max adjusts the maximum size of the cache and cleans up the oldest records if
//the new max is lower than the previous value. Not concurrency safe.
func (p *preparedLRU) max(max int) {
p.Lock()
defer p.Unlock()

for p.lru.Len() > max {
p.lru.RemoveOldest()
}
p.lru.MaxEntries = max
}

func (p *preparedLRU) clear() {
p.Lock()
defer p.Unlock()

for p.lru.Len() > 0 {
p.lru.RemoveOldest()
}
}

// PoolConfig configures the connection pool used by the driver, it defaults to
// using a round robbin host selection policy and a round robbin connection selection
// policy for each host.
Expand Down
38 changes: 16 additions & 22 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"github.com/gocql/gocql/internal/lru"
"io"
"io/ioutil"
"log"
Expand Down Expand Up @@ -593,20 +594,19 @@ type inflightPrepare struct {
}

func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment, error) {
c.session.stmtsLRU.Lock()
stmtCacheKey := c.addr + c.currentKeyspace + stmt
if val, ok := c.session.stmtsLRU.lru.Get(stmtCacheKey); ok {
c.session.stmtsLRU.Unlock()
flight := val.(*inflightPrepare)
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
flight := new(inflightPrepare)
flight.wg.Add(1)
lru.Add(stmtCacheKey, flight)
return flight
})

if ok {
flight.wg.Wait()
return flight.preparedStatment, flight.err
}

flight := new(inflightPrepare)
flight.wg.Add(1)
c.session.stmtsLRU.lru.Add(stmtCacheKey, flight)
c.session.stmtsLRU.Unlock()

prep := &writePrepareFrame{
statement: stmt,
}
Expand Down Expand Up @@ -650,9 +650,7 @@ func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment,
flight.wg.Done()

if flight.err != nil {
c.session.stmtsLRU.Lock()
c.session.stmtsLRU.lru.Remove(stmtCacheKey)
c.session.stmtsLRU.Unlock()
c.session.stmtsLRU.remove(stmtCacheKey)
}

framerPool.Put(framer)
Expand Down Expand Up @@ -799,14 +797,11 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
// is not consistent with regards to its schema.
return iter
case *RequestErrUnprepared:
c.session.stmtsLRU.Lock()
stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
if _, ok := c.session.stmtsLRU.lru.Get(stmtCacheKey); ok {
c.session.stmtsLRU.lru.Remove(stmtCacheKey)
c.session.stmtsLRU.Unlock()
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt)
if c.session.stmtsLRU.remove(stmtCacheKey) {
return c.executeQuery(qry)
}
c.session.stmtsLRU.Unlock()

return &Iter{err: x, framer: framer}
case error:
return &Iter{err: x, framer: framer}
Expand Down Expand Up @@ -945,9 +940,8 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
case *RequestErrUnprepared:
stmt, found := stmts[string(x.StatementId)]
if found {
c.session.stmtsLRU.Lock()
c.session.stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
c.session.stmtsLRU.Unlock()
key := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
c.session.stmtsLRU.remove(key)
}

framerPool.Put(framer)
Expand Down
8 changes: 6 additions & 2 deletions internal/lru/lru.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,17 @@ func (c *Cache) Get(key string) (value interface{}, ok bool) {
}

// Remove removes the provided key from the cache.
func (c *Cache) Remove(key string) {
func (c *Cache) Remove(key string) bool {
if c.cache == nil {
return
return false
}

if ele, hit := c.cache[key]; hit {
c.removeElement(ele)
return true
}

return false
}

// RemoveOldest removes the oldest item from the cache.
Expand Down
64 changes: 64 additions & 0 deletions prepared_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package gocql

import (
"github.com/gocql/gocql/internal/lru"
"sync"
)

const defaultMaxPreparedStmts = 1000

// preparedLRU is the prepared statement cache
type preparedLRU struct {
mu sync.Mutex
lru *lru.Cache
}

// Max adjusts the maximum size of the cache and cleans up the oldest records if
// the new max is lower than the previous value. Not concurrency safe.
func (p *preparedLRU) max(max int) {
p.mu.Lock()
defer p.mu.Unlock()

for p.lru.Len() > max {
p.lru.RemoveOldest()
}
p.lru.MaxEntries = max
}

func (p *preparedLRU) clear() {
p.mu.Lock()
defer p.mu.Unlock()

for p.lru.Len() > 0 {
p.lru.RemoveOldest()
}
}

func (p *preparedLRU) add(key string, val *inflightPrepare) {
p.mu.Lock()
defer p.mu.Unlock()
p.lru.Add(key, val)
}

func (p *preparedLRU) remove(key string) bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.lru.Remove(key)
}

func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) {
p.mu.Lock()
defer p.mu.Unlock()

val, ok := p.lru.Get(key)
if ok {
return val.(*inflightPrepare), true
}

return fn(p.lru), false
}

func (p *preparedLRU) keyFor(addr, keyspace, statement string) string {
// TODO: maybe use []byte for keys?
return addr + keyspace + statement
}

0 comments on commit 45c7cec

Please sign in to comment.