Skip to content

Commit

Permalink
Merge branch 'main' into release/0.0.11
Browse files Browse the repository at this point in the history
  • Loading branch information
robomics committed Mar 23, 2024
2 parents e63272e + c7230c2 commit 5497e14
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 112 deletions.
19 changes: 10 additions & 9 deletions src/hictk/balance/balance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@ static int balance_cooler(cooler::File& f, const BalanceConfig& c) {
const auto weights = balancer.get_weights(c.rescale_marginals);

if (c.stdout_) {
const auto weights_ = weights(balancing::Weights::Type::DIVISIVE);
std::for_each(weights_.begin(), weights_.end(),
[&](const auto w) { fmt::print(FMT_COMPILE("{}\n"), w); });
for (const auto& w :
balancer.get_weights(c.rescale_marginals)(balancing::Weights::Type::DIVISIVE)) {
fmt::print(FMT_COMPILE("{}\n"), w);
}
return 0;
}

Expand Down Expand Up @@ -205,13 +206,13 @@ static int balance_hic(const BalanceConfig& c) {
const Balancer balancer(f, mode, params);

if (c.stdout_) {
const auto weights_ =
balancer.get_weights(c.rescale_marginals)(balancing::Weights::Type::DIVISIVE);
std::for_each(weights_.begin(), weights_.end(),
[&](const auto w) { fmt::print(FMT_COMPILE("{}\n"), w); });
} else {
weights.emplace(res, balancer.get_weights(c.rescale_marginals));
for (const auto& w :
balancer.get_weights(c.rescale_marginals)(balancing::Weights::Type::DIVISIVE)) {
fmt::print(FMT_COMPILE("{}\n"), w);
}
return 0;
}
weights.emplace(res, balancer.get_weights(c.rescale_marginals));
}

// NOLINTNEXTLINE(misc-const-correctness)
Expand Down
207 changes: 104 additions & 103 deletions src/libhictk/balancing/include/hictk/balancing/impl/scale_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,115 +70,116 @@ inline SCALE::SCALE(PixelIt first, PixelIt last, const hictk::BinTable& bins, co
const auto matrix = mask_bins_and_init_buffers(first, last, offset, params.max_percentile,
params.tmpfile, params.chunk_size);

std::visit(
[&](const auto& m) {
VectorOfAtomicDecimals column(size(), 9);
VectorOfAtomicDecimals row(size(), 9);

m.multiply(row, _one, _tpool.get());
row.multiply(_biases);

auto dr = _biases;
auto dc = _biases;
auto current = _biases;

std::vector<double> b_conv(size(), 0);
std::vector<double> b0(size(), 0);
std::vector<bool> bad_conv(size(), false);
_ber_conv = 10.0;

for (_iter = 0, _tot_iter = 0; _convergence_stats.error > params.tol &&
_iter < params.max_iters && _tot_iter < _max_tot_iters;
++_iter, ++_tot_iter) {
update_weights(column, _bad, row, _z_target_vector, dr, m, _tpool.get());
column.multiply(dc);

update_weights(row, _bad, column, _z_target_vector, dc, m, _tpool.get());
row.multiply(dr);

geometric_mean(dr, dc, _biases1);
const auto res = compute_convergence_error(_biases1, current, _bad, params.tol);
_convergence_stats.error = res.first;
const auto num_bad = res.second;

b0 = current;
current = _biases1;

_error_queue_iter.push(_convergence_stats.error);
if (_error_queue_iter.size() == 7) {
_error_queue_iter.pop();
}
std::visit([&](const auto& m) { balance(m, bins, params); }, matrix);
}

const auto frac_bad = static_cast<double>(num_bad) / static_cast<double>(_nnz_rows);
template <typename Matrix>
inline void SCALE::balance(const Matrix& m, const BinTable& bins, const Params& params) {
VectorOfAtomicDecimals column(size(), 9);
VectorOfAtomicDecimals row(size(), 9);

m.multiply(row, _one, _tpool.get());
row.multiply(_biases);

auto dr = _biases;
auto dc = _biases;
auto current = _biases;

std::vector<double> b_conv(size(), 0);
std::vector<double> b0(size(), 0);
std::vector<bool> bad_conv(size(), false);
_ber_conv = 10.0;

for (_iter = 0, _tot_iter = 0; _convergence_stats.error > params.tol &&
_iter < params.max_iters && _tot_iter < _max_tot_iters;
++_iter, ++_tot_iter) {
update_weights(column, _bad, row, _z_target_vector, dr, m, _tpool.get());
column.multiply(dc);

update_weights(row, _bad, column, _z_target_vector, dc, m, _tpool.get());
row.multiply(dr);

geometric_mean(dr, dc, _biases1);
const auto res = compute_convergence_error(_biases1, current, _bad, params.tol);
_convergence_stats.error = res.first;
const auto num_bad = res.second;

b0 = current;
current = _biases1;

_error_queue_iter.push(_convergence_stats.error);
if (_error_queue_iter.size() == 7) {
_error_queue_iter.pop();
}

SPDLOG_INFO(FMT_STRING("Iteration {}: {}"), _tot_iter, _convergence_stats.error);
const auto frac_bad = static_cast<double>(num_bad) / static_cast<double>(_nnz_rows);

if (_convergence_stats.error < params.tol) {
SPDLOG_DEBUG(FMT_STRING("handle_convergence"));
const auto status = handle_convergenece(m, dr, dc, row);
if (status == ControlFlow::break_loop) {
break;
}
assert(status == ControlFlow::continue_loop);
reset_iter();
continue;
}
SPDLOG_INFO(FMT_STRING("Iteration {}: {}"), _tot_iter, _convergence_stats.error);

if (_iter <= 4) {
continue;
}
if (_convergence_stats.error < params.tol) {
SPDLOG_DEBUG(FMT_STRING("handle_convergence"));
const auto status = handle_convergenece(m, dr, dc, row);
if (status == ControlFlow::break_loop) {
break;
}
assert(status == ControlFlow::continue_loop);
reset_iter();
continue;
}

// check whether convergence rate is satisfactory
const auto err1 = _error_queue_iter.front();
const auto err2 = _error_queue_iter.back();
if (err2 * (1.0 + params.delta) < err1 && (_iter < params.max_iters)) {
continue;
}
if (_iter <= 4) {
continue;
}

// handle divergence
SPDLOG_DEBUG(FMT_STRING("handle_divergence"));
_convergence_stats.diverged = true;
_convergence_stats.low_divergence = static_cast<std::uint32_t>(_low_cutoff);
const auto status =
handle_diverged(m, b0, dr, dc, row, frac_bad, params.frac_bad_cutoff, params.tol);
if (status == ControlFlow::break_loop) {
break;
}
if (status == ControlFlow::continue_loop) {
continue;
}
}

m.multiply(column, _biases1, _tpool.get());
const auto row_sum_error = compute_final_error(column, _biases1, _z_target_vector, _bad);

if (_convergence_stats.error > params.tol) {
SPDLOG_DEBUG(FMT_STRING("error > tol: {} > {}"), _convergence_stats.error, params.tol);
}
if (row_sum_error > params.max_row_sum_error) {
SPDLOG_DEBUG(FMT_STRING("row_sum_error > params.max_row_sum_error: {} > {}"),
row_sum_error, params.max_row_sum_error);
}
if (_low_cutoff > _upper_bound) {
SPDLOG_DEBUG(FMT_STRING("low_cutoff > upper_bound: {} > {}"), _low_cutoff, _upper_bound);
}
// convergence not achieved, return vector of nans
if (_convergence_stats.error > params.tol || row_sum_error > params.max_row_sum_error ||
_low_cutoff > _upper_bound) {
std::fill(_biases.begin(), _biases.end(), std::numeric_limits<double>::quiet_NaN());
_scale.push_back(std::numeric_limits<double>::quiet_NaN());
_chrom_offsets = bins.num_bin_prefix_sum();
return;
}

// convergence achieved
for (std::size_t i = 0; i < size(); ++i) {
_biases[i] = _bad[i] ? std::numeric_limits<double>::quiet_NaN() : 1.0 / _biases1[i];
}
_scale.push_back(m.compute_scaling_factor_for_scale(_biases));
_chrom_offsets = bins.num_bin_prefix_sum();
},
matrix);
// check whether convergence rate is satisfactory
const auto err1 = _error_queue_iter.front();
const auto err2 = _error_queue_iter.back();
if (err2 * (1.0 + params.delta) < err1 && (_iter < params.max_iters)) {
continue;
}

// handle divergence
SPDLOG_DEBUG(FMT_STRING("handle_divergence"));
_convergence_stats.diverged = true;
_convergence_stats.low_divergence = static_cast<std::uint32_t>(_low_cutoff);
const auto status =
handle_diverged(m, b0, dr, dc, row, frac_bad, params.frac_bad_cutoff, params.tol);
if (status == ControlFlow::break_loop) {
break;
}
if (status == ControlFlow::continue_loop) {
continue;
}
}

m.multiply(column, _biases1, _tpool.get());
const auto row_sum_error = compute_final_error(column, _biases1, _z_target_vector, _bad);

if (_convergence_stats.error > params.tol) {
SPDLOG_DEBUG(FMT_STRING("error > tol: {} > {}"), _convergence_stats.error, params.tol);
}
if (row_sum_error > params.max_row_sum_error) {
SPDLOG_DEBUG(FMT_STRING("row_sum_error > params.max_row_sum_error: {} > {}"), row_sum_error,
params.max_row_sum_error);
}
if (_low_cutoff > _upper_bound) {
SPDLOG_DEBUG(FMT_STRING("low_cutoff > upper_bound: {} > {}"), _low_cutoff, _upper_bound);
}
// convergence not achieved, return vector of nans
if (_convergence_stats.error > params.tol || row_sum_error > params.max_row_sum_error ||
_low_cutoff > _upper_bound) {
std::fill(_biases.begin(), _biases.end(), std::numeric_limits<double>::quiet_NaN());
_scale.push_back(std::numeric_limits<double>::quiet_NaN());
_chrom_offsets = bins.num_bin_prefix_sum();
return;
}

// convergence achieved
for (std::size_t i = 0; i < size(); ++i) {
_biases[i] = _bad[i] ? std::numeric_limits<double>::quiet_NaN() : 1.0 / _biases1[i];
}
_scale.push_back(m.compute_scaling_factor_for_scale(_biases));
_chrom_offsets = bins.num_bin_prefix_sum();
}

inline std::size_t SCALE::size() const noexcept { return _biases.size(); }
Expand Down
3 changes: 3 additions & 0 deletions src/libhictk/balancing/include/hictk/balancing/scale.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class SCALE {
template <typename File>
[[nodiscard]] static auto compute_gw(const File& f, const Params& params) -> Result;

template <typename Matrix>
void balance(const Matrix& m, const BinTable& bins, const Params& params);

[[nodiscard]] static VC::Type map_type_to_vc(Type type) noexcept;

template <typename Matrix>
Expand Down

0 comments on commit 5497e14

Please sign in to comment.