forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: operator.h
511 lines (499 loc) · 20.9 KB
/
operator.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file operator.h
* \brief Operator interface of mxnet.
* \author Naiyan Wang
*/
#ifndef MXNET_OPERATOR_H_
#define MXNET_OPERATOR_H_
#include <dmlc/base.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <nnvm/node.h>
#include <vector>
#include <map>
#include <string>
#include <utility>
#include <functional>
#include <memory>
#include "./base.h"
#include "./resource.h"
#include "./op_attr_types.h"
namespace mxnet {
/*!
 * \brief Operator interface.
 * Operator defines basic operation unit of optimized computation graph in mxnet.
 * This interface relies on pre-allocated memory in TBlob, the caller needs to set
 * the memory region in TBlob correctly before calling Forward and Backward.
 *
 * Operator is generated by OperatorProperty.
 * To add new operator (aka. layers of neural nets) to mxnet, developer needs to create
 * a new OperatorProperty and its corresponding Operator.
 *
 * \sa TBlob, TShape, OperatorProperty
 */
class Operator {
 public:
  /*! \brief destructor */
  virtual ~Operator() {}
  /*!
   * \brief perform a forward operation of Operator, save the output to TBlob.
   * \param ctx runtime context available to this call
   * \param in_data array of input data, it is const
   * \param req the request types of saving operation, can only be kWriteTo or kWriteInplace.
   * \param out_data array of output data, pointer is used to indicate that this is holder
   *        the space of TBlob in out_data must be pre-allocated with InferShape
   * \param aux_states Auxiliary states of operator. Normally operator doesn't
   *        need them; special cases like Batch Norm require them.
   * \sa OpReqType, OpContext
   */
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_states) = 0;
  /*!
   * \brief Perform a Backward Operation, write gradient to the in_grad.
   *
   * \note
   * Convention:
   *   out_grad.size() == OperatorProperty.NumVisibleOutputs()
   *   out_data.size() == OperatorProperty.NumOutputs()
   * out_data can contain additional invisible returns that remember the
   * state carried from the Forward pass. For example mask in the dropout.
   * The gradients are passed from visible returns in this function.
   *
   * \par
   * Not all the TBlobs in the arguments will be available
   * if you override the DeclareBackwardDependency of corresponding OperatorProperty class.
   * Only the dependencies you declared will be available at corresponding position,
   * the rest of the parameters are simply dummy where you will get a nullptr.
   * You will be safe if you use the default DeclareBackwardDependency.
   * But only declaring what you need will give the engine more chance for optimization.
   *
   * \param ctx runtime context available to this call
   * \param out_grad the gradient value we get from of the Operator.
   * \param in_data the array of input data.
   * \param out_data the array of output data.
   * \param req request types of the saving operation, can be all types.
   * \param in_grad the array of gradient we need to write to.
   * \param aux_states Auxiliary states of operator. Normally operator doesn't need them.
   * \sa OperatorProperty, OpReqType, OpContext
   */
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_states) {
    // Default implementation aborts: operators that support backward must override.
    LOG(FATAL) << "Backward is not implemented";
  }
  /*! \return [Deprecated] execution type of the operator */
  virtual ExecType exec_type() const final {  // NOLINT(*) exec_type has been moved to OperatorProperty
    return ExecType::kSync;
  }
};
#if DMLC_USE_CXX11
// OperatorProperty allows C++11, while Operator do not rely on it.
/*!
 * \brief OperatorProperty is an object that stores all information about Operator.
 * It also contains methods to generate context(device) specific operators.
 *
 * It also contains various functions that can be optionally overridden to
 * provide optimization chances for the computation engine.
 */
