Skip to content

Commit

Permalink
Split request and response framer
Browse files Browse the repository at this point in the history
There is no need to share the write buffer with read goroutine,
so I'm splitting the code to use a separate framer for
request and response. This should help prevent accidental
reuse of framer in the future.
  • Loading branch information
martin-sucha committed Mar 18, 2022
1 parent 4ee4f62 commit 344f583
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 140 deletions.
61 changes: 34 additions & 27 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ func (c *Conn) closeWithError(err error) {
// we need to send the error to all waiting queries, put the state
// of this conn into not active so that it can not execute any queries.
select {
case req.resp <- err:
case req.resp <- callResp{err: err}:
case <-req.timeout:
}
if req.streamObserverContext != nil {
Expand Down Expand Up @@ -709,14 +709,16 @@ func (c *Conn) recv(ctx context.Context) error {
call, ok := c.calls[head.stream]
delete(c.calls, head.stream)
c.mu.Unlock()
if call == nil || call.framer == nil || !ok {
if call == nil || !ok {
c.logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head)
return c.discardFrame(head)
} else if head.stream != call.streamID {
panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream))
}

err = call.framer.readFrame(c, &head)
framer := newFramer(c.compressor, c.version)

err = framer.readFrame(c, &head)
if err != nil {
// only net errors should cause the connection to be closed. Though
// cassandra returning corrupt frames will be returned here as well.
Expand All @@ -728,7 +730,7 @@ func (c *Conn) recv(ctx context.Context) error {
// we either, return a response to the caller, the caller timedout, or the
// connection has closed. Either way we should never block indefinatly here
select {
case call.resp <- err:
case call.resp <- callResp{framer: framer, err: err}:
case <-call.timeout:
c.releaseStream(call)
case <-ctx.Done():
Expand Down Expand Up @@ -760,10 +762,9 @@ func (c *Conn) handleTimeout() {
}

type callReq struct {
// could use a waitgroup but this allows us to do timeouts on the read/send
resp chan error
framer *framer
timeout chan struct{} // indicates to recv() that a call has timedout
// resp will receive the frame that was sent as a response to this stream.
resp chan callResp
timeout chan struct{} // indicates to recv() that a call has timed out
streamID int // current stream in use

timer *time.Timer
Expand All @@ -776,6 +777,14 @@ type callReq struct {
streamObserverEndOnce sync.Once
}

type callResp struct {
// framer is the response frame.
// May be nil if err is not nil.
framer *framer
// err is error encountered, if any.
err error
}

type deadlineWriter struct {
w interface {
SetWriteDeadline(time.Time) error
Expand Down Expand Up @@ -927,10 +936,9 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
framer := newFramer(c.compressor, c.version)

call := &callReq{
framer: framer,
timeout: make(chan struct{}),
streamID: stream,
resp: make(chan error),
resp: make(chan callResp),
}

if c.streamObserver != nil {
Expand Down Expand Up @@ -1005,18 +1013,31 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
}

select {
case err := <-call.resp:
case resp := <-call.resp:
close(call.timeout)
if err != nil {
if resp.err != nil {
if !c.Closed() {
// if the connection is closed then we cant release the stream,
// this is because the request is still outstanding and we have
// been handed another error from another stream which caused the
// connection to close.
c.releaseStream(call)
}
return nil, err
return nil, resp.err
}
// dont release the stream if detect a timeout as another request can reuse
// that stream and get a response for the old request, which we have no
// easy way of detecting.
//
// Ensure that the stream is not released if there are potentially outstanding
// requests on the stream to prevent nil pointer dereferences in recv().
defer c.releaseStream(call)

if v := resp.framer.header.version.version(); v != c.version {
return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
}

return resp.framer, nil
case <-timeoutCh:
close(call.timeout)
c.handleTimeout()
Expand All @@ -1027,20 +1048,6 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
case <-c.ctx.Done():
return nil, ErrConnectionClosed
}

// dont release the stream if detect a timeout as another request can reuse
// that stream and get a response for the old request, which we have no
// easy way of detecting.
//
// Ensure that the stream is not released if there are potentially outstanding
// requests on the stream to prevent nil pointer dereferences in recv().
defer c.releaseStream(call)

if v := framer.header.version.version(); v != c.version {
return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
}

return framer, nil
}

// ObservedStream observes a single request/response stream.
Expand Down
70 changes: 36 additions & 34 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ func TestStream0(t *testing.T) {
f := newFramer(nil, protoVersion4)
f.writeHeader(0, opResult, 0)
f.writeInt(resultKindVoid)
f.wbuf[0] |= 0x80
f.buf[0] |= 0x80
if err := f.finish(); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1157,12 +1157,13 @@ func (srv *TestServer) errorLocked(err interface{}) {
srv.t.Error(err)
}

func (srv *TestServer) process(conn net.Conn, f *framer) {
head := f.header
func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
head := reqFrame.header
if head == nil {
srv.errorLocked("process frame with a nil header")
return
}
respFrame := newFramer(nil, reqFrame.proto)

switch head.op {
case opStartup:
Expand All @@ -1174,77 +1175,78 @@ func (srv *TestServer) process(conn net.Conn, f *framer) {
return
}
}
f.writeHeader(0, opReady, head.stream)
respFrame.writeHeader(0, opReady, head.stream)
case opOptions:
f.writeHeader(0, opSupported, head.stream)
f.writeShort(0)
respFrame.writeHeader(0, opSupported, head.stream)
respFrame.writeShort(0)
case opQuery:
query := f.readLongString()
query := reqFrame.readLongString()
first := query
if n := strings.Index(query, " "); n > 0 {
first = first[:n]
}
switch strings.ToLower(first) {
case "kill":
atomic.AddInt64(&srv.nKillReq, 1)
f.writeHeader(0, opError, head.stream)
f.writeInt(0x1001)
f.writeString("query killed")
respFrame.writeHeader(0, opError, head.stream)
respFrame.writeInt(0x1001)
respFrame.writeString("query killed")
case "use":
f.writeInt(resultKindKeyspace)
f.writeString(strings.TrimSpace(query[3:]))
respFrame.writeInt(resultKindKeyspace)
respFrame.writeString(strings.TrimSpace(query[3:]))
case "void":
f.writeHeader(0, opResult, head.stream)
f.writeInt(resultKindVoid)
respFrame.writeHeader(0, opResult, head.stream)
respFrame.writeInt(resultKindVoid)
case "timeout":
<-srv.ctx.Done()
return
case "slow":
go func() {
f.writeHeader(0, opResult, head.stream)
f.writeInt(resultKindVoid)
f.wbuf[0] = srv.protocol | 0x80
respFrame.writeHeader(0, opResult, head.stream)
respFrame.writeInt(resultKindVoid)
respFrame.buf[0] = srv.protocol | 0x80
select {
case <-srv.ctx.Done():
return
case <-time.After(50 * time.Millisecond):
f.finish()
respFrame.finish()
respFrame.writeTo(conn)
}
}()
return
case "speculative":
atomic.AddInt64(&srv.nKillReq, 1)
if atomic.LoadInt64(&srv.nKillReq) > 3 {
f.writeHeader(0, opResult, head.stream)
f.writeInt(resultKindVoid)
f.writeString("speculative query success on the node " + srv.Address)
respFrame.writeHeader(0, opResult, head.stream)
respFrame.writeInt(resultKindVoid)
respFrame.writeString("speculative query success on the node " + srv.Address)
} else {
f.writeHeader(0, opError, head.stream)
f.writeInt(0x1001)
f.writeString("speculative error")
respFrame.writeHeader(0, opError, head.stream)
respFrame.writeInt(0x1001)
respFrame.writeString("speculative error")
rand.Seed(time.Now().UnixNano())
<-time.After(time.Millisecond * 120)
}
default:
f.writeHeader(0, opResult, head.stream)
f.writeInt(resultKindVoid)
respFrame.writeHeader(0, opResult, head.stream)
respFrame.writeInt(resultKindVoid)
}
case opError:
f.writeHeader(0, opError, head.stream)
f.wbuf = append(f.wbuf, f.rbuf...)
respFrame.writeHeader(0, opError, head.stream)
respFrame.buf = append(respFrame.buf, reqFrame.buf...)
default:
f.writeHeader(0, opError, head.stream)
f.writeInt(0)
f.writeString("not supported")
respFrame.writeHeader(0, opError, head.stream)
respFrame.writeInt(0)
respFrame.writeString("not supported")
}

f.wbuf[0] = srv.protocol | 0x80
respFrame.buf[0] = srv.protocol | 0x80

if err := f.finish(); err != nil {
if err := respFrame.finish(); err != nil {
srv.errorLocked(err)
}

if err := f.writeTo(conn); err != nil {
if err := respFrame.writeTo(conn); err != nil {
srv.errorLocked(err)
}
}
Expand Down
Loading

0 comments on commit 344f583

Please sign in to comment.