Skip to content

Commit a9947ce

Browse files
committed
update librett.cpp for the changed librettPlan API
1 parent 32dec92 commit a9947ce

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

tests/librett.cpp

+21-16
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,11 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem) {
6969
TiledArray::permutation_to_col_major(perm);
7070

7171
librettHandle plan;
72-
//librettResult_t status;
72+
librett_gpuStream_t stream;
7373
librettResult status;
7474

75-
status = librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), 0);
75+
status =
76+
librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), stream);
7677

7778
BOOST_CHECK(status == LIBRETT_SUCCESS);
7879

@@ -117,7 +118,7 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym) {
117118
cudaMemcpy(a_device, a_host, A * B * sizeof(int), cudaMemcpyHostToDevice);
118119

119120
librettHandle plan;
120-
//librettResult_t status;
121+
librett_gpuStream_t stream;
121122
librettResult status;
122123

123124
std::vector<int> extent({B, A});
@@ -126,7 +127,8 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym) {
126127
std::vector<int> perm({1, 0});
127128
TiledArray::permutation_to_col_major(perm);
128129

129-
status = librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), 0);
130+
status =
131+
librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), stream);
130132

131133
BOOST_CHECK(status == LIBRETT_SUCCESS);
132134

@@ -175,16 +177,16 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym_rank_three_column_major) {
175177
// b(j,i,k) = a(i,j,k)
176178

177179
librettHandle plan;
178-
//librettResult_t status;
180+
librett_gpuStream_t stream;
179181
librettResult status;
180182

181183
std::vector<int> extent3{int(A), int(B), int(C)};
182184

183185
std::vector<int> perm3{1, 0, 2};
184186
// std::vector<int> perm3{0, 2, 1};
185187

186-
status = librettPlanMeasure(&plan, 3, extent3.data(), perm3.data(), sizeof(int),
187-
0, a_device, b_device);
188+
status = librettPlanMeasure(&plan, 3, extent3.data(), perm3.data(),
189+
sizeof(int), stream, a_device, b_device);
188190

189191
BOOST_CHECK(status == LIBRETT_SUCCESS);
190192

@@ -238,7 +240,7 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym_rank_three_row_major) {
238240
// b(j,i,k) = a(i,j,k)
239241

240242
librettHandle plan;
241-
//librettResult_t status;
243+
librett_gpuStream_t stream;
242244
librettResult status;
243245

244246
std::vector<int> extent({A, B, C});
@@ -247,8 +249,8 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym_rank_three_row_major) {
247249
std::vector<int> perm({1, 0, 2});
248250
TiledArray::permutation_to_col_major(perm);
249251

250-
status = librettPlanMeasure(&plan, 3, extent.data(), perm.data(), sizeof(int), 0,
251-
a_device, b_device);
252+
status = librettPlanMeasure(&plan, 3, extent.data(), perm.data(), sizeof(int),
253+
stream, a_device, b_device);
252254

253255
BOOST_CHECK(status == LIBRETT_SUCCESS);
254256

@@ -295,7 +297,7 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem) {
295297
}
296298

297299
librettHandle plan;
298-
//librettResult_t status;
300+
librett_gpuStream_t stream;
299301
librettResult status;
300302

301303
std::vector<int> extent({A, A});
@@ -304,7 +306,8 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem) {
304306
std::vector<int> perm({1, 0});
305307
TiledArray::permutation_to_col_major(perm);
306308

307-
status = librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), 0);
309+
status =
310+
librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), stream);
308311

309312
BOOST_CHECK(status == LIBRETT_SUCCESS);
310313

@@ -344,7 +347,7 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem_nonsym) {
344347
}
345348

346349
librettHandle plan;
347-
//librettResult_t status;
350+
librett_gpuStream_t stream;
348351
librettResult status;
349352

350353
std::vector<int> extent({B, A});
@@ -353,7 +356,8 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem_nonsym) {
353356
std::vector<int> perm({1, 0});
354357
TiledArray::permutation_to_col_major(perm);
355358

356-
status = librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), 0);
359+
status =
360+
librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), stream);
357361

358362
BOOST_CHECK(status == LIBRETT_SUCCESS);
359363

@@ -393,7 +397,7 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem_rank_three) {
393397
}
394398

395399
librettHandle plan;
396-
//librettResult_t status;
400+
librett_gpuStream_t stream;
397401
librettResult status;
398402

399403
// b(k,i,j) = a(i,j,k)
@@ -404,7 +408,8 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem_rank_three) {
404408
std::vector<int> perm({2, 0, 1});
405409
TiledArray::permutation_to_col_major(perm);
406410

407-
status = librettPlan(&plan, 3, extent.data(), perm.data(), sizeof(int), 0);
411+
status =
412+
librettPlan(&plan, 3, extent.data(), perm.data(), sizeof(int), stream);
408413

409414
BOOST_CHECK(status == LIBRETT_SUCCESS);
410415

0 commit comments

Comments
 (0)