Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions src/quantization/scalar_quantization/scalar_quantization_trainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,159 @@ ScalarQuantizationTrainer::Train(const float* data,
this->classic_train(sample_datas.data(), sample_count, upper_bound, lower_bound);
} else if (mode == TRUNC_BOUND) {
this->trunc_bound_train(sample_datas.data(), sample_count, upper_bound, lower_bound);
} else if (mode == PSO) {
this->pso_train(sample_datas.data(), sample_count, upper_bound, lower_bound);
}
}

void
ScalarQuantizationTrainer::pso_train(const float* data,
                                     uint64_t count,
                                     float* upper_bound,
                                     float* lower_bound) const {
    // Default PSO hyper-parameters; tuning happens in pso_train_impl, which
    // exists separately so tests can inject different parameter sets.
    constexpr size_t kMaxIter = 128;          // PSO iterations per dimension
    constexpr size_t kGridSideLength = 8;     // initial swarm is an 8x8 grid
    constexpr float kGridScaleFactor = 0.1f;  // grid spread around the seed bounds
    constexpr float kInitInertia = 0.9f;      // inertia decays linearly ...
    constexpr float kFinalInertia = 0.4f;     // ... down to this value
    constexpr float kCognitive = 1.8f;        // pull toward particle best (c1)
    constexpr float kSocial = 1.8f;           // pull toward global best (c2)

    pso_train_impl(data,
                   count,
                   upper_bound,
                   lower_bound,
                   kMaxIter,
                   kGridSideLength,
                   kGridScaleFactor,
                   kInitInertia,
                   kFinalInertia,
                   kCognitive,
                   kSocial);
}

