|
6 | 6 | "fmt" |
7 | 7 | "io" |
8 | 8 | "math" |
| 9 | + "math/bits" |
9 | 10 | ) |
10 | 11 |
|
11 | 12 | //go:generate stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go |
@@ -69,7 +70,7 @@ type header struct { |
69 | 70 | payloadLength int64 |
70 | 71 |
|
71 | 72 | masked bool |
72 | | - maskKey [4]byte |
| 73 | + maskKey uint32 |
73 | 74 | } |
74 | 75 |
|
75 | 76 | func makeWriteHeaderBuf() []byte { |
@@ -119,7 +120,7 @@ func writeHeader(b []byte, h header) []byte { |
119 | 120 | if h.masked { |
120 | 121 | b[1] |= 1 << 7 |
121 | 122 | b = b[:len(b)+4] |
122 | | - copy(b[len(b)-4:], h.maskKey[:]) |
| 123 | + binary.LittleEndian.PutUint32(b[len(b)-4:], h.maskKey) |
123 | 124 | } |
124 | 125 |
|
125 | 126 | return b |
@@ -192,7 +193,7 @@ func readHeader(b []byte, r io.Reader) (header, error) { |
192 | 193 | } |
193 | 194 |
|
194 | 195 | if h.masked { |
195 | | - copy(h.maskKey[:], b) |
| 196 | + h.maskKey = binary.LittleEndian.Uint32(b) |
196 | 197 | } |
197 | 198 |
|
198 | 199 | return h, nil |
@@ -321,122 +322,124 @@ func (ce CloseError) bytes() ([]byte, error) { |
321 | 322 | return buf, nil |
322 | 323 | } |
323 | 324 |
|
324 | | -// xor applies the WebSocket masking algorithm to p |
325 | | -// with the given key where the first 3 bits of pos |
326 | | -// are the starting position in the key. |
| 325 | +// mask applies the WebSocket masking algorithm to b |
| 326 | +// with the given key. |
327 | 327 | // See https://tools.ietf.org/html/rfc6455#section-5.3 |
328 | 328 | // |
329 | | -// The returned value is the position of the next byte |
330 | | -// to be used for masking in the key. This is so that |
331 | | -// unmasking can be performed without the entire frame. |
332 | | -func fastXOR(key [4]byte, keyPos int, b []byte) int { |
333 | | - // If the payload is greater than or equal to 16 bytes, then it's worth |
334 | | - // masking 8 bytes at a time. |
335 | | - // Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859 |
336 | | - if len(b) >= 16 { |
337 | | - // We first create a key that is 8 bytes long |
338 | | - // and is aligned on the position correctly. |
339 | | - var alignedKey [8]byte |
340 | | - for i := range alignedKey { |
341 | | - alignedKey[i] = key[(i+keyPos)&3] |
342 | | - } |
343 | | - k := binary.LittleEndian.Uint64(alignedKey[:]) |
| 329 | +// The returned value is the correctly rotated key to |
| 330 | +// continue to mask/unmask the message. |
| 331 | +// |
| 332 | +// It is optimized for LittleEndian and expects the key |
| 333 | +// to be in little endian. |
| 334 | +// |
| 335 | +// See https://github.com/golang/go/issues/31586 |
| 336 | +func mask(key uint32, b []byte) uint32 { |
| 337 | + if len(b) >= 8 { |
| 338 | + key64 := uint64(key)<<32 | uint64(key) |
344 | 339 |
|
345 | 340 | // At some point in the future we can clean these unrolled loops up. |
346 | 341 | // See https://github.com/golang/go/issues/31586#issuecomment-487436401 |
347 | 342 |
|
348 | 343 | // Then we xor until b is less than 128 bytes. |
349 | 344 | for len(b) >= 128 { |
350 | 345 | v := binary.LittleEndian.Uint64(b) |
351 | | - binary.LittleEndian.PutUint64(b, v^k) |
352 | | - v = binary.LittleEndian.Uint64(b[8:]) |
353 | | - binary.LittleEndian.PutUint64(b[8:], v^k) |
354 | | - v = binary.LittleEndian.Uint64(b[16:]) |
355 | | - binary.LittleEndian.PutUint64(b[16:], v^k) |
356 | | - v = binary.LittleEndian.Uint64(b[24:]) |
357 | | - binary.LittleEndian.PutUint64(b[24:], v^k) |
358 | | - v = binary.LittleEndian.Uint64(b[32:]) |
359 | | - binary.LittleEndian.PutUint64(b[32:], v^k) |
360 | | - v = binary.LittleEndian.Uint64(b[40:]) |
361 | | - binary.LittleEndian.PutUint64(b[40:], v^k) |
362 | | - v = binary.LittleEndian.Uint64(b[48:]) |
363 | | - binary.LittleEndian.PutUint64(b[48:], v^k) |
364 | | - v = binary.LittleEndian.Uint64(b[56:]) |
365 | | - binary.LittleEndian.PutUint64(b[56:], v^k) |
366 | | - v = binary.LittleEndian.Uint64(b[64:]) |
367 | | - binary.LittleEndian.PutUint64(b[64:], v^k) |
368 | | - v = binary.LittleEndian.Uint64(b[72:]) |
369 | | - binary.LittleEndian.PutUint64(b[72:], v^k) |
370 | | - v = binary.LittleEndian.Uint64(b[80:]) |
371 | | - binary.LittleEndian.PutUint64(b[80:], v^k) |
372 | | - v = binary.LittleEndian.Uint64(b[88:]) |
373 | | - binary.LittleEndian.PutUint64(b[88:], v^k) |
374 | | - v = binary.LittleEndian.Uint64(b[96:]) |
375 | | - binary.LittleEndian.PutUint64(b[96:], v^k) |
376 | | - v = binary.LittleEndian.Uint64(b[104:]) |
377 | | - binary.LittleEndian.PutUint64(b[104:], v^k) |
378 | | - v = binary.LittleEndian.Uint64(b[112:]) |
379 | | - binary.LittleEndian.PutUint64(b[112:], v^k) |
380 | | - v = binary.LittleEndian.Uint64(b[120:]) |
381 | | - binary.LittleEndian.PutUint64(b[120:], v^k) |
| 346 | + binary.LittleEndian.PutUint64(b, v^key64) |
| 347 | + v = binary.LittleEndian.Uint64(b[8:16]) |
| 348 | + binary.LittleEndian.PutUint64(b[8:16], v^key64) |
| 349 | + v = binary.LittleEndian.Uint64(b[16:24]) |
| 350 | + binary.LittleEndian.PutUint64(b[16:24], v^key64) |
| 351 | + v = binary.LittleEndian.Uint64(b[24:32]) |
| 352 | + binary.LittleEndian.PutUint64(b[24:32], v^key64) |
| 353 | + v = binary.LittleEndian.Uint64(b[32:40]) |
| 354 | + binary.LittleEndian.PutUint64(b[32:40], v^key64) |
| 355 | + v = binary.LittleEndian.Uint64(b[40:48]) |
| 356 | + binary.LittleEndian.PutUint64(b[40:48], v^key64) |
| 357 | + v = binary.LittleEndian.Uint64(b[48:56]) |
| 358 | + binary.LittleEndian.PutUint64(b[48:56], v^key64) |
| 359 | + v = binary.LittleEndian.Uint64(b[56:64]) |
| 360 | + binary.LittleEndian.PutUint64(b[56:64], v^key64) |
| 361 | + v = binary.LittleEndian.Uint64(b[64:72]) |
| 362 | + binary.LittleEndian.PutUint64(b[64:72], v^key64) |
| 363 | + v = binary.LittleEndian.Uint64(b[72:80]) |
| 364 | + binary.LittleEndian.PutUint64(b[72:80], v^key64) |
| 365 | + v = binary.LittleEndian.Uint64(b[80:88]) |
| 366 | + binary.LittleEndian.PutUint64(b[80:88], v^key64) |
| 367 | + v = binary.LittleEndian.Uint64(b[88:96]) |
| 368 | + binary.LittleEndian.PutUint64(b[88:96], v^key64) |
| 369 | + v = binary.LittleEndian.Uint64(b[96:104]) |
| 370 | + binary.LittleEndian.PutUint64(b[96:104], v^key64) |
| 371 | + v = binary.LittleEndian.Uint64(b[104:112]) |
| 372 | + binary.LittleEndian.PutUint64(b[104:112], v^key64) |
| 373 | + v = binary.LittleEndian.Uint64(b[112:120]) |
| 374 | + binary.LittleEndian.PutUint64(b[112:120], v^key64) |
| 375 | + v = binary.LittleEndian.Uint64(b[120:128]) |
| 376 | + binary.LittleEndian.PutUint64(b[120:128], v^key64) |
382 | 377 | b = b[128:] |
383 | 378 | } |
384 | 379 |
|
385 | 380 | // Then we xor until b is less than 64 bytes. |
386 | 381 | for len(b) >= 64 { |
387 | 382 | v := binary.LittleEndian.Uint64(b) |
388 | | - binary.LittleEndian.PutUint64(b, v^k) |
389 | | - v = binary.LittleEndian.Uint64(b[8:]) |
390 | | - binary.LittleEndian.PutUint64(b[8:], v^k) |
391 | | - v = binary.LittleEndian.Uint64(b[16:]) |
392 | | - binary.LittleEndian.PutUint64(b[16:], v^k) |
393 | | - v = binary.LittleEndian.Uint64(b[24:]) |
394 | | - binary.LittleEndian.PutUint64(b[24:], v^k) |
395 | | - v = binary.LittleEndian.Uint64(b[32:]) |
396 | | - binary.LittleEndian.PutUint64(b[32:], v^k) |
397 | | - v = binary.LittleEndian.Uint64(b[40:]) |
398 | | - binary.LittleEndian.PutUint64(b[40:], v^k) |
399 | | - v = binary.LittleEndian.Uint64(b[48:]) |
400 | | - binary.LittleEndian.PutUint64(b[48:], v^k) |
401 | | - v = binary.LittleEndian.Uint64(b[56:]) |
402 | | - binary.LittleEndian.PutUint64(b[56:], v^k) |
| 383 | + binary.LittleEndian.PutUint64(b, v^key64) |
| 384 | + v = binary.LittleEndian.Uint64(b[8:16]) |
| 385 | + binary.LittleEndian.PutUint64(b[8:16], v^key64) |
| 386 | + v = binary.LittleEndian.Uint64(b[16:24]) |
| 387 | + binary.LittleEndian.PutUint64(b[16:24], v^key64) |
| 388 | + v = binary.LittleEndian.Uint64(b[24:32]) |
| 389 | + binary.LittleEndian.PutUint64(b[24:32], v^key64) |
| 390 | + v = binary.LittleEndian.Uint64(b[32:40]) |
| 391 | + binary.LittleEndian.PutUint64(b[32:40], v^key64) |
| 392 | + v = binary.LittleEndian.Uint64(b[40:48]) |
| 393 | + binary.LittleEndian.PutUint64(b[40:48], v^key64) |
| 394 | + v = binary.LittleEndian.Uint64(b[48:56]) |
| 395 | + binary.LittleEndian.PutUint64(b[48:56], v^key64) |
| 396 | + v = binary.LittleEndian.Uint64(b[56:64]) |
| 397 | + binary.LittleEndian.PutUint64(b[56:64], v^key64) |
403 | 398 | b = b[64:] |
404 | 399 | } |
405 | 400 |
|
406 | 401 | // Then we xor until b is less than 32 bytes. |
407 | 402 | for len(b) >= 32 { |
408 | 403 | v := binary.LittleEndian.Uint64(b) |
409 | | - binary.LittleEndian.PutUint64(b, v^k) |
410 | | - v = binary.LittleEndian.Uint64(b[8:]) |
411 | | - binary.LittleEndian.PutUint64(b[8:], v^k) |
412 | | - v = binary.LittleEndian.Uint64(b[16:]) |
413 | | - binary.LittleEndian.PutUint64(b[16:], v^k) |
414 | | - v = binary.LittleEndian.Uint64(b[24:]) |
415 | | - binary.LittleEndian.PutUint64(b[24:], v^k) |
| 404 | + binary.LittleEndian.PutUint64(b, v^key64) |
| 405 | + v = binary.LittleEndian.Uint64(b[8:16]) |
| 406 | + binary.LittleEndian.PutUint64(b[8:16], v^key64) |
| 407 | + v = binary.LittleEndian.Uint64(b[16:24]) |
| 408 | + binary.LittleEndian.PutUint64(b[16:24], v^key64) |
| 409 | + v = binary.LittleEndian.Uint64(b[24:32]) |
| 410 | + binary.LittleEndian.PutUint64(b[24:32], v^key64) |
416 | 411 | b = b[32:] |
417 | 412 | } |
418 | 413 |
|
419 | 414 | // Then we xor until b is less than 16 bytes. |
420 | 415 | for len(b) >= 16 { |
421 | 416 | v := binary.LittleEndian.Uint64(b) |
422 | | - binary.LittleEndian.PutUint64(b, v^k) |
423 | | - v = binary.LittleEndian.Uint64(b[8:]) |
424 | | - binary.LittleEndian.PutUint64(b[8:], v^k) |
| 417 | + binary.LittleEndian.PutUint64(b, v^key64) |
| 418 | + v = binary.LittleEndian.Uint64(b[8:16]) |
| 419 | + binary.LittleEndian.PutUint64(b[8:16], v^key64) |
425 | 420 | b = b[16:] |
426 | 421 | } |
427 | 422 |
|
428 | 423 | // Then we xor until b is less than 8 bytes. |
429 | 424 | for len(b) >= 8 { |
430 | 425 | v := binary.LittleEndian.Uint64(b) |
431 | | - binary.LittleEndian.PutUint64(b, v^k) |
| 426 | + binary.LittleEndian.PutUint64(b, v^key64) |
432 | 427 | b = b[8:] |
433 | 428 | } |
434 | 429 | } |
435 | 430 |
|
| 431 | + // Then we xor until b is less than 4 bytes. |
| 432 | + for len(b) >= 4 { |
| 433 | + v := binary.LittleEndian.Uint32(b) |
| 434 | + binary.LittleEndian.PutUint32(b, v^key) |
| 435 | + b = b[4:] |
| 436 | + } |
| 437 | + |
436 | 438 | // xor remaining bytes. |
437 | 439 | for i := range b { |
438 | | - b[i] ^= key[keyPos&3] |
439 | | - keyPos++ |
| 440 | + b[i] ^= byte(key) |
| 441 | + key = bits.RotateLeft32(key, -8) |
440 | 442 | } |
441 | | - return keyPos & 3 |
| 443 | + |
| 444 | + return key |
442 | 445 | } |
0 commit comments