Skip to content

Commit

Permalink
Added weighted_sum fusion (infiniflow#1301)
Browse files Browse the repository at this point in the history
Added weighted_sum fusion

- [x] New Feature (non-breaking change which adds functionality)
- [x] Test cases
  • Loading branch information
yuzhichang authored Jun 7, 2024
1 parent beb166a commit 4e59c7f
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 83 deletions.
1 change: 1 addition & 0 deletions src/common/stl.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ export namespace std {
using std::reverse;
using std::sort;
using std::sqrt;
using std::stable_sort;
using std::tie;
using std::transform;
using std::unique;
Expand Down
253 changes: 174 additions & 79 deletions src/executor/operator/physical_fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

module;
#include <cassert>
#include <cstdlib>
#include <string>

Expand Down Expand Up @@ -51,13 +52,19 @@ import block_index;
import block_column_entry;
import mlas_matrix_multiply;
import physical_match_tensor_scan;
import physical_knn_scan;
import physical_merge_knn;

namespace infinity {

struct RRFRankDoc {
RowID row_id;
float score;
Vector<SizeT> ranks;
struct DocScore {
RowID row_id_;
u64 from_input_data_block_id_;
u32 from_block_idx_;
u32 from_row_idx_;
float fusion_score_;
Vector<float> child_scores_;
Vector<bool> mask_;
};

PhysicalFusion::PhysicalFusion(const u64 id,
Expand All @@ -75,7 +82,15 @@ PhysicalFusion::~PhysicalFusion() {}
void PhysicalFusion::Init() {
{
String &method = fusion_expr_->method_;
std::transform(method.begin(), method.end(), std::back_inserter(to_lower_method_), [](unsigned char c) { return std::tolower(c); });
String to_lower_method;
std::transform(method.begin(), method.end(), std::back_inserter(to_lower_method), [](unsigned char c) { return std::tolower(c); });
if (to_lower_method == "weighted_sum") {
fusion_method_ = FusionMethod::kWeightedSum;
} else if (to_lower_method == "match_tensor") {
fusion_method_ = FusionMethod::kMatchTensor;
} else {
fusion_method_ = FusionMethod::kRRF;
}
}
{
const auto prev_output_names_ptr = left_->GetOutputNames();
Expand All @@ -90,40 +105,69 @@ void PhysicalFusion::Init() {
(*output_types_)[output_types_->size() - 2] = MakeShared<DataType>(LogicalType::kFloat);
}
if (output_names_->size() != output_types_->size()) {
String error_message = fmt::format("output_names_ size {} is not equal to output_types_ size {}.", output_names_->size(), output_types_->size());
String error_message =
fmt::format("output_names_ size {} is not equal to output_types_ size {}.", output_names_->size(), output_types_->size());
LOG_CRITICAL(error_message);
UnrecoverableError(error_message);
}
}

// Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
void PhysicalFusion::ExecuteRRF(const Map<u64, Vector<UniquePtr<DataBlock>>> &input_data_blocks,
Vector<UniquePtr<DataBlock>> &output_data_block_array) const {
void PhysicalFusion::ExecuteRRFWeighted(const Map<u64, Vector<UniquePtr<DataBlock>>> &input_data_blocks,
Vector<UniquePtr<DataBlock>> &output_data_block_array) const {
SizeT num_children = 2 + other_children_.size();
SizeT rank_constant = 60;
SizeT window_size = 100; // is equivalent to topn
SizeT topn = 100;
Vector<float> weights;
if (fusion_expr_->options_.get() != nullptr) {
if (auto it = fusion_expr_->options_->options_.find("rank_constant"); it != fusion_expr_->options_->options_.end()) {
if (auto it = fusion_expr_->options_->options_.find("window_size"); it != fusion_expr_->options_->options_.end()) {
long l = std::strtol(it->second.c_str(), NULL, 10);
if (l >= 1) {
rank_constant = (SizeT)l;
topn = (SizeT)l;
}
}
if (auto it = fusion_expr_->options_->options_.find("window_size"); it != fusion_expr_->options_->options_.end()) {
if (auto it = fusion_expr_->options_->options_.find("topn"); it != fusion_expr_->options_->options_.end()) {
long l = std::strtol(it->second.c_str(), NULL, 10);
if (l >= 1) {
window_size = (SizeT)l;
topn = (SizeT)l;
}
}
if (fusion_method_ == FusionMethod::kRRF) {
if (auto it = fusion_expr_->options_->options_.find("rank_constant"); it != fusion_expr_->options_->options_.end()) {
long l = std::strtol(it->second.c_str(), NULL, 10);
if (l >= 1) {
rank_constant = (SizeT)l;
}
}
} else {
weights.reserve(num_children);
if (auto it = fusion_expr_->options_->options_.find("weights"); it != fusion_expr_->options_->options_.end()) {
const String &weight_str = it->second;
std::stringstream ss(weight_str);
std::string item;
while (std::getline(ss, item, ',')) {
double value = std::stod(item);
weights.push_back(value);
}
}
}
}
if (fusion_method_ == FusionMethod::kWeightedSum) {
SizeT num_weights = weights.size();
if (num_weights < num_children) {
for (SizeT i = num_weights; i < num_children; ++i) {
weights.push_back(1.0F);
}
}
}

Vector<RRFRankDoc> rrf_vec;
Map<RowID, SizeT> rrf_map; // row_id to index of rrf_vec_
Vector<u64> fragment_ids; // index of children, 0 - left, 1 - right, 2.. - other_children
Vector<DocScore> rescore_vec;
Map<RowID, SizeT> rescore_map; // row_id to index of rescore_vec_
// 1 Prepare rescore_vec
SizeT fragment_idx = 0;
// 1 calculate every doc's ranks
for (const auto &[fragment_id, input_blocks] : input_data_blocks) {
fragment_ids.push_back(fragment_id);
SizeT base_rank = 1;
u32 from_block_idx = 0;
for (const UniquePtr<DataBlock> &input_data_block : input_blocks) {
if (input_data_block->column_count() != GetOutputTypes()->size()) {
String error_message = fmt::format("input_data_block column count {} is incorrect, expect {}.",
Expand All @@ -135,50 +179,127 @@ void PhysicalFusion::ExecuteRRF(const Map<u64, Vector<UniquePtr<DataBlock>>> &in
auto &row_id_column = *input_data_block->column_vectors[input_data_block->column_count() - 1];
auto row_ids = reinterpret_cast<RowID *>(row_id_column.data());
SizeT row_n = input_data_block->row_count();
auto &row_score_column = *input_data_block->column_vectors[input_data_block->column_count() - 2];
auto row_scores = reinterpret_cast<float *>(row_score_column.data());
for (SizeT i = 0; i < row_n; i++) {
RowID docId = row_ids[i];
if (rrf_map.find(docId) == rrf_map.end()) {
RRFRankDoc doc;
doc.row_id = docId;
rrf_vec.push_back(doc);
rrf_map[docId] = rrf_vec.size() - 1;
if (rescore_map.find(docId) == rescore_map.end()) {
DocScore doc;
doc.row_id_ = docId;
doc.from_input_data_block_id_ = fragment_id;
doc.from_block_idx_ = from_block_idx;
doc.from_row_idx_ = i;
doc.child_scores_.resize(num_children, 0.0F);
doc.mask_.resize(num_children, false);
rescore_vec.push_back(doc);
rescore_map[docId] = rescore_vec.size() - 1;
}
SizeT doc_idx = rescore_map[docId];
DocScore &doc = rescore_vec[doc_idx];
assert(fragment_idx < num_children);
doc.mask_[fragment_idx] = true;
if (fusion_method_ == FusionMethod::kRRF) {
doc.child_scores_[fragment_idx] = base_rank + i;
} else {
assert(fusion_method_ == FusionMethod::kWeightedSum);
doc.child_scores_[fragment_idx] = row_scores[i];
}
RRFRankDoc &doc = rrf_vec[rrf_map[docId]];
doc.ranks.resize(fragment_idx + 1, 0);
doc.ranks[fragment_idx] = base_rank + i;
}
base_rank += row_n;
from_block_idx++;
}
fragment_idx++;
}

// 2 calculate every doc's score
for (auto &doc : rrf_vec) {
doc.score = 0.0F;
for (auto &rank : doc.ranks) {
if (rank == 0)
continue;
doc.score += 1.0F / (rank_constant + rank);
// 2 calculate every doc's fusion_score
if (fusion_method_ == FusionMethod::kRRF) {
for (auto &doc : rescore_vec) {
doc.fusion_score_ = 0.0F;
for (auto &rank : doc.child_scores_) {
if (rank < 1.0F)
continue;
doc.fusion_score_ += 1.0F / (rank_constant + rank);
}
}
}
// 3 sort docs in reverse per their score
if (rrf_vec.size() <= window_size) {
std::sort(std::begin(rrf_vec), std::end(rrf_vec), [](const RRFRankDoc &lhs, const RRFRankDoc &rhs) noexcept {
return lhs.score > rhs.score;
});
} else {
std::partial_sort(std::begin(rrf_vec),
std::begin(rrf_vec) + window_size,
std::end(rrf_vec),
[](const RRFRankDoc &lhs, const RRFRankDoc &rhs) noexcept { return lhs.score > rhs.score; });
rrf_vec.resize(window_size);
Vector<bool> min_heaps(num_children, false);
for (SizeT i = 0; i < num_children; i++) {
PhysicalOperator *child_op = nullptr;
if (i == 0)
child_op = left();
else if (i == 1)
child_op = right();
else
child_op = other_children_[i - 2].get();
switch (child_op->operator_type()) {
case PhysicalOperatorType::kKnnScan: {
PhysicalKnnScan *phy_knn_scan = static_cast<PhysicalKnnScan *>(child_op);
min_heaps[i] = phy_knn_scan->IsKnnMinHeap();
break;
}
case PhysicalOperatorType::kMergeKnn: {
PhysicalMergeKnn *phy_merge_knn = static_cast<PhysicalMergeKnn *>(child_op);
min_heaps[i] = phy_merge_knn->IsKnnMinHeap();
break;
}
case PhysicalOperatorType::kMatch: {
min_heaps[i] = true;
break;
}
default: {
String error_message = fmt::format("Cannot determine heap type of operator {}", int(child_op->operator_type()));
LOG_CRITICAL(error_message);
UnrecoverableError(error_message);
}
}
}
Vector<float> min_scores(num_children, std::numeric_limits<float>::max());
Vector<float> max_scores(num_children, std::numeric_limits<float>::min());
Vector<float> gap_scores(num_children, 0.0F);
for (auto &doc : rescore_vec) {
for (SizeT i = 0; i < num_children; ++i) {
if (!doc.mask_[i])
continue;
if (doc.child_scores_[i] < min_scores[i])
min_scores[i] = doc.child_scores_[i];
if (doc.child_scores_[i] > max_scores[i])
max_scores[i] = doc.child_scores_[i];
}
}
for (SizeT i = 0; i < num_children; ++i) {
gap_scores[i] = max_scores[i] - min_scores[i];
}
for (auto &doc : rescore_vec) {
doc.fusion_score_ = 0.0F;
for (SizeT i = 0; i < num_children; ++i) {
if (!doc.mask_[i])
continue;
if (gap_scores[i] <= 1e-6) {
doc.fusion_score_ += weights[i];
continue;
}
if (min_heaps[i]) {
doc.fusion_score_ += weights[i] * (doc.child_scores_[i] - min_scores[i]) / gap_scores[i];
} else {
doc.fusion_score_ += weights[i] * (max_scores[i] - doc.child_scores_[i]) / gap_scores[i];
}
}
}
}

// 3 sort docs in reverse per their fusion_score
std::stable_sort(std::begin(rescore_vec), std::end(rescore_vec), [](const DocScore &lhs, const DocScore &rhs) noexcept {
return lhs.fusion_score_ > rhs.fusion_score_;
});
if (rescore_vec.size() > topn) {
rescore_vec.resize(topn);
}

// 4 generate output data blocks
UniquePtr<DataBlock> output_data_block = DataBlock::MakeUniquePtr();
output_data_block->Init(*GetOutputTypes());
SizeT row_count = 0;
for (RRFRankDoc &doc : rrf_vec) {
for (DocScore &doc : rescore_vec) {
// 4.1 get every doc's columns from input data blocks
if (row_count == output_data_block->capacity()) {
output_data_block->Finalize();
Expand All @@ -187,41 +308,15 @@ void PhysicalFusion::ExecuteRRF(const Map<u64, Vector<UniquePtr<DataBlock>>> &in
output_data_block->Init(*GetOutputTypes());
row_count = 0;
}
SizeT fragment_idx = 0;
while (fragment_idx < doc.ranks.size() && doc.ranks[fragment_idx] == 0)
fragment_idx++;
if (fragment_idx >= doc.ranks.size()) {
String error_message = "Cannot find fragment_idx";
LOG_CRITICAL(error_message);
UnrecoverableError(error_message);
}
u64 fragment_id = fragment_ids[fragment_idx];
const auto &input_blocks = input_data_blocks.at(fragment_id);
if (input_blocks.size() == 0) {
String error_message = fmt::format("input_data_blocks_[{}] is empty.", fragment_id);
LOG_CRITICAL(error_message);
UnrecoverableError(error_message);
}
SizeT block_idx = 0;
SizeT row_idx = doc.ranks[fragment_idx] - 1;
while (row_idx >= input_blocks[block_idx]->row_count()) {
row_idx -= input_blocks[block_idx]->row_count();
block_idx++;
}
if (block_idx >= input_blocks.size()) {
String error_message = "Cannot find block_idx";
LOG_CRITICAL(error_message);
UnrecoverableError(error_message);
}

const auto &input_blocks = input_data_blocks.at(doc.from_input_data_block_id_);
SizeT column_n = GetOutputTypes()->size() - 2;
for (SizeT i = 0; i < column_n; ++i) {
output_data_block->column_vectors[i]->AppendWith(*input_blocks[block_idx]->column_vectors[i], row_idx, 1);
output_data_block->column_vectors[i]->AppendWith(*input_blocks[doc.from_block_idx_]->column_vectors[i], doc.from_row_idx_, 1);
}
// 4.2 add hidden columns: score, row_id
Value v = Value::MakeFloat(doc.score);
Value v = Value::MakeFloat(doc.fusion_score_);
output_data_block->column_vectors[column_n]->AppendValue(v);
output_data_block->column_vectors[column_n + 1]->AppendWith(doc.row_id, 1);
output_data_block->column_vectors[column_n + 1]->AppendWith(doc.row_id_, 1);
row_count++;
}
output_data_block->Finalize();
Expand Down Expand Up @@ -366,13 +461,13 @@ bool PhysicalFusion::ExecuteFirstOp(QueryContext *query_context, FusionOperatorS
if (!fusion_operator_state->input_complete_) {
return false;
}
if (to_lower_method_.compare("rrf") == 0) {
ExecuteRRF(fusion_operator_state->input_data_blocks_, fusion_operator_state->data_block_array_);
if (fusion_method_ == FusionMethod::kRRF || fusion_method_ == FusionMethod::kWeightedSum) {
ExecuteRRFWeighted(fusion_operator_state->input_data_blocks_, fusion_operator_state->data_block_array_);
fusion_operator_state->input_data_blocks_.clear();
fusion_operator_state->SetComplete();
return true;
}
if (to_lower_method_.compare("match_tensor") == 0) {
if (fusion_method_ == FusionMethod::kMatchTensor) {
ExecuteMatchTensor(query_context, fusion_operator_state->input_data_blocks_, fusion_operator_state->data_block_array_);
fusion_operator_state->input_data_blocks_.clear();
fusion_operator_state->SetComplete();
Expand All @@ -392,7 +487,7 @@ bool PhysicalFusion::ExecuteNotFirstOp(QueryContext *query_context, OperatorStat
UnrecoverableError(error_message);
return false;
}
if (to_lower_method_.compare("match_tensor") == 0) {
if (fusion_method_ == FusionMethod::kMatchTensor) {
Map<u64, Vector<UniquePtr<DataBlock>>> input_data_blocks;
input_data_blocks.emplace(0, std::move(operator_state->prev_op_state_->data_block_array_));
operator_state->prev_op_state_->data_block_array_.clear();
Expand Down
11 changes: 7 additions & 4 deletions src/executor/operator/physical_fusion.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import logger;
namespace infinity {
struct DataBlock;

export enum class FusionMethod { kRRF, kWeightedSum, kMatchTensor };

export class PhysicalFusion final : public PhysicalOperator {
public:
explicit PhysicalFusion(u64 id,
Expand Down Expand Up @@ -72,14 +74,15 @@ public:
private:
bool ExecuteFirstOp(QueryContext *query_context, FusionOperatorState *fusion_operator_state) const;
bool ExecuteNotFirstOp(QueryContext *query_context, OperatorState *operator_state) const;
// RRF has multiple input source, must be first op
void ExecuteRRF(const Map<u64, Vector<UniquePtr<DataBlock>>> &input_data_blocks, Vector<UniquePtr<DataBlock>> &output_data_block_array) const;
// MatchTensor may have multiple or single input source, can be first or not first op
// RRF and WeightedSum have multiple input sources, must be first fusion op
void ExecuteRRFWeighted(const Map<u64, Vector<UniquePtr<DataBlock>>> &input_data_blocks,
Vector<UniquePtr<DataBlock>> &output_data_block_array) const;
// MatchTensor may have multiple or single input source, can be first or not first fusion op
void ExecuteMatchTensor(QueryContext *query_context,
const Map<u64, Vector<UniquePtr<DataBlock>>> &input_data_blocks,
Vector<UniquePtr<DataBlock>> &output_data_block_array) const;

String to_lower_method_;
FusionMethod fusion_method_;
SharedPtr<Vector<String>> output_names_;
SharedPtr<Vector<SharedPtr<DataType>>> output_types_;
};
Expand Down
Loading

0 comments on commit 4e59c7f

Please sign in to comment.