Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ _testmain.go
*.exe

.idea/
*.iml
*.iml
.vscode/
31 changes: 28 additions & 3 deletions compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var (
}}
)

func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
func decompressNoContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
Expand All @@ -37,11 +37,23 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
return &flateReadWrapper{fr}
}

func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"

fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), dict)
return &flateReadWrapper{fr}
}

func isValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel
}

func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
func compressNoContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser {
p := &flateWriterPools[level-minCompressionLevel]
tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer)
Expand All @@ -53,6 +65,17 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}

func compressContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser {
p := &flateWriterPools[level-minCompressionLevel]
tw := &truncWriter{w: w}

// WriterDict's Reset just restores the dictionary.
// Initialization is done with New. (If possible get struct from sync.Pool)
fw, _ := flate.NewWriterDict(tw, level, dict)

return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}

// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
Expand Down Expand Up @@ -120,14 +143,16 @@ func (w *flateWriteWrapper) Close() error {
}

type flateReadWrapper struct {
fr io.ReadCloser
fr io.ReadCloser // flate.NewReader
}

func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}

n, err := r.fr.Read(p)

if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
Expand Down
56 changes: 49 additions & 7 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ const (

continuationFrame = 0
noFrame = -1

maxWindowBits = 1 << 15
)

// Close codes defined in RFC 6455, section 11.7.
Expand Down Expand Up @@ -241,7 +243,7 @@ type Conn struct {

enableWriteCompression bool
compressionLevel int
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
newCompressionWriter func(io.WriteCloser, int, []byte) io.WriteCloser

// Read fields
reader io.ReadCloser // the current reader returned to the application
Expand All @@ -259,8 +261,12 @@ type Conn struct {
readErrCount int
messageReader *messageReader // the current low-level reader

readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.ReadCloser
readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader, []byte) io.ReadCloser // arges may flateReadWrapper struct

contextTakeover bool
dict []byte
mutex sync.RWMutex
}

func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
Expand Down Expand Up @@ -499,9 +505,14 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
}
c.writer = mw
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w := c.newCompressionWriter(c.writer, c.compressionLevel)
mw.compress = true
c.writer = w
switch {
case c.contextTakeover:
c.writer = c.newCompressionWriter(c.writer, c.compressionLevel, c.dict)
// no-context-takeover
default:
c.writer = c.newCompressionWriter(c.writer, c.compressionLevel, nil)
}
}
return c.writer, nil
}
Expand Down Expand Up @@ -752,6 +763,9 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
if _, err = w.Write(data); err != nil {
return err
}
if c.contextTakeover {
c.AddDict(data)
}
return w.Close()
}

Expand Down Expand Up @@ -945,9 +959,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c}
c.reader = c.messageReader
if c.readDecompress {
c.reader = c.newDecompressionReader(c.reader)

switch {
case c.readDecompress && c.contextTakeover:
c.reader = c.newDecompressionReader(c.reader, c.dict)
case c.readDecompress:
c.reader = c.newDecompressionReader(c.reader, nil)
}

return frameType, c.reader, nil
}
}
Expand All @@ -974,9 +993,11 @@ func (r *messageReader) Read(b []byte) (int, error) {
for c.readErr == nil {

if c.readRemaining > 0 {
// Determine the size of the data to be read.
if int64(len(b)) > c.readRemaining {
b = b[:c.readRemaining]
}

n, err := c.br.Read(b)
c.readErr = hideTempErr(err)
if c.isServer {
Expand All @@ -986,6 +1007,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
}

return n, c.readErr
}

Expand Down Expand Up @@ -1023,6 +1045,12 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
return messageType, nil, err
}
p, err = ioutil.ReadAll(r)

// if context-takeover add payload to dictionary
if c.contextTakeover {
c.AddDict(p)
}

return messageType, p, err
}

Expand Down Expand Up @@ -1139,6 +1167,20 @@ func (c *Conn) SetCompressionLevel(level int) error {
return nil
}

func (c *Conn) AddDict(b []byte) {
c.mutex.Lock()
defer c.mutex.Unlock()

// Todo I do not know whether to leave the dictionary with 32768 bytes or more
// If it is recognized as a duplicate character string,
// deleting a part of the character may make it impossible to decrypt it.
c.dict = append(b, c.dict...)

if len(c.dict) > maxWindowBits {
c.dict = c.dict[:maxWindowBits]
}
}

// FormatCloseMessage formats closeCode and text as a WebSocket close message.
// An empty message is returned for code CloseNoStatusReceived.
func FormatCloseMessage(closeCode int, text string) []byte {
Expand Down
35 changes: 27 additions & 8 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
subprotocol := u.selectSubprotocol(r, responseHeader)

// Negotiate PMCE
var compress bool
var (
compress bool
contextTakeover bool
)
if u.EnableCompression {
for _, ext := range parseExtensions(r.Header) {
if ext[""] != "permessage-deflate" {
continue
// map[string]string{"":"permessage-deflate", "client_max_window_bits":""}
// detect context-takeover from client_max_window_bits
if ext[""] == "permessage-deflate" {
compress = true
}

if _, ok := ext["client_max_window_bits"]; ok {
contextTakeover = true
}
compress = true
break
}
}

Expand Down Expand Up @@ -177,8 +184,15 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c.subprotocol = subprotocol

if compress {
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
switch {
case contextTakeover:
c.contextTakeover = contextTakeover
c.newCompressionWriter = compressContextTakeover
c.newDecompressionReader = decompressContextTakeover
default:
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
}

p := c.writeBuf[:0]
Expand All @@ -191,7 +205,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
p = append(p, "\r\n"...)
}
if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
switch {
case contextTakeover:
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_max_window_bits=15; client_max_window_bits=15\r\n"...)
default:
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
}
}
for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" {
Expand Down