forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathshape_info.h
165 lines (141 loc) · 4.67 KB
/
shape_info.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#pragma once
#include "caffe2/core/operator.h"
namespace caffe2 {
/// Quantization parameters carried alongside a tensor shape.
///
/// `offset` and `scale` are vectors; the constructor seeds each of them with
/// exactly one element. `axis` is stored as given.
/// NOTE(review): `axis` is presumably the quantization axis the
/// offset/scale entries refer to — confirm against the users of this struct.
struct TORCH_API QShapeInfo {
  /// @param o the single initial entry of `offset` (default 0).
  /// @param s the single initial entry of `scale` (default 1).
  /// @param a the value stored in `axis` (default 1).
  ///
  /// Members are initialized directly (in declaration order) instead of
  /// being default-constructed, cleared and then appended to.
  QShapeInfo(float o = 0, float s = 1, uint32_t a = 1)
      : axis(a), offset(1, o), scale(1, s) {}

  uint32_t axis;
  std::vector<float> offset;
  std::vector<float> scale;
};
struct TORCH_API ShapeInfo {
ShapeInfo(bool q = false) : is_quantized(q) {}
ShapeInfo(
std::vector<TensorBoundShape_DimType>&& t,
TensorShape&& s,
bool q = false)
: shape(std::move(s)),
is_quantized(q),
dim_type(std::move(t)),
dim_type_is_set(true) {}
ShapeInfo(
const std::vector<TensorBoundShape_DimType>& t,
TensorShape&& s,
bool q = false)
: shape(std::move(s)),
is_quantized(q),
dim_type(t),
dim_type_is_set(true) {}
ShapeInfo(
const std::vector<TensorBoundShape_DimType>& t,
const TensorShape& s,
bool q = false)
: shape(s), is_quantized(q), dim_type(t), dim_type_is_set(true) {}
ShapeInfo(bool q, const QShapeInfo& info) : is_quantized(q), q_info(info) {}
ShapeInfo(
const std::vector<TensorBoundShape_DimType>& t,
TensorShape&& s,
bool q,
const QShapeInfo& info)
: shape(std::move(s)),
is_quantized(q),
q_info(info),
dim_type(t),
dim_type_is_set(true) {}
ShapeInfo(
const std::vector<TensorBoundShape_DimType>& t,
const TensorShape& s,
bool q,
const QShapeInfo& info)
: shape(s),
is_quantized(q),
q_info(info),
dim_type(t),
dim_type_is_set(true) {}
void setDimType(const std::vector<TensorBoundShape_DimType>& dim_types) {
if (shape.dims_size()) {
CAFFE_ENFORCE_EQ(shape.dims_size(), dim_types.size());
}
dim_type = dim_types;
dim_type_is_set = true;
}
void setDimType(int idx, TensorBoundShape_DimType type) {
CAFFE_ENFORCE(
dim_type.size() > static_cast<unsigned>(idx), dim_type.size(), "vs", dim_type.size());
dim_type[idx] = type;
dim_type_is_set = true;
}
bool dimTypeIsSet() {
return dim_type_is_set;
}
const std::vector<TensorBoundShape_DimType>& getDimType() const {
return dim_type;
}
TensorBoundShape_DimType getDimType(int idx) const {
if (dim_type.size() > static_cast<unsigned>(idx)) {
return dim_type[idx];
} else {
return TensorBoundShape_DimType_UNKNOWN;
}
}
bool getShapeIsFinal() {
return shape_is_final;
}
void setShapeIsFinal(bool flag) {
shape_is_final = flag;
}
TensorShape shape;
// quantization related information
bool is_quantized;
QShapeInfo q_info;
private:
// type of the shape for every dimension
// dim_type.size == shape.dims.size
std::vector<TensorBoundShape_DimType> dim_type;
bool dim_type_is_set = false;
// a flag to indicate whether the shape is final and cannot be changed
// eg: input/output of in-place ops
bool shape_is_final = false;
};
// Unordered map from a std::string key to the ShapeInfo stored for it.
// NOTE(review): keys are presumably blob/tensor names — confirm with callers.
using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
// Generates a ShapeInfo from the given Blob.
// NOTE(review): this is declared without TORCH_API, unlike the free functions
// declared further down in this header — confirm whether that is intentional.
ShapeInfo getShapeInfoFromBlob(const Blob* blob);
// Equality comparison between two ShapeInfo objects. Which members take part
// in the comparison is decided by the out-of-line definition, which is not
// visible in this header.
bool operator==(const ShapeInfo& lhs, const ShapeInfo& rhs);
// Constructs a ShapeInfo from a TensorShape together with a dim type vector
// that is built for it.
// The first dimension's type defaults to BATCH (see defaultFirstDimType)
// because the first dimension of hinted shapes is treated as BATCH.
// Shape hints on blobs in the workspace have already been inserted as
// CONSTANT, so they take effect here.
// SEQ typed tensors are few in number and are handled by
// BoundShapeInferencer.
TORCH_API ShapeInfo constructShapeInfoWithDefaultDimType(
TensorShape shape,
TensorBoundShape_DimType defaultFirstDimType =
TensorBoundShape_DimType_BATCH);
// Judging by the name and signature, parses shape information out of the
// input string and stores it into the ShapeInfoMap passed by reference.
// NOTE(review): the expected string format, and whether existing map entries
// are kept or overwritten, are defined by the implementation and cannot be
// confirmed from this header.
TORCH_API void parseShapeInfoMapFromString(const std::string&, ShapeInfoMap&);
// Extracts shape info from tensor_bound_shapes into a ShapeInfoMap.
// If necessary, the shapes are changed at the same time according to
// new_max_batch_size and new_max_feature_len.
// Both size parameters default to -1.
// NOTE(review): -1 presumably means "leave the corresponding size
// unchanged" — confirm in the implementation.
TORCH_API ShapeInfoMap extractShapeInfoFromTensorBoundShapes(
TensorBoundShapes tensor_bound_shapes,
int64_t new_max_batch_size = -1,
int64_t new_max_feature_len = -1);
// Modifies tensor_shape_and_type in place, changing shape sizes based on
// type.
// NOTE(review): presumably dimensions whose size equals old_batch_size /
// old_seq_size are rewritten to new_batch_size / new_seq_size according to
// their dim type — confirm in the implementation.
TORCH_API void changeTensorBoundShapes(
TensorBoundShape& tensor_shape_and_type,
const int64_t old_batch_size,
const int64_t old_seq_size,
const int64_t new_batch_size,
const int64_t new_seq_size);
// Modifies, in place, the size of tensor_shape at the dimension dim_index.
// NOTE(review): old_size is presumably the size the dimension is expected to
// have before it is replaced by new_size — how a mismatch is handled is
// defined in the implementation and is not visible here.
TORCH_API void modifyTensorShapeDimSize(
TensorShape* tensor_shape,
int dim_index,
const int64_t old_size,
const int64_t new_size);
} // namespace caffe2