Skip to content

Commit

Permalink
[CINN] Enable test_resnet50_with_cinn (PaddlePaddle#44017)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhhsplendid authored Jul 4, 2022
1 parent a42f48b commit cf8e86d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
19 changes: 19 additions & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1687,3 +1687,22 @@ if($ENV{USE_STANDALONE_EXECUTOR})
set_tests_properties(test_imperative_mnist_sorted_gradient
PROPERTIES ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0)
endif()

if(WITH_CINN AND WITH_TESTING)
set_tests_properties(
test_resnet50_with_cinn
PROPERTIES
LABELS
"RUN_TYPE=CINN"
ENVIRONMENT
FLAGS_allow_cinn_ops="conv2d;conv2d_grad;elementwise_add;elementwise_add_grad;relu;relu_grad;sum"
)
set_tests_properties(
test_parallel_executor_run_cinn
PROPERTIES
LABELS
"RUN_TYPE=CINN"
ENVIRONMENT
FLAGS_allow_cinn_ops="conv2d;conv2d_grad;elementwise_add;elementwise_add_grad;relu;relu_grad;sum"
)
endif()
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def test_check_resnet50_accuracy(self):

loss_c = self.train(place, loop_num, feed, use_cinn=True)
loss_p = self.train(place, loop_num, feed, use_cinn=False)
print("Losses of CINN:")
print(loss_c)
print("Losses of Paddle")
print(loss_p)
self.assertTrue(np.allclose(loss_c, loss_p, atol=1e-5))


Expand Down

0 comments on commit cf8e86d

Please sign in to comment.