Skip to content

Commit

Permalink
Reduce memory allocations in NextReader, NextWriter
Browse files Browse the repository at this point in the history
Redo 8b209f6 with support for old
versions of Go.
  • Loading branch information
garyburd committed May 31, 2016
1 parent 50d660d commit be01041
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 86 deletions.
163 changes: 77 additions & 86 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,23 +238,23 @@ type Conn struct {
writeBuf []byte // frame is constructed in this buffer.
writePos int // end of data in writeBuf.
writeFrameType int // type of the current frame.
writeSeq int // incremented to invalidate message writers.
writeDeadline time.Time
isWriting bool // for best-effort concurrent write detection
isWriting bool // for best-effort concurrent write detection
messageWriter *messageWriter // the current writer

// Read fields
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
readFinal bool // true the current message has more frames.
readSeq int // incremented to invalidate message readers.
readLength int64 // Message size.
readLimit int64 // Maximum message size.
readMaskPos int
readMaskKey [4]byte
handlePong func(string) error
handlePing func(string) error
readErrCount int
messageReader *messageReader // the current reader
}

func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
Expand All @@ -264,6 +264,9 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize
}
if readBufferSize < maxControlFramePayloadSize {
readBufferSize = maxControlFramePayloadSize
}
if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize
}
Expand Down Expand Up @@ -390,8 +393,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return hideTempErr(err)
}

// NextWriter returns a writer for the next message to send. The writer's
// Close method flushes the complete message to the network.
// NextWriter returns a writer for the next message to send. The writer's Close
// method flushes the complete message to the network.
//
// There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so.
Expand All @@ -411,7 +414,9 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
}

c.writeFrameType = messageType
return messageWriter{c, c.writeSeq}, nil
w := &messageWriter{c}
c.messageWriter = w
return w, nil
}

