Skip to content

Commit f225a4f

Browse files
umar456pradeep
authored andcommitted
Allow creation of empty sparse arrays. Allow sparse deep copies
1 parent 8eaceb8 commit f225a4f

File tree

3 files changed

+88
-45
lines changed

3 files changed

+88
-45
lines changed

src/api/c/array.cpp

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
#include <handle.hpp>
1313
#include <backend.hpp>
1414
#include <copy.hpp>
15+
#include <sparse.hpp>
1516
#include <sparse_handle.hpp>
17+
#include <af/sparse.h>
1618

1719
using namespace detail;
20+
using common::SparseArrayBase;
1821

1922
const ArrayInfo&
2023
getInfo(const af_array arr, bool sparse_check, bool device_check)
@@ -134,36 +137,43 @@ af_err af_copy_array(af_array *out, const af_array in)
134137
const ArrayInfo& info = getInfo(in, false);
135138
const af_dtype type = info.getType();
136139

137-
if(info.ndims() == 0) {
138-
return af_create_handle(out, 0, nullptr, type);
139-
}
140-
141-
af_array res;
142-
140+
af_array res = 0;
143141
if(info.isSparse()) {
144-
switch(type) {
145-
case f32: res = copySparseArray<float >(in); break;
146-
case f64: res = copySparseArray<double >(in); break;
147-
case c32: res = copySparseArray<cfloat >(in); break;
148-
case c64: res = copySparseArray<cdouble>(in); break;
149-
default : TYPE_ERROR(0, type);
142+
SparseArrayBase sbase = getSparseArrayBase(in);
143+
if(info.ndims() == 0) {
144+
return af_create_sparse_array_from_ptr(out,
145+
info.dims()[0], info.dims()[1],
146+
0, nullptr, nullptr, nullptr,
147+
type, sbase.getStorage(), afDevice);
148+
} else {
149+
switch(type) {
150+
case f32: res = copySparseArray<float >(in); break;
151+
case f64: res = copySparseArray<double >(in); break;
152+
case c32: res = copySparseArray<cfloat >(in); break;
153+
case c64: res = copySparseArray<cdouble>(in); break;
154+
default : TYPE_ERROR(0, type);
155+
}
150156
}
151157
} else {
152-
switch(type) {
153-
case f32: res = copyArray<float >(in); break;
154-
case c32: res = copyArray<cfloat >(in); break;
155-
case f64: res = copyArray<double >(in); break;
156-
case c64: res = copyArray<cdouble >(in); break;
157-
case b8: res = copyArray<char >(in); break;
158-
case s32: res = copyArray<int >(in); break;
159-
case u32: res = copyArray<uint >(in); break;
160-
case u8: res = copyArray<uchar >(in); break;
161-
case s64: res = copyArray<intl >(in); break;
162-
case u64: res = copyArray<uintl >(in); break;
163-
case s16: res = copyArray<short >(in); break;
164-
case u16: res = copyArray<ushort >(in); break;
165-
default: TYPE_ERROR(1, type);
166-
}
158+
if(info.ndims() == 0) {
159+
return af_create_handle(out, 0, nullptr, type);
160+
} else {
161+
switch(type) {
162+
case f32: res = copyArray<float >(in); break;
163+
case c32: res = copyArray<cfloat >(in); break;
164+
case f64: res = copyArray<double >(in); break;
165+
case c64: res = copyArray<cdouble >(in); break;
166+
case b8: res = copyArray<char >(in); break;
167+
case s32: res = copyArray<int >(in); break;
168+
case u32: res = copyArray<uint >(in); break;
169+
case u8: res = copyArray<uchar >(in); break;
170+
case s64: res = copyArray<intl >(in); break;
171+
case u64: res = copyArray<uintl >(in); break;
172+
case s16: res = copyArray<short >(in); break;
173+
case u16: res = copyArray<ushort >(in); break;
174+
default: TYPE_ERROR(1, type);
175+
}
176+
}
167177
}
168178
std::swap(*out, res);
169179
}

