@@ -30,13 +30,6 @@ enum BufferType {
3030 TENSOR_SPARSE = 4
3131};
3232
33- enum SparseDataType {
34- SPARSE_NO_VALUE = 0 , // do not need value pointer, all values are 1
35- SPARSE_FLOAT_VALUE = 1
36- };
37-
38- enum SparseDataFormat { SPARSE_CSR_FORMAT = 0 , SPARSE_CSC_FORMAT = 1 };
39-
4033class BufferArg ;
4134class SequenceArg ;
4235class SparseMatrixArg ;
@@ -79,19 +72,21 @@ class BufferArg {
7972 BufferArg (ValueType valueType,
8073 const TensorShape& shape,
8174 ArgType argType = UNSPECIFIED)
82- : buf_(nullptr ),
83- valueType_ (valueType),
84- shape_(shape),
85- argType_(argType) {}
75+ : buf_(nullptr ), valueType_(valueType), shape_(shape), argType_(argType) { // shape-only overload: no data pointer is attached (buf_ == nullptr)
76+ bufferType_ = TENSOR_NORMAL; // replaces the in-class default TENSOR_UNKNOWN; derived constructors may overwrite it again
77+ }
8678
8779 BufferArg (void * buf,
8880 ValueType valueType,
8981 const TensorShape& shape,
9082 ArgType argType = UNSPECIFIED)
91- : buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {}
83+ : buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) { // stores the caller's pointer as-is; nothing is allocated or copied here
84+ bufferType_ = TENSOR_NORMAL; // replaces the in-class default TENSOR_UNKNOWN; derived constructors may overwrite it again
85+ }
9286
93- BufferArg (void * buf, ValueType valueType)
94- : buf_(buf), valueType_(valueType) {}
87+ BufferArg (void * buf, ValueType valueType) : buf_(buf), valueType_(valueType) { // shape_ is default-constructed and argType_ keeps its in-class default UNSPECIFIED
88+ bufferType_ = TENSOR_NORMAL; // replaces the in-class default TENSOR_UNKNOWN
89+ }
9590
9691 BufferArg (const Matrix& matrix, ArgType argType = UNSPECIFIED)
9792 : buf_(
@@ -167,8 +162,9 @@ class BufferArg {
167162 ValueType valueType () const { return valueType_; }
168163 BufferType bufferType () const { return bufferType_; }
169164 const TensorShape& shape () const { return shape_; }
170- bool isSparse () const { return ( TENSOR_SPARSE == bufferType_) ; }
165+ bool isSparseArg () const { return TENSOR_SPARSE == bufferType_; }
171166 bool isSequenceArg () const { return TENSOR_SEQUENCE_DATA == bufferType_; }
167+ virtual size_t numElements () const { return shape_.getElements (); } // element count as reported by the shape; virtual so SparseMatrixArg can return nnz_ instead
172168
173169 const SequenceArg& sequence () const ;
174170 const SparseMatrixArg& sparse () const ;
@@ -179,6 +175,7 @@ class BufferArg {
179175 TensorShape shape_;
180176 BufferType bufferType_{TENSOR_UNKNOWN};
181177 ArgType argType_{UNSPECIFIED};
178+ // TODO(tianbing), add deviceType_
182179 // leading dimensions. The size is dims_.size()
183180 // Dims lds_;
184181};
@@ -191,6 +188,7 @@ class SequenceIdArg : public BufferArg {
191188public:
192189 SequenceIdArg (const TensorShape& shape, ArgType argType = UNSPECIFIED)
193190 : BufferArg(VALUE_TYPE_INT32, shape, argType) {
191+ bufferType_ = TENSOR_SEQUENCE_ID;
194192 CHECK_EQ (shape_.ndims (), (size_t )1 );
195193 CHECK_GT (shape_[0 ], 1 );
196194 numSeqs_ = shape_[0 ] - 1 ;
@@ -228,7 +226,9 @@ class SequenceArg : public BufferArg {
228226 SequenceArg (ValueType valueType,
229227 const TensorShape& shape,
230228 ArgType argType = UNSPECIFIED)
231- : BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {}
229+ : BufferArg(valueType, shape, argType), startPositions_(TensorShape()) { // buffer-less overload: startPositions_ is built from a default TensorShape
230+ bufferType_ = TENSOR_SEQUENCE_DATA; // overwrites the TENSOR_NORMAL set by the base constructor; this is the value isSequenceArg() tests for
231+ }
232232
233233 SequenceArg (void * buf,
234234 ValueType valueType,
@@ -269,31 +269,75 @@ class SparseMatrixArg : public BufferArg {
269269 const BufferArg& row,
270270 const BufferArg& col,
271271 size_t nnz,
272- SparseDataFormat format,
273- SparseDataType type,
272+ SparseFormat format,
273+ SparseValueType type,
274274 ArgType argType = UNSPECIFIED)
275275 : BufferArg(buf, valueType, shape, argType),
276276 row_ (row),
277277 col_(col),
278278 nnz_(nnz),
279- format_(format),
280- type_(type) {
279+ format_(static_cast <SparseDataFormat>( format) ),
280+ type_(static_cast <SparseDataType>( type) ) {
281281 bufferType_ = TENSOR_SPARSE;
282282 CHECK ((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
283283 CHECK_EQ (shape_.ndims (), (size_t )2 );
284284 CHECK_EQ (row_.shape ().ndims (), (size_t )1 );
285285 CHECK_EQ (col_.shape ().ndims (), (size_t )1 );
286- if (format == SPARSE_CSR_FORMAT ) {
286+ if (format_ == T_SPARSE_CSR ) {
287287 CHECK_EQ (nnz, col.shape ()[0 ]);
288- } else if (format == SPARSE_CSC_FORMAT ) {
288+ } else if (format_ == T_SPARSE_CSC ) {
289289 CHECK_EQ (nnz, row.shape ()[0 ]);
290290 }
291291 }
292292
293+ SparseMatrixArg (ValueType valueType, // buffer-less overload: describes a sparse matrix without attaching value/row/col pointers
294+ const TensorShape& shape,
295+ size_t nnz,
296+ SparseFormat format,
297+ SparseValueType type,
298+ ArgType argType = UNSPECIFIED)
299+ : BufferArg(valueType, shape, argType), // this base overload sets buf_ to nullptr
300+ row_(BufferArg(nullptr , VALUE_TYPE_INT32)), // placeholder only; reassigned in the body once format_ is known
301+ col_(BufferArg(nullptr , VALUE_TYPE_INT32)), // placeholder only; reassigned in the body once format_ is known
302+ nnz_(nnz),
303+ format_(static_cast <SparseDataFormat>(format)),
304+ type_(static_cast <SparseDataType>(type)) {
305+ bufferType_ = TENSOR_SPARSE; // overwrites the TENSOR_NORMAL set by the base constructor
306+ CHECK ((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE)); // only float and double values are accepted
307+ CHECK_EQ (shape_.ndims (), (size_t )2 ); // the described matrix must be two-dimensional
308+
309+ /// Length of row_: shape_[0] + 1 (height + 1) for CSR, nnz otherwise (CSC); its buf_ stays nullptr.
310+ row_ = (format_ == T_SPARSE_CSR
311+ ? BufferArg (VALUE_TYPE_INT32, TensorShape{shape_[0 ] + 1 })
312+ : BufferArg (VALUE_TYPE_INT32, TensorShape{nnz}));
313+ /// Length of col_: nnz for CSR, shape_[1] + 1 (width + 1) otherwise (CSC); its buf_ stays nullptr.
314+ col_ = (format_ == T_SPARSE_CSR
315+ ? BufferArg (VALUE_TYPE_INT32, TensorShape{nnz})
316+ : BufferArg (VALUE_TYPE_INT32, TensorShape{shape_[1 ] + 1 }));
317+ }
318+
293319 SparseMatrixArg (const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED);
294320
295321 SparseMatrixArg (const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED);
296322
323+ template <DeviceType DType> // builds a Tensor<real, DType>::SparseMatrix from the pointers and metadata held by this arg
324+ typename Tensor<real, DType>::SparseMatrix SparseMatrix () const {
325+ CHECK (buf_); // a value buffer is required, so args built by the buffer-less constructor fail here
326+ CHECK (valueType_ == DataType<real>::value); // stored value type must match the build's `real` type
327+ // CHECK(deviceType_ == DType);  -- disabled: BufferArg has no deviceType_ member yet (see the TODO in BufferArg)
328+ CHECK_EQ (2 , shape_.ndims ());
329+ return typename Tensor<real, DType>::SparseMatrix (
330+ reinterpret_cast <real*>(buf_),
331+ reinterpret_cast <int *>(row_.data ()),
332+ reinterpret_cast <int *>(col_.data ()),
333+ shape_[0 ],
334+ shape_[1 ],
335+ nnz_,
336+ static_cast <SparseValueType>(type_), // note: value type is passed before format, the reverse of this class's constructor parameters
337+ static_cast <SparseFormat>(format_),
338+ false ); // NOTE(review): meaning of this trailing flag is defined by the Tensor SparseMatrix constructor -- confirm there
339+ }
340+
297341 ~SparseMatrixArg () {}
298342
299343 void * getRowBuf () const { return row_.data (); }
@@ -302,6 +346,8 @@ class SparseMatrixArg : public BufferArg {
302346
303347 size_t nnz () const { return nnz_; }
304348
349+ size_t numElements () const override { return nnz_; } // overrides BufferArg::numElements(): reports nnz_ rather than shape_.getElements()
350+
305351 SparseDataFormat dataFormat () const { return format_; }
306352
307353 SparseDataType dataType () const { return type_; }
0 commit comments