Skip to content

Commit 66c312b

Browse files
authored
Rewrites multi threading/thread access in the trie infrastructure (deeplearning4j#10174)
* Rewrites multi threading/thread access in the trie infrastructure * update author * add tad calculator * fix author
1 parent 239af78 commit 66c312b

17 files changed

+1365
-757
lines changed

libnd4j/include/array/ConstantOffsetsBuffer.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ class SD_LIB_EXPORT ConstantOffsetsBuffer {
4343
ConstantOffsetsBuffer() = default;
4444
~ConstantOffsetsBuffer() = default;
4545

46-
const LongType *primary() const;
47-
const LongType *special() const;
48-
const LongType *platform() const;
46+
LongType *primary();
47+
LongType *special();
48+
LongType *platform();
4949
};
5050

5151
} // namespace sd

libnd4j/include/array/ConstantShapeBuffer.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@ class SD_LIB_EXPORT ConstantShapeBuffer {
5151
backward::StackTrace st;
5252
#endif
5353
#endif
54-
const LongType *primary() const;
55-
const LongType *special() const;
56-
const LongType *platform() const;
54+
LongType *primary() ;
55+
LongType *special() ;
56+
LongType *platform() ;
5757
};
5858

59+
5960
} // namespace sd
6061

6162
#endif // SD_ARRAY_CONSTANTSHAPEBUFFER_H_
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
2+
/* ******************************************************************************
3+
*
4+
* Copyright (c) 2024 Konduit K.K.
5+
* This program and the accompanying materials are made available under the
6+
* terms of the Apache License, Version 2.0 which is available at
7+
* https://www.apache.org/licenses/LICENSE-2.0.
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
* License for the specific language governing permissions and limitations
13+
* under the License.
14+
*
15+
* SPDX-License-Identifier: Apache-2.0
16+
******************************************************************************/
17+
18+
//
19+
// @author Adam Gibson
20+
//
21+
22+
#ifndef DEV_TESTS_TADCALCULATOR_H
23+
#define DEV_TESTS_TADCALCULATOR_H
24+
25+
#include <array/TadPack.h>
26+
#include <system/common.h>
27+
#include <helpers/ConstantHelper.h>
28+
#include <helpers/ConstantShapeHelper.h>
29+
#include <array/ConstantShapeBuffer.h>
30+
#include <array/ConstantOffsetsBuffer.h>
31+
#include <vector>
32+
#include <memory>
33+
34+
namespace sd {
35+
36+
/**
37+
* TadCalculator handles the computation of Tensor Along Dimension (TAD) information
38+
* including shapes and offsets for sub-arrays.
39+
*/
40+
class SD_LIB_EXPORT TadCalculator {
41+
private:
42+
LongType* _originalShape; // Original shape info pointer
43+
ConstantShapeBuffer _tadShape; // Calculated TAD shape buffer
44+
ConstantOffsetsBuffer _tadOffsets; // Calculated TAD offsets buffer
45+
LongType _numTads; // Number of TADs
46+
47+
public:
48+
/**
49+
* Constructor for TadCalculator
50+
* @param originalShape Pointer to the original shape information
51+
*/
52+
explicit TadCalculator(LongType* originalShape);
53+
~TadCalculator() = default;
54+
55+
/**
56+
* Creates a TAD pack for the given dimensions
57+
* @param dimensions Vector of dimensions to calculate TADs for
58+
*/
59+
void createTadPack(const std::vector<LongType>& dimensions);
60+
61+
/**
62+
* Returns the calculated TAD shape buffer
63+
* @return ConstantShapeBuffer containing TAD shape information
64+
*/
65+
ConstantShapeBuffer tadShape() const { return _tadShape; }
66+
67+
/**
68+
* Returns the calculated TAD offsets buffer
69+
* @return ConstantOffsetsBuffer containing TAD offset information
70+
*/
71+
ConstantOffsetsBuffer tadOffsets() const { return _tadOffsets; }
72+
73+
/**
74+
* Returns the number of TADs calculated
75+
* @return Number of TADs
76+
*/
77+
LongType numberOfTads() const { return _numTads; }
78+
};
79+
80+
} // namespace sd
81+
82+
#endif // DEV_TESTS_TADCALCULATOR_H

