Skip to content
12 changes: 10 additions & 2 deletions python/paddle/fluid/tests/unittests/test_profiler_statistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,12 @@ def test_statistic_case2(self):
'Communication', profiler.TracerEventType.Communication, 105, 110,
1000, 1001)

reduce_all_op1 = HostPythonNode('cudalaunchkernel',
reduce_all_op1 = HostPythonNode('reduce_all_op1',
profiler.TracerEventType.Operator, 105,
108, 1000, 1001)
reduce_all_op1_infershape = HostPythonNode(
'reduce_all_op1::infershape',
profiler.TracerEventType.OperatorInner, 105, 106, 1000, 1001)

reduce_all_launchkernel1 = HostPythonNode(
'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 106, 107,
Expand Down Expand Up @@ -306,6 +309,9 @@ def test_statistic_case2(self):
profiler.TracerEventType.Operator,
230, 250, 1000, 1001)

reduce_all_node2_infershape = HostPythonNode(
'reduce_all_node2::infershape',
profiler.TracerEventType.OperatorInner, 231, 232, 1000, 1001)
reduce_all_launchkernel2 = HostPythonNode(
'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 235, 240,
1000, 1001)
Expand All @@ -326,6 +332,7 @@ def test_statistic_case2(self):
userdefined_node.runtime_node.append(reduce_all_launchkernel0)
reduce_all_launchkernel0.device_node.append(nccl_reduce_all_kernel0)
communication_node.children_node.append(reduce_all_op1)
reduce_all_op1.children_node.append(reduce_all_op1_infershape)
reduce_all_op1.runtime_node.append(reduce_all_launchkernel1)
reduce_all_launchkernel1.device_node.append(nccl_reduce_all_kernel1)
conv2d_node.children_node.extend(
Expand All @@ -344,6 +351,7 @@ def test_statistic_case2(self):
sync_batch_norm_launchkernel.device_node.append(sync_batch_norm_kernel)
sync_batch_norm_cudaMemCpy.device_node.append(sync_batch_norm_memcpy)
optimization_node.children_node.append(reduce_all_node2)
reduce_all_node2.children_node.append(reduce_all_node2_infershape)
reduce_all_node2.runtime_node.append(reduce_all_launchkernel2)
reduce_all_launchkernel2.device_node.append(nccl_reduce_all_kernel2)
thread_tree = {'thread1001': root_node}
Expand Down Expand Up @@ -374,7 +382,7 @@ def test_statistic_case2(self):
profiler.TracerEventType.Operator), 78)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.OperatorInner), 45)
profiler.TracerEventType.OperatorInner), 47)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.CudaRuntime), 38)
Expand Down
Loading