Skip to content

Commit

Permalink
[BackPort][Feature] support prepare statement (StarRocks#27840)
Browse files Browse the repository at this point in the history
Signed-off-by: jukejian <jukejian@bytedance.com>
Co-authored-by: root <root@n37-042-050.byted.org>
  • Loading branch information
2 people authored and wuxueyang96 committed Mar 21, 2024
1 parent a95882d commit 2cb2f80
Show file tree
Hide file tree
Showing 38 changed files with 1,151 additions and 17 deletions.
1 change: 1 addition & 0 deletions be/src/exec/pipeline/fragment_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,7 @@ Status FragmentExecutor::_decompose_data_sink_to_operator(RuntimeState* runtime_
result_sink->get_file_opts(), dop, fragment_ctx);
} else {
op = std::make_shared<ResultSinkOperatorFactory>(context->next_operator_id(), result_sink->get_sink_type(),
result_sink->isBinaryFormat(),
result_sink->get_output_exprs(), fragment_ctx);
}
// Add result sink operator to last pipeline
Expand Down
3 changes: 2 additions & 1 deletion be/src/exec/pipeline/result_sink_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Status ResultSinkOperator::prepare(RuntimeState* state) {
// Create writer based on sink type
switch (_sink_type) {
case TResultSinkType::MYSQL_PROTOCAL:
_writer = std::make_shared<MysqlResultWriter>(_sender.get(), _output_expr_ctxs, _profile.get());
_writer = std::make_shared<MysqlResultWriter>(_sender.get(), _output_expr_ctxs, _is_binary_format,
_profile.get());
break;
case TResultSinkType::STATISTIC:
_writer = std::make_shared<StatisticResultWriter>(_sender.get(), _output_expr_ctxs, _profile.get());
Expand Down
19 changes: 12 additions & 7 deletions be/src/exec/pipeline/result_sink_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ namespace pipeline {
class ResultSinkOperator final : public Operator {
public:
ResultSinkOperator(OperatorFactory* factory, int32_t id, int32_t plan_node_id, int32_t driver_sequence,
TResultSinkType::type sink_type, std::vector<ExprContext*> output_expr_ctxs,
const std::shared_ptr<BufferControlBlock>& sender, std::atomic<int32_t>& num_sinks,
std::atomic<int64_t>& num_written_rows, FragmentContext* const fragment_ctx)
TResultSinkType::type sink_type, bool is_binary_format,
std::vector<ExprContext*> output_expr_ctxs, const std::shared_ptr<BufferControlBlock>& sender,
std::atomic<int32_t>& num_sinks, std::atomic<int64_t>& num_written_rows,
FragmentContext* const fragment_ctx)
: Operator(factory, id, "result_sink", plan_node_id, false, driver_sequence),
_sink_type(sink_type),
_is_binary_format(is_binary_format),
_output_expr_ctxs(std::move(output_expr_ctxs)),
_sender(sender),
_num_sinkers(num_sinks),
Expand Down Expand Up @@ -68,6 +70,7 @@ class ResultSinkOperator final : public Operator {

private:
TResultSinkType::type _sink_type;
bool _is_binary_format;
std::vector<ExprContext*> _output_expr_ctxs;

/// The following three fields are shared by all the ResultSinkOperators
Expand All @@ -89,10 +92,11 @@ class ResultSinkOperator final : public Operator {

class ResultSinkOperatorFactory final : public OperatorFactory {
public:
ResultSinkOperatorFactory(int32_t id, TResultSinkType::type sink_type, std::vector<TExpr> t_output_expr,
FragmentContext* const fragment_ctx)
ResultSinkOperatorFactory(int32_t id, TResultSinkType::type sink_type, bool is_binary_format,
std::vector<TExpr> t_output_expr, FragmentContext* const fragment_ctx)
: OperatorFactory(id, "result_sink", Operator::s_pseudo_plan_node_id_for_final_sink),
_sink_type(sink_type),
_is_binary_format(is_binary_format),
_t_output_expr(std::move(t_output_expr)),
_fragment_ctx(fragment_ctx) {}

Expand All @@ -105,8 +109,8 @@ class ResultSinkOperatorFactory final : public OperatorFactory {
// so it doesn't need memory barrier here.
_increment_num_sinkers_no_barrier();
return std::make_shared<ResultSinkOperator>(this, _id, _plan_node_id, driver_sequence, _sink_type,
_output_expr_ctxs, _sender, _num_sinkers, _num_written_rows,
_fragment_ctx);
_is_binary_format, _output_expr_ctxs, _sender, _num_sinkers,
_num_written_rows, _fragment_ctx);
}

Status prepare(RuntimeState* state) override;
Expand All @@ -117,6 +121,7 @@ class ResultSinkOperatorFactory final : public OperatorFactory {
void _increment_num_sinkers_no_barrier() { _num_sinkers.fetch_add(1, std::memory_order_relaxed); }

TResultSinkType::type _sink_type;
bool _is_binary_format;
std::vector<TExpr> _t_output_expr;
std::vector<ExprContext*> _output_expr_ctxs;

Expand Down
16 changes: 13 additions & 3 deletions be/src/runtime/mysql_result_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@
namespace starrocks {

MysqlResultWriter::MysqlResultWriter(BufferControlBlock* sinker, const std::vector<ExprContext*>& output_expr_ctxs,
RuntimeProfile* parent_profile)
: _sinker(sinker), _output_expr_ctxs(output_expr_ctxs), _row_buffer(nullptr), _parent_profile(parent_profile) {}
bool is_binary_format, RuntimeProfile* parent_profile)
: _sinker(sinker),
_output_expr_ctxs(output_expr_ctxs),
_row_buffer(nullptr),
_is_binary_format(is_binary_format),
_parent_profile(parent_profile) {}

MysqlResultWriter::~MysqlResultWriter() {
delete _row_buffer;
Expand All @@ -60,7 +64,7 @@ Status MysqlResultWriter::init(RuntimeState* state) {
return Status::InternalError("sinker is NULL pointer.");
}

_row_buffer = new (std::nothrow) MysqlRowBuffer();
_row_buffer = new (std::nothrow) MysqlRowBuffer(_is_binary_format);

if (nullptr == _row_buffer) {
return Status::InternalError("no memory to alloc.");
Expand Down Expand Up @@ -134,6 +138,9 @@ StatusOr<TFetchDataResultPtr> MysqlResultWriter::_process_chunk(Chunk* chunk) {
SCOPED_TIMER(_convert_tuple_timer);
for (int i = 0; i < num_rows; ++i) {
DCHECK_EQ(0, _row_buffer->length());
if (_is_binary_format) {
_row_buffer->start_binary_row(num_columns);
};
for (auto& result_column : result_columns) {
result_column->put_mysql_row_buffer(_row_buffer, i);
}
Expand Down Expand Up @@ -176,6 +183,9 @@ StatusOr<TFetchDataResultPtrs> MysqlResultWriter::process_chunk(Chunk* chunk) {

for (int i = 0; i < num_rows; ++i) {
DCHECK_EQ(0, _row_buffer->length());
if (_is_binary_format) {
_row_buffer->start_binary_row(num_columns);
}
for (auto& result_column : result_columns) {
result_column->put_mysql_row_buffer(_row_buffer, i);
}
Expand Down
3 changes: 2 additions & 1 deletion be/src/runtime/mysql_result_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ using TFetchDataResultPtrs = std::vector<TFetchDataResultPtr>;
class MysqlResultWriter final : public ResultWriter {
public:
MysqlResultWriter(BufferControlBlock* sinker, const std::vector<ExprContext*>& output_expr_ctxs,
RuntimeProfile* parent_profile);
bool is_binary_format, RuntimeProfile* parent_profile);

~MysqlResultWriter() override;

Expand All @@ -72,6 +72,7 @@ class MysqlResultWriter final : public ResultWriter {
BufferControlBlock* _sinker;
const std::vector<ExprContext*>& _output_expr_ctxs;
MysqlRowBuffer* _row_buffer;
bool _is_binary_format;

RuntimeProfile* _parent_profile; // parent profile from result sink. not owned
// total time cost on append chunk operation
Expand Down
5 changes: 4 additions & 1 deletion be/src/runtime/result_sink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ ResultSink::ResultSink(const RowDescriptor& row_desc, const std::vector<TExpr>&
CHECK(sink.__isset.file_options);
_file_opts = std::make_shared<ResultFileOptions>(sink.file_options);
}

_is_binary_format = sink.is_binary_row;
}

Status ResultSink::prepare_exprs(RuntimeState* state) {
Expand Down Expand Up @@ -94,7 +96,8 @@ Status ResultSink::prepare(RuntimeState* state) {
// create writer based on sink type
switch (_sink_type) {
case TResultSinkType::MYSQL_PROTOCAL:
_writer.reset(new (std::nothrow) MysqlResultWriter(_sender.get(), _output_expr_ctxs, _profile));
_writer.reset(new (std::nothrow)
MysqlResultWriter(_sender.get(), _output_expr_ctxs, _is_binary_format, _profile));
break;
case TResultSinkType::FILE:
CHECK(_file_opts.get() != nullptr);
Expand Down
3 changes: 3 additions & 0 deletions be/src/runtime/result_sink.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,12 @@ class ResultSink final : public DataSink {

std::shared_ptr<ResultFileOptions> get_file_opts() const { return _file_opts; }

bool isBinaryFormat() const { return _is_binary_format; }

private:
Status prepare_exprs(RuntimeState* state);
TResultSinkType::type _sink_type;
bool _is_binary_format;
// set file options when sink type is FILE
std::shared_ptr<ResultFileOptions> _file_opts;

Expand Down
4 changes: 3 additions & 1 deletion be/src/util/mysql_global.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ typedef unsigned char uchar;
*(T + 1) = (uchar)(((uint32_t)(A) >> 8)); \
*(T + 2) = (uchar)(((A) >> 16)); \
} while (0)
#define int4store(T, A) *((uint32_t*)(T)) = (uint32_t)(A)
#define int8store(T, A) *((int64_t*)(T)) = (uint64_t)(A)

#define float4store(T, A) *((float*)(T)) = (float)(A)
#define float8store(T, A) *((double*)(T)) = (double)(A)
#define MAX_TINYINT_WIDTH 3 /* Max width for a TINY w.o. sign */
#define MAX_SMALLINT_WIDTH 5 /* Max width for a SHORT w.o. sign */
#define MAX_INT_WIDTH 10 /* Max width for a LONG w.o. sign */
Expand Down
67 changes: 67 additions & 0 deletions be/src/util/mysql_row_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

#include "common/logging.h"
#include "gutil/strings/fastmem.h"
#include "runtime/large_int_value.h"
#include "util/mysql_global.h"

namespace starrocks {
Expand Down Expand Up @@ -78,6 +79,16 @@ static uint8_t* pack_vlen(uint8_t* packet, uint64_t length) {
}

void MysqlRowBuffer::push_null() {
if (_is_binary_format) {
uint offset = (_field_pos + 2) / 8 + 1;
uint bit = (1 << ((_field_pos + 2) & 7));
/* Room for this as it's allocated start_binary_row*/
char* to = _data.data() + offset;
*to = (char)((uchar)*to | (uchar)bit);
_field_pos++;
return;
}

if (_array_level == 0) {
_data.push_back(0xfb);
} else {
Expand All @@ -86,9 +97,49 @@ void MysqlRowBuffer::push_null() {
}
}

template <typename T>
void MysqlRowBuffer::push_number_binary_format(T data) {
_field_pos++;
if constexpr (std::is_same_v<T, float>) {
char buff[4];
float4store(buff, data);
_data.append(buff, 4);
} else if constexpr (std::is_same_v<T, double>) {
char buff[8];
float8store(buff, data);
_data.append(buff, 8);
} else if constexpr (std::is_same_v<std::make_signed_t<T>, int8_t>) {
char buff[1];
int1store(buff, data);
_data.append(buff, 1);
} else if constexpr (std::is_same_v<std::make_signed_t<T>, int16_t>) {
char buff[2];
int2store(buff, data);
_data.append(buff, 2);
} else if constexpr (std::is_same_v<std::make_signed_t<T>, int32_t>) {
char buff[4];
int4store(buff, data);
_data.append(buff, 4);
} else if constexpr (std::is_same_v<std::make_signed_t<T>, int64_t>) {
char buff[8];
int8store(buff, data);
_data.append(buff, 8);
} else if constexpr (std::is_same_v<std::make_signed_t<T>, __int128>) {
std::string value = LargeIntValue::to_string(data);
_push_string_normal(value.data(), value.size());
} else {
CHECK(false) << "unhandled data type";
}
}

template <typename T>
void MysqlRowBuffer::push_number(T data) {
static_assert(std::is_arithmetic_v<T> || std::is_same_v<T, __int128>);

if (_is_binary_format) {
return push_number_binary_format(data);
}

int length = 0;
char* end = nullptr;
char* pos = nullptr;
Expand Down Expand Up @@ -133,6 +184,10 @@ void MysqlRowBuffer::push_number(T data) {
}

void MysqlRowBuffer::push_string(const char* str, size_t length, char escape_char) {
if (_is_binary_format) {
++_field_pos;
}

if (_array_level == 0) {
_push_string_normal(str, length);
} else {
Expand All @@ -155,6 +210,10 @@ void MysqlRowBuffer::push_string(const char* str, size_t length, char escape_cha
}

void MysqlRowBuffer::push_decimal(const Slice& s) {
if (_is_binary_format) {
++_field_pos;
}

if (_array_level == 0) {
_push_string_normal(s.data, s.size);
} else {
Expand Down Expand Up @@ -252,6 +311,14 @@ template void MysqlRowBuffer::push_number<__int128>(__int128);
template void MysqlRowBuffer::push_number<float>(float);
template void MysqlRowBuffer::push_number<double>(double);

void MysqlRowBuffer::start_binary_row(uint32_t num_cols) {
DCHECK(_is_binary_format) << "start_binary_row() only for is_binary_format=true";
int bit_fields = (num_cols + 9) / 8;
char* pos = _resize_extra(bit_fields + 1);
memset(pos, 0, 1 + bit_fields);
_field_pos = 0;
}

} // namespace starrocks

/* vim: set ts=4 sw=4 sts=4 tw=100 */
11 changes: 11 additions & 0 deletions be/src/util/mysql_row_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,13 @@ namespace starrocks {
class MysqlRowBuffer final {
public:
MysqlRowBuffer() = default;
MysqlRowBuffer(bool is_binary_format) : _is_binary_format(is_binary_format){};
~MysqlRowBuffer() = default;

void reset() { _data.clear(); }

void start_binary_row(uint32_t num_cols);

void push_null();
void push_tinyint(int8_t data) { push_number(data); }
void push_smallint(int16_t data) { push_number(data); }
Expand All @@ -63,6 +66,10 @@ class MysqlRowBuffer final {
template <typename T>
void push_number(T data);
void push_number(uint24_t data) { push_number((uint32_t)data); }

template <typename T>
void push_number_binary_format(T data);

void push_decimal(const Slice& s);

void begin_push_array() { _enter_scope('['); }
Expand Down Expand Up @@ -101,6 +108,10 @@ class MysqlRowBuffer final {
raw::RawString _data;
uint32_t _array_level = 0;
uint32_t _array_offset = 0;

bool _is_binary_format = false;
// used for calculate null position if is_binary_format = true
uint32_t _field_pos = 0;
};

} // namespace starrocks
45 changes: 45 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/analysis/DateLiteral.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import com.google.common.base.Preconditions;
import com.starrocks.catalog.PrimitiveType;
import com.starrocks.catalog.ScalarType;
import com.starrocks.catalog.Type;
import com.starrocks.common.AnalysisException;
import com.starrocks.common.util.DateUtils;
Expand Down Expand Up @@ -452,4 +453,48 @@ public boolean isNullable() {
}
return true;
}

@Override
public void parseMysqlParam(ByteBuffer data) {
int len = getParamLen(data);
if (type.getPrimitiveType() == PrimitiveType.DATE) {
if (len >= 4) {
year = (int) data.getChar();
month = (int) data.get();
day = (int) data.get();
hour = 0;
minute = 0;
second = 0;
microsecond = 0;
} else {
copy(MIN_DATE);
}
return;
}
if (type.getPrimitiveType() == PrimitiveType.DATETIME) {
if (len >= 4) {
year = (int) data.getChar();
month = (int) data.get();
day = (int) data.get();
microsecond = 0;
if (len > 4) {
hour = (int) data.get();
minute = (int) data.get();
second = (int) data.get();
} else {
hour = 0;
minute = 0;
second = 0;
microsecond = 0;
}
if (len > 7) {
microsecond = data.getInt();
// choose the highest scale to keep microsecond value
type = ScalarType.createDecimalV2Type(6);
}
} else {
copy(MIN_DATETIME);
}
}
}
}
Loading

0 comments on commit 2cb2f80

Please sign in to comment.