@@ -4,14 +4,13 @@ package websocket
44
55import (
66 "bufio"
7- "compress/flate"
87 "context"
98 "crypto/rand"
109 "encoding/binary"
1110 "io"
12- "sync"
1311 "time"
1412
13+ "github.com/klauspost/compress/flate"
1514 "golang.org/x/xerrors"
1615
1716 "nhooyr.io/websocket/internal/errd"
@@ -51,16 +50,15 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
5150type msgWriter struct {
5251 c * Conn
5352
54- mu * mu
55- writeMu sync.Mutex
53+ mu * mu
5654
5755 ctx context.Context
5856 opcode opcode
5957 closed bool
6058 flate bool
6159
62- trimWriter * trimLastFourBytesWriter
63- flateWriter * flate. Writer
60+ trimWriter * trimLastFourBytesWriter
61+ dict * slidingWindow
6462}
6563
6664func newMsgWriter (c * Conn ) * msgWriter {
@@ -72,16 +70,16 @@ func newMsgWriter(c *Conn) *msgWriter {
7270}
7371
7472func (mw * msgWriter ) ensureFlate () {
73+ if mw .flateContextTakeover () && mw .dict == nil {
74+ mw .dict = newSlidingWindow (8192 )
75+ }
76+
7577 if mw .trimWriter == nil {
7678 mw .trimWriter = & trimLastFourBytesWriter {
7779 w : writerFunc (mw .write ),
7880 }
7981 }
8082
81- if mw .flateWriter == nil {
82- mw .flateWriter = getFlateWriter (mw .trimWriter )
83- }
84-
8583 mw .flate = true
8684}
8785
@@ -138,20 +136,10 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
138136 return nil
139137}
140138
141- func (mw * msgWriter ) returnFlateWriter () {
142- if mw .flateWriter != nil {
143- putFlateWriter (mw .flateWriter )
144- mw .flateWriter = nil
145- }
146- }
147-
148139// Write writes the given bytes to the WebSocket connection.
149140func (mw * msgWriter ) Write (p []byte ) (_ int , err error ) {
150141 defer errd .Wrap (& err , "failed to write" )
151142
152- mw .writeMu .Lock ()
153- defer mw .writeMu .Unlock ()
154-
155143 if mw .closed {
156144 return 0 , xerrors .New ("cannot use closed writer" )
157145 }
@@ -165,7 +153,11 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
165153 }
166154
167155 if mw .flate {
168- return mw .flateWriter .Write (p )
156+ err = flate .StatelessDeflate (mw .trimWriter , p , false , mw .dict .getBuf ())
157+ if mw .flateContextTakeover () {
158+ mw .dict .write (p )
159+ }
160+ return len (p ), err
169161 }
170162
171163 return mw .write (p )
@@ -184,17 +176,14 @@ func (mw *msgWriter) write(p []byte) (int, error) {
184176func (mw * msgWriter ) Close () (err error ) {
185177 defer errd .Wrap (& err , "failed to close writer" )
186178
187- mw .writeMu .Lock ()
188- defer mw .writeMu .Unlock ()
189-
190179 if mw .closed {
191180 return xerrors .New ("cannot use closed writer" )
192181 }
193182
194183 if mw .flate {
195- err = mw .flateWriter . Flush ( )
184+ err = flate . StatelessDeflate ( mw .trimWriter , nil , true , mw . dict . getBuf () )
196185 if err != nil {
197- return xerrors .Errorf ("failed to flush flate writer : %w" , err )
186+ return xerrors .Errorf ("failed to flush flate: %w" , err )
198187 }
199188 }
200189
@@ -207,18 +196,10 @@ func (mw *msgWriter) Close() (err error) {
207196 return xerrors .Errorf ("failed to write fin frame: %w" , err )
208197 }
209198
210- if mw .flate && ! mw .flateContextTakeover () {
211- mw .returnFlateWriter ()
212- }
213199 mw .mu .Unlock ()
214200 return nil
215201}
216202
217- func (mw * msgWriter ) close () {
218- mw .writeMu .Lock ()
219- mw .returnFlateWriter ()
220- }
221-
222203func (c * Conn ) writeControl (ctx context.Context , opcode opcode , p []byte ) error {
223204 ctx , cancel := context .WithTimeout (ctx , time .Second * 5 )
224205 defer cancel ()
0 commit comments