Skip to content

Commit

Permalink
[WebSocket] Reduce memcpy at WebSocketFrameParser::Decode()
Browse files Browse the repository at this point in the history
This reduces memcpy by storing only the incomplete header bytes that need
to be carried over to the next round of decoding, rather than the whole data.

This makes performance a bit faster:
(Win10, local release build not fully optimized on Z840)
ToT:  147 MB/s
This: 154 MB/s (+4.5%)

Bug: 865001
Change-Id: I52de6b178034d69d7ff538d31c2ed660504e04a3
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1715921
Reviewed-by: Adam Rice <ricea@chromium.org>
Reviewed-by: Yutaka Hirano <yhirano@chromium.org>
Commit-Queue: Yoichi Osato <yoichio@chromium.org>
Cr-Commit-Position: refs/heads/master@{#681743}
  • Loading branch information
Yoichi Osato authored and Commit Bot committed Jul 29, 2019
1 parent fc2a8d3 commit 9a74219
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 94 deletions.
172 changes: 93 additions & 79 deletions net/websockets/websocket_frame_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@ const uint8_t kPayloadLengthMask = 0x7F;
const uint64_t kMaxPayloadLengthWithoutExtendedLengthField = 125;
const uint64_t kPayloadLengthWithTwoByteExtendedLengthField = 126;
const uint64_t kPayloadLengthWithEightByteExtendedLengthField = 127;
// Largest possible size of a frame header: the base header plus the maximum
// extended length field plus the masking key.
const size_t kMaximumFrameHeaderSize =
net::WebSocketFrameHeader::kBaseHeaderSize +
net::WebSocketFrameHeader::kMaximumExtendedLengthSize +
net::WebSocketFrameHeader::kMaskingKeyLength;

} // namespace.

