@@ -43,38 +43,33 @@ class SD_LIB_EXPORT ConstantShapeHelper {
4343
4444 ~ConstantShapeHelper ();
4545 ConstantShapeBuffer* bufferForShapeInfo (DataType dataType, char order, const std::vector<LongType>& shape);
46- ConstantShapeBuffer* bufferForShapeInfo (ShapeDescriptor *descriptor);
47- ConstantShapeBuffer* bufferForShapeInfo (const LongType* shapeInfo);
48- ConstantShapeBuffer* bufferForShapeInfo (DataType dataType, char order, int rank, const LongType* shape);
49- ConstantShapeBuffer* createShapeInfoWithUnitiesForBroadcast (const LongType* maxShapeInfo,
50- const LongType* minShapeInfo,
51- memory::Workspace* workspace = nullptr ,
52- const std::vector<LongType>& dimensions = {});
53- ConstantShapeBuffer* createShapeInfoWithNoUnitiesForReduce (const LongType* maxShapeInfo,
54- const std::vector<LongType>* dimsWithUnities,
55- memory::Workspace* workspace = nullptr );
56- ConstantShapeBuffer* createSubArrShapeInfo (const LongType* inShapeInfo, const LongType* dims,
57- const LongType dimsSize,
58- memory::Workspace* workspace = nullptr );
46+ ConstantShapeBuffer* bufferForShapeInfo (LongType* shapeInfo);
47+ ConstantShapeBuffer* bufferForShapeInfo (DataType dataType, char order, int rank, LongType* shape);
48+ ConstantShapeBuffer* createShapeInfoWithUnitiesForBroadcast ( LongType* maxShapeInfo,
49+ LongType* minShapeInfo,
50+ memory::Workspace* workspace = nullptr ,
51+ const std::vector<LongType>& dimensions = {});
52+ ConstantShapeBuffer* createShapeInfoWithNoUnitiesForReduce ( LongType* maxShapeInfo,
53+ const std::vector<LongType>* dimsWithUnities,
54+ memory::Workspace* workspace = nullptr );
55+
56+ LongType* emptyShapeInfo (DataType dataType);
57+ LongType * scalarShapeInfo (DataType dataType);
58+ LongType* vectorShapeInfo (LongType length, DataType dataType);
59+ LongType* createShapeInfo (ShapeDescriptor *descriptor);
60+ LongType* createShapeInfo (DataType dataType, char order, const std::vector<LongType>& shape);
61+ LongType* createShapeInfo (DataType dataType, const char order, const int rank, LongType* shape, LongType extraProperties);
62+ LongType* createShapeInfo (DataType dataType, LongType* shapeInfo);
63+ LongType* createFromExisting (LongType* shapeInfo, sd::memory::Workspace* workspace);
64+ LongType* createFromExisting (LongType* shapeInfo, bool destroyOriginal = true );
5965
60- const LongType* emptyShapeInfo (DataType dataType);
61- const LongType* scalarShapeInfo (DataType dataType);
62- const LongType* vectorShapeInfo (LongType length, DataType dataType);
63- const LongType* createShapeInfo (ShapeDescriptor *descriptor);
64- const LongType* createShapeInfo (DataType dataType, char order, const std::vector<LongType>& shape);
65- const LongType* createShapeInfo (DataType dataType, const char order, const int rank, const LongType* shape, LongType extraProperties);
66- const LongType* createShapeInfo (DataType dataType, const LongType* shapeInfo);
67- const LongType* createFromExisting (const LongType* shapeInfo, sd::memory::Workspace* workspace);
68- const LongType* createFromExisting (const LongType* shapeInfo, bool destroyOriginal = true );
69- const LongType* createFromExisting (sd::LongType* shapeInfo, sd::memory::Workspace* workspace);
70- const LongType* createFromExisting (sd::LongType* shapeInfo, bool destroyOriginal = true );
7166
7267 bool checkBufferExistenceForShapeInfo (ShapeDescriptor *descriptor);
7368
74- ConstantShapeBuffer* storeAndWrapBuffer (const LongType* shapeInfo);
75- const LongType* castToDataType (const LongType* shapeInfo, const DataType newType);
76- const LongType* emptyShapeInfoWithShape (const DataType dataType, std::vector<LongType>& shape);
77- ConstantShapeBuffer* createConstBuffFromExisting (const sd::LongType* shapeInfo, sd::memory::Workspace* workspace);
69+ ConstantShapeBuffer* storeAndWrapBuffer ( LongType* shapeInfo);
70+ LongType* castToDataType ( LongType* shapeInfo, DataType newType);
71+ LongType* emptyShapeInfoWithShape (const DataType dataType, std::vector<LongType>& shape);
72+ ConstantShapeBuffer* createConstBuffFromExisting ( sd::LongType* shapeInfo, sd::memory::Workspace* workspace);
7873};
7974} // namespace sd
8075
0 commit comments