/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file operator_util.h
* \brief Utility functions and registries to help quickly build new operators.
* [Deprecated]
* Use the register functions in this file when possible to simplify operator creation.
* Operators registered in this file will be exposed to both the NDArray API and the symbolic API.
*
* \author Tianqi Chen
*/
#ifndef MXNET_OPERATOR_UTIL_H_
#define MXNET_OPERATOR_UTIL_H_
#ifdef _MSC_VER
#pragma warning(disable:4503) // disable warning: decorated name length exceeded.
#endif
#include <dmlc/registry.h>
#include <dmlc/parameter.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "./base.h"
#include "./operator.h"
#if DMLC_USE_CXX11
#include <functional>
#endif
namespace mxnet {
/*! \brief namespace of operators */
namespace op {
/*! \brief base class of all gradient function arguments */
struct GradFunctionArgument {
/*! \brief The real data */
TBlob data;
};
/*! \brief First input to the function */
struct Input0 : GradFunctionArgument {};
/*! \brief Second input to the function */
struct Input1 : GradFunctionArgument {};
/*! \brief Output value of the function */
struct OutputValue : GradFunctionArgument {};
/*! \brief Gradient of output value */
struct OutputGrad : GradFunctionArgument {};
/*!
* \brief Environment arguments that are used by the function.
* These can be things like a scalar argument, used when adding a scalar to a tensor.
*/
struct EnvArguments {
/*! \brief scalar argument, if enabled */
real_t scalar;
/*! \brief keyword arguments */
std::vector<std::pair<std::string, std::string> > kwargs;
/*! \brief pointer to the resources requested */
std::vector<Resource> resource;
};
/*!
* \brief Source function that generates output based on env.
* The result container is pre-allocated with the correct shape.
* \param env The Environment arguments.
* \param ret The container to store the return value.
* \param req The requirement to store the result.
* \param ctx Runtime context to execute the function.
*/
typedef void (*SourceFunction)(const EnvArguments& env,
TBlob* ret,
OpReqType req,
RunContext ctx);
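/*!
 * A minimal sketch of a SourceFunction (illustrative only; FillScalar_ is a
 * hypothetical name). It fills the pre-allocated output with the scalar passed
 * through EnvArguments (available when set_enable_scalar is used), dispatching
 * the assignment with the ASSIGN_DISPATCH macro defined later in this file.
 * \code
 * template<typename xpu>
 * void FillScalar_(const EnvArguments& env, TBlob* ret,
 *                  OpReqType req, RunContext ctx) {
 *   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
 *   mshadow::Tensor<xpu, 2> out = ret->FlatTo2D<xpu, real_t>(s);
 *   ASSIGN_DISPATCH(out, req, env.scalar);  // write or accumulate the scalar
 * }
 * \endcode
 */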
/*!
* \brief Shape inference function to get the correct shape.
* \param env The Environment arguments.
* \return The inferred result shape.
*/
typedef TShape (*SourceShapeFunction)(const EnvArguments& env);
/*!
* \brief Unary function that takes a src and saves the result to ret.
* The result container is pre-allocated with the correct shape.
* \param src The source data.
* \param env The Environment arguments.
* \param ret The container to store the return value.
* \param req The requirement to store the result.
* \param ctx Runtime context to execute the function.
*/
typedef void (*UnaryFunction)(const TBlob& src,
const EnvArguments& env,
TBlob* ret,
OpReqType req,
RunContext ctx);
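/*!
 * A minimal sketch of a UnaryFunction (illustrative only; SquareForward_ is a
 * hypothetical name, assuming mshadow expression templates and the
 * ASSIGN_DISPATCH macro defined later in this file): compute the element-wise
 * square of src into the pre-allocated ret.
 * \code
 * template<typename xpu>
 * void SquareForward_(const TBlob& src, const EnvArguments& env,
 *                     TBlob* ret, OpReqType req, RunContext ctx) {
 *   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
 *   mshadow::Tensor<xpu, 2> out = ret->FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> in = src.FlatTo2D<xpu, real_t>(s);
 *   ASSIGN_DISPATCH(out, req, in * in);  // element-wise square
 * }
 * \endcode
 */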
/*!
* \brief Shape inference function to get the correct shape given source.
* \param src The source shape
* \param env The Environment arguments.
* \return The inferred result shape.
*/
typedef TShape (*UnaryShapeFunction)(const TShape& src,
const EnvArguments& env);
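/*!
 * A minimal sketch of a UnaryShapeFunction (illustrative only; ScalarShape_ is
 * a hypothetical name): map any input shape to a single-element output, as a
 * norm-style reduction would.
 * \code
 * inline TShape ScalarShape_(const TShape& src, const EnvArguments& env) {
 *   mshadow::index_t shape[] = {1};
 *   return TShape(shape, shape + 1);
 * }
 * \endcode
 */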
/*!
* \brief Gradient function that takes only the output gradient and computes the gradient w.r.t. the input.
* \param out_grad the gradient w.r.t. the output of the function.
* \param env The Environment arguments.
* \param in_grad The container to store the resulting input gradient.
* \param req The requirement to store the ret value.
* \param ctx Runtime context to execute the function.
*/
typedef void (*UnaryGradFunctionT0)(const OutputGrad& out_grad,
const EnvArguments& env,
TBlob* in_grad,
OpReqType req,
RunContext ctx);
/*!
* \brief Gradient function that takes the output value and output gradient of the function and computes the gradient w.r.t. the input.
* \param out_grad the gradient w.r.t. the output of the function.
* \param out_value the output value of the function.
* \param env The Environment arguments.
* \param in_grad The container to store the resulting input gradient.
* \param req The requirement to store the ret value.
* \param ctx Runtime context to execute the function.
*/
typedef void (*UnaryGradFunctionT1)(const OutputGrad& out_grad,
const OutputValue& out_value,
const EnvArguments& env,
TBlob* in_grad,
OpReqType req,
RunContext ctx);
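/*!
 * A sketch of a UnaryGradFunctionT1 (illustrative only; SigmoidGrad_ is a
 * hypothetical name): the sigmoid gradient needs only the output value,
 * in_grad = out_grad * out * (1 - out).
 * \code
 * template<typename xpu>
 * void SigmoidGrad_(const OutputGrad& out_grad, const OutputValue& out_value,
 *                   const EnvArguments& env, TBlob* in_grad,
 *                   OpReqType req, RunContext ctx) {
 *   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
 *   mshadow::Tensor<xpu, 2> ograd = out_grad.data.FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> out = out_value.data.FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> igrad = in_grad->FlatTo2D<xpu, real_t>(s);
 *   ASSIGN_DISPATCH(igrad, req, ograd * out * (1.0f - out));
 * }
 * \endcode
 */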
/*!
* \brief Gradient function that takes the input value and output gradient of the function and computes the gradient w.r.t. the input.
* \param out_grad the gradient w.r.t. the output of the function.
* \param in_data0 the input value of the function.
* \param env The Environment arguments.
* \param in_grad The container to store the resulting input gradient.
* \param req The requirement to store the ret value.
* \param ctx Runtime context to execute the function.
*/
typedef void (*UnaryGradFunctionT2)(const OutputGrad& out_grad,
const Input0& in_data0,
const EnvArguments& env,
TBlob* in_grad,
OpReqType req,
RunContext ctx);
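/*!
 * A sketch of a UnaryGradFunctionT2 (illustrative only; SquareGrad_ is a
 * hypothetical name): the gradient of the element-wise square needs the input,
 * in_grad = 2 * in * out_grad.
 * \code
 * template<typename xpu>
 * void SquareGrad_(const OutputGrad& out_grad, const Input0& in_data0,
 *                  const EnvArguments& env, TBlob* in_grad,
 *                  OpReqType req, RunContext ctx) {
 *   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
 *   mshadow::Tensor<xpu, 2> ograd = out_grad.data.FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> in = in_data0.data.FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> igrad = in_grad->FlatTo2D<xpu, real_t>(s);
 *   ASSIGN_DISPATCH(igrad, req, 2.0f * in * ograd);
 * }
 * \endcode
 */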
/*!
* \brief Binary function that takes lhs, rhs and saves the result to ret.
* The result container is pre-allocated with the correct shape.
* \param lhs The left operand
* \param rhs The right operand
* \param env The Environment arguments.
* \param ret The container to store the return value.
* \param req The requirement to store the result.
* \param ctx Runtime context to execute the function.
*/
typedef void (*BinaryFunction)(const TBlob& lhs,
const TBlob& rhs,
const EnvArguments& env,
TBlob* ret,
OpReqType req,
RunContext ctx);
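/*!
 * A minimal sketch of a BinaryFunction (illustrative only; MulForward_ is a
 * hypothetical name): element-wise product of lhs and rhs.
 * \code
 * template<typename xpu>
 * void MulForward_(const TBlob& lhs, const TBlob& rhs,
 *                  const EnvArguments& env, TBlob* ret,
 *                  OpReqType req, RunContext ctx) {
 *   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
 *   mshadow::Tensor<xpu, 2> out = ret->FlatTo2D<xpu, real_t>(s);
 *   ASSIGN_DISPATCH(out, req,
 *                   lhs.FlatTo2D<xpu, real_t>(s) * rhs.FlatTo2D<xpu, real_t>(s));
 * }
 * \endcode
 */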
/*!
* \brief Shape inference function to get the correct shape given source shapes.
* \param lhs The shape of left operand.
* \param rhs The shape of right operand.
* \param env The Environment arguments.
* \return The inferred result shape.
*/
typedef TShape (*BinaryShapeFunction)(const TShape& lhs,
const TShape& rhs,
const EnvArguments& env);
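/*!
 * A minimal sketch of a BinaryShapeFunction (illustrative only; SameShape_ is
 * a hypothetical name) matching the default behaviour: require both operands
 * to have the same shape and return it.
 * \code
 * inline TShape SameShape_(const TShape& lhs, const TShape& rhs,
 *                          const EnvArguments& env) {
 *   CHECK_EQ(lhs, rhs) << "operands must have the same shape";
 *   return lhs;
 * }
 * \endcode
 */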
/*!
* \brief Gradient function that takes only the output gradient and computes the gradients w.r.t. both inputs.
* The total gradient is handled as a whole to make it easy to combine a few ops.
* \param out_grad the gradient w.r.t. the output of the function.
* \param env The Environment arguments.
* \param lhs_grad The container to store the resulting lhs gradient.
* \param rhs_grad The container to store the resulting rhs gradient.
* \param req_lhs_grad The requirement to store the lhs_grad
* \param req_rhs_grad The requirement to store the rhs_grad
* \param ctx Runtime context to execute the function.
*/
typedef void (*BinaryGradFunctionT0)(const OutputGrad& out_grad,
const EnvArguments& env,
TBlob* lhs_grad,
TBlob* rhs_grad,
OpReqType req_lhs_grad,
OpReqType req_rhs_grad,
RunContext ctx);
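/*!
 * A sketch of a BinaryGradFunctionT0 (illustrative only; PlusGrad_ and
 * identity_op are hypothetical names, identity_op standing in for
 * mshadow_op::identity from src/operator/mshadow_op.h): the gradient of
 * element-wise addition passes out_grad through to both operands.
 * \code
 * struct identity_op {
 *   template<typename DType>
 *   MSHADOW_XINLINE static DType Map(DType a) { return a; }
 * };
 *
 * template<typename xpu>
 * void PlusGrad_(const OutputGrad& out_grad, const EnvArguments& env,
 *                TBlob* lhs_grad, TBlob* rhs_grad,
 *                OpReqType req_lhs_grad, OpReqType req_rhs_grad,
 *                RunContext ctx) {
 *   using namespace mshadow::expr;
 *   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
 *   mshadow::Tensor<xpu, 2> ograd = out_grad.data.FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> lgrad = lhs_grad->FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> rgrad = rhs_grad->FlatTo2D<xpu, real_t>(s);
 *   ASSIGN_DISPATCH(lgrad, req_lhs_grad, F<identity_op>(ograd));
 *   ASSIGN_DISPATCH(rgrad, req_rhs_grad, F<identity_op>(ograd));
 * }
 * \endcode
 */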
/*!
* \brief Gradient function that takes the inputs of the function and computes the gradients w.r.t. both inputs.
* \param out_grad the gradient w.r.t. the output of the function.
* \param lhs The left operand to the function.
* \param rhs The right operand to the function.
* \param env The Environment arguments.
* \param lhs_grad The container to store the resulting lhs gradient.
* \param rhs_grad The container to store the resulting rhs gradient.
* \param req_lhs_grad The requirement to store the lhs_grad
* \param req_rhs_grad The requirement to store the rhs_grad
* \param ctx Runtime context to execute the function.
*/
typedef void (*BinaryGradFunctionT1)(const OutputGrad& out_grad,
const Input0& lhs,
const Input1& rhs,
const EnvArguments& env,
TBlob* lhs_grad,
TBlob* rhs_grad,
OpReqType req_lhs_grad,
OpReqType req_rhs_grad,
RunContext ctx);
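/*!
 * A sketch of a BinaryGradFunctionT1 (illustrative only; MulGrad_ is a
 * hypothetical name): the gradient of element-wise multiplication needs both
 * inputs, lhs_grad = out_grad * rhs and rhs_grad = out_grad * lhs. rhs_grad is
 * written first so that an in-place lhs_grad (kInplaceOutLhs) does not clobber
 * out_grad before it is read again.
 * \code
 * template<typename xpu>
 * void MulGrad_(const OutputGrad& out_grad, const Input0& lhs,
 *               const Input1& rhs, const EnvArguments& env,
 *               TBlob* lhs_grad, TBlob* rhs_grad,
 *               OpReqType req_lhs_grad, OpReqType req_rhs_grad,
 *               RunContext ctx) {
 *   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
 *   mshadow::Tensor<xpu, 2> ograd = out_grad.data.FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> ml = lhs.data.FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> mr = rhs.data.FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> lgrad = lhs_grad->FlatTo2D<xpu, real_t>(s);
 *   mshadow::Tensor<xpu, 2> rgrad = rhs_grad->FlatTo2D<xpu, real_t>(s);
 *   ASSIGN_DISPATCH(rgrad, req_rhs_grad, ograd * ml);
 *   ASSIGN_DISPATCH(lgrad, req_lhs_grad, ograd * mr);
 * }
 * \endcode
 */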
/*! \brief options in the registry to set in-place operation of the operator */
enum SimpleOpInplaceOption {
/*! \brief do not allow inplace in arguments */
kNoInplace,
/*! \brief in unary forward, allow inplace in with out */
kInplaceInOut,
/*! \brief in unary backward, allow inplace out_grad with in_grad */
kInplaceOutIn,
/*! \brief in binary forward, allow inplace left operand with out */
kInplaceLhsOut,
/*! \brief in binary backward, allow inplace out_grad with lhs_grad */
kInplaceOutLhs
};
/*! \brief options in the registry to set the position of the scalar argument */
enum SimpleOpScalarOption {
kScalarBeforeArray,
kArrayBeforeScalar
};
/*! \brief options in the registry to set symbolic registration */
enum SimpleOpRegOption {
kNotRegisterSymbolic,
kRegisterSymbolic
};
/*! \brief registry entry to register simple operators via functions. */
class SimpleOpRegEntry {
public:
/*! \brief declare self type */
typedef SimpleOpRegEntry TSelf;
/*! \brief name of the operator */
std::string name;
/*!
* \brief set a separate name for the symbolic operator
* This must be called before set_function.
* Default: this is set to be the same as the name of the operator.
* \param symbol_name the name of symbolic operator.
*/
virtual TSelf& set_symbol_op_name(char const* symbol_name) = 0;
/*!
* \brief set whether a scalar argument needs to be passed in env
* A function cannot have both kwargs and scalar arguments.
* Default: this is set to false
* \param enable_scalar whether to enable scalar argument
* \param type_mask the position of the scalar argument.
*/
virtual TSelf& set_enable_scalar(
bool enable_scalar,
SimpleOpScalarOption type_mask = kArrayBeforeScalar) = 0;
/*!
* \brief set whether to enable kwargs
* A function cannot have both kwargs and scalar arguments.
* Default: this is set to false
* \param enable_kwargs whether to enable kwargs
*/
virtual TSelf& set_enable_kwargs(bool enable_kwargs) = 0;
/*!
* \brief set resource request
* By default there is no resource request.
* The requested resources will be available in both forward and backward passes.
* \param reqs the request.
*/
virtual TSelf& set_resource_request(
const std::vector<ResourceRequest>& reqs) = 0;
/*!
* \brief set resource request
* By default there is no resource request.
* The requested resources will be available in both forward and backward passes.
* \param req the request.
*/
virtual TSelf& set_resource_request(ResourceRequest req) = 0;
/*!
* \brief set the shape inference function for a source operator.
* \param fshapeinfer The function that performs the shape inference.
*/
virtual TSelf& set_shape_function(SourceShapeFunction fshapeinfer) = 0;
/*!
* \brief set shape inference function.
* Default: out_shape = in_shape
* \param fshapeinfer The unary shape function that performs the inference.
*/
virtual TSelf& set_shape_function(UnaryShapeFunction fshapeinfer) = 0;
/*!
* \brief set shape inference function to be the binary inference function
* Default: out_shape = lhs_shape, and lhs_shape must equal rhs_shape.
* \param fshapeinfer The binary shape function that performs the inference.
*/
virtual TSelf& set_shape_function(BinaryShapeFunction fshapeinfer) = 0;
/*!
* \brief set the function of this operator to fsource
* \param dev_mask The device mask of the device the function can act on.
* \param fsource The source function that performs the operation.
* \param register_symbolic Whether to register a symbolic operator as well.
*/
virtual TSelf& set_function(
int dev_mask,
SourceFunction fsource,
SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
/*!
* \brief set the function of this operator to funary
* \param dev_mask The device mask of the device the function can act on.
* \param funary The unary function that performs the operation.
* \param inplace_in_out Whether to do in-place optimization between the input and the output.
* \param register_symbolic Whether to register a symbolic operator as well.
*/
virtual TSelf& set_function(
int dev_mask,
UnaryFunction funary,
SimpleOpInplaceOption inplace_in_out,
SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
/*!
* \brief set the function of this operator to fbinary
* \param dev_mask The device mask of the device the function can act on.
* \param fbinary The binary function that performs the operation.
* \param inplace_lhs_out Whether to do in-place optimization between lhs and the output.
* \param register_symbolic Whether to register a symbolic operator as well.
*/
virtual TSelf& set_function(
int dev_mask,
BinaryFunction fbinary,
SimpleOpInplaceOption inplace_lhs_out,
SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
/*!
* \brief set the gradient function of this operator.
* \param dev_mask The device mask of the device the function can act on.
* \param fgrad The gradient function to be set.
* \param inplace_out_in_grad whether out_grad and in_grad can share memory.
*/
virtual TSelf& set_gradient(int dev_mask,
UnaryGradFunctionT0 fgrad,
SimpleOpInplaceOption inplace_out_in_grad) = 0;
/*!
* \brief set the gradient function of this operator.
* \param dev_mask The device mask of the device the function can act on.
* \param fgrad The gradient function to be set.
* \param inplace_out_in_grad whether out_grad and in_grad can share memory.
*/
virtual TSelf& set_gradient(int dev_mask,
UnaryGradFunctionT1 fgrad,
SimpleOpInplaceOption inplace_out_in_grad) = 0;
/*!
* \brief set the gradient function of this operator.
* \param dev_mask The device mask of the device the function can act on.
* \param fgrad The gradient function to be set.
* \param inplace_out_in_grad whether out_grad and in_grad can share memory.
*/
virtual TSelf& set_gradient(int dev_mask,
UnaryGradFunctionT2 fgrad,
SimpleOpInplaceOption inplace_out_in_grad) = 0;
/*!
* \brief set the gradient function of this operator.
* \param dev_mask The device mask of the device the function can act on.
* \param fgrad The gradient function to be set.
* \param inplace_out_lhs_grad whether out_grad and lhs_grad can share memory.
*/
virtual TSelf& set_gradient(int dev_mask,
BinaryGradFunctionT0 fgrad,
SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
/*!
* \brief set the gradient function of this operator.
* \param dev_mask The device mask of the device the function can act on.
* \param fgrad The gradient function to be set.
* \param inplace_out_lhs_grad whether out_grad and lhs_grad can share memory.
*/
virtual TSelf& set_gradient(int dev_mask,
BinaryGradFunctionT1 fgrad,
SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
/*!
* \brief Describe the function.
* \param description The description of the function.
* \return reference to self.
*/
virtual TSelf& describe(const std::string &description) = 0;
/*!
* \brief Add additional arguments to the function.
* \param args argument information.
* \return reference to self.
*/
virtual TSelf& add_arguments(const std::vector<dmlc::ParamFieldInfo> &args) = 0;
/*! \brief virtual destructor */
virtual ~SimpleOpRegEntry() {}
};
/*! \brief registry for TBlob functions */
class SimpleOpRegistry {
public:
/*!
* \brief Internal function to register an entry under the given name, or find it if it already exists.
* \param name name of the function
* \return ref to the registered entry, used to set properties
*/
SimpleOpRegEntry &__REGISTER_OR_FIND__(char const* name);
/*!
* \brief Find the entry with the corresponding name.
* \param name name of the function
* \return the corresponding registry entry; throws if the name is not registered
*/
inline static const SimpleOpRegEntry *Find(const std::string &name) {
return Get()->fmap_.at(name);
}
/*! \return global singleton of the registry */
static SimpleOpRegistry* Get();
private:
// destructor
~SimpleOpRegistry();
/*! \brief internal registry map */
std::map<std::string, SimpleOpRegEntry*> fmap_;
};
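/*!
 * Lookup sketch (illustrative; the operator name "plus" is only an example):
 * entries created through the registration macro below can be retrieved by
 * operator name.
 * \code
 * const SimpleOpRegEntry* entry = SimpleOpRegistry::Find("plus");
 * \endcode
 */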
/*!
* \brief assign the expression to out according to request
* \param out the data to be assigned
* \param req the assignment request
* \param exp the expression
* \tparam OType output type
* \tparam Exp expression type
*/
#define ASSIGN_DISPATCH(out, req, exp) \
{ \
switch (req) { \
case kNullOp: \
break; \
case kWriteTo: \
case kWriteInplace: \
(out) = (exp); \
break; \
case kAddTo: \
(out) += (exp); \
break; \
default: \
LOG(FATAL) << "not reached"; \
} \
}
/*!
* \brief Maximum ndim supported for special operators, such as broadcasting with non-contiguous lhs/rhs
*/
#define MXNET_SPECIAL_MAX_NDIM 5
//--------------------------------------------------------------
// The following part are API Registration of Simple Operators
//--------------------------------------------------------------
/*!
* \brief Macro to register simple operator to both imperative and symbolic API.
*
* see src/operator/elementwise_unary_op-inl.h for an example
*
* \code
* // example of registering a sigmoid operator on GPU
* // MySigmoid is of type UnaryFunction,
* // MySigmoidGrad is of type UnaryGradFunctionT2
*
* MXNET_REGISTER_SIMPLE_OP(sigmoid, gpu)
* .set_function(gpu::kDevMask, MySigmoid<gpu>, kInplaceInOut)
* .set_gradient(gpu::kDevMask, MySigmoidGrad<gpu>, kInplaceOutIn)
* .describe("Sigmoid function");
*
* \endcode
*/
#define MXNET_REGISTER_SIMPLE_OP(Name, DEV) \
static ::mxnet::op::SimpleOpRegEntry & \
__make_ ## SimpleOpRegEntry ## _ ## Name ## __ ## DEV ##__ = \
::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)
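/*!
 * A fuller registration sketch chaining the SimpleOpRegEntry setters
 * (illustrative only; SquareForward_ and SquareGrad_ are the hypothetical
 * helpers sketched above, not part of this header):
 * \code
 * MXNET_REGISTER_SIMPLE_OP(my_square, cpu)
 * .set_function(cpu::kDevMask, SquareForward_<cpu>, kInplaceInOut)
 * .set_gradient(cpu::kDevMask, SquareGrad_<cpu>, kInplaceOutIn)
 * .describe("Take the element-wise square of src.");
 * \endcode
 */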
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_UTIL_H_