Skip to content

Commit

Permalink
TracedLayer Error Message Enhancement (PaddlePaddle#25734)
Browse files Browse the repository at this point in the history
Enhance TracedLayer Error Message

Note: this PR uses assert to check type somewhere and check_type somewhere, the reason is that the check_type skips checking when it is under dygraph mode.
  • Loading branch information
zhhsplendid authored Jul 29, 2020
1 parent c9285a1 commit b3f58d3
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 9 deletions.
44 changes: 38 additions & 6 deletions python/paddle/fluid/dygraph/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

import warnings
from paddle.fluid import core
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, FunctionSpec
from paddle.fluid.dygraph.layers import Layer
Expand All @@ -43,10 +44,13 @@ def create_program_from_desc(program_desc):
def _extract_vars(inputs, result_list):
if isinstance(inputs, Variable):
result_list.append(inputs)

if isinstance(inputs, (list, tuple)):
elif isinstance(inputs, (list, tuple)):
for var in inputs:
_extract_vars(var, result_list)
else:
raise TypeError(
"The type of 'each element of inputs' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}.".
format(type(inputs)))


def extract_vars(inputs):
Expand Down Expand Up @@ -1063,12 +1067,13 @@ def trace(layer, inputs):
Args:
layer (dygraph.Layer): the layer object to be traced.
inputs (list(Variable)): the input variables of the layer object.
inputs (list(Tensor)|tuple(Tensor)|Tensor): the input tensors of
the layer object.
Returns:
tuple: A tuple of 2 items, whose the first item is the output of
:code:`layer(*inputs)` , and the second item is the created
TracedLayer object.
:code:`layer(*inputs)` , and the second item is the created
TracedLayer object.
Examples:
.. code-block:: python:
Expand Down Expand Up @@ -1100,6 +1105,10 @@ def forward(self, input):
# save the static graph model for inference
static_layer.save_inference_model(dirname='./saved_infer_model')
"""
assert isinstance(
layer, Layer
), "The type of 'layer' in fluid.dygraph.jit.TracedLayer.trace must be fluid.dygraph.Layer, but received {}.".format(
type(layer))
outs, prog, feed, fetch, parameters = _trace(layer, inputs)
traced = TracedLayer(prog, parameters, feed, fetch)
return outs, traced
Expand Down Expand Up @@ -1149,6 +1158,14 @@ def forward(self, input):
out_static_graph = static_layer([in_var])
"""
assert self._compiled_program is None, "Cannot set strategy after run"
assert isinstance(
build_strategy, (type(None), BuildStrategy)
), "The type of 'build_strategy' in fluid.dygraph.jit.TracedLayer.set_strategy must be fluid.BuildStrategy, but received {}.".format(
type(build_strategy))
assert isinstance(
exec_strategy, (type(None), ExecutionStrategy)
), "The type of 'exec_strategy' in fluid.dygraph.jit.TracedLayer.set_strategy must be fluid.ExecutionStrategy, but received {}.".format(
type(exec_strategy))
self._build_strategy = build_strategy
self._exec_strategy = exec_strategy

Expand Down Expand Up @@ -1239,6 +1256,21 @@ def forward(self, input):
fetch, = exe.run(program, feed={feed_vars[0]: in_np}, fetch_list=fetch_vars)
print(fetch.shape) # (2, 10)
"""
check_type(dirname, "dirname", str,
"fluid.dygraph.jit.TracedLayer.save_inference_model")
check_type(feed, "feed", (type(None), list),
"fluid.dygraph.jit.TracedLayer.save_inference_model")
if isinstance(feed, list):
for f in feed:
check_type(f, "each element of feed", int,
"fluid.dygraph.jit.TracedLayer.save_inference_model")
check_type(fetch, "fetch", (type(None), list),
"fluid.dygraph.jit.TracedLayer.save_inference_model")
if isinstance(fetch, list):
for f in fetch:
check_type(f, "each element of fetch", int,
"fluid.dygraph.jit.TracedLayer.save_inference_model")

from paddle.fluid.io import save_inference_model

def get_feed_fetch(all_vars, partial_vars):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def ptb_rnn_cpu_float32(self, is_sparse):
program = traced_layer.program

traced_layer.save_inference_model(
'./infe_imperative_ptb_rnn', feed=range(4))
'./infe_imperative_ptb_rnn', feed=list(range(4)))
else:
outs = ptb_model(x, y, init_hidden, init_cell)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1010,8 +1010,8 @@ def transformer_sort_gradient_float32(self, is_sparse):
program = traced_layer.program
traced_layer.save_inference_model(
'./infer_imperative_transformer',
feed=range(len(ins_static)),
fetch=range(len(outs_static)))
feed=list(range(len(ins_static))),
fetch=list(range(len(outs_static))))
else:
outs = transformer(enc_inputs, dec_inputs, label, weights)

Expand Down
156 changes: 156 additions & 0 deletions python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle.fluid as fluid
import six
import unittest


class SimpleFCLayer(fluid.dygraph.Layer):
def __init__(self, feature_size, batch_size, fc_size):
super(SimpleFCLayer, self).__init__()
self._linear = fluid.dygraph.Linear(feature_size, fc_size)
self._offset = fluid.dygraph.to_variable(
np.random.random((batch_size, fc_size)).astype('float32'))

def forward(self, x):
fc = self._linear(x)
return fc + self._offset


class TestTracedLayerErrMsg(unittest.TestCase):
def setUp(self):
self.batch_size = 4
self.feature_size = 3
self.fc_size = 2
self.layer = self._train_simple_net()
if six.PY2:
self.type_str = 'type'
else:
self.type_str = 'class'

def test_trace_err(self):
with fluid.dygraph.guard():
in_x = fluid.dygraph.to_variable(
np.random.random((self.batch_size, self.feature_size)).astype(
'float32'))

with self.assertRaises(AssertionError) as e:
dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
None, [in_x])
self.assertEqual(
"The type of 'layer' in fluid.dygraph.jit.TracedLayer.trace must be fluid.dygraph.Layer, but received <{} 'NoneType'>.".
format(self.type_str), str(e.exception))
with self.assertRaises(TypeError) as e:
dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
self.layer, 3)
self.assertEqual(
"The type of 'each element of inputs' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received <{} 'int'>.".
format(self.type_str, self.type_str), str(e.exception))
with self.assertRaises(TypeError) as e:
dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
self.layer, [True, 1])
self.assertEqual(
"The type of 'each element of inputs' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received <{} 'bool'>.".
format(self.type_str), str(e.exception))

dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
self.layer, [in_x])

def test_set_strategy_err(self):
with fluid.dygraph.guard():
in_x = fluid.dygraph.to_variable(
np.random.random((self.batch_size, self.feature_size)).astype(
'float32'))
dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
self.layer, [in_x])

with self.assertRaises(AssertionError) as e:
traced_layer.set_strategy(1, fluid.ExecutionStrategy())
self.assertEqual(
"The type of 'build_strategy' in fluid.dygraph.jit.TracedLayer.set_strategy must be fluid.BuildStrategy, but received <{} 'int'>.".
format(self.type_str), str(e.exception))

with self.assertRaises(AssertionError) as e:
traced_layer.set_strategy(fluid.BuildStrategy(), False)
self.assertEqual(
"The type of 'exec_strategy' in fluid.dygraph.jit.TracedLayer.set_strategy must be fluid.ExecutionStrategy, but received <{} 'bool'>.".
format(self.type_str), str(e.exception))

traced_layer.set_strategy(build_strategy=fluid.BuildStrategy())
traced_layer.set_strategy(exec_strategy=fluid.ExecutionStrategy())
traced_layer.set_strategy(fluid.BuildStrategy(),
fluid.ExecutionStrategy())

def test_save_inference_model_err(self):
with fluid.dygraph.guard():
in_x = fluid.dygraph.to_variable(
np.random.random((self.batch_size, self.feature_size)).astype(
'float32'))
dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
self.layer, [in_x])

dirname = './traced_layer_err_msg'
with self.assertRaises(TypeError) as e:
traced_layer.save_inference_model([0])
self.assertEqual(
"The type of 'dirname' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'str'>, but received <{} 'list'>. ".
format(self.type_str, self.type_str), str(e.exception))
with self.assertRaises(TypeError) as e:
traced_layer.save_inference_model(dirname, [0], [None])
self.assertEqual(
"The type of 'each element of fetch' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'int'>, but received <{} 'NoneType'>. ".
format(self.type_str, self.type_str), str(e.exception))
with self.assertRaises(TypeError) as e:
traced_layer.save_inference_model(dirname, [0], False)
self.assertEqual(
"The type of 'fetch' in fluid.dygraph.jit.TracedLayer.save_inference_model must be (<{} 'NoneType'>, <{} 'list'>), but received <{} 'bool'>. ".
format(self.type_str, self.type_str, self.type_str),
str(e.exception))
with self.assertRaises(TypeError) as e:
traced_layer.save_inference_model(dirname, [None], [0])
self.assertEqual(
"The type of 'each element of feed' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'int'>, but received <{} 'NoneType'>. ".
format(self.type_str, self.type_str), str(e.exception))
with self.assertRaises(TypeError) as e:
traced_layer.save_inference_model(dirname, True, [0])
self.assertEqual(
"The type of 'feed' in fluid.dygraph.jit.TracedLayer.save_inference_model must be (<{} 'NoneType'>, <{} 'list'>), but received <{} 'bool'>. ".
format(self.type_str, self.type_str, self.type_str),
str(e.exception))

traced_layer.save_inference_model(dirname)

def _train_simple_net(self):
layer = None
with fluid.dygraph.guard():
layer = SimpleFCLayer(self.feature_size, self.batch_size,
self.fc_size)
optimizer = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=layer.parameters())

for i in range(5):
in_x = fluid.dygraph.to_variable(
np.random.random((self.batch_size, self.feature_size))
.astype('float32'))
dygraph_out = layer(in_x)
loss = fluid.layers.reduce_mean(dygraph_out)
loss.backward()
optimizer.minimize(loss)
return layer


if __name__ == '__main__':
unittest.main()

0 comments on commit b3f58d3

Please sign in to comment.