Skip to content

Commit

Permalink
Merge pull request apache#417 from Zariel/timeout-queries
Browse files Browse the repository at this point in the history
apply a timeout for every query at the connection
  • Loading branch information
0x6e6562 committed Jun 16, 2015
2 parents 677750e + b21c5b0 commit 5792e6b
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 30 deletions.
42 changes: 39 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)

Expand Down Expand Up @@ -84,6 +85,12 @@ type ConnErrorHandler interface {
HandleError(conn *Conn, err error, closed bool)
}

// TimeoutLimit is how many query timeouts we will allow to occur on a
// connection before that connection is closed and restarted. This is to
// prevent a single query timeout from killing a connection which may be
// serving more queries just fine.
// Default is 10, should not be changed concurrently with queries.
var TimeoutLimit int64 = 10

// Conn is a single connection to a Cassandra node. It can be used to execute
// queries, but users are usually advised to use a more reliable, higher
// level API.
Expand All @@ -107,6 +114,8 @@ type Conn struct {

closedMu sync.RWMutex
isClosed bool

timeouts int64
}

// Connect establishes a connection to a Cassandra node.
Expand Down Expand Up @@ -298,7 +307,16 @@ func (c *Conn) serve() {
}
}

c.closeWithError(err)
}

func (c *Conn) closeWithError(err error) {
if c.Closed() {
return
}

c.Close()

for id := 0; id < len(c.calls); id++ {
req := &c.calls[id]
// we need to send the error to all waiting queries, put the state
Expand Down Expand Up @@ -339,7 +357,11 @@ func (c *Conn) recv() error {

// once we get to here we know that the caller must be waiting and that there
// is no error.
call.resp <- nil
select {
case call.resp <- nil:
default:
// the caller may have timed out and stopped waiting; drop the signal
}

return nil
}
Expand All @@ -357,6 +379,12 @@ func (c *Conn) releaseStream(stream int) {
}
}

// handleTimeout records one more query timeout against the connection.
// The counter is bumped atomically; once the running total exceeds
// TimeoutLimit the connection is closed with ErrTooManyTimeouts.
func (c *Conn) handleTimeout() {
	total := atomic.AddInt64(&c.timeouts, 1)
	if total <= TimeoutLimit {
		// still within the allowed budget, keep the connection alive
		return
	}

	c.closeWithError(ErrTooManyTimeouts)
}

func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
// TODO: move tracer onto conn
stream := <-c.uniq
Expand All @@ -376,7 +404,13 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
return nil, err
}

err = <-call.resp
select {
case err = <-call.resp:
case <-time.After(c.timeout):
c.handleTimeout()
return nil, ErrTimeoutNoResponse
}

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -730,5 +764,7 @@ type inflightPrepare struct {
}

// Sentinel errors returned by the connection code. Callers compare against
// these values directly.
var (
	// ErrQueryArgLength reports a query argument length mismatch.
	ErrQueryArgLength = errors.New("query argument length mismatch")

	// ErrTimeoutNoResponse is returned by Conn.exec when no response arrives
	// for a request within the connection's timeout.
	ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")

	// ErrTooManyTimeouts is the error a connection is closed with once the
	// number of query timeouts on it exceeds TimeoutLimit.
	ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
)
72 changes: 45 additions & 27 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,6 @@ func TestQueryRetry(t *testing.T) {
}
}

// TestSlowQuery runs the "slow" query, which the test server answers only
// after a one second delay, and expects it to complete without error.
func TestSlowQuery(t *testing.T) {
	srv := NewTestServer(t, defaultProto)
	defer srv.Stop()

	db, err := newTestSession(srv.Address, defaultProto)
	if err != nil {
		// Nothing else can run without a session; stop the test here.
		t.Fatalf("NewCluster: %v", err)
	}
	// Release the session's connections when the test finishes.
	defer db.Close()

	if err := db.Query("slow").Exec(); err != nil {
		t.Fatal(err)
	}
}

