Skip to content

Commit 76fcb02

Browse files
committed
feat(transformer): Add MRL-E preprocess Transformer
Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
1 parent 434f1a1 commit 76fcb02

File tree

8 files changed

+99
-8
lines changed

8 files changed

+99
-8
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
2+
// Copyright 2024-present the vsag project
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#pragma once
17+
18+
#include "metric_type.h"
19+
#include "simd/normalize.h"
20+
#include "vector_transformer.h"
21+
#include "vsag_exception.h"
22+
23+
namespace vsag {
24+
struct MRLETMeta : public TransformerMeta {};
25+
26+
template <MetricType metric = MetricType::METRIC_TYPE_L2SQR>
27+
class MRLETransformer : public VectorTransformer {
28+
public:
29+
explicit MRLETransformer(Allocator* allocator, int64_t input_dim, int64_t output_dim)
30+
: VectorTransformer(allocator, input_dim, output_dim) {
31+
this->type_ = VectorTransformerType::MRLE;
32+
}
33+
34+
virtual ~MRLETransformer() override = default;
35+
36+
TransformerMetaPtr
37+
Transform(const float* original_vec, float* transformed_vec) const override {
38+
auto meta = std::make_shared<MRLETMeta>();
39+
memcpy(transformed_vec, original_vec, this->output_dim_ * sizeof(float));
40+
if constexpr (metric == MetricType::METRIC_TYPE_COSINE) {
41+
Normalize(transformed_vec, transformed_vec, this->output_dim_);
42+
}
43+
return meta;
44+
}
45+
46+
void
47+
InverseTransform(const float* transformed_vec, float* original_vec) const override {
48+
throw VsagException(ErrorType::INTERNAL_ERROR, "InverseTransform not implement");
49+
}
50+
51+
void
52+
Serialize(StreamWriter& writer) const override{};
53+
54+
void
55+
Deserialize(StreamReader& reader) override{};
56+
57+
void
58+
Train(const float* data, uint64_t count) override{};
59+
};
60+
61+
} // namespace vsag

src/impl/transform/transformer_headers.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#pragma once
1717

1818
#include "fht_kac_rotate_transformer.h"
19+
#include "mrle_transformer.h"
1920
#include "pca_transformer.h"
2021
#include "random_orthogonal_transformer.h"
2122
#include "vector_transformer.h"

src/impl/transform/vector_transformer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Allocator;
2626
DEFINE_POINTER(VectorTransformer);
2727
DEFINE_POINTER(TransformerMeta);
2828

29-
enum class VectorTransformerType { NONE, PCA, RANDOM_ORTHOGONAL, FHT, RESIDUAL, NORMALIZE };
29+
enum class VectorTransformerType { NONE, PCA, RANDOM_ORTHOGONAL, FHT, RESIDUAL, NORMALIZE, MRLE };
3030

3131
struct TransformerMeta {
3232
virtual void

src/impl/transform/vector_transformer_parameter.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,17 @@ VectorTransformerParameter::FromJson(const JsonType& json) {
2828
if (json.Contains(PCA_DIM_KEY)) {
2929
pca_dim_ = json[PCA_DIM_KEY].GetInt();
3030
}
31+
32+
if (json.Contains(MRLE_DIM_KEY)) {
33+
mrle_dim_ = json[MRLE_DIM_KEY].GetInt();
34+
}
3135
}
3236

3337
JsonType
3438
VectorTransformerParameter::ToJson() const {
3539
JsonType json;
3640
json[PCA_DIM_KEY].SetInt(pca_dim_);
41+
json[MRLE_DIM_KEY].SetInt(mrle_dim_);
3742
json[INPUT_DIM_KEY].SetInt(input_dim_);
3843
return json;
3944
}
@@ -53,6 +58,9 @@ VectorTransformerParameter::CheckCompatibility(const ParamPtr& other) const {
5358
if (input_dim_ != param->input_dim_) {
5459
return false;
5560
}
61+
if (mrle_dim_ != param->mrle_dim_) {
62+
return false;
63+
}
5664
return true;
5765
}
5866

src/impl/transform/vector_transformer_parameter.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ class VectorTransformerParameter : public Parameter {
3535
CheckCompatibility(const vsag::ParamPtr& other) const override;
3636

3737
public:
38-
uint32_t input_dim_;
39-
uint32_t pca_dim_;
38+
uint32_t input_dim_{0};
39+
uint32_t pca_dim_{0};
40+
uint32_t mrle_dim_{0};
4041
};
4142

4243
} // namespace vsag