func (c *Conn) flushFrame(final bool, extra []byte) error {
Expand All @@ -420,7 +425,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
// Check for invalid control frames.
if isControl(c.writeFrameType) &&
(!final || length > maxControlFramePayloadSize) {
c.writeSeq++
c.messageWriter = nil
c.writeFrameType = noFrame
c.writePos = maxFrameHeaderSize
return errInvalidControlFrame
Expand Down Expand Up @@ -488,20 +493,17 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
c.writePos = maxFrameHeaderSize
c.writeFrameType = continuationFrame
if final {
c.writeSeq++
c.messageWriter = nil
c.writeFrameType = noFrame
}
return c.writeErr
}

type messageWriter struct {
c *Conn
seq int
}
type messageWriter struct{ c *Conn }

func (w messageWriter) err() error {
func (w *messageWriter) err() error {
c := w.c
if c.writeSeq != w.seq {
if c.messageWriter != w {
return errWriteClosed
}
if c.writeErr != nil {
Expand All @@ -510,7 +512,7 @@ func (w messageWriter) err() error {
return nil
}

func (w messageWriter) ncopy(max int) (int, error) {
func (w *messageWriter) ncopy(max int) (int, error) {
n := len(w.c.writeBuf) - w.c.writePos
if n <= 0 {
if err := w.c.flushFrame(false, nil); err != nil {
Expand All @@ -524,7 +526,7 @@ func (w messageWriter) ncopy(max int) (int, error) {
return n, nil
}

func (w messageWriter) write(final bool, p []byte) (int, error) {
func (w *messageWriter) write(final bool, p []byte) (int, error) {
if err := w.err(); err != nil {
return 0, err
}
Expand All @@ -551,11 +553,11 @@ func (w messageWriter) write(final bool, p []byte) (int, error) {
return nn, nil
}

func (w messageWriter) Write(p []byte) (int, error) {
func (w *messageWriter) Write(p []byte) (int, error) {
return w.write(false, p)
}

func (w messageWriter) WriteString(p string) (int, error) {
func (w *messageWriter) WriteString(p string) (int, error) {
if err := w.err(); err != nil {
return 0, err
}
Expand All @@ -573,7 +575,7 @@ func (w messageWriter) WriteString(p string) (int, error) {
return nn, nil
}

func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
if err := w.err(); err != nil {
return 0, err
}
Expand All @@ -598,7 +600,7 @@ func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
return nn, err
}

func (w messageWriter) Close() error {
func (w *messageWriter) Close() error {
if err := w.err(); err != nil {
return err
}
Expand All @@ -608,20 +610,22 @@ func (w messageWriter) Close() error {
// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error {
wr, err := c.NextWriter(messageType)
w, err := c.NextWriter(messageType)
if err != nil {
return err
}
w := wr.(messageWriter)
if _, err := w.write(true, data); err != nil {
if _, ok := w.(*messageWriter); ok && c.isServer {
// Optimize write as a single frame.
n := copy(c.writeBuf[c.writePos:], data)
c.writePos += n
data = data[n:]
err = c.flushFrame(true, data)
return err
}
if c.writeSeq == w.seq {
if err := c.flushFrame(true, nil); err != nil {
return err
}
if _, err = w.Write(data); err != nil {
return err
}
return nil
return w.Close()
}

// SetWriteDeadline sets the write deadline on the underlying network
Expand All @@ -635,22 +639,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {

// Read methods

// readFull is like io.ReadFull except that io.EOF is never returned.
func (c *Conn) readFull(p []byte) (err error) {
var n int
for n < len(p) && err == nil {
var nn int
nn, err = c.br.Read(p[n:])
n += nn
}
if n == len(p) {
err = nil
} else if err == io.EOF {
err = errUnexpectedEOF
}
return
}

func (c *Conn) advanceFrame() (int, error) {

// 1. Skip remainder of previous frame.
Expand All @@ -663,16 +651,16 @@ func (c *Conn) advanceFrame() (int, error) {

// 2. Read and parse first two bytes of frame header.

var b [8]byte
if err := c.readFull(b[:2]); err != nil {
p, err := c.read(2)
if err != nil {
return noFrame, err
}

final := b[0]&finalBit != 0
frameType := int(b[0] & 0xf)
reserved := int((b[0] >> 4) & 0x7)
mask := b[1]&maskBit != 0
c.readRemaining = int64(b[1] & 0x7f)
final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
reserved := int((p[0] >> 4) & 0x7)
mask := p[1]&maskBit != 0
c.readRemaining = int64(p[1] & 0x7f)

if reserved != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
Expand Down Expand Up @@ -704,15 +692,17 @@ func (c *Conn) advanceFrame() (int, error) {

switch c.readRemaining {
case 126:
if err := c.readFull(b[:2]); err != nil {
p, err := c.read(2)
if err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
c.readRemaining = int64(binary.BigEndian.Uint16(p))
case 127:
if err := c.readFull(b[:8]); err != nil {
p, err := c.read(8)
if err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
c.readRemaining = int64(binary.BigEndian.Uint64(p))
}

// 4. Handle frame masking.
Expand All @@ -723,9 +713,11 @@ func (c *Conn) advanceFrame() (int, error) {

if mask {
c.readMaskPos = 0
if err := c.readFull(c.readMaskKey[:]); err != nil {
p, err := c.read(len(c.readMaskKey))
if err != nil {
return noFrame, err
}
copy(c.readMaskKey[:], p)
}

// 5. For text and binary messages, enforce read limit and return.
Expand All @@ -745,9 +737,9 @@ func (c *Conn) advanceFrame() (int, error) {

var payload []byte
if c.readRemaining > 0 {
payload = make([]byte, c.readRemaining)
payload, err = c.read(int(c.readRemaining))
c.readRemaining = 0
if err := c.readFull(payload); err != nil {
if err != nil {
return noFrame, err
}
if c.isServer {
Expand Down Expand Up @@ -805,7 +797,7 @@ func (c *Conn) handleProtocolError(message string) error {
// this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {

c.readSeq++
c.messageReader = nil
c.readLength = 0

for c.readErr == nil {
Expand All @@ -815,7 +807,9 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
break
}
if frameType == TextMessage || frameType == BinaryMessage {
return frameType, messageReader{c, c.readSeq}, nil
r := &messageReader{c}
c.messageReader = r
return frameType, r, nil
}
}

Expand All @@ -830,51 +824,48 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
return noFrame, nil, c.readErr
}

type messageReader struct {
c *Conn
seq int
}

func (r messageReader) Read(b []byte) (int, error) {
type messageReader struct{ c *Conn }

if r.seq != r.c.readSeq {
func (r *messageReader) Read(b []byte) (int, error) {
c := r.c
if c.messageReader != r {
return 0, io.EOF
}

for r.c.readErr == nil {
for c.readErr == nil {

if r.c.readRemaining > 0 {
if int64(len(b)) > r.c.readRemaining {
b = b[:r.c.readRemaining]
if c.readRemaining > 0 {
if int64(len(b)) > c.readRemaining {
b = b[:c.readRemaining]
}
n, err := r.c.br.Read(b)
r.c.readErr = hideTempErr(err)
if r.c.isServer {
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
n, err := c.br.Read(b)
c.readErr = hideTempErr(err)
if c.isServer {
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
}
r.c.readRemaining -= int64(n)
if r.c.readRemaining > 0 && r.c.readErr == io.EOF {
r.c.readErr = errUnexpectedEOF
c.readRemaining -= int64(n)
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
}
return n, r.c.readErr
return n, c.readErr
}

if r.c.readFinal {
r.c.readSeq++
if c.readFinal {
c.messageReader = nil
return 0, io.EOF
}

frameType, err := r.c.advanceFrame()
frameType, err := c.advanceFrame()
switch {
case err != nil:
r.c.readErr = hideTempErr(err)
c.readErr = hideTempErr(err)
case frameType == TextMessage || frameType == BinaryMessage:
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
}
}

err := r.c.readErr
if err == io.EOF && r.seq == r.c.readSeq {
err := c.readErr
if err == io.EOF && c.messageReader == r {
err = errUnexpectedEOF
}
return 0, err
Expand Down
18 changes: 18 additions & 0 deletions conn_read.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// +build go1.5

package websocket

import "io"

func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
return p, err
}
Loading

0 comments on commit be01041

Please sign in to comment.