Skip to content

Commit

Permalink
Merge pull request #336 from ksahlin/nams
Browse files Browse the repository at this point in the history
Speed up NAM merging a bit
  • Loading branch information
marcelm authored Sep 6, 2023
2 parents f6b6208 + 15c8466 commit 0da9e81
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 23 deletions.
58 changes: 36 additions & 22 deletions src/nam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ struct Hit {
int query_end;
int ref_start;
int ref_end;
bool is_rc = false;
};

void add_to_hits_per_ref(
inline void add_to_hits_per_ref(
robin_hood::unordered_map<unsigned int, std::vector<Hit>>& hits_per_ref,
int query_start,
int query_end,
bool is_rc,
const StrobemerIndex& index,
size_t position
) {
Expand All @@ -24,18 +22,19 @@ void add_to_hits_per_ref(
int ref_end = ref_start + index.strobe2_offset(position) + index.k();
int diff = std::abs((query_end - query_start) - (ref_end - ref_start));
if (diff <= min_diff) {
hits_per_ref[index.reference_index(position)].push_back(Hit{query_start, query_end, ref_start, ref_end, is_rc});
hits_per_ref[index.reference_index(position)].push_back(Hit{query_start, query_end, ref_start, ref_end});
min_diff = diff;
}
}
}

std::vector<Nam> merge_hits_into_nams(
void merge_hits_into_nams(
robin_hood::unordered_map<unsigned int, std::vector<Hit>>& hits_per_ref,
int k,
bool sort
bool sort,
bool is_revcomp,
std::vector<Nam>& nams // inout
) {
std::vector<Nam> nams;
int nam_id_cnt = 0;
for (auto &[ref_id, hits] : hits_per_ref) {
if (sort) {
Expand All @@ -53,7 +52,7 @@ std::vector<Nam> merge_hits_into_nams(
for (auto & o : open_nams) {

// Extend NAM
if (( o.is_rc == h.is_rc) && (o.query_prev_hit_startpos < h.query_start) && (h.query_start <= o.query_end ) && (o.ref_prev_hit_startpos < h.ref_start) && (h.ref_start <= o.ref_end) ){
if ((o.query_prev_hit_startpos < h.query_start) && (h.query_start <= o.query_end ) && (o.ref_prev_hit_startpos < h.ref_start) && (h.ref_start <= o.ref_end) ){
if ( (h.query_end > o.query_end) && (h.ref_end > o.ref_end) ) {
o.query_end = h.query_end;
o.ref_end = h.ref_end;
Expand Down Expand Up @@ -94,7 +93,7 @@ std::vector<Nam> merge_hits_into_nams(
n.query_prev_hit_startpos = h.query_start;
n.ref_prev_hit_startpos = h.ref_start;
n.n_hits = 1;
n.is_rc = h.is_rc;
n.is_rc = is_revcomp;
// n.score += (float)1 / (float)h.count;
open_nams.push_back(n);
}
Expand Down Expand Up @@ -134,6 +133,18 @@ std::vector<Nam> merge_hits_into_nams(
nams.push_back(n);
}
}
}

std::vector<Nam> merge_hits_into_nams_forward_and_reverse(
std::array<robin_hood::unordered_map<unsigned int, std::vector<Hit>>, 2>& hits_per_ref,
int k,
bool sort
) {
std::vector<Nam> nams;
for (size_t is_revcomp = 0; is_revcomp < 2; ++is_revcomp) {
auto& hits_oriented = hits_per_ref[is_revcomp];
merge_hits_into_nams(hits_oriented, k, sort, is_revcomp, nams);
}
return nams;
}

Expand All @@ -149,8 +160,9 @@ std::pair<float, std::vector<Nam>> find_nams(
const QueryRandstrobeVector &query_randstrobes,
const StrobemerIndex& index
) {
robin_hood::unordered_map<unsigned int, std::vector<Hit>> hits_per_ref;
hits_per_ref.reserve(100);
std::array<robin_hood::unordered_map<unsigned int, std::vector<Hit>>, 2> hits_per_ref;
hits_per_ref[0].reserve(100);
hits_per_ref[1].reserve(100);
int nr_good_hits = 0, total_hits = 0;
for (const auto &q : query_randstrobes) {
size_t position = index.find(q.hash);
Expand All @@ -160,11 +172,11 @@ std::pair<float, std::vector<Nam>> find_nams(
continue;
}
nr_good_hits++;
add_to_hits_per_ref(hits_per_ref, q.start, q.end, q.is_reverse, index, position);
add_to_hits_per_ref(hits_per_ref[q.is_reverse], q.start, q.end, index, position);
}
}
float nonrepetitive_fraction = total_hits > 0 ? ((float) nr_good_hits) / ((float) total_hits) : 1.0;
auto nams = merge_hits_into_nams(hits_per_ref, index.k(), false);
auto nams = merge_hits_into_nams_forward_and_reverse(hits_per_ref, index.k(), false);
return make_pair(nonrepetitive_fraction, nams);
}

Expand All @@ -179,30 +191,30 @@ std::vector<Nam> find_nams_rescue(
unsigned int rescue_cutoff
) {
struct RescueHit {
unsigned int count;
size_t position;
unsigned int count;
unsigned int query_start;
unsigned int query_end;
bool is_rc;

bool operator< (const RescueHit& rhs) const {
return std::tie(count, query_start, query_end, is_rc)
< std::tie(rhs.count, rhs.query_start, rhs.query_end, rhs.is_rc);
return std::tie(count, query_start, query_end)
< std::tie(rhs.count, rhs.query_start, rhs.query_end);
}
};

robin_hood::unordered_map<unsigned int, std::vector<Hit>> hits_per_ref;
std::array<robin_hood::unordered_map<unsigned int, std::vector<Hit>>, 2> hits_per_ref;
std::vector<RescueHit> hits_fw;
std::vector<RescueHit> hits_rc;
hits_per_ref.reserve(100);
hits_per_ref[0].reserve(100);
hits_per_ref[1].reserve(100);
hits_fw.reserve(5000);
hits_rc.reserve(5000);

for (auto &qr : query_randstrobes) {
size_t position = index.find(qr.hash);
if (position != index.end()) {
unsigned int count = index.get_count(position);
RescueHit rh{count, position, qr.start, qr.end, qr.is_reverse};
RescueHit rh{position, count, qr.start, qr.end};
if (qr.is_reverse){
hits_rc.push_back(rh);
} else {
Expand All @@ -213,18 +225,20 @@ std::vector<Nam> find_nams_rescue(

std::sort(hits_fw.begin(), hits_fw.end());
std::sort(hits_rc.begin(), hits_rc.end());
size_t is_revcomp = 0;
for (auto& rescue_hits : {hits_fw, hits_rc}) {
int cnt = 0;
for (auto &rh : rescue_hits) {
if ((rh.count > rescue_cutoff && cnt >= 5) || rh.count > 1000) {
break;
}
add_to_hits_per_ref(hits_per_ref, rh.query_start, rh.query_end, rh.is_rc, index, rh.position);
add_to_hits_per_ref(hits_per_ref[is_revcomp], rh.query_start, rh.query_end, index, rh.position);
cnt++;
}
is_revcomp++;
}

return merge_hits_into_nams(hits_per_ref, index.k(), true);
return merge_hits_into_nams_forward_and_reverse(hits_per_ref, index.k(), true);
}

std::ostream& operator<<(std::ostream& os, const Nam& n) {
Expand Down
1 change: 1 addition & 0 deletions src/nam.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STROBEALIGN_NAM_HPP

#include <vector>
#include <array>
#include "index.hpp"
#include "randstrobes.hpp"

Expand Down
2 changes: 1 addition & 1 deletion tests/baseline-commit.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
baseline_commit=19d5d20f5fb1051169639a1efb37bd7371a078f2
baseline_commit=91f22234baf192a3167354a7f2bbfefb49cb2162

0 comments on commit 0da9e81

Please sign in to comment.