namespace net {

WebSocketFrameParser::WebSocketFrameParser()
: current_read_pos_(0),
frame_offset_(0),
websocket_error_(kWebSocketNormalClosure) {
: frame_offset_(0), websocket_error_(kWebSocketNormalClosure) {
std::fill(masking_key_.key,
masking_key_.key + WebSocketFrameHeader::kMaskingKeyLength,
'\0');
Expand All @@ -52,106 +54,116 @@ bool WebSocketFrameParser::Decode(
if (!length)
return true;

// TODO(yutak): Remove copy.
buffer_.insert(buffer_.end(), data, data + length);
base::span<const char> data_span = base::make_span(data, length);
// If we have incomplete frame header, try to decode a header combining with
// |data|.
bool first_chunk = false;
if (incomplete_header_buffer_.size() > 0) {
DCHECK(!current_frame_header_.get());
const size_t original_size = incomplete_header_buffer_.size();
DCHECK_LE(original_size, kMaximumFrameHeaderSize);
incomplete_header_buffer_.insert(
incomplete_header_buffer_.end(), data,
data + std::min(length, kMaximumFrameHeaderSize - original_size));
const size_t consumed = DecodeFrameHeader(incomplete_header_buffer_);
if (websocket_error_ != kWebSocketNormalClosure)
return false;
if (!current_frame_header_.get())
return true;

DCHECK_GE(consumed, original_size);
data_span = data_span.subspan(consumed - original_size);
incomplete_header_buffer_.clear();
first_chunk = true;
}

while (current_read_pos_ < buffer_.size()) {
bool first_chunk = false;
DCHECK(incomplete_header_buffer_.empty());
while (data_span.size() > 0 || first_chunk) {
if (!current_frame_header_.get()) {
DecodeFrameHeader();
const size_t consumed = DecodeFrameHeader(data_span);
if (websocket_error_ != kWebSocketNormalClosure)
return false;
// If frame header is incomplete, then carry over the remaining
// data to the next round of Decode().
if (!current_frame_header_.get())
break;
if (!current_frame_header_.get()) {
DCHECK(!consumed);
incomplete_header_buffer_.insert(incomplete_header_buffer_.end(),
data_span.data(),
data_span.data() + data_span.size());
// Sanity check: the size of carried-over data should not exceed
// the maximum possible length of a frame header.
DCHECK_LT(incomplete_header_buffer_.size(), kMaximumFrameHeaderSize);
return true;
}
DCHECK_GE(data_span.size(), consumed);
data_span = data_span.subspan(consumed);
first_chunk = true;
}

DCHECK(incomplete_header_buffer_.empty());
std::unique_ptr<WebSocketFrameChunk> frame_chunk =
DecodeFramePayload(first_chunk);
DecodeFramePayload(first_chunk, &data_span);
first_chunk = false;
DCHECK(frame_chunk.get());
frame_chunks->push_back(std::move(frame_chunk));

if (current_frame_header_.get()) {
DCHECK(current_read_pos_ == buffer_.size());
break;
}
}

// Drain unnecessary data. TODO(yutak): Remove copy. (but how?)
buffer_.erase(buffer_.begin(), buffer_.begin() + current_read_pos_);
current_read_pos_ = 0;

// Sanity check: the size of carried-over data should not exceed
// the maximum possible length of a frame header.
static const size_t kMaximumFrameHeaderSize =
WebSocketFrameHeader::kBaseHeaderSize +
WebSocketFrameHeader::kMaximumExtendedLengthSize +
WebSocketFrameHeader::kMaskingKeyLength;
DCHECK_LT(buffer_.size(), kMaximumFrameHeaderSize);

return true;
}

void WebSocketFrameParser::DecodeFrameHeader() {
size_t WebSocketFrameParser::DecodeFrameHeader(base::span<const char> data) {
DVLOG(3) << "DecodeFrameHeader buffer size:"
<< ", data size:" << data.size();
typedef WebSocketFrameHeader::OpCode OpCode;
static const int kMaskingKeyLength = WebSocketFrameHeader::kMaskingKeyLength;

DCHECK(!current_frame_header_.get());

const char* start = &buffer_.front() + current_read_pos_;
const char* current = start;
const char* end = &buffer_.front() + buffer_.size();

// Header needs 2 bytes at minimum.
if (end - current < 2)
return;

uint8_t first_byte = *current++;
uint8_t second_byte = *current++;

bool final = (first_byte & kFinalBit) != 0;
bool reserved1 = (first_byte & kReserved1Bit) != 0;
bool reserved2 = (first_byte & kReserved2Bit) != 0;
bool reserved3 = (first_byte & kReserved3Bit) != 0;
OpCode opcode = first_byte & kOpCodeMask;
if (data.size() < 2)
return 0;
size_t current = 0;
const uint8_t first_byte = data[current++];
const uint8_t second_byte = data[current++];

const bool final = (first_byte & kFinalBit) != 0;
const bool reserved1 = (first_byte & kReserved1Bit) != 0;
const bool reserved2 = (first_byte & kReserved2Bit) != 0;
const bool reserved3 = (first_byte & kReserved3Bit) != 0;
const OpCode opcode = first_byte & kOpCodeMask;

bool masked = (second_byte & kMaskBit) != 0;
uint64_t payload_length = second_byte & kPayloadLengthMask;
if (payload_length == kPayloadLengthWithTwoByteExtendedLengthField) {
if (end - current < 2)
return;
if (data.size() < current + 2)
return 0;
uint16_t payload_length_16;
base::ReadBigEndian(current, &payload_length_16);
base::ReadBigEndian(&data[current], &payload_length_16);
current += 2;
payload_length = payload_length_16;
if (payload_length <= kMaxPayloadLengthWithoutExtendedLengthField)
if (payload_length <= kMaxPayloadLengthWithoutExtendedLengthField) {
websocket_error_ = kWebSocketErrorProtocolError;
return 0;
}
} else if (payload_length == kPayloadLengthWithEightByteExtendedLengthField) {
if (end - current < 8)
return;
base::ReadBigEndian(current, &payload_length);
if (data.size() < current + 8)
return 0;
base::ReadBigEndian(&data[current], &payload_length);
current += 8;
if (payload_length <= UINT16_MAX ||
payload_length > static_cast<uint64_t>(INT64_MAX)) {
websocket_error_ = kWebSocketErrorProtocolError;
} else if (payload_length > static_cast<uint64_t>(INT32_MAX)) {
return 0;
}
if (payload_length > static_cast<uint64_t>(INT32_MAX)) {
websocket_error_ = kWebSocketErrorMessageTooBig;
return 0;
}
}
if (websocket_error_ != kWebSocketNormalClosure) {
buffer_.clear();
current_read_pos_ = 0;
current_frame_header_.reset();
frame_offset_ = 0;
return;
}
DCHECK_EQ(websocket_error_, kWebSocketNormalClosure);

const bool masked = (second_byte & kMaskBit) != 0;
static const int kMaskingKeyLength = WebSocketFrameHeader::kMaskingKeyLength;
if (masked) {
if (end - current < kMaskingKeyLength)
return;
std::copy(current, current + kMaskingKeyLength, masking_key_.key);
if (data.size() < current + kMaskingKeyLength)
return 0;
std::copy(&data[current], &data[current] + kMaskingKeyLength,
masking_key_.key);
current += kMaskingKeyLength;
} else {
std::fill(masking_key_.key, masking_key_.key + kMaskingKeyLength, '\0');
Expand All @@ -164,37 +176,39 @@ void WebSocketFrameParser::DecodeFrameHeader() {
current_frame_header_->reserved3 = reserved3;
current_frame_header_->masked = masked;
current_frame_header_->payload_length = payload_length;
current_read_pos_ += current - start;
DCHECK_EQ(0u, frame_offset_);
return current;
}

std::unique_ptr<WebSocketFrameChunk> WebSocketFrameParser::DecodeFramePayload(
bool first_chunk) {
bool first_chunk,
base::span<const char>* data) {
// The cast here is safe because |payload_length| is already checked to be
// less than std::numeric_limits<int>::max() when the header is parsed.
int next_size = static_cast<int>(
std::min(static_cast<uint64_t>(buffer_.size() - current_read_pos_),
const int chunk_data_size = static_cast<int>(
std::min(static_cast<uint64_t>(data->size()),
current_frame_header_->payload_length - frame_offset_));

auto frame_chunk = std::make_unique<WebSocketFrameChunk>();
if (first_chunk) {
frame_chunk->header = current_frame_header_->Clone();
}
frame_chunk->final_chunk = false;
if (next_size) {
frame_chunk->data =
base::MakeRefCounted<IOBufferWithSize>(static_cast<int>(next_size));
if (chunk_data_size) {
frame_chunk->data = base::MakeRefCounted<IOBufferWithSize>(
static_cast<int>(chunk_data_size));
char* io_data = frame_chunk->data->data();
memcpy(io_data, &buffer_.front() + current_read_pos_, next_size);
// TODO(yoichio): Remove copy by making |frame_chunk| having refs of |data|.
memcpy(io_data, data->data(), chunk_data_size);
*data = data->subspan(chunk_data_size);
if (current_frame_header_->masked) {
// The masking function is its own inverse, so we use the same function to
// unmask as to mask.
MaskWebSocketFramePayload(
masking_key_, frame_offset_, io_data, next_size);
MaskWebSocketFramePayload(masking_key_, frame_offset_, io_data,
chunk_data_size);
}

current_read_pos_ += next_size;
frame_offset_ += next_size;
frame_offset_ += chunk_data_size;
}

DCHECK_LE(frame_offset_, current_frame_header_->payload_length);
Expand Down
27 changes: 14 additions & 13 deletions net/websockets/websocket_frame_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <memory>
#include <vector>

#include "base/containers/span.h"
#include "base/macros.h"
#include "net/base/net_export.h"
#include "net/websockets/websocket_errors.h"
Expand Down Expand Up @@ -48,27 +49,27 @@ class NET_EXPORT WebSocketFrameParser {
WebSocketError websocket_error() const { return websocket_error_; }

private:
// Tries to decode a frame header from |current_read_pos_|.
// If successful, this function updates |current_read_pos_|,
// |current_frame_header_|, and |masking_key_| (if available).
// This function may set |failed_| to true if it observes a corrupt frame.
// Tries to decode a frame header from |data|.
// If successful, this function updates
// |current_frame_header_|, and |masking_key_| (if available) and returns
// the number of consumed bytes in |data|.
// If there is not enough data in the remaining buffer to parse a frame
// header, this function returns without doing anything.
void DecodeFrameHeader();
// header, this function returns 0 without doing anything.
// This function may update |websocket_error_| if it observes a corrupt frame.
size_t DecodeFrameHeader(base::span<const char> data);

// Decodes frame payload and creates a WebSocketFrameChunk object.
// This function updates |current_read_pos_| and |frame_offset_| after
// This function updates |frame_offset_| after
// parsing. This function returns a frame object even if no payload data is
// available at this moment, so the receiver could make use of frame header
// information. If the end of frame is reached, this function clears
// |current_frame_header_|, |frame_offset_| and |masking_key_|.
std::unique_ptr<WebSocketFrameChunk> DecodeFramePayload(bool first_chunk);
std::unique_ptr<WebSocketFrameChunk> DecodeFramePayload(
bool first_chunk,
base::span<const char>* data);

// Internal buffer to store the data to parse.
std::vector<char> buffer_;

// Position in |buffer_| where the next round of parsing starts.
size_t current_read_pos_;
// Internal buffer to store the data to parse header.
std::vector<char> incomplete_header_buffer_;

// Frame header and masking key of the current frame.
// |masking_key_| is filled with zeros if the current frame is not masked.
Expand Down
4 changes: 2 additions & 2 deletions net/websockets/websocket_frame_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,9 @@ TEST(WebSocketFrameParserTest, DecodePartialHeader) {
}
if (kFrameHeaderTests[i].error_code == kWebSocketNormalClosure &&
j == last_byte_offset) {
EXPECT_EQ(1u, frames.size());
EXPECT_EQ(1u, frames.size()) << "i=" << i << ", j=" << j;
} else {
EXPECT_EQ(0u, frames.size());
EXPECT_EQ(0u, frames.size()) << "i=" << i << ", j=" << j;
}
}
if (frames.size() != 1u)
Expand Down

0 comments on commit 9a74219

Please sign in to comment.