@@ -61,73 +61,80 @@ void build_mr_linkage(
6161 size_t n = X.extent (1 );
6262 auto stream = raft::resource::get_cuda_stream (handle);
6363
64- auto mr_indptr = raft::make_device_vector<value_idx, value_idx>(handle, m + 1 );
65- raft::sparse::COO<value_t , value_idx, nnz_t > mr_coo (stream, min_samples * m * 2 );
66-
67- auto inds = raft::make_device_matrix<value_idx, value_idx>(handle, m, min_samples);
68- auto dists = raft::make_device_matrix<value_t , value_idx>(handle, m, min_samples);
69-
70- if (all_neighbors_p.metric != metric) {
71- RAFT_LOG_WARN (" Setting all neighbors metric to given metrix for build_mr_linkage" );
72- all_neighbors_p.metric = metric;
73- }
74- cuvs::neighbors::all_neighbors::build (
75- handle, all_neighbors_p, X, inds.view (), dists.view (), core_dists, alpha);
76-
77- // self-loops get max distance
78- auto coo_rows = raft::make_device_vector<value_idx, value_idx>(handle, min_samples * m);
79- raft::linalg::map_offset (handle, coo_rows.view (), raft::div_const_op<value_idx>(min_samples));
80-
81- raft::sparse::linalg::symmetrize (handle,
82- coo_rows.data_handle (),
83- inds.data_handle (),
84- dists.data_handle (),
85- static_cast <value_idx>(m),
86- static_cast <value_idx>(m),
87- static_cast <nnz_t >(min_samples * m),
88- mr_coo);
89-
90- raft::sparse::convert::sorted_coo_to_csr (
91- mr_coo.rows (), mr_coo.nnz , mr_indptr.data_handle (), m + 1 , stream);
92-
93- auto rows_view = raft::make_device_vector_view<const value_idx, nnz_t >(mr_coo.rows (), mr_coo.nnz );
94- auto cols_view = raft::make_device_vector_view<const value_idx, nnz_t >(mr_coo.cols (), mr_coo.nnz );
95- auto vals_in_view =
96- raft::make_device_vector_view<const value_t , nnz_t >(mr_coo.vals (), mr_coo.nnz );
97- auto vals_out_view = raft::make_device_vector_view<value_t , nnz_t >(mr_coo.vals (), mr_coo.nnz );
98-
99- raft::linalg::map (
100- handle,
101- vals_out_view,
102- [=] __device__ (const value_idx row, const value_idx col, const value_t val) {
103- return row == col ? std::numeric_limits<value_t >::max () : val;
104- },
105- rows_view,
106- cols_view,
107- vals_in_view);
108-
109- rmm::device_uvector<value_idx> color (m, raft::resource::get_cuda_stream (handle));
110- cuvs::sparse::neighbors::MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t >
111- reduction_op (core_dists.data_handle (), m);
112-
113- size_t nnz = m * min_samples;
114-
115- detail::build_sorted_mst<value_idx, value_t >(handle,
116- X.data_handle (),
117- mr_indptr.data_handle (),
118- mr_coo.cols (),
119- mr_coo.vals (),
120- m,
121- n,
122- out_mst.structure_view ().get_rows ().data (),
123- out_mst.structure_view ().get_cols ().data (),
124- out_mst.get_elements ().data (),
125- color.data (),
126- mr_coo.nnz ,
127- reduction_op,
128- metric,
129- 10 );
130-
64+ { // scope to drop mr_coo and mr_indptr early
65+ std::optional<raft::sparse::COO<value_t , value_idx, nnz_t >> mr_coo;
66+
67+ { // scope to drop inds and dists matrices early
68+ auto inds = raft::make_device_matrix<value_idx, value_idx>(handle, m, min_samples);
69+ auto dists = raft::make_device_matrix<value_t , value_idx>(handle, m, min_samples);
70+
71+ if (all_neighbors_p.metric != metric) {
72+ RAFT_LOG_WARN (" Setting all neighbors metric to given metrix for build_mr_linkage" );
73+ all_neighbors_p.metric = metric;
74+ }
75+ cuvs::neighbors::all_neighbors::build (
76+ handle, all_neighbors_p, X, inds.view (), dists.view (), core_dists, alpha);
77+
78+ // allocate memory after all neighbors build
79+ mr_coo.emplace (stream, min_samples * m * 2 );
80+ // self-loops get max distance
81+ auto coo_rows = raft::make_device_vector<value_idx, value_idx>(handle, min_samples * m);
82+ raft::linalg::map_offset (handle, coo_rows.view (), raft::div_const_op<value_idx>(min_samples));
83+
84+ raft::sparse::linalg::symmetrize (handle,
85+ coo_rows.data_handle (),
86+ inds.data_handle (),
87+ dists.data_handle (),
88+ static_cast <value_idx>(m),
89+ static_cast <value_idx>(m),
90+ static_cast <nnz_t >(min_samples * m),
91+ mr_coo.value ());
92+ } // scope to drop inds and dists matrices early
93+ auto mr_indptr = raft::make_device_vector<value_idx, value_idx>(handle, m + 1 );
94+ raft::sparse::convert::sorted_coo_to_csr (
95+ mr_coo.value ().rows (), mr_coo.value ().nnz , mr_indptr.data_handle (), m + 1 , stream);
96+
97+ auto rows_view = raft::make_device_vector_view<const value_idx, nnz_t >(mr_coo.value ().rows (),
98+ mr_coo.value ().nnz );
99+ auto cols_view = raft::make_device_vector_view<const value_idx, nnz_t >(mr_coo.value ().cols (),
100+ mr_coo.value ().nnz );
101+ auto vals_in_view = raft::make_device_vector_view<const value_t , nnz_t >(mr_coo.value ().vals (),
102+ mr_coo.value ().nnz );
103+ auto vals_out_view =
104+ raft::make_device_vector_view<value_t , nnz_t >(mr_coo.value ().vals (), mr_coo.value ().nnz );
105+
106+ raft::linalg::map (
107+ handle,
108+ vals_out_view,
109+ [=] __device__ (const value_idx row, const value_idx col, const value_t val) {
110+ return row == col ? std::numeric_limits<value_t >::max () : val;
111+ },
112+ rows_view,
113+ cols_view,
114+ vals_in_view);
115+
116+ rmm::device_uvector<value_idx> color (m, raft::resource::get_cuda_stream (handle));
117+ cuvs::sparse::neighbors::MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t >
118+ reduction_op (core_dists.data_handle (), m);
119+
120+ size_t nnz = m * min_samples;
121+
122+ detail::build_sorted_mst<value_idx, value_t >(handle,
123+ X.data_handle (),
124+ mr_indptr.data_handle (),
125+ mr_coo.value ().cols (),
126+ mr_coo.value ().vals (),
127+ m,
128+ n,
129+ out_mst.structure_view ().get_rows ().data (),
130+ out_mst.structure_view ().get_cols ().data (),
131+ out_mst.get_elements ().data (),
132+ color.data (),
133+ mr_coo.value ().nnz ,
134+ reduction_op,
135+ metric,
136+ 10 );
137+ } // scope to drop mr_coo and mr_indptr early
131138 /* *
132139 * Perform hierarchical labeling
133140 */
0 commit comments