Skip to content

Commit 9413bf8

Browse files
authored
Merge pull request #5 from edawson/chainer-utils-backtrace
[cudamapper] Remove OverlapperMinimap test file, refactor to use backtrace
2 parents c2dc3e1 + 1dffb46 commit 9413bf8

File tree

4 files changed

+113
-33
lines changed

4 files changed

+113
-33
lines changed

cudamapper/src/chainer_utils.cu

Lines changed: 90 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,81 @@ __device__ bool operator==(const QueryReadID& a, const QueryReadID& b)
5050
return a.query_read_id_ == b.query_read_id_;
5151
}
5252

53+
// Builds an Overlap spanning a chain of anchors, given the chain's first and
// last anchor and the number of anchors it contains.
// Query coordinates are reported in increasing order; the strand is inferred
// from whether the target position decreases from `start` to `end`.
__device__ Overlap create_simple_overlap(const Anchor& start, const Anchor& end, const int32_t num_anchors)
{
    // A single chain must never span different query/target read pairs.
    assert(start.query_read_id_ == end.query_read_id_ && start.target_read_id_ == end.target_read_id_);

    Overlap overlap;
    overlap.num_residues_   = num_anchors;
    overlap.query_read_id_  = start.query_read_id_;
    overlap.target_read_id_ = start.target_read_id_;

    // Normalize the query interval so start <= end regardless of anchor order.
    overlap.query_start_position_in_read_ = min(start.query_position_in_read_, end.query_position_in_read_);
    overlap.query_end_position_in_read_   = max(start.query_position_in_read_, end.query_position_in_read_);

    // Reverse strand: target positions run opposite to query positions,
    // so the chain's end anchor carries the smaller target coordinate.
    const bool is_negative_strand = end.target_position_in_read_ < start.target_position_in_read_;

    overlap.relative_strand                = is_negative_strand ? RelativeStrand::Reverse : RelativeStrand::Forward;
    overlap.target_start_position_in_read_ = is_negative_strand ? end.target_position_in_read_ : start.target_position_in_read_;
    overlap.target_end_position_in_read_   = is_negative_strand ? start.target_position_in_read_ : end.target_position_in_read_;

    return overlap;
}
79+
80+
// Walks predecessor chains to produce one Overlap per chain-terminal anchor.
//
// Launch layout: 1D grid, one thread per anchor (threads beyond n_anchors
// exit immediately). For each anchor whose chain score is at least
// `min_score`, the owning thread backtraces through `predecessors` (terminated
// by -1) to find the chain's first anchor, counts the anchors in the chain,
// and writes the resulting overlap to overlaps[anchor_index]. Every interior
// (non-terminal) anchor of a chain has its max_select_mask entry cleared so
// only chain terminals survive downstream selection; anchors scoring below
// min_score are likewise masked out.
//
// NOTE(review): multiple threads may write max_select_mask[pred] for a shared
// predecessor concurrently; all such writes store `false`, so the race appears
// benign — confirm this is the intended semantics.
__global__ void backtrace_anchors_to_overlaps(const Anchor* anchors,
                                              Overlap* overlaps,
                                              double* scores,
                                              bool* max_select_mask,
                                              int32_t* predecessors,
                                              const int32_t n_anchors,
                                              const int32_t min_score)
{
    // Use int32_t (not std::size_t) so the bounds check and the index
    // arithmetic below stay in one signed type, matching n_anchors and
    // the int32_t predecessor indices (avoids signed/unsigned comparison
    // and a narrowing conversion).
    const int32_t d_tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (d_tid < n_anchors)
    {
        const int32_t global_overlap_index = d_tid;
        if (scores[d_tid] >= min_score)
        {
            int32_t index                = global_overlap_index;
            int32_t first_index          = index;
            int32_t num_anchors_in_chain = 0;
            const Anchor final_anchor    = anchors[global_overlap_index];

            // Follow the predecessor chain back to its head (-1 sentinel).
            while (index != -1)
            {
                first_index        = index;
                const int32_t pred = predecessors[index];
                if (pred != -1)
                {
                    // Interior anchor: it cannot be a chain terminal.
                    max_select_mask[pred] = false;
                }
                ++num_anchors_in_chain;
                // Reuse the value already loaded from global memory instead
                // of re-reading predecessors[index].
                index = pred;
            }

            const Anchor first_anchor      = anchors[first_index];
            overlaps[global_overlap_index] = create_simple_overlap(first_anchor, final_anchor, num_anchors_in_chain);
        }
        else
        {
            // Chain score too low: exclude this anchor from selection.
            max_select_mask[global_overlap_index] = false;
        }
    }
}
127+
53128
__global__ void convert_offsets_to_ends(std::int32_t* starts, std::int32_t* lengths, std::int32_t* ends, std::int32_t n_starts)
54129
{
55130
std::int32_t d_tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -89,7 +164,7 @@ __global__ void calculate_tile_starts(const std::int32_t* query_starts,
89164
const std::int32_t* tiles_per_query_up_to_point)
90165
{
91166
int32_t d_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
92-
int32_t stride = blockDim.x * gridDim.x;
167+
int32_t stride = blockDim.x * gridDim.x;
93168
if (d_thread_id < num_queries)
94169
{
95170
// for each tile, we look up the query it corresponds to and offset it by the which tile in the query
@@ -167,7 +242,7 @@ void encode_anchor_query_locations(const Anchor* anchors,
167242
cub::DeviceScan::ExclusiveSum(d_temp_storage,
168243
temp_storage_bytes,
169244
query_lengths.data(), // this is the vector of encoded lengths
170-
query_starts.data(), // at this point, this vector is empty
245+
query_starts.data(), // at this point, this vector is empty
171246
n_queries,
172247
_cuda_stream);
173248

@@ -178,15 +253,15 @@ void encode_anchor_query_locations(const Anchor* anchors,
178253
cub::DeviceScan::ExclusiveSum(d_temp_storage,
179254
temp_storage_bytes,
180255
query_lengths.data(), // this is the vector of encoded lengths
181-
query_starts.data(),
256+
query_starts.data(),
182257
n_queries,
183258
_cuda_stream);
184259

185260
// paper uses the ends and finds the beginnings with x - w + 1, are we converting to that here?
186261
// TODO VI: I'm not entirely sure what this is for? I think we want to change the read query
187262
// (defined by [query_start, query_start + query_length] to [query_end - query_length + 1, query_end])
188263
// The above () is NOT true
189-
convert_offsets_to_ends<<<(n_queries / block_size) + 1, block_size, 0, _cuda_stream>>>(query_starts.data(), // this gives how many starts at each index
264+
convert_offsets_to_ends<<<(n_queries / block_size) + 1, block_size, 0, _cuda_stream>>>(query_starts.data(), // this gives how many starts at each index
190265
query_lengths.data(), // this is the vector of encoded lengths
191266
query_ends.data(),
192267
n_queries);
@@ -215,21 +290,21 @@ void encode_anchor_query_locations(const Anchor* anchors,
215290
d_temp_storage = nullptr;
216291
temp_storage_bytes = 0;
217292
cub::DeviceScan::ExclusiveSum(d_temp_storage,
218-
temp_storage_bytes,
219-
tiles_per_query.data(), // this is the vector of encoded lengths
220-
d_tiles_per_query_up_to_point.data(),
221-
n_queries,
222-
_cuda_stream);
293+
temp_storage_bytes,
294+
tiles_per_query.data(), // this is the vector of encoded lengths
295+
d_tiles_per_query_up_to_point.data(),
296+
n_queries,
297+
_cuda_stream);
223298

224299
d_temp_buf.clear_and_resize(temp_storage_bytes);
225300
d_temp_storage = d_temp_buf.data();
226-
301+
227302
cub::DeviceScan::ExclusiveSum(d_temp_storage,
228-
temp_storage_bytes,
229-
tiles_per_query.data(), // this is the vector of encoded lengths
230-
d_tiles_per_query_up_to_point.data(),
231-
n_queries,
232-
_cuda_stream);
303+
temp_storage_bytes,
304+
tiles_per_query.data(), // this is the vector of encoded lengths
305+
d_tiles_per_query_up_to_point.data(),
306+
n_queries,
307+
_cuda_stream);
233308

234309
calculate_tile_starts<<<(n_queries / block_size) + 1, block_size, 0, _cuda_stream>>>(query_starts.data(), tiles_per_query.data(), tile_starts.data(), tile_size, n_queries, d_tiles_per_query_up_to_point.data());
235310
}
@@ -276,7 +351,6 @@ void encode_anchor_query_target_pairs(const Anchor* anchors,
276351
d_num_query_target_pairs.data(),
277352
n_anchors);
278353

279-
280354
n_query_target_pairs = cudautils::get_value_from_device(d_num_query_target_pairs.data(), _cuda_stream);
281355

282356
d_temp_storage = nullptr;

cudamapper/src/chainer_utils.cuh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,14 @@ struct TileResults
126126
__device__ bool
127127
operator==(const QueryTargetPair& a, const QueryTargetPair& b);
128128

129+
__global__ void backtrace_anchors_to_overlaps(const Anchor* anchors,
130+
Overlap* overlaps,
131+
double* scores,
132+
bool* max_select_mask,
133+
int32_t* predecessors,
134+
const int32_t n_anchors,
135+
const int32_t min_score);
136+
129137
__global__ void convert_offsets_to_ends(std::int32_t* starts, std::int32_t* lengths, std::int32_t* ends, std::int32_t n_starts);
130138

131139
__global__ void calculate_tile_starts(const std::int32_t* query_starts,

cudamapper/src/overlapper_minimap.cu

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ __global__ void mask_overlaps(Overlap* overlaps, std::size_t n_overlaps, bool* s
161161
const bool mask_self_self = false;
162162
auto query_bases_per_residue = static_cast<double>(overlap_query_length) / static_cast<double>(overlaps[d_tid].num_residues_);
163163
auto target_bases_per_residue = static_cast<double>(overlap_target_length) / static_cast<double>(overlaps[d_tid].num_residues_);
164-
select_mask[d_tid] = select_mask[d_tid] && (overlap_query_length >= min_overlap_length) && (overlap_target_length >= min_overlap_length);
165-
select_mask[d_tid] = select_mask[d_tid] && (overlaps[d_tid].num_residues_ >= min_residues);
164+
select_mask[d_tid] = select_mask[d_tid] && (overlap_query_length >= min_overlap_length) && (overlap_target_length >= min_overlap_length);
165+
select_mask[d_tid] = select_mask[d_tid] && (overlaps[d_tid].num_residues_ >= min_residues);
166166
//mask[d_tid] &= !mask_self_self;
167167
select_mask[d_tid] = select_mask[d_tid] && (query_bases_per_residue < max_bases_per_residue) && (target_bases_per_residue < max_bases_per_residue);
168168
// Look at the overlaps and all the overlaps adjacent to me, up to some maximum. Between neighbor i and myself, if
@@ -314,7 +314,6 @@ __device__ __forceinline__ int32_t fast_approx_log2(const int32_t val)
314314
return 8;
315315
}
316316

317-
318317
// TODO VI: This may need to be fixed at some point. Likely the last line
319318
__device__ __forceinline__ int32_t log_linear_anchor_weight(const Anchor& a,
320319
const Anchor& b,
@@ -398,7 +397,7 @@ __global__ void chain_anchors_in_block(const Anchor* anchors,
398397
const int32_t* tile_starts,
399398
const int32_t num_anchors,
400399
const int32_t num_query_tiles,
401-
const int32_t batch_id, // which batch number we are on
400+
const int32_t batch_id, // which batch number we are on
402401
const int32_t batch_size, // fixed to TILE_SIZE...?
403402
const int32_t word_size,
404403
const int32_t max_distance,
@@ -429,14 +428,14 @@ __global__ void chain_anchors_in_block(const Anchor* anchors,
429428
__shared__ int32_t block_predecessor_cache[PREDECESSOR_SEARCH_ITERATIONS];
430429

431430
// Initialize the local caches
432-
block_anchor_cache[thread_id_in_block] = anchors[global_read_index];
431+
block_anchor_cache[thread_id_in_block] = anchors[global_read_index];
433432
// I _believe_ some or most of these will be 0
434433
// not sure why we downcast to integer here
435-
block_score_cache[thread_id_in_block] = static_cast<int32_t>(scores[global_read_index]);
434+
block_score_cache[thread_id_in_block] = static_cast<int32_t>(scores[global_read_index]);
436435
// I _believe some or most of these will be -1 at first
437436
block_predecessor_cache[thread_id_in_block] = predecessors[global_read_index];
438437
// Still not sure what this is for
439-
block_max_select_mask[thread_id_in_block] = false;
438+
block_max_select_mask[thread_id_in_block] = false;
440439

441440
// iterate through the tile
442441
for (int32_t i = PREDECESSOR_SEARCH_ITERATIONS, counter = 0; counter < batch_size; ++counter, ++i)
@@ -486,7 +485,7 @@ __global__ void chain_anchors_in_block(const Anchor* anchors,
486485
// possible_successor_anchor.target_position_in_read_);
487486
__syncthreads();
488487

489-
// if
488+
// if
490489
if (current_score + marginal_score >= block_score_cache[thread_id_in_block] && (global_read_index + i) < num_anchors)
491490
{
492491
//current_score = current_score + marginal_score;
@@ -739,7 +738,7 @@ void OverlapperMinimap::get_overlaps(std::vector<Overlap>& fused_overlaps,
739738
// generates the scheduler blocks
740739
chainerutils::encode_anchor_query_locations(d_anchors.data(),
741740
n_anchors,
742-
TILE_SIZE, // This is 1024
741+
TILE_SIZE, // This is 1024
743742
query_id_starts,
744743
query_id_lengths,
745744
query_id_ends,
@@ -799,13 +798,13 @@ void OverlapperMinimap::get_overlaps(std::vector<Overlap>& fused_overlaps,
799798
#endif
800799

801800
// the deschedule block. Get outputs from here
802-
produce_anchor_chains<<<(n_anchors / block_size) + 1, block_size, 0, _cuda_stream>>>(d_anchors.data(),
803-
d_overlaps_source.data(),
804-
d_anchor_scores.data(),
805-
d_overlaps_select_mask.data(),
806-
d_anchor_predecessors.data(),
807-
n_anchors,
808-
20);
801+
chainerutils::backtrace_anchors_to_overlaps<<<BLOCK_COUNT, block_size, 0, _cuda_stream>>>(d_anchors.data(),
802+
d_overlaps_source.data(),
803+
d_anchor_scores.data(),
804+
d_overlaps_select_mask.data(),
805+
d_anchor_predecessors.data(),
806+
n_anchors,
807+
40);
809808

810809
// TODO VI: I think we can get better device occupancy here with some kernel refactoring
811810
mask_overlaps<<<(n_anchors / block_size) + 1, block_size, 0, _cuda_stream>>>(d_overlaps_source.data(),

cudamapper/tests/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ set(SOURCES
2828
Test_CudamapperMinimizer.cpp
2929
Test_CudamapperOverlapper.cpp
3030
Test_CudamapperOverlapperTriggered.cu
31-
Test_CudamapperOverlapperMinimap.cu
3231
Test_CudamapperUtilsKmerFunctions.cpp
3332
)
3433

0 commit comments

Comments
 (0)