MRT_CFG.EVALUATE.ITER_NUM = 10

def forward(net, data, ctx, baxis, olen):
-    #TODO(ryt.dev) documentation
-    """ Multiple xpu run support.
+    """
+    Multiple xpu run support.
+
+    Parameters
+    ----------
+    net : mxnet.gluon.block.SymbolBlock
+        Graph for inference.
+    data : mxnet.ndarray.ndarray.NDArray
+        Input data to pass into the graph.
+    ctx : mx.context.Context or list
+        Context(s) for inference.
+    baxis : int
+        Axis id of the batch dimension.
+    olen : int
+        Number of outputs of the graph.
+
+    Returns
+    -------
+    outs : mxnet.ndarray.ndarray.NDArray or list
+        Inference result of the graph with respect to the given input data;
+        for multiple outputs, outs will be a list whose entries are of type
+        mxnet.ndarray.ndarray.NDArray.
    """
    data = gluon.utils.split_and_load(
        data, ctx_list=ctx, batch_axis=baxis, even_split=False)
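
As a rough illustration of the documented signature, a minimal usage sketch of forward(); the model file names, input shape and olen value below are placeholders, not part of this patch:

    # Load a SymbolBlock and run it through forward() on a single CPU context.
    import mxnet as mx
    from mxnet import gluon

    net = gluon.SymbolBlock.imports(
        "model-symbol.json", ["data"], "model-0000.params", ctx=mx.cpu())
    batch = mx.nd.random.uniform(shape=(16, 3, 224, 224))
    # baxis=0: batch dimension; olen=1: the graph has a single output (assumed).
    outs = forward(net, batch, ctx=[mx.cpu()], baxis=0, olen=1)
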
@@ -48,6 +68,19 @@ def forward(net, data, ctx, baxis, olen):
    return outs

def get_evaluation_info(cm_cfg, pass_cfg, logger=None):
+    """
+    YAML configuration API to get the evaluation function,
+    quantization function and dataset iteration function.
+
+    Parameters
+    ----------
+    cm_cfg : yacs.config.CfgNode
+        CfgNode of common stage.
+    pass_cfg : yacs.config.CfgNode
+        CfgNode of evaluation stage.
+    logger : logging.RootLogger
+        Console logger.
+    """
    model_dir = cm_cfg.MODEL_DIR
    model_name = cm_cfg.MODEL_NAME
    verbosity = cm_cfg.VERBOSITY
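
A hedged sketch of how the three callables returned by get_evaluation_info() might be driven; the MRT_CFG.COMMON attribute and the (data, label) shape of the iterator output are assumptions inferred from the surrounding code, not a documented contract:

    # Obtain the callables and run one quantized-inference batch.
    evalfunc, data_iter_func, quantize = get_evaluation_info(
        MRT_CFG.COMMON, MRT_CFG.EVALUATE)   # MRT_CFG.COMMON is assumed here
    data, label = data_iter_func()          # assumed to yield one (data, label) batch
    quantize(data, label)                   # signature taken from the hunk header below
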
@@ -134,6 +167,18 @@ def quantize(data, label):
    return evalfunc, data_iter_func, quantize

def evaluate(cm_cfg, pass_cfg, logger=None):
+    """
+    YAML configuration API of the MRT evaluation stage.
+
+    Parameters
+    ----------
+    cm_cfg : yacs.config.CfgNode
+        CfgNode of common stage.
+    pass_cfg : yacs.config.CfgNode
+        CfgNode of evaluation stage.
+    logger : logging.RootLogger
+        Console logger.
+    """
    evalfunc, data_iter_func, quantize = get_evaluation_info(
        cm_cfg, pass_cfg, logger=logger)

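
A hedged sketch of invoking the evaluation stage from a YAML file through yacs; the YAML path and the COMMON/EVALUATE attribute names on the merged config are assumptions:

    # Merge user overrides into a copy of the default MRT config and evaluate.
    cfg = MRT_CFG.clone()
    cfg.merge_from_file("model.yaml")   # hypothetical YAML configuration file
    evaluate(cfg.COMMON, cfg.EVALUATE)
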
@@ -152,7 +197,19 @@ def evaluate(cm_cfg, pass_cfg, logger=None):
        logger.info("evaluatation stage skipped")

def get_ctx_eval(ctx):
-    #TODO(ryt.dev) documentation
+    """
+    Get the context instance(s) for the evaluation stage.
+
+    Parameters
+    ----------
+    ctx : mx.context.Context or list
+        The input context.
+
+    Returns
+    -------
+    ctx : list of mx.context.Context
+        The modified context, normalized to a list.
+    """
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    elif isinstance(ctx, list):
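
A small illustration of the branch shown above: a single context is wrapped into a list, which is the form the multi-device forward() expects; the expected result is an assumption based on the visible branch only:

    import mxnet as mx

    ctx_list = get_ctx_eval(mx.cpu())   # expected: [cpu(0)], per the branch above
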
@@ -166,7 +223,31 @@ def inference_original_model(
        symbol_file, params_file, data, batch_axis=0,
        device_type=MRT_CFG.EVALUATE.DEVICE_TYPE,
        device_ids=MRT_CFG.EVALUATE.DEVICE_IDS):
-    #TODO(ryt.dev) documentation
+    """
+    MRT inference API for the original (non-quantized) model.
+
+    Parameters
+    ----------
+    symbol_file : str
+        Path to the mxnet symbol JSON file of the original model.
+    params_file : str
+        Path to the mxnet parameters file of the original model.
+    data : mxnet.ndarray.ndarray.NDArray
+        Input data to pass into the graph.
+    batch_axis : int
+        Axis id of the batch dimension.
+    device_type : str
+        Context type string chosen from `cpu` or `gpu`.
+    device_ids : list
+        List of context ids.
+
+    Returns
+    -------
+    outs : mxnet.ndarray.ndarray.NDArray or list
+        Inference result of the graph with respect to the given input data;
+        for multiple outputs, outs will be a list whose entries are of type
+        mxnet.ndarray.ndarray.NDArray.
+    """

    ctx = get_ctx_eval(get_ctx(device_type, device_ids))
    omodel = Model.load(symbol_file, params_file)
@@ -180,7 +261,35 @@ def inference_quantized_model(
        qsymbol_file, qparams_file, qext_file, data, batch_axis=0, split=False,
        device_type=MRT_CFG.EVALUATE.DEVICE_TYPE,
        device_ids=MRT_CFG.EVALUATE.DEVICE_IDS):
-    #TODO(ryt.dev) documentation
+    """
+    MRT inference API for the quantized model.
+
+    Parameters
+    ----------
+    qsymbol_file : str
+        Path to the quantized mxnet symbol JSON file.
+    qparams_file : str
+        Path to the quantized mxnet parameters file.
+    qext_file : str
+        Path to the quantized extension file, which stores intermediate results.
+    data : mxnet.ndarray.ndarray.NDArray
+        Input data to pass into the graph.
+    batch_axis : int
+        Axis id of the batch dimension.
+    split : bool
+        Flag indicating whether the model is split before quantization.
+    device_type : str
+        Context type string chosen from `cpu` or `gpu`.
+    device_ids : list
+        List of context ids.
+
+    Returns
+    -------
+    outs : mxnet.ndarray.ndarray.NDArray or list
+        Inference result of the graph with respect to the given input data;
+        for multiple outputs, outs will be a list whose entries are of type
+        mxnet.ndarray.ndarray.NDArray.
+    """

    ctx = get_ctx_eval(get_ctx(device_type, device_ids))
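
To tie the two inference entry points together, a hedged end-to-end sketch; every file path below is a placeholder (the actual MRT output naming is not specified in this patch) and the device settings are illustrative:

    import mxnet as mx

    data = mx.nd.random.uniform(shape=(1, 3, 224, 224))   # dummy input batch
    # Full-precision graph.
    outs = inference_original_model(
        "resnet-symbol.json", "resnet-0000.params", data,
        batch_axis=0, device_type="cpu", device_ids=[0])
    # Quantized graph produced by the MRT pipeline.
    qouts = inference_quantized_model(
        "resnet.quantize.json", "resnet.quantize.params", "resnet.quantize.ext",
        data, batch_axis=0, split=False, device_type="cpu", device_ids=[0])
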