void
ScalarQuantizationTrainer::pso_train_impl(const float* data,
                                          uint64_t count,
                                          float* upper_bound,
                                          float* lower_bound,
                                          size_t max_iter,
                                          size_t grid_side_length,
                                          float grid_scale_factor,
                                          float init_inertia,
                                          float final_inertia,
                                          float c1,
                                          float c2) const {
    // Refine per-dimension quantization bounds with particle swarm optimization
    // (PSO): seed with the classic bounds, then, independently for each
    // dimension, search over (lower, step_size) pairs minimizing the total
    // squared reconstruction error of the sampled data.
    //
    // NOTE(review): the fix here removes GitHub review-thread text that had been
    // pasted into the middle of the iteration loop; the algorithm is unchanged.
    this->classic_train(data, count, upper_bound, lower_bound);
    float div = (1 << this->bits_) - 1;  // number of quantization levels - 1

#pragma omp parallel for
    for (uint64_t i = 0; i < dim_; ++i) {
        float init_upper_bound = upper_bound[i];
        float init_lower_bound = lower_bound[i];
        const float init_range_width = init_upper_bound - init_lower_bound;
        const float init_range_center = (init_lower_bound + init_upper_bound) * 0.5f;

        // Per-thread RNG: random_device-seeded, so results are not reproducible
        // across runs. Velocities start at +/-1% of the initial range width.
        std::random_device rd;
        std::mt19937 gen(rd());
        std::uniform_real_distribution<float> v_dis(-init_range_width * 0.01f,
                                                    init_range_width * 0.01f);
        std::uniform_real_distribution<float> p_dis(0.0f, 1.0f);

        // Total squared quantization error for this dimension given candidate
        // (lower, step_size). Codes are clamped to [0, div].
        auto loss = [=, this](float lower, float step_size) {
            step_size = std::max(step_size, 1e-6f);  // guard against degenerate step
            float total_loss = 0.0f;
            for (uint64_t j = 0; j < count; ++j) {
                float value = data[j * dim_ + i];
                float quantized_code = std::round((value - lower) / step_size);
                quantized_code = std::min(quantized_code, div);
                quantized_code = std::max(quantized_code, 0.0f);
                float diff = value - (lower + quantized_code * step_size);
                total_loss += diff * diff;
            }
            return total_loss;
        };

        // A PSO particle: position (lower, step_size), velocity, and its own
        // best-seen position/loss.
        struct Particle {
            float lower;
            float step_size;
            float v_lower;
            float v_step_size;
            float best_lower;
            float best_step_size;
            float min_loss;

            Particle(const float l_val, const float s_val, const float vl_val, const float vs_val)
                : lower(l_val),
                  step_size(s_val),
                  v_lower(vl_val),
                  v_step_size(vs_val),
                  best_lower(l_val),
                  best_step_size(s_val),
                  min_loss(std::numeric_limits<float>::max()) {
            }
        };

        // Initialize the swarm on a grid_side_length^2 grid: lower offsets are
        // centered on the classic lower bound, step sizes span 0.5x..~1.5x the
        // classic step.
        std::vector<Particle> swarm;
        swarm.reserve(grid_side_length * grid_side_length);
        for (size_t m = 0; m < grid_side_length; ++m) {
            for (size_t n = 0; n < grid_side_length; ++n) {
                float particle_lower =
                    init_lower_bound + (static_cast<float>(m) - grid_side_length * 0.5f) *
                                           grid_scale_factor * init_range_width / grid_side_length;
                float particle_step_size =
                    init_range_width / div * (0.5f + static_cast<float>(n) / grid_side_length);
                particle_step_size = std::max(particle_step_size, 1e-6f);
                swarm.emplace_back(particle_lower, particle_step_size, v_dis(gen), v_dis(gen));
            }
        }

        // Global best starts at the classic solution, so PSO can never end up
        // worse than classic_train on the sampled data.
        float global_best_lower = init_lower_bound;
        float global_best_step_size = std::max(init_range_width / div, 1e-6f);
        float global_min_loss = loss(init_lower_bound, global_best_step_size);
        for (auto& particle : swarm) {
            float curr_loss = loss(particle.lower, particle.step_size);
            particle.min_loss = curr_loss;
            if (curr_loss < global_min_loss) {
                global_min_loss = curr_loss;
                global_best_lower = particle.lower;
                global_best_step_size = particle.step_size;
            }
        }

        // Standard PSO update with linearly decaying inertia.
        for (size_t iter = 0; iter < max_iter; ++iter) {
            float inertia =
                init_inertia - (init_inertia - final_inertia) * static_cast<float>(iter) / max_iter;
            for (auto& particle : swarm) {
                // NOTE(review): r1/r2 are shared between the lower and step_size
                // velocity updates; textbook PSO draws fresh randoms per
                // coordinate. Kept as-is to preserve behavior.
                float r1 = p_dis(gen);
                float r2 = p_dis(gen);
                particle.v_lower = inertia * particle.v_lower +
                                   c1 * r1 * (particle.best_lower - particle.lower) +
                                   c2 * r2 * (global_best_lower - particle.lower);
                particle.v_step_size = inertia * particle.v_step_size +
                                       c1 * r1 * (particle.best_step_size - particle.step_size) +
                                       c2 * r2 * (global_best_step_size - particle.step_size);
                particle.lower += particle.v_lower;
                particle.step_size += particle.v_step_size;
                if (particle.step_size <= 1e-6f) {
                    particle.step_size = 1e-6f;  // keep step size strictly positive
                }
                float curr_loss = loss(particle.lower, particle.step_size);
                if (curr_loss < particle.min_loss) {
                    particle.min_loss = curr_loss;
                    particle.best_lower = particle.lower;
                    particle.best_step_size = particle.step_size;
                }
                if (curr_loss < global_min_loss) {
                    global_min_loss = curr_loss;
                    global_best_lower = particle.lower;
                    global_best_step_size = particle.step_size;
                }
            }
        }
        // Convert the winning (lower, step_size) back to [lower, upper] bounds.
        lower_bound[i] = global_best_lower;
        upper_bound[i] = global_best_lower + global_best_step_size * div;
    }
}

