@@ -43,6 +43,58 @@ std::string json_vector<std::string>(
4343}
4444
4545#ifdef PADDLE_WITH_CUPTI
46+
47+ #ifdef PADDLE_WITH_HIP
48+
49+ #include " hip/hip_runtime.h"
50+ float CalculateEstOccupancy (uint32_t DeviceId,
51+ int32_t DynamicSharedMemory,
52+ int32_t BlockX,
53+ int32_t BlockY,
54+ int32_t BlockZ,
55+ void * kernelFunc,
56+ uint8_t launchType) {
57+ float occupancy = 0.0 ;
58+ std::vector<int > device_ids = GetSelectedDevices ();
59+ if (DeviceId < device_ids.size ()) {
60+ const gpuDeviceProp& device_property = GetDeviceProperties (DeviceId);
61+ int blockSize = BlockX * BlockY * BlockZ;
62+ int numBlock = 0 ;
63+ hipError_t status;
64+ if (launchType == 0 ) {
65+ status = hipOccupancyMaxActiveBlocksPerMultiprocessor (
66+ &numBlock, kernelFunc, blockSize, DynamicSharedMemory);
67+ if (status == hipSuccess) {
68+ occupancy = static_cast <double >(numBlock) * blockSize /
69+ device_property.maxThreadsPerMultiProcessor ;
70+ } else {
71+ LOG (WARNING) << " Failed to calculate estimated occupancy, status = "
72+ << status << std::endl;
73+ }
74+ } else if (launchType == 100 ) {
75+ status = hipModuleOccupancyMaxActiveBlocksPerMultiprocessor (
76+ &numBlock,
77+ reinterpret_cast <hipFunction_t>(kernelFunc),
78+ blockSize,
79+ DynamicSharedMemory);
80+ if (status == hipSuccess) {
81+ occupancy = static_cast <double >(numBlock) * blockSize /
82+ device_property.maxThreadsPerMultiProcessor ;
83+ } else {
84+ LOG (WARNING) << " Failed to calculate estimated occupancy, status = "
85+ << status << std::endl;
86+ }
87+ } else {
88+ LOG (WARNING) << " Failed to calculate estimated occupancy, can not "
89+ " recognize launchType : "
90+ << launchType << std::endl;
91+ }
92+ }
93+ return occupancy;
94+ }
95+
96+ #else
97+
4698float CalculateEstOccupancy (uint32_t DeviceId,
4799 uint16_t RegistersPerThread,
48100 int32_t StaticSharedMemory,
@@ -88,7 +140,9 @@ float CalculateEstOccupancy(uint32_t DeviceId,
88140 }
89141 return occupancy;
90142}
91- #endif
143+ #endif // PADDLE_WITH_HIP
144+
145+ #endif // PADDLE_WITH_CUPTI
92146
93147const char * StringTracerMemEventType (TracerMemEventType type) {
94148 static const char * categary_name_[] = {
0 commit comments