Skip to content

Commit e2bd294

Browse files
committed
Fix lu, rank and qr handling of empty arrays and check for nullptr
1 parent fc99193 commit e2bd294

File tree

6 files changed

+83
-16
lines changed

6 files changed

+83
-16
lines changed

src/api/c/lu.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ af_err af_lu(af_array *lower, af_array *upper, af_array *pivot,
4949

5050
af_dtype type = i_info.getType();
5151

52+
ARG_ASSERT(0, lower != nullptr);
53+
ARG_ASSERT(1, upper != nullptr);
54+
ARG_ASSERT(2, pivot != nullptr);
5255
ARG_ASSERT(3, i_info.isFloating()); // Only floating and complex types
5356

5457
if (i_info.ndims() == 0) {
@@ -81,21 +84,21 @@ af_err af_lu_inplace(af_array *pivot, af_array in, const bool is_lapack_piv) {
8184
}
8285

8386
ARG_ASSERT(1, i_info.isFloating()); // Only floating and complex types
87+
ARG_ASSERT(0, pivot != nullptr);
8488

8589
if (i_info.ndims() == 0) {
8690
return af_create_handle(pivot, 0, nullptr, type);
8791
}
8892

8993
af_array out;
90-
9194
switch (type) {
9295
case f32: out = lu_inplace<float>(in, is_lapack_piv); break;
9396
case f64: out = lu_inplace<double>(in, is_lapack_piv); break;
9497
case c32: out = lu_inplace<cfloat>(in, is_lapack_piv); break;
9598
case c64: out = lu_inplace<cdouble>(in, is_lapack_piv); break;
9699
default: TYPE_ERROR(1, type);
97100
}
98-
if (pivot != NULL) std::swap(*pivot, out);
101+
std::swap(*pivot, out);
99102
}
100103
CATCHALL;
101104

src/api/c/qr.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ af_err af_qr(af_array *q, af_array *r, af_array *tau, const af_array in) {
5555
return AF_SUCCESS;
5656
}
5757

58+
ARG_ASSERT(0, q != nullptr);
59+
ARG_ASSERT(1, r != nullptr);
60+
ARG_ASSERT(2, tau != nullptr);
5861
ARG_ASSERT(3, i_info.isFloating()); // Only floating and complex types
5962

6063
switch (type) {
@@ -81,21 +84,21 @@ af_err af_qr_inplace(af_array *tau, af_array in) {
8184
af_dtype type = i_info.getType();
8285

8386
ARG_ASSERT(1, i_info.isFloating()); // Only floating and complex types
87+
ARG_ASSERT(0, tau != nullptr);
8488

8589
if (i_info.ndims() == 0) {
8690
return af_create_handle(tau, 0, nullptr, type);
8791
}
8892

8993
af_array out;
90-
9194
switch (type) {
9295
case f32: out = qr_inplace<float>(in); break;
9396
case f64: out = qr_inplace<double>(in); break;
9497
case c32: out = qr_inplace<cfloat>(in); break;
9598
case c64: out = qr_inplace<cdouble>(in); break;
9699
default: TYPE_ERROR(1, type);
97100
}
98-
if (tau != NULL) std::swap(*tau, out);
101+
std::swap(*tau, out);
99102
}
100103
CATCHALL;
101104

src/api/c/rank.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,17 @@ af_err af_rank(uint* out, const af_array in, const double tol) {
5656
af_dtype type = i_info.getType();
5757

5858
ARG_ASSERT(1, i_info.isFloating()); // Only floating and complex types
59+
ARG_ASSERT(0, out != nullptr);
5960

60-
uint output;
61-
if (i_info.ndims() == 0) {
62-
output = 0;
63-
return AF_SUCCESS;
64-
}
65-
66-
switch (type) {
67-
case f32: output = rank<float>(in, tol); break;
68-
case f64: output = rank<double>(in, tol); break;
69-
case c32: output = rank<cfloat>(in, tol); break;
70-
case c64: output = rank<cdouble>(in, tol); break;
71-
default: TYPE_ERROR(1, type);
61+
uint output = 0;
62+
if (i_info.ndims() != 0) {
63+
switch (type) {
64+
case f32: output = rank<float>(in, tol); break;
65+
case f64: output = rank<double>(in, tol); break;
66+
case c32: output = rank<cfloat>(in, tol); break;
67+
case c64: output = rank<cdouble>(in, tol); break;
68+
default: TYPE_ERROR(1, type);
69+
}
7270
}
7371
std::swap(*out, output);
7472
}

test/lu_dense.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,46 @@ TYPED_TEST(LU, RectangularLarge1) {
235235
TYPED_TEST(LU, RectangularMultipleOfTwoLarge1) {
236236
luTester<TypeParam>(512, 1024, eps<TypeParam>());
237237
}
238+
239+
TEST(LU, NullLowerOutput) {
240+
if (noLAPACKTests()) return;
241+
dim4 dims(3, 3);
242+
af_array in = 0;
243+
ASSERT_SUCCESS(af_randu(&in, dims.ndims(), dims.get(), f32));
244+
245+
af_array upper, pivot;
246+
ASSERT_EQ(AF_ERR_ARG, af_lu(NULL, &upper, &pivot, in));
247+
ASSERT_SUCCESS(af_release_array(in));
248+
}
249+
250+
TEST(LU, NullUpperOutput) {
251+
if (noLAPACKTests()) return;
252+
dim4 dims(3, 3);
253+
af_array in = 0;
254+
ASSERT_SUCCESS(af_randu(&in, dims.ndims(), dims.get(), f32));
255+
256+
af_array lower, pivot;
257+
ASSERT_EQ(AF_ERR_ARG, af_lu(&lower, NULL, &pivot, in));
258+
ASSERT_SUCCESS(af_release_array(in));
259+
}
260+
261+
TEST(LU, NullPivotOutput) {
262+
if (noLAPACKTests()) return;
263+
dim4 dims(3, 3);
264+
af_array in = 0;
265+
ASSERT_SUCCESS(af_randu(&in, dims.ndims(), dims.get(), f32));
266+
267+
af_array lower, upper;
268+
ASSERT_EQ(AF_ERR_ARG, af_lu(&lower, &upper, NULL, in));
269+
ASSERT_SUCCESS(af_release_array(in));
270+
}
271+
272+
TEST(LU, InPlaceNullOutput) {
273+
if (noLAPACKTests()) return;
274+
dim4 dims(3, 3);
275+
af_array in = 0;
276+
ASSERT_SUCCESS(af_randu(&in, dims.ndims(), dims.get(), f32));
277+
278+
ASSERT_EQ(AF_ERR_ARG, af_lu_inplace(NULL, in, true));
279+
ASSERT_SUCCESS(af_release_array(in));
280+
}

test/qr_dense.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,13 @@ TYPED_TEST(QR, RectangularLarge1) {
179179
TYPED_TEST(QR, RectangularMultipleOfTwoLarge1) {
180180
qrTester<TypeParam>(512, 1024, eps<TypeParam>());
181181
}
182+
183+
TEST(QR, InPlaceNullOutput) {
184+
if (noLAPACKTests()) return;
185+
dim4 dims(3, 3);
186+
af_array in = 0;
187+
ASSERT_SUCCESS(af_randu(&in, dims.ndims(), dims.get(), f32));
188+
189+
ASSERT_EQ(AF_ERR_ARG, af_qr_inplace(NULL, in));
190+
ASSERT_SUCCESS(af_release_array(in));
191+
}

test/rank_dense.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,13 @@ void detTest() {
112112
}
113113

114114
TYPED_TEST(Det, Small) { detTest<TypeParam>(); }
115+
116+
TEST(Rank, NullOutput) {
117+
if (noLAPACKTests()) return;
118+
dim4 dims(3, 3);
119+
af_array in = 0;
120+
af_randu(&in, dims.ndims(), dims.get(), f32);
121+
122+
ASSERT_EQ(AF_ERR_ARG, af_rank(NULL, in, 1e-6));
123+
ASSERT_SUCCESS(af_release_array(in));
124+
}

0 commit comments

Comments
 (0)