Skip to content
8 changes: 4 additions & 4 deletions framework/api/nn/apibase.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def run(self, res, data=None, **kwargs):
paddle.seed(self.seed)
self._check_params(res, data, **kwargs)
dygraph_forward_res = self._dygraph_forward()
if isinstance(dygraph_forward_res, (list)):
if isinstance(dygraph_forward_res, (list, tuple)):
compare(dygraph_forward_res, res, self.delta, self.rtol)
else:
compare(dygraph_forward_res.numpy(), res, self.delta, self.rtol)
Expand Down Expand Up @@ -267,7 +267,7 @@ def _baserun(self, res, data=None, **kwargs):
self._check_params(res, data, **kwargs)
dygraph_forward_res = self._dygraph_forward()
logging.info("dygraph forward result is :")
if isinstance(dygraph_forward_res, (list)):
if isinstance(dygraph_forward_res, (list, tuple)):
compare(dygraph_forward_res, res, self.delta, self.rtol)
logging.info(dygraph_forward_res)
else:
Expand Down Expand Up @@ -328,7 +328,7 @@ def _baserun(self, res, data=None, **kwargs):
paddle.seed(self.seed)
self._check_params(res, data, **kwargs)
dygraph_forward_res = self._dygraph_forward()
if isinstance(dygraph_forward_res, (list)):
if isinstance(dygraph_forward_res, (list, tuple)):
compare(dygraph_forward_res, res, self.delta, self.rtol)
else:
compare(dygraph_forward_res.numpy(), res, self.delta, self.rtol)
Expand Down Expand Up @@ -687,7 +687,7 @@ def compare(result, expect, delta=1e-6, rtol=1e-5):
assert res
# tools.assert_equal(result.shape, expect.shape)
assert result.shape == expect.shape
elif isinstance(result, list):
elif isinstance(result, (list, tuple)):
for i, j in enumerate(result):
if isinstance(j, (np.generic, np.ndarray)):
compare(j, expect[i], delta, rtol)
Expand Down
Loading