src/inner_string_params.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,14 @@ const char* const QUANTIZATION_TYPE_VALUE_TQ = "tq";
8282
const char* const TRANSFORMER_TYPE_VALUE_PCA = "pca";
8383
const char* const TRANSFORMER_TYPE_VALUE_ROM = "rom";
8484
const char* const TRANSFORMER_TYPE_VALUE_FHT = "fht";
85+
const char* const TRANSFORMER_TYPE_VALUE_MRLE = "mrle";
8586
const char* const TRANSFORMER_TYPE_VALUE_RESIDUAL = "residual";
8687
const char* const TRANSFORMER_TYPE_VALUE_NORMALIZE = "normalize";
8788

8889
// vector transformer param
8990
const char* const INPUT_DIM_KEY = "input_dim";
9091
const char* const PCA_DIM_KEY = "pca_dim";
92+
const char* const MRLE_DIM_KEY = "mrle_dim";
9193
const char* const USE_FHT_KEY = "use_fht";
9294

9395
// quantization param

src/quantization/transform_quantization/transform_quantizer.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ TransformQuantizer<QuantTmpl, metric>::TransformQuantizer(const TransformQuantiz
132132
transformer_param.input_dim_ = this->dim_;
133133
for (const auto& transform_str : param->tq_chain_) {
134134
transform_chain_.emplace_back(MakeTransformerInstance(transform_str, transformer_param));
135+
if (transform_chain_.back()->GetType() == VectorTransformerType::MRLE and
136+
transform_chain_.size() > 1) {
137+
throw VsagException(ErrorType::INVALID_ARGUMENT,
138+
fmt::format("MRLE must be first if exists"));
139+
}
135140
transformer_param.input_dim_ = transform_chain_.back()->GetOutputDim();
136141
}
137142

@@ -186,6 +191,17 @@ TransformQuantizer<QuantTmpl, metric>::MakeTransformerInstance(
186191
return std::make_shared<RandomOrthogonalMatrix>(this->allocator_, input_dim, output_dim);
187192
}
188193

194+
if (transform_str == TRANSFORMER_TYPE_VALUE_MRLE) {
195+
if (param.mrle_dim_ != 0) {
196+
output_dim = param.mrle_dim_;
197+
if (output_dim > input_dim) {
198+
throw VsagException(ErrorType::INVALID_ARGUMENT,
199+
fmt::format("mrle dim must be less than input dim"));
200+
}
201+
}
202+
return std::make_shared<MRLETransformer<metric>>(this->allocator_, input_dim, output_dim);
203+
}
204+
189205
throw VsagException(ErrorType::INVALID_ARGUMENT,
190206
fmt::format("invalid transformer name {}", transform_str));
191207
};

src/quantization/transform_quantization/transform_quantizer_test.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ TestComputeMetricTQ(std::string tq_chain, uint64_t dim, int count, float error =
3636
constexpr static const char* param_template = R"(
3737
{{
3838
"tq_chain": "{}",
39-
"pca_dim": {}
39+
"pca_dim": {},
40+
"mrle_dim": {}
4041
}}
4142
)";
42-
auto param_str = fmt::format(param_template, tq_chain, dim - 1);
43+
auto param_str = fmt::format(param_template, tq_chain, dim, dim - 1);
4344
auto param_json = vsag::JsonType::Parse(param_str);
4445
param->FromJson(param_json);
4546

@@ -66,10 +67,11 @@ TestSerializeDeserializeTQ(std::string tq_chain, uint64_t dim, int count) {
6667
constexpr static const char* param_template = R"(
6768
{{
6869
"tq_chain": "{}",
69-
"pca_dim": {}
70+
"pca_dim": {},
71+
"mrle_dim": {}
7072
}}
7173
)";
72-
auto param_str = fmt::format(param_template, tq_chain, dim - 2);
74+
auto param_str = fmt::format(param_template, tq_chain, dim, dim - 1);
7375
auto param_json = vsag::JsonType::Parse(param_str);
7476
param->FromJson(param_json);
7577

@@ -92,7 +94,7 @@ TestSerializeDeserializeTQ(std::string tq_chain, uint64_t dim, int count) {
9294

9395
TEST_CASE("TQ Compute", "[ut][TransformQuantizer]") {
9496
constexpr MetricType metrics[1] = {MetricType::METRIC_TYPE_L2SQR};
95-
std::string tq_chain = GENERATE("rom, pca, fp32", "rom, fp32", "fht, fp32");
97+
std::string tq_chain = GENERATE("rom, pca, fp32", "rom, fp32", "fht, fp32", "mrle, fp32");
9698

9799
for (auto dim : dims) {
98100
if (dim < 100) {

0 commit comments

Comments
 (0)