class OperatorProperty {
 public:
  /*!
   * \brief virtual destructor
   */
  virtual ~OperatorProperty() {}
  /*!
   * \brief Initialize the Operator by setting the parameters
   * This function needs to be called before all other functions.
   * \param kwargs the keyword arguments parameters
   */
  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
  /*!
   * \brief Get a map representation of internal parameters.
   * This can be used by Init to recover the state of OperatorProperty.
   */
  virtual std::map<std::string, std::string> GetParams() const = 0;
  /*!
   * \brief Get input arguments of the Operator.
   * \return vector of arguments.
   */
  virtual std::vector<std::string> ListArguments() const {
    return {"data"};
  }
  /*!
   * \brief Get name of output values of Operator
   * \return name of output values.
   */
  virtual std::vector<std::string> ListOutputs() const {
    return {"output"};
  }
  /*!
   * \brief Get name of auxiliary states of Operator
   * \return name of return values.
   */
  virtual std::vector<std::string> ListAuxiliaryStates() const {
    return {};
  }
  /*! \return number of real return values of the Operator */
  virtual int NumOutputs() const {
    return this->ListOutputs().size();
  }
  /*!
   * \brief get number of visible return values during Symbol creation.
   * If NumVisibleOutputs() = k, and NumOutputs() = n.
   * The first k returns will be presented in the resulting symbol.
   *
   * The rest of the returns can be used for auxiliary states for Backward.
   * For example, Dropout will return [data, mask], with NumVisibleOutputs() == 1.
   * So when user calls sym = Dropout(input), only data is presented in sym.
   * But all the returns will be presented in out_data parameter of Backward if requested.
   *
   * \return number of default return values
   */
  virtual int NumVisibleOutputs() const {
    return NumOutputs();
  }
  /*!
   * \brief infer the shapes of outputs and unknown input arguments
   * \param in_shape the shape of input arguments of the operator
   *     this should be of same length as the vector returned by ListArguments
   *     in_shape allows unknown elements, which are checked by shape.ndim() == 0.
   *     For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
   *     For known shapes, InferShape will check shape consistency
   *
   *     common practice: set the shape of data input, and usually weight's shape can be inferred
   *
   * \param out_shape the shape of outputs of the operator
   *     InferShape will modify the vector to fill output TShape
   * \param aux_shape the shape of auxiliary states of the operator
   *     InferShape will modify the vector to fill output TShape
   * \return true if the shape inference is successful, false if there is not enough information.
   * \throws dmlc::Error if the known arg_shapes are inconsistent.
   */
  virtual bool InferShape(std::vector<TShape> *in_shape,
                          std::vector<TShape> *out_shape,
                          std::vector<TShape> *aux_shape) const = 0;
  /*!
   * \brief infer the data types of outputs and unknown input arguments
   * \param in_type the type of input arguments of the operator
   *     this should be of same length as the vector returned by ListArguments
   *     in_type allows unknown elements, which are marked by -1.
   *     For unknown types, InferType will try to fill in the correct type in in_type
   *     For known types, InferType will check type consistency
   *
   *     common practice: set the type of data input, and usually weight's type can be inferred
   *
   * \param out_type the type of outputs of the operator
   *     InferType will modify the vector to fill output types
   * \param aux_type the type of auxiliary states of the operator
   *     InferType will modify the vector to fill output types
   * \return true if the type inference is successful, false if there is not enough information.
   * \throws dmlc::Error if the known arg_types are inconsistent.
   */
  virtual bool InferType(std::vector<int> *in_type,
                         std::vector<int> *out_type,
                         std::vector<int> *aux_type) const {
    CHECK_LE(in_type->size(), this->ListArguments().size());
    int n_in = this->ListArguments().size();
    // Default implementation only accepts the default dtype (or unknown, -1)
    // for every input, and assigns the default dtype everywhere.
    for (unsigned i = 0; i < in_type->size(); ++i) {
      CHECK(in_type->at(i) == mshadow::default_type_flag ||
            in_type->at(i) == -1) << "Unsupported data type " << in_type->at(i);
    }
    in_type->clear();
    for (int i = 0; i < n_in; ++i) in_type->push_back(mshadow::default_type_flag);
    int n_out = this->ListOutputs().size();
    out_type->clear();
    for (int i = 0; i < n_out; ++i) out_type->push_back(mshadow::default_type_flag);
    int n_aux = this->ListAuxiliaryStates().size();
    aux_type->clear();
    for (int i = 0; i < n_aux; ++i) aux_type->push_back(mshadow::default_type_flag);
    return true;
  }
  /*!
   * \brief Copy this OperatorProperty.
   * \return a pointer to the copied OperatorProperty
   */
  virtual OperatorProperty* Copy() const = 0;
  /*!
   * \brief Create a Operator on specific context
   */
  virtual Operator* CreateOperator(Context ctx) const = 0;
  /*!
   * \brief Create a Operator on specific context and input shape/type
   * \param ctx context of this operator
   * \param in_shape shape of the input ndarrays
   * \param in_type dtype of the input ndarrays
   * \return the created operator
   */
  virtual Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                                     std::vector<int> *in_type) const {
    // Run type and shape inference first so in_shape/in_type are fully
    // specified before the operator is constructed; aborts on failure.
    std::vector<int> out_type, aux_type;
    std::vector<TShape> out_shape, aux_shape;
    out_type.resize(this->ListOutputs().size());
    out_shape.resize(this->ListOutputs().size());
    aux_type.resize(this->ListAuxiliaryStates().size());
    aux_shape.resize(this->ListAuxiliaryStates().size());
    CHECK(InferType(in_type, &out_type, &aux_type));
    CHECK(InferShape(in_shape, &out_shape, &aux_shape));
    return CreateOperator(ctx);
  }
  /*!
   * \brief return the type string of the Operator
   *  subclasses override this function.
   * \return The type string.
   */
  virtual std::string TypeString() const = 0;
  //--------------------------------------------------------
  // All the below functions are optional to override.
  //--------------------------------------------------------
  /*!
   * \brief Declare additional resource required in forward pass.
   *  These additional resources will be presented in OpContext.requested
   *  in the same order of the returned Resource.
   * \param in_shape The input shape to the operator, corresponds to shapes of in_data.
   * \return Additional resource request
   */
  virtual std::vector<ResourceRequest> ForwardResource(
      const std::vector<TShape> &in_shape) const {
    return std::vector<ResourceRequest>();
  }
  /*!
   * \brief Declare additional resource required in backward pass.
   *  These additional resources will be presented in OpContext.requested
   *  in the same order of the returned Resource.
   * \param in_shape The input shape to the operator, corresponds to shapes of in_data.
   * \return Additional resource request
   */
  virtual std::vector<ResourceRequest> BackwardResource(
      const std::vector<TShape> &in_shape) const {
    return std::vector<ResourceRequest>();
  }
  /*!
   * \brief Declare the input requirement of Backward pass.
   *
   *  Only the returned list of variables will be used in Backward.
   *  This function is used for memory optimization.
   *  It is advised to override and only return what is actually needed.
   *  If this function is not overridden, all the variables will be valid in Backward.
   *
   * \code
   *  // The following code declares Backward need out_grad[0], in_data[0], in_data[1]
   *  vector<int> BackwardInputs(const vector<int> &out_grad,
   *                             const vector<int> &in_data,
   *                             const vector<int> &out_data) const {
   *    return {out_grad[0], in_data[0], in_data[1]};
   *  }
   * \endcode
   * \param out_grad gradient of outputs in backward pass.
   * \param in_data the input data in forward pass.
   * \param out_data the output data in forward pass.
   * \return an integer vector indicating the input requirements
   * \sa BackwardInputs
   */
  virtual std::vector<int> DeclareBackwardDependency(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data) const {
    // By default requires to see all the things.
    // remember to override this function to get a better performance.
    std::vector<int> ret = out_grad;
    ret.insert(ret.end(), in_data.begin(), in_data.end());
    ret.insert(ret.end(), out_data.begin(), out_data.end());
    return ret;
  }
  /*!
   * \brief Get possible forward inplace options.
   *  This function enables optimization to reuse memory of inputs in output.
   *  Only override when necessary, by default in-place is disabled.
   *
   *  The reason for void* type in the out_data is to distinguish the order
   *  of mappings between the two, compiler will report error when
   *  in_data and out_data's order in the pair get reversed.
   *
   * \code
   *  // The following code says out_data[0] can share data with in_data[0]
   *  vector<pair<int, void*> > ForwardInplaceOption(const vector<int> &in_data,
   *                                                 const vector<void*> &out_data) const {
   *    return {{in_data[0], out_data[0]}};
   *  }
   * \endcode
   * \param in_data The input data in forward pass.
   * \param out_data The output data in forward pass.
   * \return list of pairs that map input->output,
   *   indicating possible in place operations.
   */
  virtual std::vector<std::pair<int, void*> > ForwardInplaceOption(
      const std::vector<int> &in_data,
      const std::vector<void*> &out_data) const {
    return std::vector<std::pair<int, void*> >();
  }
  /*!
   * \brief Get possible backward inplace options.
   *  This function enables optimization to reuse memory of inputs in output.
   *  Only override when necessary, by default in-place is disabled.
   *
   *  The reason for void* type in the in_grad is to distinguish the order
   *  of mappings between the two, compiler will report error when
   *  in_data and out_data's order in the pair get reversed.
   *
   * \code
   *  // The following code says in_grad[0] can share data with in_data[0]
   *  vector<pair<int, void*> > BackwardInplaceOption(
   *                  const std::vector<int> &out_grad,
   *                  const std::vector<int> &in_data,
   *                  const std::vector<int> &out_data,
   *                  const std::vector<void*> &in_grad) const {
   *    return {{in_data[0], in_grad[0]}};
   *  }
   * \endcode
   * \param in_data The input data in forward pass.
   * \param out_data The output data in forward pass.
   * \param in_grad Gradient of inputs in backward pass.
   * \param out_grad Gradient of outputs in backward pass.
   * \return list of pairs that map input->output,
   *   indicating possible in place operations.
   */
  virtual std::vector<std::pair<int, void*> > BackwardInplaceOption(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data,
      const std::vector<void*> &in_grad) const {
    return std::vector<std::pair<int, void*> >();
  }
  /*!
   * \brief Get Backward Input Dependency for generic types of data.
   *  Normally T can be pointer of Symbol::DataEntry, or NDArray.
   *  This function will select the result list of T according to DeclareBackwardDependency.
   *
   * \param in_data the input data in forward pass.
   * \param out_data the output data in forward pass.
   * \param out_grad gradient of outputs in backward pass.
   * \tparam T the generic type parameter.
   * \return vector of inputs the Backward Operation depends on.
   * \sa DeclareBackwardDependency
   */
  template<typename T>
  inline std::vector<T> BackwardInputs(const std::vector<T> &out_grad,
                                       const std::vector<T> &in_data,
                                       const std::vector<T> &out_data) const {
    // Assign each element of [out_grad, in_data, out_data] a consecutive
    // index, matching the flattened layout built in all_data below.
    int counter = 0;
    std::vector<int> out_grad_index(out_grad.size());
    std::vector<int> in_data_index(in_data.size());
    std::vector<int> out_data_index(out_data.size());
    for (size_t i = 0; i < out_grad_index.size(); ++i) {
      out_grad_index[i] = counter++;
    }
    for (size_t i = 0; i < in_data_index.size(); ++i) {
      in_data_index[i] = counter++;
    }
    for (size_t i = 0; i < out_data_index.size(); ++i) {
      out_data_index[i] = counter++;
    }
    std::vector<T> all_data;
    all_data.insert(all_data.end(), out_grad.begin(), out_grad.end());
    all_data.insert(all_data.end(), in_data.begin(), in_data.end());
    all_data.insert(all_data.end(), out_data.begin(), out_data.end());
    // Let the (possibly overridden) dependency declaration pick the indices,
    // then gather the corresponding T values in the declared order.
    std::vector<int> ret_index = this->DeclareBackwardDependency(
        out_grad_index, in_data_index, out_data_index);
    std::vector<T> ret(ret_index.size());
    for (size_t i = 0; i < ret_index.size(); ++i) {
      ret[i] = all_data[ret_index[i]];
    }
    return ret;
  }
  /*!
   * \brief create OperatorProperty
   * \param type_name the type string of the OperatorProperty
   * \return a new constructed OperatorProperty
   */
  static OperatorProperty *Create(const char* type_name);
  /*! \return execution type of the operator */
  virtual ExecType exec_type() const {
    return ExecType::kSync;
  }
};
/*! \brief alias for the factory function of operator property */
using OperatorPropertyFactory = std::function<OperatorProperty *()>;
/*!
 * \brief Registry entry for OperatorProperty factory functions.
 */
