Skip to content

Commit

Permalink
Add safety checks around missed allocations for AsyncWebSocketMessage…
Browse files Browse the repository at this point in the history
…Buffer
  • Loading branch information
mathieucarbou committed Jan 26, 2024
1 parent ed538f9 commit ade1030
Showing 1 changed file with 45 additions and 11 deletions.
56 changes: 45 additions & 11 deletions src/AsyncWebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,21 @@ AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer()
AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(uint8_t* data, size_t size)
: _buffer(std::make_shared<std::vector<uint8_t>>(size))
{
std::memcpy(_buffer->data(), data, size);
if (_buffer->capacity() < size) {
_buffer.reset();
_buffer = std::make_shared<std::vector<uint8_t>>(0);
} else {
std::memcpy(_buffer->data(), data, size);
}
}

AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(size_t size)
: _buffer(std::make_shared<std::vector<uint8_t>>(size))
{
if (_buffer->capacity() < size) {
_buffer.reset();
_buffer = std::make_shared<std::vector<uint8_t>>(0);
}
}

AsyncWebSocketMessageBuffer::~AsyncWebSocketMessageBuffer()
Expand Down Expand Up @@ -443,6 +452,9 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
if (!_client)
return;

if (buffer->size() == 0)
return;

{
AsyncWebLockGuard l(_lock);
if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES)
Expand Down Expand Up @@ -687,8 +699,10 @@ std::shared_ptr<std::vector<uint8_t>> makeSharedBuffer(const uint8_t *message, s

void AsyncWebSocketClient::text(AsyncWebSocketMessageBuffer * buffer)
{
text(std::move(buffer->_buffer));
delete buffer;
if (buffer) {
text(std::move(buffer->_buffer));
delete buffer;
}
}

void AsyncWebSocketClient::text(std::shared_ptr<std::vector<uint8_t>> buffer)
Expand Down Expand Up @@ -739,8 +753,10 @@ void AsyncWebSocketClient::text(const __FlashStringHelper *data)

void AsyncWebSocketClient::binary(AsyncWebSocketMessageBuffer * buffer)
{
binary(std::move(buffer->_buffer));
delete buffer;
if (buffer) {
binary(std::move(buffer->_buffer));
delete buffer;
}
}

void AsyncWebSocketClient::binary(std::shared_ptr<std::vector<uint8_t>> buffer)
Expand Down Expand Up @@ -936,8 +952,10 @@ void AsyncWebSocket::text(uint32_t id, const __FlashStringHelper *data)

void AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer * buffer)
{
textAll(std::move(buffer->_buffer));
delete buffer;
if (buffer) {
textAll(std::move(buffer->_buffer));
delete buffer;
}
}

void AsyncWebSocket::textAll(std::shared_ptr<std::vector<uint8_t>> buffer)
Expand Down Expand Up @@ -1014,8 +1032,10 @@ void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t

void AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer * buffer)
{
binaryAll(std::move(buffer->_buffer));
delete buffer;
if (buffer) {
binaryAll(std::move(buffer->_buffer));
delete buffer;
}
}

void AsyncWebSocket::binaryAll(std::shared_ptr<std::vector<uint8_t>> buffer)
Expand Down Expand Up @@ -1200,12 +1220,26 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)

AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(size_t size)
{
return new AsyncWebSocketMessageBuffer(size);
AsyncWebSocketMessageBuffer * buffer = new AsyncWebSocketMessageBuffer(size);
if (buffer->length() != size)
{
delete buffer;
return nullptr;
} else {
return buffer;
}
}

AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(uint8_t * data, size_t size)
{
return new AsyncWebSocketMessageBuffer(data, size);
AsyncWebSocketMessageBuffer * buffer = new AsyncWebSocketMessageBuffer(data, size);
if (buffer->length() != size)
{
delete buffer;
return nullptr;
} else {
return buffer;
}
}

/*
Expand Down

0 comments on commit ade1030

Please sign in to comment.