Skip to content

Commit

Permalink
support string store and equals judgement
Browse files Browse the repository at this point in the history
  • Loading branch information
MrPresent-Han committed Nov 14, 2024
1 parent b45e9f7 commit 8bdfd11
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 12 deletions.
2 changes: 1 addition & 1 deletion internal/core/src/exec/HashTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ void HashTable<nullableKeys>::groupProbe(milvus::exec::HashLookup &lookup) {

template<bool nullableKeys>
void HashTable<nullableKeys>::setHashMode(HashMode mode, int32_t numNew) {
// set hash mode kArray/kHash/kNormalizedKey
// TODO set hash mode kArray/kHash/kNormalizedKey
}

template <bool nullable>
Expand Down
4 changes: 2 additions & 2 deletions internal/core/src/exec/operator/AggregationNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ void PhyAggregationNode::prepareOutput(vector_size_t size){
}

RowVectorPtr PhyAggregationNode::GetOutput() {
if (finished_||!no_more_input_||(!no_more_input_ && !grouping_set_->hasOutput())) {
if (finished_||(!no_more_input_ && !grouping_set_->hasOutput())) {
input_ = nullptr;
return nullptr;
}

const auto& queryConfig = operator_context_->get_driver_context()->GetQueryConfig();
auto batch_size = queryConfig->get_expr_batch_size();
const auto outputRowCount = isGlobal_?1:batch_size;
const auto outputRowCount = isGlobal_? 1: batch_size;
prepareOutput(outputRowCount);
const bool hasData = grouping_set_->getOutput(outputRowCount, outputRowCount, resultIterator_, output_);
if (!hasData) {
Expand Down
33 changes: 24 additions & 9 deletions internal/core/src/exec/operator/query-agg/RowContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ class RowContainer {
return (row[nullByte] & nullMask) != 0;
}

static inline const std::string*& strAt(const char* group, int32_t offset) {
return *reinterpret_cast<const std::string**>(const_cast<char*>(group + offset));
}

template <typename T>
static inline T valueAt(const char* group, int32_t offset) {
return *reinterpret_cast<const T*>(group + offset);
Expand All @@ -183,12 +187,15 @@ class RowContainer {
vector_size_t index) {
if constexpr (Type == DataType::NONE || Type == DataType::ROW || Type == DataType::JSON || Type == DataType::ARRAY) {
PanicInfo(DataTypeInvalid, "Cannot support complex data type:[ROW/JSON/ARRAY] in rows container for now");
} else if constexpr (Type == DataType::VARCHAR || Type == DataType::STRING) {
PanicInfo(DataTypeInvalid, "Cannot support varchar/string types in rows container for now");
} else {
using T = typename TypeTraits<Type>::NativeType;
T* raw_value = static_cast<T*>(column->RawValueAt(index, sizeof(T)));
return milvus::comparePrimitiveAsc(*raw_value, valueAt<T>(row, offset));
if constexpr (std::is_same_v<T, std::string>) {
const std::string& raw = *static_cast<std::string*>(raw_value);
return raw == *(strAt(row, offset));
} else {
return milvus::comparePrimitiveAsc(*raw_value, valueAt<T>(row, offset));
}
}
}

Expand Down Expand Up @@ -233,15 +240,19 @@ class RowContainer {
int32_t offset,
int32_t nullByte,
uint8_t nullMask) {
static std::string null_string_val = "";
static std::string* null_string_val_ptr = &null_string_val;
if constexpr (Type == DataType::NONE || Type == DataType::ROW || Type == DataType::JSON || Type == DataType::ARRAY) {
PanicInfo(DataTypeInvalid, "Cannot support complex data type:[ROW/JSON/ARRAY] in rows container for now");
} else if constexpr (Type == DataType::VARCHAR || Type == DataType::STRING) {
PanicInfo(DataTypeInvalid, "Cannot support varchar/string types in rows container for now");
} else {
using T = typename milvus::TypeTraits<Type>::NativeType;
if (!column->ValidAt(index)) {
row[nullByte]|=nullMask;
*reinterpret_cast<T*>(row+offset) = T();
if constexpr (std::is_same_v<T, std::string>) {
*reinterpret_cast<std::string**>(row + offset) = null_string_val_ptr;
} else {
*reinterpret_cast<T*>(row+offset) = T();
}
return;
}
storeNoNulls<Type>(column, index, row, offset);
Expand All @@ -256,11 +267,15 @@ class RowContainer {
using T = typename milvus::TypeTraits<Type>::NativeType;
if constexpr (Type == DataType::NONE || Type == DataType::ROW || Type == DataType::JSON || Type == DataType::ARRAY) {
PanicInfo(DataTypeInvalid, "Cannot support complex data type:[ROW/JSON/ARRAY] in rows container for now");
} else if constexpr (Type == DataType::VARCHAR || Type == DataType::STRING) {
PanicInfo(DataTypeInvalid, "Cannot support varchar/string types in rows container for now");
} else {
auto raw_val_ptr = column->RawValueAt(index, sizeof(T));
*reinterpret_cast<T*>(group + offset) = *(static_cast<T*>(raw_val_ptr));
if constexpr (std::is_same_v<T, std::string>) {
// the string object and also the underlying char array are both allocated on the heap
// must call clear method to deallocate these memory allocated for varchar type to avoid memory leak
*reinterpret_cast<std::string**>(group + offset) = new std::string(*static_cast<std::string*>(raw_val_ptr));
} else {
*reinterpret_cast<T*>(group + offset) = *(static_cast<T*>(raw_val_ptr));
}
}
}

Expand Down

0 comments on commit 8bdfd11

Please sign in to comment.