
Commit 3943938

Tensor description
Differential Revision: D76915254
Pull Request resolved: #11792
1 parent fcc7f3b commit 3943938

2 files changed: 151 additions, 0 deletions


extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift

Lines changed: 7 additions & 0 deletions
@@ -1095,3 +1095,10 @@ public extension Tensor {
     ))
   }
 }
+
+@available(*, deprecated, message: "This API is experimental.")
+extension Tensor: CustomStringConvertible {
+  public var description: String {
+    self.anyTensor.description
+  }
+}
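On the Swift side the new conformance simply forwards to the wrapped anyTensor, so a Tensor can be printed or interpolated directly and will render the formatted dump built in ExecuTorchTensor.mm. A minimal usage sketch; the Tensor<Float> initializer shown here is assumed for illustration and the exact factory in the ExecuTorch Swift API may differ:

import ExecuTorch

// Hypothetical initializer, for illustration only.
let tensor = Tensor<Float>([1, 2, 3, 4], shape: [2, 2])

print(tensor)           // goes through CustomStringConvertible -> anyTensor.description
let text = "\(tensor)"  // string interpolation uses the same description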

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 144 additions & 0 deletions
@@ -17,6 +17,86 @@
 using namespace executorch::extension;
 using namespace executorch::runtime;
 
+static inline NSString *dataTypeDescription(ExecuTorchDataType dataType) {
+  switch (dataType) {
+    case ExecuTorchDataTypeByte:
+      return @"byte";
+    case ExecuTorchDataTypeChar:
+      return @"char";
+    case ExecuTorchDataTypeShort:
+      return @"short";
+    case ExecuTorchDataTypeInt:
+      return @"int";
+    case ExecuTorchDataTypeLong:
+      return @"long";
+    case ExecuTorchDataTypeHalf:
+      return @"half";
+    case ExecuTorchDataTypeFloat:
+      return @"float";
+    case ExecuTorchDataTypeDouble:
+      return @"double";
+    case ExecuTorchDataTypeComplexHalf:
+      return @"complexHalf";
+    case ExecuTorchDataTypeComplexFloat:
+      return @"complexFloat";
+    case ExecuTorchDataTypeComplexDouble:
+      return @"complexDouble";
+    case ExecuTorchDataTypeBool:
+      return @"bool";
+    case ExecuTorchDataTypeQInt8:
+      return @"qint8";
+    case ExecuTorchDataTypeQUInt8:
+      return @"quint8";
+    case ExecuTorchDataTypeQInt32:
+      return @"qint32";
+    case ExecuTorchDataTypeBFloat16:
+      return @"bfloat16";
+    case ExecuTorchDataTypeQUInt4x2:
+      return @"quint4x2";
+    case ExecuTorchDataTypeQUInt2x4:
+      return @"quint2x4";
+    case ExecuTorchDataTypeBits1x8:
+      return @"bits1x8";
+    case ExecuTorchDataTypeBits2x4:
+      return @"bits2x4";
+    case ExecuTorchDataTypeBits4x2:
+      return @"bits4x2";
+    case ExecuTorchDataTypeBits8:
+      return @"bits8";
+    case ExecuTorchDataTypeBits16:
+      return @"bits16";
+    case ExecuTorchDataTypeFloat8_e5m2:
+      return @"float8_e5m2";
+    case ExecuTorchDataTypeFloat8_e4m3fn:
+      return @"float8_e4m3fn";
+    case ExecuTorchDataTypeFloat8_e5m2fnuz:
+      return @"float8_e5m2fnuz";
+    case ExecuTorchDataTypeFloat8_e4m3fnuz:
+      return @"float8_e4m3fnuz";
+    case ExecuTorchDataTypeUInt16:
+      return @"uint16";
+    case ExecuTorchDataTypeUInt32:
+      return @"uint32";
+    case ExecuTorchDataTypeUInt64:
+      return @"uint64";
+    default:
+      return @"undefined";
+  }
+}
+
+static inline NSString *shapeDynamismDescription(ExecuTorchShapeDynamism dynamism) {
+  switch (dynamism) {
+    case ExecuTorchShapeDynamismStatic:
+      return @"static";
+    case ExecuTorchShapeDynamismDynamicBound:
+      return @"dynamicBound";
+    case ExecuTorchShapeDynamismDynamicUnbound:
+      return @"dynamicUnbound";
+    default:
+      return @"undefined";
+  }
+}
+
 NSInteger ExecuTorchSizeOfDataType(ExecuTorchDataType dataType) {
   return elementSize(static_cast<ScalarType>(dataType));
 }
@@ -150,6 +230,70 @@ - (BOOL)isEqual:(nullable id)other {
   return [self isEqualToTensor:(ExecuTorchTensor *)other];
 }
 
+- (NSString *)description {
+  std::ostringstream os;
+  os << "Tensor {";
+  os << "\n dataType: " << dataTypeDescription(static_cast<ExecuTorchDataType>(_tensor->scalar_type())).UTF8String << ",";
+  os << "\n shape: [";
+  const auto& sizes = _tensor->sizes();
+  for (size_t index = 0; index < sizes.size(); ++index) {
+    if (index > 0) {
+      os << ",";
+    }
+    os << sizes[index];
+  }
+  os << "],";
+  os << "\n strides: [";
+  const auto& strides = _tensor->strides();
+  for (size_t index = 0; index < strides.size(); ++index) {
+    if (index > 0) {
+      os << ",";
+    }
+    os << strides[index];
+  }
+  os << "],";
+  os << "\n dimensionOrder: [";
+  const auto& dim_order = _tensor->dim_order();
+  for (size_t index = 0; index < dim_order.size(); ++index) {
+    if (index > 0) {
+      os << ",";
+    }
+    os << static_cast<int>(dim_order[index]);
+  }
+  os << "],";
+  os << "\n shapeDynamism: " << shapeDynamismDescription(static_cast<ExecuTorchShapeDynamism>(_tensor->shape_dynamism())).UTF8String << ",";
+  auto const count = _tensor->numel();
+  os << "\n count: " << count << ",";
+  os << "\n scalars: [";
+  ET_SWITCH_REALHBBF16_TYPES(
+    static_cast<ScalarType>(_tensor->scalar_type()),
+    nullptr,
+    "description",
+    CTYPE,
+    [&] {
+      auto const *pointer = reinterpret_cast<const CTYPE*>(_tensor->unsafeGetTensorImpl()->data());
+      auto const countToPrint = std::min(count, (ssize_t)100);
+      for (size_t index = 0; index < countToPrint; ++index) {
+        if (index > 0) {
+          os << ",";
+        }
+        if constexpr (std::is_same_v<CTYPE, int8_t> ||
+                      std::is_same_v<CTYPE, uint8_t>) {
+          os << static_cast<int>(pointer[index]);
+        } else {
+          os << pointer[index];
+        }
+      }
+      if (count > countToPrint) {
+        os << ",...";
+      }
+    }
+  );
+  os << "]";
+  os << "\n}";
+  return @(os.str().c_str());
+}
+
 @end
 
 @implementation ExecuTorchTensor (BytesNoCopy)
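Given how the ostringstream is assembled above, the string returned by -description should look roughly like the following for a small contiguous 2x2 float tensor. The field order, single-space indentation, and the 100-element cap on printed scalars come straight from the code; the concrete values (including the shapeDynamism label) are illustrative only:

Tensor {
 dataType: float,
 shape: [2,2],
 strides: [2,1],
 dimensionOrder: [0,1],
 shapeDynamism: static,
 count: 4,
 scalars: [1,2,3,4]
}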
