@@ -50,6 +50,81 @@ __device__ bool operator==(const QueryReadID& a, const QueryReadID& b)
5050 return a.query_read_id_ == b.query_read_id_ ;
5151}
5252
53+ __device__ Overlap create_simple_overlap (const Anchor& start, const Anchor& end, const int32_t num_anchors)
54+ {
55+ Overlap overlap;
56+ overlap.num_residues_ = num_anchors;
57+
58+ overlap.query_read_id_ = start.query_read_id_ ;
59+ overlap.target_read_id_ = start.target_read_id_ ;
60+ assert (start.query_read_id_ == end.query_read_id_ && start.target_read_id_ == end.target_read_id_ );
61+
62+ overlap.query_start_position_in_read_ = min (start.query_position_in_read_ , end.query_position_in_read_ );
63+ overlap.query_end_position_in_read_ = max (start.query_position_in_read_ , end.query_position_in_read_ );
64+ bool is_negative_strand = end.target_position_in_read_ < start.target_position_in_read_ ;
65+ if (is_negative_strand)
66+ {
67+ overlap.relative_strand = RelativeStrand::Reverse;
68+ overlap.target_start_position_in_read_ = end.target_position_in_read_ ;
69+ overlap.target_end_position_in_read_ = start.target_position_in_read_ ;
70+ }
71+ else
72+ {
73+ overlap.relative_strand = RelativeStrand::Forward;
74+ overlap.target_start_position_in_read_ = start.target_position_in_read_ ;
75+ overlap.target_end_position_in_read_ = end.target_position_in_read_ ;
76+ }
77+ return overlap;
78+ }
79+
80+ __global__ void backtrace_anchors_to_overlaps (const Anchor* anchors,
81+ Overlap* overlaps,
82+ double * scores,
83+ bool * max_select_mask,
84+ int32_t * predecessors,
85+ const int32_t n_anchors,
86+ const int32_t min_score)
87+ {
88+ const std::size_t d_tid = blockIdx .x * blockDim .x + threadIdx .x ;
89+ if (d_tid < n_anchors)
90+ {
91+
92+ int32_t global_overlap_index = d_tid;
93+ if (scores[d_tid] >= min_score)
94+ {
95+
96+ int32_t index = global_overlap_index;
97+ int32_t first_index = index;
98+ int32_t num_anchors_in_chain = 0 ;
99+ Anchor final_anchor = anchors[global_overlap_index];
100+
101+ while (index != -1 )
102+ {
103+ first_index = index;
104+ int32_t pred = predecessors[index];
105+ if (pred != -1 )
106+ {
107+ max_select_mask[pred] = false ;
108+ }
109+ num_anchors_in_chain++;
110+ index = predecessors[index];
111+ }
112+ Anchor first_anchor = anchors[first_index];
113+ overlaps[global_overlap_index] = create_simple_overlap (first_anchor, final_anchor, num_anchors_in_chain);
114+ // Overlap final_overlap = overlaps[global_overlap_index];
115+ // printf("%d %d %d %d %d %d %d %f\n",
116+ // final_overlap.query_read_id_, final_overlap.query_start_position_in_read_, final_overlap.query_end_position_in_read_,
117+ // final_overlap.target_read_id_, final_overlap.target_start_position_in_read_, final_overlap.target_end_position_in_read_,
118+ // final_overlap.num_residues_,
119+ // final_score);
120+ }
121+ else
122+ {
123+ max_select_mask[global_overlap_index] = false ;
124+ }
125+ }
126+ }
127+
53128__global__ void convert_offsets_to_ends (std::int32_t * starts, std::int32_t * lengths, std::int32_t * ends, std::int32_t n_starts)
54129{
55130 std::int32_t d_tid = blockIdx .x * blockDim .x + threadIdx .x ;
@@ -89,7 +164,7 @@ __global__ void calculate_tile_starts(const std::int32_t* query_starts,
89164 const std::int32_t * tiles_per_query_up_to_point)
90165{
91166 int32_t d_thread_id = blockIdx .x * blockDim .x + threadIdx .x ;
92- int32_t stride = blockDim .x * gridDim .x ;
167+ int32_t stride = blockDim .x * gridDim .x ;
93168 if (d_thread_id < num_queries)
94169 {
95170 // for each tile, we look up the query it corresponds to and offset it by the which tile in the query
@@ -167,7 +242,7 @@ void encode_anchor_query_locations(const Anchor* anchors,
167242 cub::DeviceScan::ExclusiveSum (d_temp_storage,
168243 temp_storage_bytes,
169244 query_lengths.data (), // this is the vector of encoded lengths
170- query_starts.data (), // at this point, this vector is empty
245+ query_starts.data (), // at this point, this vector is empty
171246 n_queries,
172247 _cuda_stream);
173248
@@ -178,15 +253,15 @@ void encode_anchor_query_locations(const Anchor* anchors,
178253 cub::DeviceScan::ExclusiveSum (d_temp_storage,
179254 temp_storage_bytes,
180255 query_lengths.data (), // this is the vector of encoded lengths
181- query_starts.data (),
256+ query_starts.data (),
182257 n_queries,
183258 _cuda_stream);
184259
185260 // paper uses the ends and finds the beginnings with x - w + 1, are we converting to that here?
186261 // TODO VI: I'm not entirely sure what this is for? I think we want to change the read query
187262 // (defined by [query_start, query_start + query_length] to [query_end - query_length + 1, query_end])
188263 // The above () is NOT true
189- convert_offsets_to_ends<<<(n_queries / block_size) + 1 , block_size, 0 , _cuda_stream>>> (query_starts.data (), // this gives how many starts at each index
264+ convert_offsets_to_ends<<<(n_queries / block_size) + 1 , block_size, 0 , _cuda_stream>>> (query_starts.data (), // this gives how many starts at each index
190265 query_lengths.data (), // this is the vector of encoded lengths
191266 query_ends.data (),
192267 n_queries);
@@ -215,21 +290,21 @@ void encode_anchor_query_locations(const Anchor* anchors,
215290 d_temp_storage = nullptr ;
216291 temp_storage_bytes = 0 ;
217292 cub::DeviceScan::ExclusiveSum (d_temp_storage,
218- temp_storage_bytes,
219- tiles_per_query.data (), // this is the vector of encoded lengths
220- d_tiles_per_query_up_to_point.data (),
221- n_queries,
222- _cuda_stream);
293+ temp_storage_bytes,
294+ tiles_per_query.data (), // this is the vector of encoded lengths
295+ d_tiles_per_query_up_to_point.data (),
296+ n_queries,
297+ _cuda_stream);
223298
224299 d_temp_buf.clear_and_resize (temp_storage_bytes);
225300 d_temp_storage = d_temp_buf.data ();
226-
301+
227302 cub::DeviceScan::ExclusiveSum (d_temp_storage,
228- temp_storage_bytes,
229- tiles_per_query.data (), // this is the vector of encoded lengths
230- d_tiles_per_query_up_to_point.data (),
231- n_queries,
232- _cuda_stream);
303+ temp_storage_bytes,
304+ tiles_per_query.data (), // this is the vector of encoded lengths
305+ d_tiles_per_query_up_to_point.data (),
306+ n_queries,
307+ _cuda_stream);
233308
234309 calculate_tile_starts<<<(n_queries / block_size) + 1 , block_size, 0 , _cuda_stream>>> (query_starts.data (), tiles_per_query.data (), tile_starts.data (), tile_size, n_queries, d_tiles_per_query_up_to_point.data ());
235310 }
@@ -276,7 +351,6 @@ void encode_anchor_query_target_pairs(const Anchor* anchors,
276351 d_num_query_target_pairs.data (),
277352 n_anchors);
278353
279-
280354 n_query_target_pairs = cudautils::get_value_from_device (d_num_query_target_pairs.data (), _cuda_stream);
281355
282356 d_temp_storage = nullptr ;
0 commit comments