Expand All @@ -65,6 +218,14 @@ ScalarQuantizationTrainer::TrainUniform(const float* data,
// case for count == 1 or trunc_rate > 0.5
std::swap(lower_bound, upper_bound);
}
} else if (mode == PSO) {
this->pso_train(sample_datas.data(), sample_count, upper.data(), lower.data());
upper_bound = *std::min_element(upper.begin(), upper.end());
lower_bound = *std::max_element(lower.begin(), lower.end());
if (lower_bound > upper_bound) {
// case for count == 1 or trunc_rate > 0.5
std::swap(lower_bound, upper_bound);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#pragma once

#include <cstdint>
#include <vector>

#include "typing.h"

Expand All @@ -25,6 +26,7 @@ enum SQTrainMode {
CLASSIC = 1,
K_MEANS = 2,
TRUNC_BOUND = 3,
PSO = 4,
};

class ScalarQuantizationTrainer {
Expand Down Expand Up @@ -67,6 +69,22 @@ class ScalarQuantizationTrainer {
float* upper_bound,
float* lower_bound) const;

    /// Trains per-dimension quantization bounds with particle swarm optimization
    /// using built-in default hyper-parameters; delegates to pso_train_impl.
    /// @param data        row-major samples, count x dim
    /// @param count       number of sample rows
    /// @param upper_bound out: per-dimension upper bounds (size dim)
    /// @param lower_bound out: per-dimension lower bounds (size dim)
    void
    pso_train(const float* data, uint64_t count, float* upper_bound, float* lower_bound) const;

    /// PSO training with explicit hyper-parameters, exposed separately so tests
    /// can exercise the algorithm with custom settings.
    /// @param max_iter          number of PSO iterations
    /// @param grid_side_length  initial swarm is a grid_side_length^2 grid
    /// @param grid_scale_factor spread of the initial grid around the seed bounds
    /// @param init_inertia      starting inertia weight (decays linearly)
    /// @param final_inertia     inertia weight at the last iteration
    /// @param c1                cognitive (particle-best) acceleration coefficient
    /// @param c2                social (global-best) acceleration coefficient
    void
    pso_train_impl(const float* data,
                   uint64_t count,
                   float* upper_bound,
                   float* lower_bound,
                   size_t max_iter,
                   size_t grid_side_length,
                   float grid_scale_factor,
                   float init_inertia,
                   float final_inertia,
                   float c1,
                   float c2) const;

uint64_t
sample_train_data(const float* data,
uint64_t count,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "scalar_quantization_trainer.h"

#include <algorithm>
#include <catch2/catch_message.hpp>
#include <catch2/catch_test_macros.hpp>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

using namespace vsag;

// Mean squared reconstruction error of uniformly quantizing `data` to `bits`
// bits over the range [lower, upper]. Codes are clamped to the representable
// range. Returns 0 for empty input (the original divided 0 by 0, yielding NaN).
float
compute_mse(const std::vector<float>& data, float lower, float upper, int bits) {
    if (data.empty()) {
        return 0.0f;
    }
    float div = (1 << bits) - 1;           // number of quantization levels - 1
    float step = (upper - lower) / div;    // quantization step size
    float mse = 0.0f;
    for (float v : data) {
        float code = std::round((v - lower) / step);
        code = std::min(std::max(code, 0.0f), div);
        float recon = lower + code * step;  // dequantized value
        mse += (v - recon) * (v - recon);
    }
    return mse / static_cast<float>(data.size());
}

TEST_CASE("ScalarQuantizationTrainer", "[ft][scalar_quantization_trainer]") {
    // 1000 uniform samples in [0, 1), fixed seed for reproducible data.
    std::vector<float> data;
    std::mt19937 gen(42);
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    for (int i = 0; i < 1000; ++i) {
        data.push_back(dist(gen));
    }
    int bits = 4;
    float lower_c[1], upper_c[1], lower_t[1], upper_t[1], lower_p[1], upper_p[1];

    ScalarQuantizationTrainer trainer(1, bits);

    // CLASSIC
    trainer.Train(data.data(), data.size(), upper_c, lower_c, false, vsag::CLASSIC);
    float mse_classic = compute_mse(data, lower_c[0], upper_c[0], bits);

    // TRUNC_BOUND
    trainer.Train(data.data(), data.size(), upper_t, lower_t, false, vsag::TRUNC_BOUND);
    float mse_trunc = compute_mse(data, lower_t[0], upper_t[0], bits);

    // PSO
    trainer.Train(data.data(), data.size(), upper_p, lower_p, false, vsag::PSO);
    float mse_pso = compute_mse(data, lower_p[0], upper_p[0], bits);

    // Bug fix: compare the bound *values*. The previous `lower_c <= upper_c`
    // compared decayed array pointers, which says nothing about the training
    // result.
    REQUIRE(lower_c[0] <= upper_c[0]);
    REQUIRE(lower_t[0] <= upper_t[0]);
    REQUIRE(lower_p[0] <= upper_p[0]);

    // PSO should beat both baselines on this data. Note the PSO implementation
    // seeds from std::random_device, so these margins are stochastic.
    REQUIRE(mse_pso < mse_classic * 0.95);
    REQUIRE(mse_pso < mse_trunc);
}
Loading