Skip to content

Commit

Permalink
query: add context support (apache#710)
Browse files Browse the repository at this point in the history
Add WithContext on Batch and Query to supply a context which will cause
queries to abort when the context is cancelled/timeouted or done.
  • Loading branch information
Zariel committed Apr 16, 2016
1 parent 516d6d2 commit 437c5ce
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 24 deletions.
12 changes: 6 additions & 6 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package gocql

import (
"bytes"
"golang.org/x/net/context"
"io"
"math"
"math/big"
Expand Down Expand Up @@ -1118,7 +1119,7 @@ func TestQueryInfo(t *testing.T) {
defer session.Close()

conn := getRandomConn(t, session)
info, err := conn.prepareStatement("SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil)

if err != nil {
t.Fatalf("Failed to execute query for preparing statement: %v", err)
Expand Down Expand Up @@ -1833,7 +1834,7 @@ func TestRoutingKey(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}

routingKeyInfo, err := session.routingKeyInfo("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
if err != nil {
t.Fatalf("failed to get routing key info due to error: %v", err)
}
Expand All @@ -1857,7 +1858,7 @@ func TestRoutingKey(t *testing.T) {
}

// verify the cache is working
routingKeyInfo, err = session.routingKeyInfo("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
if err != nil {
t.Fatalf("failed to get routing key info due to error: %v", err)
}
Expand Down Expand Up @@ -1891,7 +1892,7 @@ func TestRoutingKey(t *testing.T) {
t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey)
}

routingKeyInfo, err = session.routingKeyInfo("SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?")
routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?")
if err != nil {
t.Fatalf("failed to get routing key info due to error: %v", err)
}
Expand Down Expand Up @@ -1999,7 +2000,7 @@ func TestNegativeStream(t *testing.T) {
return f.finishWrite()
})

frame, err := conn.exec(writer, nil)
frame, err := conn.exec(context.Background(), writer, nil)
if err == nil {
t.Fatalf("expected to get an error on stream %d", stream)
} else if frame != nil {
Expand Down Expand Up @@ -2411,5 +2412,4 @@ func TestSchemaReset(t *testing.T) {
if val != expVal {
t.Errorf("expected to get val=%q got=%q", expVal, val)
}

}
29 changes: 19 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"golang.org/x/net/context"
"io"
"io/ioutil"
"log"
Expand Down Expand Up @@ -284,7 +285,7 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
}

frameTicker <- struct{}{}
framer, err := c.exec(&writeStartupFrame{opts: m}, nil)
framer, err := c.exec(context.Background(), &writeStartupFrame{opts: m}, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -320,7 +321,7 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker c

for {
frameTicker <- struct{}{}
framer, err := c.exec(req, nil)
framer, err := c.exec(context.Background(), req, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -531,7 +532,7 @@ type callReq struct {
timer *time.Timer
}

func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
// TODO: move tracer onto conn
stream, ok := c.streams.GetStream()
if !ok {
Expand Down Expand Up @@ -593,6 +594,11 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
timeoutCh = call.timer.C
}

var ctxDone <-chan struct{}
if ctx != nil {
ctxDone = ctx.Done()
}

select {
case err := <-call.resp:
close(call.timeout)
Expand All @@ -610,6 +616,9 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
close(call.timeout)
c.handleTimeout()
return nil, ErrTimeoutNoResponse
case <-ctxDone:
close(call.timeout)
return nil, ctx.Err()
case <-c.quit:
return nil, ErrConnectionClosed
}
Expand Down Expand Up @@ -642,7 +651,7 @@ type inflightPrepare struct {
preparedStatment *preparedStatment
}

func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment, error) {
func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
flight := new(inflightPrepare)
Expand All @@ -660,7 +669,7 @@ func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment,
statement: stmt,
}

framer, err := c.exec(prep, tracer)
framer, err := c.exec(ctx, prep, tracer)
if err != nil {
flight.err = err
flight.wg.Done()
Expand Down Expand Up @@ -732,7 +741,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
if qry.shouldPrepare() {
// Prepare all DML queries. Other queries can not be prepared.
var err error
info, err = c.prepareStatement(qry.stmt, qry.trace)
info, err = c.prepareStatement(qry.context, qry.stmt, qry.trace)
if err != nil {
return &Iter{err: err}
}
Expand Down Expand Up @@ -783,7 +792,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
}
}

framer, err := c.exec(frame, qry.trace)
framer, err := c.exec(qry.context, frame, qry.trace)
if err != nil {
return &Iter{err: err}
}
Expand Down Expand Up @@ -883,7 +892,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
q.params.consistency = Any

framer, err := c.exec(q, nil)
framer, err := c.exec(context.Background(), q, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -926,7 +935,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
entry := &batch.Entries[i]
b := &req.statements[i]
if len(entry.Args) > 0 || entry.binding != nil {
info, err := c.prepareStatement(entry.Stmt, nil)
info, err := c.prepareStatement(batch.context, entry.Stmt, nil)
if err != nil {
return &Iter{err: err}
}
Expand Down Expand Up @@ -970,7 +979,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
}

// TODO: should batch support tracing?
framer, err := c.exec(req, nil)
framer, err := c.exec(batch.context, req, nil)
if err != nil {
return &Iter{err: err}
}
Expand Down
23 changes: 22 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"golang.org/x/net/context"
"io"
"io/ioutil"
"net"
Expand Down Expand Up @@ -482,7 +483,7 @@ func TestStream0(t *testing.T) {
})

// need to write out an invalid frame, which we need a connection to do
framer, err := conn.exec(writer, nil)
framer, err := conn.exec(context.Background(), writer, nil)
if err == nil {
t.Fatal("expected to get an error on stream 0")
} else if !strings.HasPrefix(err.Error(), expErr) {
Expand Down Expand Up @@ -523,6 +524,26 @@ func TestConnClosedBlocked(t *testing.T) {
}
}

func TestContext_Timeout(t *testing.T) {
srv := NewTestServer(t, defaultProto)
defer srv.Stop()

cluster := testCluster(srv.Address, defaultProto)
cluster.Timeout = 5 * time.Second
db, err := cluster.CreateSession()
if err != nil {
t.Fatal(err)
}
defer db.Close()

ctx, cancel := context.WithCancel(context.Background())
cancel()
err = db.Query("timeout").WithContext(ctx).Exec()
if err != context.Canceled {
t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
}
}

func NewTestServer(t testing.TB, protocol uint8) *TestServer {
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
if err != nil {
Expand Down
10 changes: 6 additions & 4 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
crand "crypto/rand"
"errors"
"fmt"
"golang.org/x/net/context"
"log"
"math/rand"
"net"
Expand Down Expand Up @@ -193,9 +194,10 @@ func (c *controlConn) registerEvents(conn *Conn) error {
return nil
}

framer, err := conn.exec(&writeRegisterFrame{
events: events,
}, nil)
framer, err := conn.exec(context.Background(),
&writeRegisterFrame{
events: events,
}, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -282,7 +284,7 @@ func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
return nil, errNoControl
}

framer, err := conn.exec(w, nil)
framer, err := conn.exec(context.Background(), w, nil)
if err != nil {
return nil, err
}
Expand Down
23 changes: 20 additions & 3 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/net/context"
"io"
"net"
"strconv"
Expand Down Expand Up @@ -359,7 +360,7 @@ func (s *Session) getConn() *Conn {
}

// returns routing key indexes and type info
func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) {
s.routingKeyInfoCache.mu.Lock()

entry, cached := s.routingKeyInfoCache.lru.Get(stmt)
Expand Down Expand Up @@ -402,7 +403,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
}

// get the query info for the statement
info, inflight.err = conn.prepareStatement(stmt, nil)
info, inflight.err = conn.prepareStatement(ctx, stmt, nil)
if inflight.err != nil {
// don't cache this error
s.routingKeyInfoCache.Remove(stmt)
Expand Down Expand Up @@ -587,6 +588,7 @@ type Query struct {
defaultTimestamp bool
defaultTimestampValue int64
disableSkipMetadata bool
context context.Context

disableAutoPage bool
}
Expand Down Expand Up @@ -669,6 +671,13 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
return q
}

// WithContext will set the context to use during a query, it will be used to
// timeout when waiting for responses from Cassandra.
func (q *Query) WithContext(ctx context.Context) *Query {
q.context = ctx
return q
}

func (q *Query) execute(conn *Conn) *Iter {
return conn.executeQuery(q)
}
Expand Down Expand Up @@ -700,7 +709,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) {
}

// try to determine the routing key
routingKeyInfo, err := q.session.routingKeyInfo(q.stmt)
routingKeyInfo, err := q.session.routingKeyInfo(q.context, q.stmt)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1078,6 +1087,7 @@ type Batch struct {
totalLatency int64
serialCons SerialConsistency
defaultTimestamp bool
context context.Context
}

// NewBatch creates a new batch operation without defaults from the cluster
Expand Down Expand Up @@ -1135,6 +1145,13 @@ func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
return b
}

// WithContext will set the context to use during a query, it will be used to
// timeout when waiting for responses from Cassandra.
func (b *Batch) WithContext(ctx context.Context) *Batch {
b.context = ctx
return b
}

// Size returns the number of batch statements to be executed by the batch operation.
func (b *Batch) Size() int {
return len(b.Entries)
Expand Down

0 comments on commit 437c5ce

Please sign in to comment.