libnd4j/include/array/TadPack.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,21 @@ class SD_LIB_EXPORT TadPack {
4545
TadPack() = default;
4646
~TadPack() {};
4747

48-
const LongType* primaryShapeInfo() const;
49-
const LongType* primaryOffsets() const;
48+
LongType* primaryShapeInfo();
49+
LongType* primaryOffsets();
5050

51-
const LongType* specialShapeInfo() const;
52-
const LongType* specialOffsets() const;
51+
LongType* specialShapeInfo();
52+
LongType* specialOffsets();
5353

5454
LongType numberOfTads() const;
55-
LongType shapeInfoLength() const;
55+
LongType shapeInfoLength();
5656
/**
5757
* Extracts an NDArray view for the given TAD index.
5858
* @param input The input NDArray.
5959
* @param tadIndex The index of the TAD to extract.
6060
* @return A new NDArray view representing the TAD.
6161
*/
62-
NDArray *extractTadView(NDArray* input, sd::LongType tadIndex) const {
62+
NDArray *extractTadView(NDArray* input, sd::LongType tadIndex) {
6363
auto shapeInfo = primaryShapeInfo();
6464
auto offsets = primaryOffsets();
6565

@@ -73,10 +73,10 @@ class SD_LIB_EXPORT TadPack {
7373
* These methods return either primary or special pointers depending on platform binaries were compiled for
7474
* @return
7575
*/
76-
const LongType* platformShapeInfo() const;
77-
const LongType* platformOffsets() const;
76+
LongType* platformShapeInfo();
77+
LongType* platformOffsets();
7878

79-
void print(const char* msg) const;
79+
void print(const char* msg);
8080
};
8181
} // namespace sd
8282

libnd4j/include/helpers/ConstantShapeHelper.h

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -43,38 +43,33 @@ class SD_LIB_EXPORT ConstantShapeHelper {
4343

4444
~ConstantShapeHelper();
4545
ConstantShapeBuffer* bufferForShapeInfo(DataType dataType, char order, const std::vector<LongType>& shape);
46-
ConstantShapeBuffer* bufferForShapeInfo(ShapeDescriptor *descriptor);
47-
ConstantShapeBuffer* bufferForShapeInfo(const LongType* shapeInfo);
48-
ConstantShapeBuffer* bufferForShapeInfo(DataType dataType, char order, int rank, const LongType* shape);
49-
ConstantShapeBuffer* createShapeInfoWithUnitiesForBroadcast(const LongType* maxShapeInfo,
50-
const LongType* minShapeInfo,
51-
memory::Workspace* workspace = nullptr,
52-
const std::vector<LongType>& dimensions = {});
53-
ConstantShapeBuffer* createShapeInfoWithNoUnitiesForReduce(const LongType* maxShapeInfo,
54-
const std::vector<LongType>* dimsWithUnities,
55-
memory::Workspace* workspace = nullptr);
56-
ConstantShapeBuffer* createSubArrShapeInfo(const LongType* inShapeInfo, const LongType* dims,
57-
const LongType dimsSize,
58-
memory::Workspace* workspace = nullptr);
46+
ConstantShapeBuffer* bufferForShapeInfo(LongType* shapeInfo);
47+
ConstantShapeBuffer* bufferForShapeInfo(DataType dataType, char order, int rank, LongType* shape);
48+
ConstantShapeBuffer* createShapeInfoWithUnitiesForBroadcast( LongType* maxShapeInfo,
49+
LongType* minShapeInfo,
50+
memory::Workspace* workspace = nullptr,
51+
const std::vector<LongType>& dimensions = {});
52+
ConstantShapeBuffer* createShapeInfoWithNoUnitiesForReduce( LongType* maxShapeInfo,
53+
const std::vector<LongType>* dimsWithUnities,
54+
memory::Workspace* workspace = nullptr);
55+
56+
LongType* emptyShapeInfo(DataType dataType);
57+
LongType * scalarShapeInfo(DataType dataType);
58+
LongType* vectorShapeInfo(LongType length, DataType dataType);
59+
LongType* createShapeInfo(ShapeDescriptor *descriptor);
60+
LongType* createShapeInfo(DataType dataType, char order, const std::vector<LongType>& shape);
61+
LongType* createShapeInfo(DataType dataType, const char order, const int rank, LongType* shape, LongType extraProperties);
62+
LongType* createShapeInfo(DataType dataType, LongType* shapeInfo);
63+
LongType* createFromExisting(LongType* shapeInfo, sd::memory::Workspace* workspace);
64+
LongType* createFromExisting(LongType* shapeInfo, bool destroyOriginal = true);
5965

60-
const LongType* emptyShapeInfo(DataType dataType);
61-
const LongType* scalarShapeInfo(DataType dataType);
62-
const LongType* vectorShapeInfo(LongType length, DataType dataType);
63-
const LongType* createShapeInfo(ShapeDescriptor *descriptor);
64-
const LongType* createShapeInfo(DataType dataType, char order, const std::vector<LongType>& shape);
65-
const LongType* createShapeInfo(DataType dataType, const char order, const int rank, const LongType* shape, LongType extraProperties);
66-
const LongType* createShapeInfo(DataType dataType, const LongType* shapeInfo);
67-
const LongType* createFromExisting(const LongType* shapeInfo, sd::memory::Workspace* workspace);
68-
const LongType* createFromExisting(const LongType* shapeInfo, bool destroyOriginal = true);
69-
const LongType* createFromExisting(sd::LongType* shapeInfo, sd::memory::Workspace* workspace);
70-
const LongType* createFromExisting(sd::LongType* shapeInfo, bool destroyOriginal = true);
7166

7267
bool checkBufferExistenceForShapeInfo(ShapeDescriptor *descriptor);
7368

74-
ConstantShapeBuffer* storeAndWrapBuffer(const LongType* shapeInfo);
75-
const LongType* castToDataType(const LongType* shapeInfo, const DataType newType);
76-
const LongType* emptyShapeInfoWithShape(const DataType dataType, std::vector<LongType>& shape);
77-
ConstantShapeBuffer* createConstBuffFromExisting(const sd::LongType* shapeInfo, sd::memory::Workspace* workspace);
69+
ConstantShapeBuffer* storeAndWrapBuffer( LongType* shapeInfo);
70+
LongType* castToDataType( LongType* shapeInfo, DataType newType);
71+
LongType* emptyShapeInfoWithShape(const DataType dataType, std::vector<LongType>& shape);
72+
ConstantShapeBuffer* createConstBuffFromExisting( sd::LongType* shapeInfo, sd::memory::Workspace* workspace);
7873
};
7974
} // namespace sd
8075

libnd4j/include/helpers/ConstantTadHelper.h

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@
3333
namespace sd {
3434
class SD_LIB_EXPORT ConstantTadHelper {
3535
private:
36-
std::mutex _mutex;
3736
DirectTadTrie _trie; // Single trie for device 0
3837

39-
ConstantTadHelper();
38+
ConstantTadHelper() = default;
4039

4140
public:
4241
~ConstantTadHelper() = default;
@@ -51,34 +50,15 @@ class SD_LIB_EXPORT ConstantTadHelper {
5150
* @param keepUnitiesInShape
5251
* @return
5352
*/
54-
TadPack *tadForDimensions(const LongType *originalShape, const std::vector<LongType> *dimensions,
55-
const bool keepUnitiesInShape = false);
56-
TadPack *tadForDimensions(const LongType *originalShape, LongType *dimensions, LongType dimLength,
57-
const bool keepUnitiesInShape = false);
58-
TadPack *tadForDimensions(const LongType *originalShape, LongType dimension, const bool keepUnitiesInShape = false);
53+
54+
TadPack *tadForDimensions(LongType *originalShape, LongType *dimensions, LongType dimLength,
55+
bool keepUnitiesInShape = false);
5956
TadPack *tadForDimensions(ShapeDescriptor &descriptor, std::vector<LongType> &dimensions,
6057
const bool keepUnitiesInShape = false);
6158
TadPack *tadForDimensions(TadDescriptor *descriptor);
6259

63-
/**
64-
* This method returns number of cached TAD shapes/offsets on specific device
65-
* @return
66-
*/
67-
SD_INLINE int cachedEntriesForDevice(int deviceId) {
68-
std::lock_guard<std::mutex> lock(_mutex);
69-
if (deviceId > 0) THROW_EXCEPTION("deviceId > number of actual devices");
70-
71-
return _trie.totalCachedEntries();
72-
}
73-
74-
/**
75-
* This method returns total number of cached TAD shapes/offsets on all devices
76-
* @return
77-
*/
78-
SD_INLINE int totalCachedEntries() {
79-
std::lock_guard<std::mutex> lock(_mutex);
80-
return _trie.totalCachedEntries();
81-
}
60+
TadPack *tadForDimensions(LongType *originalShape, LongType dimension, bool keepUnitiesInShape);
61+
TadPack *tadForDimensions(LongType *originalShape, std::vector<LongType> *dimensions, bool keepUnitiesInShape);
8262
};
8363
} // namespace sd
8464

0 commit comments

Comments
 (0)