Skip to content

Commit 094ced5

Browse files
【custom】add Custom pass list in LoadCustomRuntimeLib and analysis_predictor will use it in customplace (#71362)
* add customdevice default_pass * add declare * rm stdmove * add custom_load pass
1 parent 0f78075 commit 094ced5

File tree

7 files changed

+42
-2
lines changed

7 files changed

+42
-2
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
#include "paddle/phi/api/include/context_pool.h"
6060
#include "paddle/phi/api/include/tensor.h"
6161
#include "paddle/phi/backends/context_pool.h"
62+
#include "paddle/phi/backends/device_manager.h"
6263
#include "paddle/phi/common/backend.h"
6364
#include "paddle/phi/common/data_type.h"
6465
#include "paddle/phi/common/place.h"
@@ -992,6 +993,8 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
992993
} else if (config_.use_custom_device()) {
993994
// custom device
994995
if (!config_.custom_pass_only_) {
996+
auto kPirCustomDevicePasses =
997+
phi::CustomDevicePassManager::Instance()->GetCustomDevicePass();
995998
for (const auto &custom_device_pass : kPirCustomDevicePasses) {
996999
if (std::find(config_.deleted_passes_.begin(),
9971000
config_.deleted_passes_.end(),

paddle/fluid/inference/api/paddle_pass_builder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ class PD_INFER_DECL IpuPassStrategy final : public PassStrategy {
323323
deleted_passes_ = other.deleted_passes_;
324324
}
325325
};
326+
326327
#ifdef PADDLE_WITH_OPENVINO
327328
/// \brief List of OpenVINO subgraph passes.
328329
PD_INFER_DECL extern const std::vector<std::string> kOVSubgraphPasses;

paddle/phi/backends/custom/custom_device.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,11 @@ void LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
10601060
<< "]. Register failed!!! there may be a "
10611061
"Custom Runtime with the same name.";
10621062
}
1063+
if (runtime_params.pir_default_passes != nullptr) {
1064+
CustomDevicePassManager::Instance()->SetCustomDevicePass(
1065+
*(reinterpret_cast<std::vector<std::string>*>(
1066+
runtime_params.pir_default_passes)));
1067+
}
10631068
} else {
10641069
LOG(WARNING) << "Skipped lib [" << dso_lib_path
10651070
<< "]. Wrong Runtime parameters!!! please check the version "

paddle/phi/backends/custom/fake_cpu_device.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,16 @@ C_Status BlasAXPBY(const C_Device device,
225225
#define DEVICE_TYPE "FakeCPU"
226226
#define SUB_DEVICE_TYPE "V100"
227227

228+
std::vector<std::string> fake_cpu_device_pass_list = {"fake_cpu_device_pass"};
229+
228230
void InitFakeCPUDevice(CustomRuntimeParams *params) {
229231
params->device_type = const_cast<char *>(DEVICE_TYPE);
230232
params->sub_device_type = const_cast<char *>(SUB_DEVICE_TYPE);
231233
params->version.major = PADDLE_CUSTOM_RUNTIME_MAJOR_VERSION;
232234
params->version.minor = PADDLE_CUSTOM_RUNTIME_MINOR_VERSION;
233235
params->version.patch = PADDLE_CUSTOM_RUNTIME_PATCH_VERSION;
234-
236+
params->pir_default_passes =
237+
reinterpret_cast<void *>(&fake_cpu_device_pass_list);
235238
memset(reinterpret_cast<void *>(params->interface),
236239
0,
237240
sizeof(C_DeviceInterface));

paddle/phi/backends/device_ext.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,7 @@ struct CustomRuntimeParams {
715715
// Plugin fill it
716716
char* sub_device_type;
717717

718+
void* pir_default_passes;
718719
char reserved[32];
719720
};
720721

paddle/phi/backends/device_manager.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
// limitations under the License.
1414

1515
#pragma once
16-
16+
#include <cstring>
1717
#include <unordered_map>
18+
#include <vector>
1819

1920
#include "paddle/phi/common/data_type.h"
2021
#include "paddle/phi/common/place.h"
@@ -295,6 +296,28 @@ class DeviceManager {
295296

296297
std::vector<std::string> ListAllLibraries(const std::string& library_dir);
297298
#ifdef PADDLE_WITH_CUSTOM_DEVICE
299+
300+
class CustomDevicePassManager {
301+
public:
302+
explicit CustomDevicePassManager(const std::vector<std::string>& passes)
303+
: all_passes_(passes) {}
304+
~CustomDevicePassManager() = default;
305+
static CustomDevicePassManager* Instance() {
306+
std::vector<std::string> passes;
307+
static CustomDevicePassManager manager(passes);
308+
return &manager;
309+
}
310+
void SetCustomDevicePass(const std::vector<std::string>& passes) {
311+
all_passes_ = passes;
312+
}
313+
const std::vector<std::string> GetCustomDevicePass() const {
314+
return all_passes_;
315+
}
316+
317+
private:
318+
std::vector<std::string> all_passes_;
319+
};
320+
298321
void LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
299322
std::unique_ptr<C_DeviceInterface> device_interface,
300323
const std::string& dso_lib_path,

test/cpp/fluid/platform/device/custom/custom_device_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ void RegisterDevice() {
3636
InitFakeCPUDevice(&runtime_params);
3737
phi::LoadCustomRuntimeLib(
3838
runtime_params, std::move(device_interface), "", nullptr);
39+
40+
std::vector<std::string> passes =
41+
phi::CustomDevicePassManager::Instance()->GetCustomDevicePass();
42+
EXPECT_EQ(passes[0], "fake_cpu_device_pass");
3943
}
4044

4145
void InitDevice() {

0 commit comments

Comments
 (0)