Skip to content

Commit

Permalink
[Enhancement] Add memory limit for ByteBuffer allocation in data inge…
Browse files Browse the repository at this point in the history
…stion (StarRocks#49308)

Signed-off-by: srlch <linzichao@starrocks.com>
  • Loading branch information
srlch authored Aug 19, 2024
1 parent 92748b8 commit 4a34533
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 32 deletions.
11 changes: 11 additions & 0 deletions be/src/common/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ struct StatusInstance {
#define RETURN_IF_ERROR(stmt) RETURN_IF_ERROR_INTERNAL(stmt)
#endif

// Evaluates `stmt` exactly once. If the result's ok() is false, the macro stores
// to_status(result) -- with __FILE__, __LINE__ and the stringified `stmt` appended as
// context -- into `err_status`, then executes a bare `return;`.
//
// Because of the bare `return;` this can only be expanded inside a function returning void.
// The do { ... } while (false) wrapper makes the expansion a single statement, so it can be
// followed by `;` and used as the body of an unbraced if/else.
//
// NOTE(review): "STATUE" looks like a misspelling of "STATUS"; the name is left untouched
// here because other macros and callers refer to it by this spelling.
#define SET_STATUE_AND_RETURN_IF_ERROR_INTERNAL(err_status, stmt) \
do { \
auto&& status__ = (stmt); \
if (UNLIKELY(!status__.ok())) { \
err_status = to_status(status__).clone_and_append_context(__FILE__, __LINE__, AS_STRING(stmt)); \
return; \
} \
} while (false)

// Public spelling of the macro above; forwards both arguments unchanged.
#define SET_STATUE_AND_RETURN_IF_ERROR(err_status, stmt) SET_STATUE_AND_RETURN_IF_ERROR_INTERNAL(err_status, stmt)

#define EXIT_IF_ERROR(stmt) \
do { \
auto&& status__ = (stmt); \
Expand Down
11 changes: 11 additions & 0 deletions be/src/common/statusor.h
Original file line number Diff line number Diff line change
Expand Up @@ -733,4 +733,15 @@ inline std::ostream& operator<<(std::ostream& os, const StatusOr<T>& st) {
// an lvalue StatusOr which you *don't* want to move out of cast appropriately.
#define ASSIGN_OR_RETURN(lhs, rhs) ASSIGN_OR_RETURN_IMPL(VARNAME_LINENUM(value_or_err), lhs, rhs)

// Implementation detail of ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR; do not expand directly.
//
// `varname` names the temporary that holds the evaluated `rhs`. It is a macro parameter
// (rather than a fixed identifier) so that the public macro can pass a line-unique name:
// the temporary is declared in the caller's scope, and a fixed name would make two
// expansions in the same scope a redeclaration error and would clash with any caller
// variable of that name.
#define ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR_IMPL(varname, err_status, lhs, rhs) \
    auto&& varname = (rhs);                                                             \
    SET_STATUE_AND_RETURN_IF_ERROR(err_status, varname);                                \
    lhs = std::move(varname).value();

// ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR is basically the same as ASSIGN_OR_RETURN, except:
// 1. return void if the status of rhs is NOT ok
// 2. set the status of rhs into err_status before return
//
// `lhs` may be an existing lvalue or a declaration (e.g. `ByteBufferPtr buf`); `rhs` is
// evaluated exactly once.
//
// The expansion is several statements and is intentionally NOT wrapped in do { } while (false),
// because a variable declared through `lhs` must stay visible to the caller. Consequently it
// must not be used as the body of an unbraced if/else/for/while. Like ASSIGN_OR_RETURN, the
// temporary's name is derived from __LINE__, so use at most one expansion per source line.
#define ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR(err_status, lhs, rhs) \
    ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR_IMPL(VARNAME_LINENUM(value_or_err), err_status, lhs, rhs)

} // namespace starrocks
3 changes: 2 additions & 1 deletion be/src/exec/json_scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,8 @@ Status JsonReader::_read_file_stream() {
if (_file_stream_buffer->capacity < _file_stream_buffer->remaining() + simdjson::SIMDJSON_PADDING) {
// For efficiency reasons, simdjson requires a string with a few bytes (simdjson::SIMDJSON_PADDING) at the end.
// Hence, a re-allocation is needed if the space is not enough.
auto buf = ByteBuffer::allocate_with_tracker(_file_stream_buffer->remaining() + simdjson::SIMDJSON_PADDING);
ASSIGN_OR_RETURN(auto buf, ByteBuffer::allocate_with_tracker(_file_stream_buffer->remaining() +
simdjson::SIMDJSON_PADDING));
buf->put_bytes(_file_stream_buffer->ptr, _file_stream_buffer->remaining());
buf->flip();
std::swap(buf, _file_stream_buffer);
Expand Down
19 changes: 13 additions & 6 deletions be/src/http/action/stream_load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ Status StreamLoadAction::_on_header(HttpRequest* http_req, StreamLoadContext* ct
if (ctx->format == TFileFormatType::FORMAT_JSON) {
// Allocate buffer in advance, since the json payload cannot be parsed in stream mode.
// For efficiency reasons, simdjson requires a string with a few bytes (simdjson::SIMDJSON_PADDING) at the end.
ctx->buffer = ByteBuffer::allocate_with_tracker(ctx->body_bytes + simdjson::SIMDJSON_PADDING);
ASSIGN_OR_RETURN(ctx->buffer,
ByteBuffer::allocate_with_tracker(ctx->body_bytes + simdjson::SIMDJSON_PADDING));
}
} else {
#ifndef BE_TEST
Expand Down Expand Up @@ -356,15 +357,19 @@ void StreamLoadAction::on_chunk_data(HttpRequest* req) {
while ((len = evbuffer_get_length(evbuf)) > 0) {
if (ctx->buffer == nullptr) {
// Initialize buffer.
ctx->buffer = ByteBuffer::allocate_with_tracker(
ctx->format == TFileFormatType::FORMAT_JSON ? std::max(len, ctx->kDefaultBufferSize) : len);
ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR(
ctx->status, ctx->buffer,
ByteBuffer::allocate_with_tracker(ctx->format == TFileFormatType::FORMAT_JSON
? std::max(len, ctx->kDefaultBufferSize)
: len));

} else if (ctx->buffer->remaining() < len) {
if (ctx->format == TFileFormatType::FORMAT_JSON) {
// For json format, we need to build a complete json before we push the buffer to the pipe.
// buffer capacity is not enough, so we try to expand the buffer.
ByteBufferPtr buf =
ByteBuffer::allocate_with_tracker(BitUtil::RoundUpToPowerOfTwo(ctx->buffer->pos + len));
ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR(
ctx->status, ByteBufferPtr buf,
ByteBuffer::allocate_with_tracker(BitUtil::RoundUpToPowerOfTwo(ctx->buffer->pos + len)));
buf->put_bytes(ctx->buffer->ptr, ctx->buffer->pos);
std::swap(buf, ctx->buffer);

Expand All @@ -379,7 +384,9 @@ void StreamLoadAction::on_chunk_data(HttpRequest* req) {
return;
}

ctx->buffer = ByteBuffer::allocate_with_tracker(std::max(len, ctx->kDefaultBufferSize));
ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR(
ctx->status, ctx->buffer,
ByteBuffer::allocate_with_tracker(std::max(len, ctx->kDefaultBufferSize)));
}
}

Expand Down
15 changes: 11 additions & 4 deletions be/src/http/action/transaction_stream_load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,11 @@ void TransactionStreamLoadAction::on_chunk_data(HttpRequest* req) {
while ((len = evbuffer_get_length(evbuf)) > 0) {
if (ctx->buffer == nullptr) {
// Initialize buffer.
ctx->buffer = ByteBuffer::allocate_with_tracker(
ctx->format == TFileFormatType::FORMAT_JSON ? std::max(len, ctx->kDefaultBufferSize) : len);
ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR(
ctx->status, ctx->buffer,
ByteBuffer::allocate_with_tracker(ctx->format == TFileFormatType::FORMAT_JSON
? std::max(len, ctx->kDefaultBufferSize)
: len));

} else if (ctx->buffer->remaining() < len) {
if (ctx->format == TFileFormatType::FORMAT_JSON) {
Expand All @@ -544,7 +547,9 @@ void TransactionStreamLoadAction::on_chunk_data(HttpRequest* req) {
ctx->status = Status::MemoryLimitExceeded(err_msg);
return;
}
ByteBufferPtr buf = ByteBuffer::allocate_with_tracker(BitUtil::RoundUpToPowerOfTwo(data_sz));
ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR(
ctx->status, ByteBufferPtr buf,
ByteBuffer::allocate_with_tracker(BitUtil::RoundUpToPowerOfTwo(data_sz)));
buf->put_bytes(ctx->buffer->ptr, ctx->buffer->pos);
std::swap(buf, ctx->buffer);

Expand All @@ -559,7 +564,9 @@ void TransactionStreamLoadAction::on_chunk_data(HttpRequest* req) {
return;
}

ctx->buffer = ByteBuffer::allocate_with_tracker(std::max(len, ctx->kDefaultBufferSize));
ASSIGN_OR_SET_STATUS_AND_RETURN_IF_ERROR(
ctx->status, ctx->buffer,
ByteBuffer::allocate_with_tracker(std::max(len, ctx->kDefaultBufferSize)));
}
}

Expand Down
2 changes: 1 addition & 1 deletion be/src/runtime/routine_load/kafka_consumer_pipe.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class KafkaConsumerPipe : public StreamLoadPipe {

Status append_json(const char* data, size_t size, char row_delimiter) {
// For efficiency reasons, simdjson requires a string with a few bytes (simdjson::SIMDJSON_PADDING) at the end.
auto buf = ByteBuffer::allocate_with_tracker(size + simdjson::SIMDJSON_PADDING);
ASSIGN_OR_RETURN(auto buf, ByteBuffer::allocate_with_tracker(size + simdjson::SIMDJSON_PADDING));
buf->put_bytes(data, size);
buf->flip();
return append(std::move(buf));
Expand Down
10 changes: 5 additions & 5 deletions be/src/runtime/stream_load/stream_load_pipe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Status StreamLoadPipe::append(const char* data, size_t size) {
// need to allocate a new chunk, min chunk is 64k
size_t chunk_size = std::max(_min_chunk_size, size - pos);
chunk_size = BitUtil::RoundUpToPowerOfTwo(chunk_size);
_write_buf = ByteBuffer::allocate_with_tracker(chunk_size);
ASSIGN_OR_RETURN(_write_buf, ByteBuffer::allocate_with_tracker(chunk_size));
_write_buf->put_bytes(data + pos, size - pos);
return Status::OK();
}
Expand Down Expand Up @@ -195,7 +195,7 @@ Status StreamLoadPipe::no_block_read(uint8_t* data, size_t* data_size, bool* eof
// put back the read data to the buf_queue, read the data in the next time
size_t chunk_size = bytes_read;
chunk_size = BitUtil::RoundUpToPowerOfTwo(chunk_size);
ByteBufferPtr write_buf = ByteBuffer::allocate_with_tracker(chunk_size);
ASSIGN_OR_RETURN(ByteBufferPtr write_buf, ByteBuffer::allocate_with_tracker(chunk_size));
write_buf->put_bytes((char*)data, bytes_read);
write_buf->flip();
// error happens iff pipe is cancelled
Expand Down Expand Up @@ -293,7 +293,7 @@ StatusOr<ByteBufferPtr> CompressedStreamLoadPipeReader::read() {
}

if (_decompressed_buffer == nullptr) {
_decompressed_buffer = ByteBuffer::allocate_with_tracker(buffer_size);
ASSIGN_OR_RETURN(_decompressed_buffer, ByteBuffer::allocate_with_tracker(buffer_size));
}

ASSIGN_OR_RETURN(auto buf, StreamLoadPipeReader::read());
Expand All @@ -316,7 +316,7 @@ StatusOr<ByteBufferPtr> CompressedStreamLoadPipeReader::read() {
while (!stream_end) {
// buffer size grows exponentially
buffer_size = buffer_size < MAX_DECOMPRESS_BUFFER_SIZE ? buffer_size * 2 : MAX_DECOMPRESS_BUFFER_SIZE;
auto piece = ByteBuffer::allocate_with_tracker(buffer_size);
ASSIGN_OR_RETURN(auto piece, ByteBuffer::allocate_with_tracker(buffer_size));
RETURN_IF_ERROR(_decompressor->decompress(
reinterpret_cast<uint8_t*>(buf->ptr) + total_bytes_read, buf->remaining() - total_bytes_read,
&bytes_read, reinterpret_cast<uint8_t*>(piece->ptr), piece->capacity, &bytes_written, &stream_end));
Expand All @@ -332,7 +332,7 @@ StatusOr<ByteBufferPtr> CompressedStreamLoadPipeReader::read() {
if (_decompressed_buffer->remaining() < pieces_size) {
// align to 1024 bytes.
auto sz = ALIGN_UP(_decompressed_buffer->pos + pieces_size, 1024);
_decompressed_buffer = ByteBuffer::reallocate(_decompressed_buffer, sz);
ASSIGN_OR_RETURN(_decompressed_buffer, ByteBuffer::reallocate_with_tracker(_decompressed_buffer, sz));
}
for (const auto& piece : pieces) {
_decompressed_buffer->put_bytes(piece->ptr, piece->pos);
Expand Down
32 changes: 20 additions & 12 deletions be/src/util/byte_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@
#include "common/logging.h"
#include "gutil/strings/fastmem.h"
#include "runtime/current_thread.h"
#include "runtime/exec_env.h"
#include "runtime/mem_tracker.h"
#include "storage/utils.h"
#include "testutil/sync_point.h"

namespace starrocks {

Expand All @@ -61,25 +64,30 @@ struct MemTrackerDeleter {
};

struct ByteBuffer {
static ByteBufferPtr allocate(size_t size) {
ByteBufferPtr ptr(new ByteBuffer(size));
return ptr;
}

static ByteBufferPtr allocate(size_t size, MemTracker* tracker) {
static StatusOr<ByteBufferPtr> allocate_with_tracker(size_t size) {
auto tracker = CurrentThread::mem_tracker();
if (tracker == nullptr) {
return allocate(size);
return Status::InternalError("current thread memory tracker Not Found when allocate ByteBuffer");
}
#ifndef BE_TEST
// check limit before allocation
TRY_CATCH_BAD_ALLOC(ByteBufferPtr ptr(new ByteBuffer(size), MemTrackerDeleter(tracker)); return ptr;);
#else
ByteBufferPtr ptr(new ByteBuffer(size), MemTrackerDeleter(tracker));
return ptr;
Status ret = Status::OK();
TEST_SYNC_POINT_CALLBACK("ByteBuffer::allocate_with_tracker", &ret);
if (ret.ok()) {
return ptr;
} else {
return ret;
}
#endif
}

static ByteBufferPtr allocate_with_tracker(size_t size) { return allocate(size, CurrentThread::mem_tracker()); }

static ByteBufferPtr reallocate(const ByteBufferPtr& old_ptr, size_t new_size) {
static StatusOr<ByteBufferPtr> reallocate_with_tracker(const ByteBufferPtr& old_ptr, size_t new_size) {
if (new_size <= old_ptr->capacity) return old_ptr;

ByteBufferPtr ptr(new ByteBuffer(new_size));
ASSIGN_OR_RETURN(ByteBufferPtr ptr, allocate_with_tracker(new_size));
ptr->put_bytes(old_ptr->ptr, old_ptr->pos);
return ptr;
}
Expand Down
83 changes: 83 additions & 0 deletions be/test/http/stream_load_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include "http/action/stream_load.h"

#include <event2/buffer.h>
#include <event2/http.h>
#include <event2/http_struct.h>
#include <gtest/gtest.h>
Expand All @@ -45,6 +46,7 @@
#include "http/http_request.h"
#include "runtime/exec_env.h"
#include "runtime/stream_load/load_stream_mgr.h"
#include "runtime/stream_load/stream_load_context.h"
#include "runtime/stream_load/stream_load_executor.h"
#include "testutil/sync_point.h"
#include "util/brpc_stub_cache.h"
Expand Down Expand Up @@ -299,4 +301,85 @@ TEST_F(StreamLoadActionTest, plan_fail) {
SyncPoint::GetInstance()->DisableProcessing();
}

// Checks that StreamLoadAction::on_chunk_data() reports a failed ByteBuffer allocation through
// ctx->status, and that the very same pending data is handled successfully once the failure
// injection is removed.
//
// The failure is injected through the "ByteBuffer::allocate_with_tracker" sync point, whose
// callback receives a Status* and overwrites it with MemoryLimitExceeded. Three scenarios are
// exercised, each as "fail, then retry and succeed":
//   1. ctx->buffer is null when the chunk arrives;
//   2. ctx->buffer is a 1-byte buffer while 3 bytes are pending (presumably the
//      "remaining() < len" path for the context's default format -- confirm against
//      on_chunk_data);
//   3. the same as (2) with ctx->format forced to FORMAT_JSON.
TEST_F(StreamLoadActionTest, huge_malloc) {
    StreamLoadAction action(&_env, _limiter.get());
    auto ctx = new StreamLoadContext(&_env);
    ctx->ref();
    ctx->body_sink = std::make_shared<StreamLoadPipe>();
    HttpRequest request(_evhttp_req);
    const std::string content = "abc";

    // Value-initialise the libevent request so that every field the test does not set
    // explicitly is zero instead of indeterminate.
    struct evhttp_request ev_req {};
    ev_req.remote_host = nullptr;
    auto evb = evbuffer_new();
    ev_req.input_buffer = evb;
    request._ev_req = &ev_req;

    request._headers.emplace(HttpHeaders::AUTHORIZATION, "Basic cm9vdDo=");
    request._headers.emplace(HttpHeaders::CONTENT_LENGTH, "16");
    request._headers.emplace(HTTP_DB_KEY, "db");
    request._headers.emplace(HTTP_LABEL_KEY, "123");
    request._headers.emplace(HTTP_COLUMN_SEPARATOR, "|");
    request.set_handler(&action);
    request.set_handler_ctx(ctx);

    // Arms the sync point so that the next tracked ByteBuffer allocation reports
    // MemoryLimitExceeded.
    const auto inject_alloc_failure = [] {
        SyncPoint::GetInstance()->EnableProcessing();
        SyncPoint::GetInstance()->SetCallBack("ByteBuffer::allocate_with_tracker", [](void* arg) {
            *static_cast<Status*>(arg) = Status::MemoryLimitExceeded("TestFail");
        });
    };
    // Disarms the sync point again so that allocations behave normally.
    const auto clear_alloc_failure = [] {
        SyncPoint::GetInstance()->ClearCallBack("ByteBuffer::allocate_with_tracker");
        SyncPoint::GetInstance()->DisableProcessing();
    };

    // Scenario 1: no buffer has been allocated yet.
    evbuffer_add(evb, content.data(), content.size());
    inject_alloc_failure();
    ctx->status = Status::OK();
    action.on_chunk_data(&request);
    ASSERT_TRUE(ctx->status.is_mem_limit_exceeded());
    clear_alloc_failure();
    ctx->status = Status::OK();
    action.on_chunk_data(&request);
    ASSERT_TRUE(ctx->status.ok());

    // Scenario 2: an existing 1-byte buffer with 3 bytes pending.
    evbuffer_add(evb, content.data(), content.size());
    inject_alloc_failure();
    ctx->buffer = ByteBufferPtr(new ByteBuffer(1));
    ctx->status = Status::OK();
    action.on_chunk_data(&request);
    ASSERT_TRUE(ctx->status.is_mem_limit_exceeded());
    ctx->buffer = nullptr;
    clear_alloc_failure();
    ctx->buffer = ByteBufferPtr(new ByteBuffer(1));
    ctx->status = Status::OK();
    action.on_chunk_data(&request);
    ASSERT_TRUE(ctx->status.ok());
    ctx->buffer = nullptr;

    // Scenario 3: same as scenario 2, but with the JSON format selected.
    evbuffer_add(evb, content.data(), content.size());
    const auto old_format = ctx->format;
    inject_alloc_failure();
    ctx->format = TFileFormatType::FORMAT_JSON;
    ctx->buffer = ByteBufferPtr(new ByteBuffer(1));
    ctx->status = Status::OK();
    action.on_chunk_data(&request);
    ASSERT_TRUE(ctx->status.is_mem_limit_exceeded());
    ctx->buffer = nullptr;
    clear_alloc_failure();
    ctx->format = TFileFormatType::FORMAT_JSON;
    ctx->buffer = ByteBufferPtr(new ByteBuffer(1));
    ctx->status = Status::OK();
    action.on_chunk_data(&request);
    ASSERT_TRUE(ctx->status.ok());
    ctx->buffer = nullptr;
    ctx->format = old_format;

    // Detach the context from the request before dropping the reference taken above.
    request.set_handler_ctx(nullptr);
    request.set_handler(nullptr);
    if (ctx->unref()) {
        delete ctx;
    }
    evbuffer_free(evb);
}

} // namespace starrocks
Loading

0 comments on commit 4a34533

Please sign in to comment.