forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexecutor.h
180 lines (173 loc) · 7.75 KB
/
executor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file executor.h
* \brief Symbolic executor interface of mxnet.
* \author Min Lin, Bing Xu
*/
#ifndef MXNET_EXECUTOR_H_
#define MXNET_EXECUTOR_H_
#include <dmlc/base.h>
#include <functional>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "./base.h"
#include "./c_api.h"
#include "./ndarray.h"
#include "./operator.h"
// check c++11
#if DMLC_USE_CXX11 == 0
#error "CXX11 was required for symbolic module"
#endif
namespace mxnet {
/*! \brief use symbolic graph from NNVM */
using nnvm::Symbol;
/*!
* \brief Executor of a computation graph.
* Executor can be created by Binding a symbol.
*/
class Executor {
 public:
  /*! \brief destructor; virtual because implementations are destroyed through Executor* */
  virtual ~Executor() {}
  /*!
   * \brief Perform a Forward operation of Operator
   * After this operation, user can get the result by using function head.
   * \param is_train Whether this is training phase.
   */
  virtual void Forward(bool is_train) = 0;
  /*!
   * \brief Perform a Partial Forward operation of Operator.
   *  Only issue operation specified by step.
   *  The caller must keep calling PartialForward with increasing steps, until step_left=0.
   * \param is_train Whether this is training phase.
   * \param step current step, user can always start from 0
   * \param step_left Number of steps left to finish the forward.
   */
  virtual void PartialForward(bool is_train, int step, int *step_left) = 0;
  /*!
   * \brief Perform a Backward operation of the Operator.
   *  This must be called after Forward.
   *  After this operation, NDArrays specified by grad_in_args_store will be updated accordingly.
   *  User is allowed to pass in an empty Array if the head node is
   *  loss function and head gradient is not needed.
   *
   * \param head_grads the gradient of head nodes to be backpropagated.
   * \param is_train Whether this backward pass is part of training.
   */
  virtual void Backward(const std::vector<NDArray> &head_grads, bool is_train = true) = 0;
  /*!
   * \brief print the execution plan info to output stream.
   *  Default implementation prints nothing; subclasses may override.
   * \param os the output stream we like to print to.
   */
  virtual void Print(std::ostream &os) const {} // NOLINT(*)
  /*!
   * \brief get array of outputs in the executor.
   * \return array of outputs in the executor.
   */
  virtual const std::vector<NDArray> &outputs() const = 0;
  /*!
   * \brief get input argument map, key is arg name, value is arg's NDArray.
   * \return input argument map in the executor.
   */
  virtual const std::unordered_map<std::string, NDArray>& in_arg_map() const = 0;
  /*!
   * \brief get input argument gradient map, key is arg name, value is gradient's NDArray.
   * \return input argument gradient map in the executor.
   */
  virtual const std::unordered_map<std::string, NDArray>& arg_grad_map() const = 0;
  /*!
   * \brief get aux state map, key is arg name, value is aux state's NDArray.
   * \return aux state map in the executor.
   */
  virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
  /*!
   * \brief Return a new executor with the same symbol and shared memory,
   *  but different input/output shapes.
   *
   * \param partial_shaping Whether to allow changing the shape of unspecified arguments.
   * \param allow_up_sizing Whether to allow allocating new ndarrays that's larger than the original.
   * \param default_ctx the default context of binding.
   * \param ctx_map Context mapping group to context.
   * \param provided_arg_shapes New shape for arguments.
   * \param in_args the NDArray that stores the input arguments.
   * \param arg_grads NDArray that is used to store the gradient output of the input arguments.
   * \param aux_states NDArray that is used as internal states.
   * \return a new executor.
   */
  virtual Executor* Reshape(const bool partial_shaping,
                            const bool allow_up_sizing,
                            const Context& default_ctx,
                            const std::map<std::string, Context>& ctx_map,
                            const std::unordered_map<std::string, TShape>&
                              provided_arg_shapes,
                            std::vector<NDArray>* in_args,
                            std::vector<NDArray>* arg_grads,
                            std::vector<NDArray>* aux_states) = 0;
  /*!
   * \brief Create an operator by bind symbol with context and arguments.
   *  If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp.
   *
   * \param symbol the symbol that specifies the output of Forward pass.
   * \param default_ctx the default context of binding.
   * \param group2ctx Context mapping group to context.
   * \param in_args the NDArray that stores the input arguments to the symbol.
   * \param arg_grad_store NDArray that is used to store the gradient output of the input arguments.
   * \param grad_req_type requirement type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}.
   * \param aux_states NDArray that is used as internal state in op
   * \param shared_exec input executor to share memory with.
   * \return a new executor.
   */
  static Executor *Bind(nnvm::Symbol symbol,
                        const Context& default_ctx,
                        const std::map<std::string, Context>& group2ctx,
                        const std::vector<NDArray> &in_args,
                        const std::vector<NDArray> &arg_grad_store,
                        const std::vector<OpReqType> &grad_req_type,
                        const std::vector<NDArray> &aux_states,
                        Executor* shared_exec = nullptr);  // nullptr for consistency with SimpleBind
  /*!
   * \brief Bind a symbol by letting the executor allocate the argument arrays itself,
   *  given per-argument contexts, shapes, dtypes and storage types.
   *  NOTE(review): parameter semantics below are inferred from names and the Bind/Reshape
   *  docs above — confirm against the implementation.
   * \param symbol the symbol that specifies the output of Forward pass.
   * \param default_ctx the default context of binding.
   * \param group2ctx Context mapping group to context.
   * \param in_arg_ctxes contexts for the input arguments.
   * \param arg_grad_ctxes contexts for the argument gradients.
   * \param aux_state_ctxes contexts for the auxiliary states.
   * \param arg_shape_map shapes of arguments, keyed by argument name.
   * \param arg_dtype_map dtypes of arguments, keyed by argument name.
   * \param arg_stype_map storage types of arguments, keyed by argument name.
   * \param grad_req_types requirement types of gradient saving, see Bind.
   * \param param_names names of the parameters.
   * \param in_args output: NDArrays allocated for the input arguments.
   * \param arg_grads output: NDArrays allocated for the argument gradients.
   * \param aux_states output: NDArrays allocated for the auxiliary states.
   * \param shared_data_arrays optional in/out map of arrays shared across executors.
   * \param shared_exec optional input executor to share memory with.
   * \return a new executor.
   */
  static Executor* SimpleBind(nnvm::Symbol symbol,
                              const Context& default_ctx,
                              const std::map<std::string, Context>& group2ctx,
                              const std::vector<Context>& in_arg_ctxes,
                              const std::vector<Context>& arg_grad_ctxes,
                              const std::vector<Context>& aux_state_ctxes,
                              const std::unordered_map<std::string, TShape>& arg_shape_map,
                              const std::unordered_map<std::string, int>& arg_dtype_map,
                              const std::unordered_map<std::string, int>& arg_stype_map,
                              const std::vector<OpReqType>& grad_req_types,
                              const std::unordered_set<std::string>& param_names,
                              std::vector<NDArray>* in_args,
                              std::vector<NDArray>* arg_grads,
                              std::vector<NDArray>* aux_states,
                              std::unordered_map<std::string, NDArray>*
                                shared_data_arrays = nullptr,
                              Executor* shared_exec = nullptr);
  /*!
   * \brief the prototype of user-defined monitor callback
   */
  typedef std::function<void(const char*, void*)> MonitorCallback;
  /*!
   * \brief Install a callback to notify the completion of operation.
   *  Default implementation ignores the callback; subclasses may override.
   */
  virtual void SetMonitorCallback(const MonitorCallback& callback) {}
};  // class executor
} // namespace mxnet
#endif // MXNET_EXECUTOR_H_