Skip to content

Commit

Permalink
Don't reset connection on non-write errors
Browse files Browse the repository at this point in the history
If there is an error like too big frame that we are sending,
we don't want to close the connection as we abort before trying
to write there. We don't clobber the data stream in this case.

As framer is basically a buffer of frame data,
I moved the I/O operations to separate methods with
explicit reader/writer arguments so that we can
distinguish when IO fails.
  • Loading branch information
martin-sucha committed Mar 18, 2022
1 parent 73e4d19 commit 4ee4f62
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 63 deletions.
2 changes: 1 addition & 1 deletion cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2731,7 +2731,7 @@ func TestNegativeStream(t *testing.T) {
const stream = -50
writer := frameWriterFunc(func(f *framer, streamID int) error {
f.writeHeader(0, opOptions, stream)
return f.finishWrite()
return f.finish()
})

frame, err := conn.exec(context.Background(), writer, nil)
Expand Down
28 changes: 18 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ func (s *startupCoordinator) setupConn(ctx context.Context) error {
return nil
}

func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) {
func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder) (frame, error) {
select {
case s.frameTicker <- struct{}{}:
case <-ctx.Done():
Expand Down Expand Up @@ -681,17 +681,17 @@ func (c *Conn) recv(ctx context.Context) error {
return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream)
} else if head.stream == -1 {
// TODO: handle cassandra event frames, we shouldnt get any currently
framer := newFramer(c, c, c.compressor, c.version)
if err := framer.readFrame(&head); err != nil {
framer := newFramer(c.compressor, c.version)
if err := framer.readFrame(c, &head); err != nil {
return err
}
go c.session.handleEvent(framer)
return nil
} else if head.stream <= 0 {
// reserved stream that we dont use, probably due to a protocol error
// or a bug in Cassandra, this should be an error, parse it and return.
framer := newFramer(c, c, c.compressor, c.version)
if err := framer.readFrame(&head); err != nil {
framer := newFramer(c.compressor, c.version)
if err := framer.readFrame(c, &head); err != nil {
return err
}

Expand All @@ -716,7 +716,7 @@ func (c *Conn) recv(ctx context.Context) error {
panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream))
}

err = call.framer.readFrame(&head)
err = call.framer.readFrame(c, &head)
if err != nil {
// only net errors should cause the connection to be closed. Though
// cassandra returning corrupt frames will be returned here as well.
Expand Down Expand Up @@ -912,7 +912,7 @@ func (w *writeCoalescer) writeFlusher(interval time.Duration) {
}
}

func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error) {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
Expand All @@ -924,7 +924,7 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
}

// resp is basically a waiting semaphore protecting the framer
framer := newFramer(c, c, c.compressor, c.version)
framer := newFramer(c.compressor, c.version)

call := &callReq{
framer: framer,
Expand Down Expand Up @@ -958,7 +958,15 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
})
}

err := req.writeFrame(framer, stream)
err := req.buildFrame(framer, stream)
if err != nil {
// We failed to serialize the frame into a buffer.
// This should not affect the connection, we just free the current call.
c.releaseStream(call)
return nil, err
}

err = framer.writeTo(c)
if err != nil {
// closeWithError will block waiting for this stream to either receive a response
// or for us to timeout, close the timeout chan here. Im not entirely sure
Expand Down Expand Up @@ -1210,7 +1218,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
}

var (
frame frameWriter
frame frameBuilder
info *preparedStatment
)

Expand Down
24 changes: 16 additions & 8 deletions conn_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build all || unit
// +build all unit

package gocql
Expand Down Expand Up @@ -668,11 +669,14 @@ func TestStream0(t *testing.T) {
const expErr = "gocql: received unexpected frame on stream 0"

var buf bytes.Buffer
f := newFramer(nil, &buf, nil, protoVersion4)
f := newFramer(nil, protoVersion4)
f.writeHeader(0, opResult, 0)
f.writeInt(resultKindVoid)
f.wbuf[0] |= 0x80
if err := f.finishWrite(); err != nil {
if err := f.finish(); err != nil {
t.Fatal(err)
}
if err := f.writeTo(&buf); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -1115,7 +1119,7 @@ func (srv *TestServer) serve() {
srv.onRecv(framer)
}

go srv.process(framer)
go srv.process(conn, framer)
}
}(conn)
}
Expand Down Expand Up @@ -1153,7 +1157,7 @@ func (srv *TestServer) errorLocked(err interface{}) {
srv.t.Error(err)
}

func (srv *TestServer) process(f *framer) {
func (srv *TestServer) process(conn net.Conn, f *framer) {
head := f.header
if head == nil {
srv.errorLocked("process frame with a nil header")
Expand Down Expand Up @@ -1204,7 +1208,7 @@ func (srv *TestServer) process(f *framer) {
case <-srv.ctx.Done():
return
case <-time.After(50 * time.Millisecond):
f.finishWrite()
f.finish()
}
}()
return
Expand Down Expand Up @@ -1236,7 +1240,11 @@ func (srv *TestServer) process(f *framer) {

f.wbuf[0] = srv.protocol | 0x80

if err := f.finishWrite(); err != nil {
if err := f.finish(); err != nil {
srv.errorLocked(err)
}

if err := f.writeTo(conn); err != nil {
srv.errorLocked(err)
}
}
Expand All @@ -1247,9 +1255,9 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) {
if err != nil {
return nil, err
}
framer := newFramer(conn, conn, nil, srv.protocol)
framer := newFramer(nil, srv.protocol)

err = framer.readFrame(&head)
err = framer.readFrame(conn, &head)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion control.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ func (c *controlConn) getConn() *connHost {
return c.conn.Load().(*connHost)
}

func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
func (c *controlConn) writeFrame(w frameBuilder) (frame, error) {
ch := c.getConn()
if ch == nil {
return nil, errNoControl
Expand Down
Loading

0 comments on commit 4ee4f62

Please sign in to comment.