func TestSimplePoolRoundRobin(t *testing.T) {
servers := make([]*TestServer, 5)
addrs := make([]string, len(servers))
Expand Down Expand Up @@ -486,6 +471,43 @@ func TestPolicyConnPoolSSL(t *testing.T) {
}
}

// TestQueryTimeout verifies that a query which never receives a response
// fails with ErrTimeoutNoResponse once the configured timeout has elapsed.
// The test server blocks on the "timeout" query until it is stopped, so the
// only way the query can return is through the client-side timeout.
func TestQueryTimeout(t *testing.T) {
	srv := NewTestServer(t, protoVersion2)
	defer srv.Stop()

	cluster := NewCluster(srv.Address)
	// Set the timeout arbitrarily low so that the query hits the timeout in a
	// timely manner.
	cluster.Timeout = 1 * time.Millisecond

	db, err := cluster.CreateSession()
	if err != nil {
		// Without a session there is nothing to test; stop now instead of
		// carrying on and using db below.
		t.Fatalf("NewCluster: %v", err)
	}
	defer db.Close()

	// Buffered so the query goroutine can always deliver its result and exit,
	// even if the test has already given up waiting.
	ch := make(chan error, 1)

	go func() {
		// Hand the outcome (nil included) to the test goroutine; all
		// assertions happen there so that no testing.T method is called
		// from this goroutine after the test has finished.
		ch <- db.Query("timeout").Exec()
	}()

	select {
	case err := <-ch:
		if err == nil {
			t.Fatalf("err was nil, expected to get a timeout after %v", db.cfg.Timeout)
		}
		if err != ErrTimeoutNoResponse {
			t.Fatalf("expected to get %v for timeout got %v", ErrTimeoutNoResponse, err)
		}
	case <-time.After(10*time.Millisecond + db.cfg.Timeout):
		// The extra 10ms on top of the query timeout gives the query
		// goroutine time to be scheduled before we declare failure.
		t.Fatalf("query did not timeout after %v", db.cfg.Timeout)
	}
}

func NewTestServer(t testing.TB, protocol uint8) *TestServer {
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
if err != nil {
Expand All @@ -508,6 +530,7 @@ func NewTestServer(t testing.TB, protocol uint8) *TestServer {
t: t,
protocol: protocol,
headerSize: headerSize,
quit: make(chan struct{}),
}

go srv.serve()
Expand Down Expand Up @@ -545,6 +568,7 @@ func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
t: t,
protocol: protocol,
headerSize: headerSize,
quit: make(chan struct{}),
}
go srv.serve()
return srv
Expand All @@ -560,6 +584,8 @@ type TestServer struct {

protocol byte
headerSize int

quit chan struct{}
}

func (srv *TestServer) serve() {
Expand Down Expand Up @@ -592,6 +618,7 @@ func (srv *TestServer) serve() {

// Stop shuts the test server down: it closes the listener and then closes
// the quit channel, which releases handlers blocked on it (for example the
// "timeout" query handler).
func (srv *TestServer) Stop() {
	// The listener's close error is deliberately ignored; the server is
	// being torn down and there is nothing useful to do with it here.
	_ = srv.listen.Close()

	close(srv.quit)
}

func (srv *TestServer) process(f *framer) {
Expand Down Expand Up @@ -619,24 +646,15 @@ func (srv *TestServer) process(f *framer) {
f.writeHeader(0, opError, head.stream)
f.writeInt(0x1001)
f.writeString("query killed")
case "slow":
go func() {
<-time.After(1 * time.Second)
f.writeHeader(0, opResult, head.stream)
f.wbuf[0] = srv.protocol | 0x80
f.writeInt(resultKindVoid)
if err := f.finishWrite(); err != nil {
srv.t.Error(err)
}
}()

return
case "use":
f.writeInt(resultKindKeyspace)
f.writeString(strings.TrimSpace(query[3:]))
case "void":
f.writeHeader(0, opResult, head.stream)
f.writeInt(resultKindVoid)
case "timeout":
<-srv.quit
return
default:
f.writeHeader(0, opResult, head.stream)
f.writeInt(resultKindVoid)
Expand Down

0 comments on commit 5792e6b

Please sign in to comment.