src/api/c/sparse.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,14 @@ af_array createSparseArrayFromPtr(
113113
{
114114
SparseArray<T> sparse = createEmptySparseArray<T>(dims, nNZ, stype);
115115

116-
if(source == afHost)
117-
sparse = common::createHostDataSparseArray(
118-
dims, nNZ, values, rowIdx, colIdx, stype);
119-
else if (source == afDevice)
120-
sparse = common::createDeviceDataSparseArray(
121-
dims, nNZ, values, rowIdx, colIdx, stype);
116+
if(nNZ) {
117+
if(source == afHost)
118+
sparse = common::createHostDataSparseArray(
119+
dims, nNZ, values, rowIdx, colIdx, stype);
120+
else if (source == afDevice)
121+
sparse = common::createDeviceDataSparseArray(
122+
dims, nNZ, values, rowIdx, colIdx, stype);
123+
}
122124

123125
return getHandle(sparse);
124126
}

test/sparse.cpp

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,39 @@
1212
#include <sparse_common.hpp>
1313

1414
#define SPARSE_TESTS(T, eps) \
15-
TEST(SPARSE, T##Square) \
15+
TEST(Sparse, T##Square) \
1616
{ \
1717
sparseTester<T>(1000, 1000, 100, 5, eps); \
1818
} \
19-
TEST(SPARSE, T##RectMultiple) \
19+
TEST(Sparse, T##RectMultiple) \
2020
{ \
2121
sparseTester<T>(2048, 1024, 512, 3, eps); \
2222
} \
23-
TEST(SPARSE, T##RectDense) \
23+
TEST(Sparse, T##RectDense) \
2424
{ \
2525
sparseTester<T>(500, 1000, 250, 1, eps); \
2626
} \
27-
TEST(SPARSE, T##MatVec) \
27+
TEST(Sparse, T##MatVec) \
2828
{ \
2929
sparseTester<T>(625, 1331, 1, 2, eps); \
3030
} \
31-
TEST(SPARSE_TRANSPOSE, T##MatVec) \
31+
TEST(Sparse, Transpose_##T##MatVec) \
3232
{ \
3333
sparseTransposeTester<T>(625, 1331, 1, 2, eps); \
3434
} \
35-
TEST(SPARSE_TRANSPOSE, T##Square) \
35+
TEST(Sparse, Transpose_##T##Square) \
3636
{ \
3737
sparseTransposeTester<T>(1000, 1000, 100, 5, eps); \
3838
} \
39-
TEST(SPARSE_TRANSPOSE, T##RectMultiple) \
39+
TEST(Sparse, Transpose_##T##RectMultiple) \
4040
{ \
4141
sparseTransposeTester<T>(2048, 1024, 512, 3, eps); \
4242
} \
43-
TEST(SPARSE_TRANSPOSE, T##RectDense) \
43+
TEST(Sparse, Transpose_##T##RectDense) \
4444
{ \
4545
sparseTransposeTester<T>(453, 751, 397, 1, eps); \
4646
} \
47-
TEST(SPARSE, T##ConvertCSR) \
47+
TEST(Sparse, T##ConvertCSR) \
4848
{ \
4949
convertCSR<T>(2345, 5678, 0.5); \
5050
} \
@@ -57,7 +57,7 @@ SPARSE_TESTS(cdouble, 1E-5)
5757
#undef SPARSE_TESTS
5858

5959
#define CREATE_TESTS(STYPE) \
60-
TEST(SPARSE_CREATE, STYPE) \
60+
TEST(Sparse, Create_##STYPE) \
6161
{ \
6262
createFunction<STYPE>(); \
6363
}
@@ -67,7 +67,7 @@ CREATE_TESTS(AF_STORAGE_COO)
6767

6868
#undef CREATE_TESTS
6969

70-
TEST(SPARSE_CREATE, AF_STORAGE_CSC)
70+
TEST(Sparse, Create_AF_STORAGE_CSC)
7171
{
7272
af::array d = af::identity(3, 3);
7373

@@ -78,7 +78,7 @@ TEST(SPARSE_CREATE, AF_STORAGE_CSC)
7878
}
7979

8080
#define CAST_TESTS_TYPES(Ti, To, SUFFIX, M, N, F) \
81-
TEST(SPARSE_CAST, Ti##_##To##_##SUFFIX) \
81+
TEST(Sparse, Cast_##Ti##_##To##_##SUFFIX) \
8282
{ \
8383
sparseCastTester<Ti, To>(M, N, F); \
8484
} \
@@ -171,3 +171,34 @@ TYPED_TEST(Sparse, DeepCopy) {
171171
ASSERT_TRUE(allTrue<bool>(d == d2));
172172
}
173173
}
174+
175+
TYPED_TEST(Sparse, Empty) {
176+
if (noDoubleTests<TypeParam>()) return;
177+
using namespace af;
178+
af_array ret = 0;
179+
dim_t rows = 0, cols = 0, nnz = 0;
180+
EXPECT_EQ(AF_SUCCESS,
181+
af_create_sparse_array_from_ptr(
182+
&ret,
183+
rows, cols,
184+
nnz, NULL, NULL, NULL,
185+
(af_dtype)dtype_traits<TypeParam>::af_type,
186+
AF_STORAGE_CSR, afHost));
187+
bool sparse = false;
188+
EXPECT_EQ(AF_SUCCESS, af_is_sparse(&sparse, ret));
189+
EXPECT_EQ(true, sparse);
190+
}
191+
192+
TYPED_TEST(Sparse, EmptyDeepCopy) {
193+
if (noDoubleTests<TypeParam>()) return;
194+
using namespace af;
195+
array a = sparse(0, 0,
196+
array(0, (af_dtype)af::dtype_traits<TypeParam>::af_type),
197+
array(0, s32), array(0, s32));
198+
EXPECT_TRUE(a.issparse());
199+
EXPECT_EQ(0, sparseGetNNZ(a));
200+
201+
array b = a.copy();
202+
EXPECT_TRUE(b.issparse());
203+
EXPECT_EQ(0, sparseGetNNZ(b));
204+
}

0 commit comments

Comments
 (0)