struct OperatorPropertyReg
    : public dmlc::FunctionRegEntryBase<OperatorPropertyReg,
                                        OperatorPropertyFactory> {
  /*!
   * \brief Set key_var_num_args
   *  When this is set, the API caller is required to pass in an
   *  argument with key=key_num_args.c_str(), and value=num_args.
   *  num_args is the number of positional arguments when calling the function.
   *
   *  This is used to pass in the length of positional arguments
   *  for operators that can take variable length of input.
   *  Most operators do not need to set this property.
   *
   * \param key the key name to be set
   */
  inline OperatorPropertyReg& set_key_var_num_args(const std::string &key) {  // NOLINT(*)
    this->key_var_num_args = key;
    return *this;
  }
  /*!
   * \brief Check if TypeString of the type matches the registered name
   */
  inline OperatorPropertyReg& check_name() {
    // RAII ownership of the temporary property: it is released even if
    // TypeString() or the CHECK below throws (the raw new/delete version
    // leaked on exception).
    std::unique_ptr<OperatorProperty> p(this->body());
    std::string type = p->TypeString();
    CHECK_EQ(this->name, type)
        << "Register Name and TypeString mismatch, name=\"" << this->name << "\","
        << " but TypeString=\"" << type << "\"";
    return *this;
  }
  /*! \brief The key num_args name. */
  std::string key_var_num_args;
};
//---------------------------------------------------------------------------------
// The following part are API Registration of Operators
// See also MXNET_REGISTER_SIMPLE_OP in operator_util.h for registering simple ops.
//---------------------------------------------------------------------------------
/*!
 * \brief Macro to register OperatorProperty
 *  Registers the factory in the dmlc registry and validates (via check_name)
 *  that the registered name matches the property's TypeString().
 *
 * \code
 * // example of registering a fully connected operator
 * REGISTER_OP_PROPERTY(FullyConnected, FullyConnectedOpProp)
 * .describe("Fully connected layer");
 *
 * \endcode
 */
#define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \
  DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \
  .set_body([]() { return new OperatorPropertyType(); }) \
  .set_return_type("NDArray-or-Symbol") \
  .check_name()
#endif // DMLC_USE_CXX11
} // namespace mxnet
#endif // MXNET_OPERATOR_H_