Skip to content

Commit

Permalink
add profiler_range (#59634)
Browse files Browse the repository at this point in the history
* add profiler_range

* add test cases and fix logic

* Update test_job_schedule_profiler_range.py

* Update CMakeLists.txt

* Update CMakeLists.txt

* add test case
  • Loading branch information
AndSonder authored Dec 7, 2023
1 parent 9498ed4 commit b9d9b17
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/paddle/profiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,20 @@ def _nvprof_range(iter_id, start, end, exit_after_prof=True):
core.nvprof_stop()
if exit_after_prof:
sys.exit()


@contextmanager
def job_schedule_profiler_range(iter_id, start, end, exit_after_prof=True):
if start >= end:
yield False
return

try:
if iter_id >= start and iter_id < end:
yield True
else:
yield False
finally:
if iter_id == end - 1:
if exit_after_prof:
sys.exit()
3 changes: 3 additions & 0 deletions test/auto_parallel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
# End of unittests WITH single card WITHOUT timeout

endif()

py_test_modules(test_job_schedule_profiler_range MODULES
test_job_schedule_profiler_range)
91 changes: 91 additions & 0 deletions test/auto_parallel/test_job_schedule_profiler_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from paddle.profiler.utils import job_schedule_profiler_range


class TestJobScheDuleProfilerRange(unittest.TestCase):
def test_not_exit_after_prof_1(self):
status_list = [
False,
False,
False,
True,
True,
True,
False,
False,
False,
False,
]
for i in range(10):
with job_schedule_profiler_range(i, 3, 6, False) as status:
self.assertEqual(status, status_list[i])

def test_not_exit_after_prof_2(self):
status_list = [
True,
True,
True,
True,
True,
False,
False,
False,
False,
False,
]
for i in range(10):
with job_schedule_profiler_range(i, 0, 5, False) as status:
self.assertEqual(status, status_list[i])

def test_not_exit_after_prof_3(self):
status_list = [
False,
False,
False,
True,
True,
False,
False,
False,
False,
False,
]
for i in range(10):
with job_schedule_profiler_range(i, 3, 5, False) as status:
self.assertEqual(status, status_list[i])

def test_end_less_than_start(self):
status_list = [
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
]
for i in range(10):
with job_schedule_profiler_range(i, 5, 3, False) as status:
self.assertEqual(status, status_list[i])


if __name__ == "__main__":
unittest.main()

0 comments on commit b9d9b17

Please sign in to comment.