Skip to content

Commit

Permalink
batch: return an iter containing an error like queries
Browse files Browse the repository at this point in the history
To be consistent with queries batches should return an error inside and
iter instead of directly. To access to the error callers should use
iter.Close().
  • Loading branch information
Zariel committed Jan 15, 2016
1 parent 2fc8e5b commit d16cdd2
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 59 deletions.
2 changes: 1 addition & 1 deletion cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ func TestReprepareBatch(t *testing.T) {
stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
batch := session.NewBatch(UnloggedBatch)
batch.Query(stmt, "bar")
if _, err := conn.executeBatch(batch); err != nil {
if err := conn.executeBatch(batch).Close(); err != nil {
t.Fatalf("Failed to execute query for reprepare statement: %v", err)
}

Expand Down
31 changes: 15 additions & 16 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -817,9 +817,9 @@ func (c *Conn) UseKeyspace(keyspace string) error {
return nil
}

func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
func (c *Conn) executeBatch(batch *Batch) *Iter {
if c.version == protoVersion1 {
return nil, ErrUnsupported
return &Iter{err: ErrUnsupported}
}

n := len(batch.Entries)
Expand All @@ -831,15 +831,15 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
defaultTimestamp: batch.defaultTimestamp,
}

stmts := make(map[string]string)
stmts := make(map[string]string, len(batch.Entries))

for i := 0; i < n; i++ {
entry := &batch.Entries[i]
b := &req.statements[i]
if len(entry.Args) > 0 || entry.binding != nil {
info, err := c.prepareStatement(entry.Stmt, nil)
if err != nil {
return nil, err
return &Iter{err: err}
}

var args []interface{}
Expand All @@ -848,12 +848,12 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
} else {
args, err = entry.binding(info)
if err != nil {
return nil, err
return &Iter{err: err}
}
}

if len(args) != len(info.Args) {
return nil, ErrQueryArgLength
return &Iter{err: ErrQueryArgLength}
}

b.preparedID = info.Id
Expand All @@ -864,7 +864,7 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
for j := 0; j < len(info.Args); j++ {
val, err := Marshal(info.Args[j].TypeInfo, args[j])
if err != nil {
return nil, err
return &Iter{err: err}
}

b.values[j].value = val
Expand All @@ -878,18 +878,18 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
// TODO: should batch support tracing?
framer, err := c.exec(req, nil)
if err != nil {
return nil, err
return &Iter{err: err}
}

resp, err := framer.parseFrame()
if err != nil {
return nil, err
return &Iter{err: err, framer: framer}
}

switch x := resp.(type) {
case *resultVoidFrame:
framerPool.Put(framer)
return nil, nil
return &Iter{}
case *RequestErrUnprepared:
stmt, found := stmts[string(x.StatementId)]
if found {
Expand All @@ -903,7 +903,7 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
if found {
return c.executeBatch(batch)
} else {
return nil, x
return &Iter{err: err, framer: framer}
}
case *resultRowsFrame:
iter := &Iter{
Expand All @@ -912,13 +912,12 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
framer: framer,
}

return iter, nil
return iter
case error:
framerPool.Put(framer)
return nil, x

return &Iter{err: err, framer: framer}
default:
framerPool.Put(framer)
return nil, NewErrProtocol("Unknown type in response to batch statement: %s", x)
return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer}
}
}

Expand Down
93 changes: 51 additions & 42 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,100 +452,109 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
return routingKeyInfo, nil
}

func (s *Session) executeBatch(batch *Batch) (*Iter, error) {
func (s *Session) executeBatch(batch *Batch) *Iter {
// fail fast
if s.Closed() {
return nil, ErrSessionClosed
return &Iter{err: ErrSessionClosed}
}

// Prevent the execution of the batch if greater than the limit
// Currently batches have a limit of 65536 queries.
// https://datastax-oss.atlassian.net/browse/JAVA-229
if batch.Size() > BatchSizeMaximum {
return nil, ErrTooManyStmts
return &Iter{err: ErrTooManyStmts}
}

var err error
var iter *Iter
batch.attempts = 0
batch.totalLatency = 0
for {
host, conn := s.pool.Pick(nil)

//Assign the error unavailable and break loop
batch.attempts++
if conn == nil {
err = ErrNoConnections
if batch.rt == nil || !batch.rt.Attempt(batch) {
// Assign the error unavailable and break loop
iter = &Iter{err: ErrNoConnections}
break
}

continue
}

if conn == nil {
iter = &Iter{err: ErrNoConnections}
break
}

t := time.Now()
iter, err = conn.executeBatch(batch)
batch.totalLatency += time.Now().Sub(t).Nanoseconds()
batch.attempts++

// Update host
host.Mark(err)
iter = conn.executeBatch(batch)

batch.totalLatency += time.Since(t).Nanoseconds()
// Exit loop if operation executed correctly
if err == nil {
if iter.err == nil {
host.Mark(nil)
break
}

// Mark host with error if returned from Close
host.Mark(iter.Close())

if batch.rt == nil || !batch.rt.Attempt(batch) {
break
}
}

return iter, err
return iter
}

// ExecuteBatch executes a batch operation and returns nil if successful
// otherwise an error is returned describing the failure.
func (s *Session) ExecuteBatch(batch *Batch) error {
_, err := s.executeBatch(batch)
return err
iter := s.executeBatch(batch)
return iter.Close()
}

// ExecuteBatchCAS executes a batch operation and returns nil if successful and
// ExecuteBatchCAS executes a batch operation and returns true if successful and
// an iterator (to scan aditional rows if more than one conditional statement)
// was sent, otherwise an error is returned describing the failure.
// was sent.
// Further scans on the interator must also remember to include
// the applied boolean as the first argument to *Iter.Scan
func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bool, iter *Iter, err error) {
if iter, err := s.executeBatch(batch); err == nil {
if err := iter.checkErrAndNotFound(); err != nil {
return false, nil, err
}
if len(iter.Columns()) > 1 {
dest = append([]interface{}{&applied}, dest...)
iter.Scan(dest...)
} else {
iter.Scan(&applied)
}
return applied, iter, nil
} else {
iter = s.executeBatch(batch)
if err := iter.checkErrAndNotFound(); err != nil {
iter.Close()
return false, nil, err
}

if len(iter.Columns()) > 1 {
dest = append([]interface{}{&applied}, dest...)
iter.Scan(dest...)
} else {
iter.Scan(&applied)
}

return applied, iter, nil
}

// MapExecuteBatchCAS executes a batch operation much like ExecuteBatchCAS,
// however it accepts a map rather than a list of arguments for the initial
// scan.
func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) (applied bool, iter *Iter, err error) {
if iter, err := s.executeBatch(batch); err == nil {
if err := iter.checkErrAndNotFound(); err != nil {
return false, nil, err
}
iter.MapScan(dest)
applied = dest["[applied]"].(bool)
delete(dest, "[applied]")

// we usually close here, but instead of closing, just returin an error
// if MapScan failed. Although Close just returns err, using Close
// here might be confusing as we are not actually closing the iter
return applied, iter, iter.err
} else {
iter = s.executeBatch(batch)
if err := iter.checkErrAndNotFound(); err != nil {
iter.Close()
return false, nil, err
}
iter.MapScan(dest)
applied = dest["[applied]"].(bool)
delete(dest, "[applied]")

// we usually close here, but instead of closing, just returin an error
// if MapScan failed. Although Close just returns err, using Close
// here might be confusing as we are not actually closing the iter
return applied, iter, iter.err
}

func (s *Session) connect(addr string, errorHandler ConnErrorHandler) (*Conn, error) {
Expand Down

0 comments on commit d16cdd2

Please sign in to comment.