Skip to content

Commit 1ea101e

Browse files
authored
Tensor helpers to get the size of a data type and number of elements of a shape.
Differential Revision: D71903681 Pull Request resolved: #9673
1 parent 93c3b2f commit 1ea101e

File tree

3 files changed

+78
-0
lines changed

3 files changed

+78
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,28 @@ typedef NS_ENUM(uint8_t, ExecuTorchShapeDynamism) {
6161
ExecuTorchShapeDynamismDynamicUnbound,
6262
} NS_SWIFT_NAME(ShapeDynamism);
6363

64+
/**
65+
* Returns the size in bytes of the specified data type.
66+
*
67+
* @param dataType An ExecuTorchDataType value representing the tensor's element type.
68+
* @return An NSInteger indicating the size in bytes.
69+
*/
70+
FOUNDATION_EXPORT
71+
__attribute__((deprecated("This API is experimental.")))
72+
NSInteger ExecuTorchSizeOfDataType(ExecuTorchDataType dataType)
73+
NS_SWIFT_NAME(size(ofDataType:));
74+
75+
/**
76+
* Computes the total number of elements in a tensor based on its shape.
77+
*
78+
* @param shape An NSArray of NSNumber objects, where each element represents a dimension size.
79+
* @return An NSInteger equal to the product of the sizes of all dimensions.
80+
*/
81+
FOUNDATION_EXPORT
82+
__attribute__((deprecated("This API is experimental.")))
83+
NSInteger ExecuTorchElementCountOfShape(NSArray<NSNumber *> *shape)
84+
NS_SWIFT_NAME(elementCount(ofShape:));
85+
6486
/**
6587
* A tensor class for ExecuTorch operations.
6688
*

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@
1616
using namespace executorch::aten;
1717
using namespace executorch::extension;
1818

19+
NSInteger ExecuTorchSizeOfDataType(ExecuTorchDataType dataType) {
20+
return elementSize(static_cast<ScalarType>(dataType));
21+
}
22+
23+
NSInteger ExecuTorchElementCountOfShape(NSArray<NSNumber *> *shape) {
24+
NSInteger count = 1;
25+
for (NSNumber *dimension in shape) {
26+
count *= dimension.integerValue;
27+
}
28+
return count;
29+
}
30+
1931
@implementation ExecuTorchTensor {
2032
TensorPtr _tensor;
2133
NSArray<NSNumber *> *_shape;

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,50 @@
1111
import XCTest
1212

1313
class TensorTest: XCTestCase {
14+
func testElementCountOfShape() {
15+
XCTAssertEqual(elementCount(ofShape: [2, 3, 4]), 24)
16+
XCTAssertEqual(elementCount(ofShape: [5]), 5)
17+
XCTAssertEqual(elementCount(ofShape: []), 1)
18+
}
19+
20+
func testSizeOfDataType() {
21+
let expectedSizes: [DataType: Int] = [
22+
.byte: 1,
23+
.char: 1,
24+
.short: 2,
25+
.int: 4,
26+
.long: 8,
27+
.half: 2,
28+
.float: 4,
29+
.double: 8,
30+
.complexHalf: 4,
31+
.complexFloat: 8,
32+
.complexDouble: 16,
33+
.bool: 1,
34+
.qInt8: 1,
35+
.quInt8: 1,
36+
.qInt32: 4,
37+
.bFloat16: 2,
38+
.quInt4x2: 1,
39+
.quInt2x4: 1,
40+
.bits1x8: 1,
41+
.bits2x4: 1,
42+
.bits4x2: 1,
43+
.bits8: 1,
44+
.bits16: 2,
45+
.float8_e5m2: 1,
46+
.float8_e4m3fn: 1,
47+
.float8_e5m2fnuz: 1,
48+
.float8_e4m3fnuz: 1,
49+
.uInt16: 2,
50+
.uInt32: 4,
51+
.uInt64: 8,
52+
]
53+
for (dataType, expectedSize) in expectedSizes {
54+
XCTAssertEqual(size(ofDataType: dataType), expectedSize, "Size for \(dataType) should be \(expectedSize)")
55+
}
56+
}
57+
1458
func testInitBytesNoCopy() {
1559
var data: [Float] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
1660
let tensor = data.withUnsafeMutableBytes {

0 commit comments

Comments
 (0)