diff --git a/paddle/fluid/inference/lite/CMakeLists.txt b/paddle/fluid/inference/lite/CMakeLists.txt index 061fe2da55971..851900c0236d1 100644 --- a/paddle/fluid/inference/lite/CMakeLists.txt +++ b/paddle/fluid/inference/lite/CMakeLists.txt @@ -2,5 +2,5 @@ cc_binary(test_leaky_relu SRCS test_leaky_relu.cc DEPS lite_full_static dynload_ cc_library(lite_op_teller SRCS op_teller.cc DEPS framework_proto device_context boost xxhash) cc_library(lite_engine SRCS engine.cc DEPS lite_full_static framework_proto) cc_library(lite_tensor_utils SRCS tensor_utils.cc DEPS memcpy lite_full_static framework_proto boost) -cc_test(test_lite_engine SRCS test_engine.cc DEPS lite_engine protobuf) +cc_test(test_lite_engine SRCS test_engine.cc DEPS lite_engine protobuf framework_proto glog gtest analysis) cc_test(test_lite_predictor SRCS test_predictor.cc DEPS lite_engine paddle_fluid) diff --git a/paddle/fluid/inference/lite/engine.cc b/paddle/fluid/inference/lite/engine.cc index a951516945101..2558e725f13ec 100644 --- a/paddle/fluid/inference/lite/engine.cc +++ b/paddle/fluid/inference/lite/engine.cc @@ -15,20 +15,16 @@ #define LITE_WITH_CUDA 1 #include "paddle/fluid/inference/lite/engine.h" -#include "lite/core/context.h" -#include "lite/core/device_info.h" namespace paddle { namespace inference { namespace lite { -bool EngineManager::Empty() const { - return engines_.size() == 0; -} +bool EngineManager::Empty() const { return engines_.size() == 0; } bool EngineManager::Has(const std::string& name) const { if (engines_.count(name) == 0) { - return false; + return false; } return engines_.at(name).get() != nullptr; } @@ -37,12 +33,12 @@ paddle::lite::Predictor* EngineManager::Get(const std::string& name) const { return engines_.at(name).get(); } -paddle::lite::Predictor* EngineManager::Create( - const std::string& name, const EngineConfig& cfg) { - paddle::lite::Env::Init(); +paddle::lite::Predictor* EngineManager::Create(const std::string& name, + const EngineConfig& cfg) { auto* p = new paddle::lite::Predictor(); - p->Build("", cfg.model, cfg.param, cfg.prefer_place, cfg.valid_places, cfg.neglected_passes, - cfg.model_type, cfg.memory_from_memory); + paddle::lite::Env::Init(); + p->Build("", cfg.model, cfg.param, cfg.prefer_place, cfg.valid_places, + cfg.neglected_passes, cfg.model_type, cfg.model_from_memory); engines_[name].reset(p); return p; } diff --git a/paddle/fluid/inference/lite/engine.h b/paddle/fluid/inference/lite/engine.h index 7e49ad4c08691..f29607490ed17 100644 --- a/paddle/fluid/inference/lite/engine.h +++ b/paddle/fluid/inference/lite/engine.h @@ -15,8 +15,10 @@ #pragma once #include +#include #include #include +#include #include "lite/api/cxx_api.h" @@ -31,7 +33,7 @@ struct EngineConfig { std::vector valid_places; std::vector neglected_passes; lite_api::LiteModelType model_type{lite_api::LiteModelType::kProtobuf}; - bool memory_from_memory{true}; + bool model_from_memory{true}; }; class EngineManager { @@ -39,10 +41,13 @@ class EngineManager { bool Empty() const; bool Has(const std::string& name) const; paddle::lite::Predictor* Get(const std::string& name) const; - paddle::lite::Predictor* Create(const std::string& name, const EngineConfig& cfg); + paddle::lite::Predictor* Create(const std::string& name, + const EngineConfig& cfg); void DeleteAll(); + private: - std::unordered_map> engines_; + std::unordered_map> + engines_; }; } // namespace lite diff --git a/paddle/fluid/inference/lite/tensor_utils.cc b/paddle/fluid/inference/lite/tensor_utils.cc index 9dcec4210707e..8021768ce389b 100644 --- a/paddle/fluid/inference/lite/tensor_utils.cc +++ b/paddle/fluid/inference/lite/tensor_utils.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include "paddle/fluid/inference/lite/engine.h" #include "paddle/fluid/inference/lite/tensor_utils.h" +#include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/inference/lite/engine.h" namespace paddle { namespace inference { @@ -40,7 +40,20 @@ platform::Place GetNativePlace(const TargetType& type) { } } -framework::proto::VarType::Type GetNativePrecisionType(const PrecisionType& type) { +PrecisionType GetLitePrecisionType(framework::proto::VarType::Type type) { + switch (type) { + case framework::proto::VarType_Type_FP32: + return PrecisionType::kFloat; + case framework::proto::VarType_Type_INT8: + return PrecisionType::kInt8; + default: + LOG(FATAL) << "Error precision type."; + return PrecisionType::kUnk; + } +} + +framework::proto::VarType::Type GetNativePrecisionType( + const PrecisionType& type) { switch (type) { case PrecisionType::kFloat: return framework::proto::VarType_Type_FP32; @@ -63,7 +76,8 @@ framework::DataLayout GetNativeLayoutType(const DataLayoutType& type) { } void MemoryCopy(const platform::Place& dst_place, void* dst_data, - const platform::Place& src_place, const void* src_data, const size_t size) { + const platform::Place& src_place, const void* src_data, + const size_t size) { const platform::CPUPlace cpu_place; const platform::CUDAPlace gpu_place; if (platform::is_cpu_place(dst_place) && platform::is_cpu_place(src_place)) { @@ -71,14 +85,18 @@ void MemoryCopy(const platform::Place& dst_place, void* dst_data, } else { #ifdef PADDLE_WITH_CUDA // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &ctx = *pool.Get(platform::CUDAPlace()); - auto stream = reinterpret_cast(ctx).stream(); - if (platform::is_cpu_place(dst_place) && platform::is_gpu_place(src_place)) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(platform::CUDAPlace()); + auto stream = + reinterpret_cast(ctx).stream(); + if (platform::is_cpu_place(dst_place) && + platform::is_gpu_place(src_place)) { memory::Copy(cpu_place, dst_data, gpu_place, src_data, size, stream); - } else if (platform::is_gpu_place(dst_place) && platform::is_cpu_place(src_place)) { + } else if (platform::is_gpu_place(dst_place) && + platform::is_cpu_place(src_place)) { memory::Copy(gpu_place, dst_data, cpu_place, src_data, size, stream); - } else if (platform::is_gpu_place(dst_place) && platform::is_gpu_place(src_place)) { + } else if (platform::is_gpu_place(dst_place) && + platform::is_gpu_place(src_place)) { memory::Copy(gpu_place, dst_data, gpu_place, src_data, size, stream); } #else @@ -87,9 +105,14 @@ void MemoryCopy(const platform::Place& dst_place, void* dst_data, } } -} // namespace +} // namespace + +void InitLiteTensorType(paddle::lite::Tensor* lite, + const framework::LoDTensor& fluid) { + lite->set_precision(GetLitePrecisionType(fluid.type())); +} -template<> +template <> void TensorCopy(paddle::lite::Tensor* dst, const framework::LoDTensor& src) { const platform::Place& src_place = src.place(); const platform::Place& dst_place = GetNativePlace(dst->target()); @@ -98,10 +121,11 @@ void TensorCopy(paddle::lite::Tensor* dst, const framework::LoDTensor& src) { dst->Resize(framework::vectorize(src.dims())); const void* src_data = src.data(); void* dst_data = dst->mutable_data(size); - MemoryCopy(dst_place, dst_data, src_place, src_data, size); + MemoryCopy(dst_place, dst_data, src_place, src_data, + size * framework::SizeOfType(src.type())); } -template<> +template <> void TensorCopy(framework::LoDTensor* dst, const paddle::lite::Tensor& src) { const platform::Place& src_place = GetNativePlace(src.target()); const platform::Place& dst_place = dst->place(); @@ -110,7 +134,8 @@ void TensorCopy(framework::LoDTensor* dst, const paddle::lite::Tensor& src) { const size_t size = static_cast(src.numel()); const void* src_data = src.raw_data(); void* dst_data = dst->mutable_data(dst_place, dst->type()); - MemoryCopy(dst_place, dst_data, src_place, src_data, size); + MemoryCopy(dst_place, dst_data, src_place, src_data, + size * framework::SizeOfType(dst->type())); } } // namespace lite diff --git a/paddle/fluid/inference/lite/tensor_utils.h b/paddle/fluid/inference/lite/tensor_utils.h index 334eda46ee902..943b1613abb36 100644 --- a/paddle/fluid/inference/lite/tensor_utils.h +++ b/paddle/fluid/inference/lite/tensor_utils.h @@ -14,9 +14,9 @@ #pragma once -#include "paddle/fluid/framework/tensor.h" #include "lite/api/paddle_place.h" #include "lite/core/tensor.h" +#include "paddle/fluid/framework/tensor.h" namespace paddle { namespace inference { @@ -25,6 +25,9 @@ namespace lite { template void TensorCopy(DstTensor* dst, const SrcTensor& src); +void InitLiteTensorType(paddle::lite::Tensor* lite, + const framework::LoDTensor& fluid); + } // namespace lite } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/lite/test_engine.cc b/paddle/fluid/inference/lite/test_engine.cc index 7f80f5e786a0b..990903f056060 100644 --- a/paddle/fluid/inference/lite/test_engine.cc +++ b/paddle/fluid/inference/lite/test_engine.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include #include +#include +#include #include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_ops.h" @@ -23,44 +23,88 @@ #include "paddle/fluid/inference/lite/engine.h" #include "paddle/fluid/inference/utils/singleton.h" +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" + namespace paddle { namespace lite { namespace { -std::string read_file(const std::string &file) { - std::ifstream ifs(file.c_str(), std::ios::in | std::ios::binary | std::ios::ate); - std::ifstream::pos_type file_size = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - std::vector bytes(file_size); - ifs.read(bytes.data(), file_size); - return std::string(bytes.data(), file_size); +void AddTensorToBlockDesc(framework::proto::BlockDesc* block, + const std::string& name, + const std::vector& shape) { + using framework::proto::VarType; + auto* var = block->add_vars(); + framework::VarDesc desc(name); + desc.SetType(VarType::LOD_TENSOR); + desc.SetDataType(VarType::FP32); + desc.SetShape(shape); + *var = *desc.Proto(); } -} // namespace +void make_fake_model(std::string* model, std::string* param) { + framework::ProgramDesc program; + auto* block_ = program.Proto()->mutable_blocks(0); + LOG(INFO) << "create block desc"; + framework::BlockDesc block_desc(&program, block_); + LOG(INFO) << "create feed op"; + auto* feed0 = block_desc.AppendOp(); + feed0->SetType("feed"); + feed0->SetInput("X", {"feed"}); + feed0->SetOutput("Out", {"x"}); + feed0->SetAttr("col", 1); + AddTensorToBlockDesc(block_, "x", std::vector({2, 4, 1, 1})); + *block_->add_ops() = *feed0->Proto(); + ASSERT_EQ(block_->ops_size(), 1); + framework::Scope scope; + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); + *model = program.Proto()->SerializeAsString(); +} +} // namespace -TEST(EngineManager, Create) { - const std::string unique_key("engine_0"); - const std::string model_dir = "/shixiaowei02/models/tmp/__model__"; +TEST(EngineManager, manual) { + ASSERT_EQ( + inference::Singleton::Global().Empty(), + true); inference::lite::EngineConfig config; - config.model = read_file(model_dir); - config.param = ""; - config.prefer_place = {TARGET(kCUDA), PRECISION(kFloat)}; + make_fake_model(&(config.model), &(config.param)); + + const std::string unique_key("engine_0"); + config.model_from_memory = true; + config.prefer_place = {TARGET(kX86), PRECISION(kFloat)}; config.valid_places = { - paddle::lite::Place({TARGET(kHost), PRECISION(kFloat)}), + paddle::lite::Place({TARGET(kX86), PRECISION(kFloat)}), + paddle::lite::Place({TARGET(kHost), PRECISION(kAny)}), #ifdef PADDLE_WITH_CUDA - paddle::lite::Place({TARGET(kCUDA), PRECISION(kFloat)}), + paddle::lite::Place({TARGET(kCUDA), PRECISION(kFloat)}), #endif }; - inference::Singleton::Global() - .Create(unique_key, config); - /* - paddle::lite::Predictor* engine = inference::Singleton::Global() - .Get(Attr(unique_key)); - */ + LOG(INFO) << "Create EngineManager"; + inference::Singleton::Global().Create( + unique_key, config); + LOG(INFO) << "Create EngineManager done"; + ASSERT_EQ( + inference::Singleton::Global().Empty(), + false); + ASSERT_EQ(inference::Singleton::Global().Has( + unique_key), + true); + paddle::lite::Predictor* engine_0 = + inference::Singleton::Global().Get( + unique_key); + + CHECK_NOTNULL(engine_0); + inference::Singleton::Global().DeleteAll(); + CHECK(inference::Singleton::Global().Get( + unique_key) == nullptr) + << "the engine_0 should be nullptr"; } } // namespace lite diff --git a/paddle/fluid/operators/lite/CMakeLists.txt b/paddle/fluid/operators/lite/CMakeLists.txt index ca3b62648378b..5bb7892590848 100644 --- a/paddle/fluid/operators/lite/CMakeLists.txt +++ b/paddle/fluid/operators/lite/CMakeLists.txt @@ -1 +1,2 @@ op_library(lite_engine_op DEPS lite_engine lite_tensor_utils) +cc_test(test_lite_engine_op SRCS lite_engine_op_test.cc DEPS lite_engine_op analysis) diff --git a/paddle/fluid/operators/lite/lite_engine_op.h b/paddle/fluid/operators/lite/lite_engine_op.h index fb2c54ce3daf3..71e8092f9343c 100644 --- a/paddle/fluid/operators/lite/lite_engine_op.h +++ b/paddle/fluid/operators/lite/lite_engine_op.h @@ -38,18 +38,19 @@ class LiteEngineOp : public framework::OperatorBase { private: std::vector in_names_; std::vector out_names_; - paddle::lite::Predictor* engine_; + paddle::lite::Predictor *engine_; public: LiteEngineOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) { in_names_ = Inputs("Xs"); out_names_ = Outputs("Ys"); - engine_ = inference::Singleton::Global() - .Get(Attr("engine_key")); + engine_ = + inference::Singleton::Global().Get( + Attr("engine_key")); } protected: @@ -61,15 +62,22 @@ class LiteEngineOp : public framework::OperatorBase { void Execute(const framework::Scope &scope, const platform::Place &dev_place) const { for (size_t i = 0; i < in_names_.size(); i++) { - const framework::LoDTensor& src_t = inference::analysis::GetFromScope(scope, in_names_[i]); - paddle::lite::Tensor* dst_t = engine_->GetInput(i); + const framework::LoDTensor &src_t = + inference::analysis::GetFromScope(scope, + in_names_[i]); + paddle::lite::Tensor *dst_t = engine_->GetInput(i); + inference::lite::InitLiteTensorType(dst_t, src_t); inference::lite::TensorCopy(dst_t, src_t); } engine_->Run(); cudaDeviceSynchronize(); for (size_t i = 0; i < out_names_.size(); i++) { - const paddle::lite::Tensor& src_t = *(engine_->GetOutput(i)); - framework::LoDTensor* dst_t = &inference::analysis::GetFromScope(scope, out_names_[i]); + const paddle::lite::Tensor &src_t = *(engine_->GetOutput(i)); + framework::LoDTensor *dst_t = + &inference::analysis::GetFromScope( + scope, out_names_[i]); + inference::lite::InitLiteTensorType( + &const_cast(src_t), *dst_t); inference::lite::TensorCopy(dst_t, src_t); } } diff --git a/paddle/fluid/operators/lite/lite_engine_op_test.cc b/paddle/fluid/operators/lite/lite_engine_op_test.cc new file mode 100644 index 0000000000000..91c4fec461cf8 --- /dev/null +++ b/paddle/fluid/operators/lite/lite_engine_op_test.cc @@ -0,0 +1,164 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include + +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/inference/utils/singleton.h" +#include "paddle/fluid/operators/lite/lite_engine_op.h" +#include "paddle/fluid/operators/lite/ut_helper.h" + +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" + +USE_NO_KERNEL_OP(lite_engine) +namespace paddle { +namespace operators { + +namespace { +void CreateTensor(framework::Scope* scope, const std::string& name, + const std::vector& shape) { + auto* var = scope->Var(name); + auto* tensor = var->GetMutable(); + auto dims = framework::make_ddim(shape); + tensor->Resize(dims); +#ifdef PADDLE_WITH_CUDA + platform::CUDAPlace place; +#else + platform::CPUPlace place; +#endif + inference::lite::RandomizeTensor(tensor, place); +} + +void AddTensorToBlockDesc(framework::proto::BlockDesc* block, + const std::string& name, + const std::vector& shape, bool persistable) { + using framework::proto::VarType; + auto* var = block->add_vars(); + framework::VarDesc desc(name); + desc.SetType(VarType::LOD_TENSOR); + desc.SetDataType(VarType::FP32); + desc.SetShape(shape); + desc.SetPersistable(persistable); + *var = *desc.Proto(); +} +} // namespace + +TEST(LiteEngineOp, manual) { + framework::ProgramDesc program; + auto* block_ = program.Proto()->mutable_blocks(0); + + LOG(INFO) << "create block desc"; + framework::BlockDesc block_desc(&program, block_); + LOG(INFO) << "create elementwise_add op"; + auto* elt_add = block_desc.AppendOp(); + elt_add->SetType("elementwise_add"); + elt_add->SetInput("X", std::vector({"x"})); + elt_add->SetInput("Y", std::vector({"y"})); + elt_add->SetOutput("Out", std::vector({"z"})); + elt_add->SetAttr("axis", -1); + LOG(INFO) << "create fetch op"; + auto* fetch = block_desc.AppendOp(); + fetch->SetType("fetch"); + fetch->SetInput("X", std::vector({"z"})); + fetch->SetOutput("Out", std::vector({"out"})); + fetch->SetAttr("col", 0); + // Set inputs' variable shape in BlockDesc + AddTensorToBlockDesc(block_, "x", std::vector({2, 4}), true); + AddTensorToBlockDesc(block_, "y", std::vector({2, 4}), true); + AddTensorToBlockDesc(block_, "z", std::vector({2, 4}), false); + AddTensorToBlockDesc(block_, "out", std::vector({2, 4}), false); + + *block_->add_ops() = *elt_add->Proto(); + *block_->add_ops() = *fetch->Proto(); + + framework::Scope scope; +#ifdef PADDLE_WITH_CUDA + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); +#else + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); +#endif + // Prepare variables. + CreateTensor(&scope, "x", std::vector({2, 4})); + CreateTensor(&scope, "y", std::vector({2, 4})); + CreateTensor(&scope, "z", std::vector({2, 4})); + CreateTensor(&scope, "out", std::vector({2, 4})); + + ASSERT_EQ(block_->ops_size(), 2); + + auto serialize_params = [](std::string* str, framework::Scope* scope, + const std::vector& params) { + std::ostringstream os; +#ifdef PADDLE_WITH_CUDA + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); +#else + platform::CPUDeviceContext ctx; +#endif + for (const auto& param : params) { + PADDLE_ENFORCE_NOT_NULL(scope->FindVar(param), + "Block should already have a '%s' variable", + param); + auto* tensor = scope->FindVar(param)->GetMutable(); + framework::SerializeToStream(os, *tensor, ctx); + } + *str = os.str(); + }; + std::vector repetitive_params{"x", "y"}; + inference::lite::EngineConfig config; + config.prefer_place = { +#ifdef PADDLE_WITH_CUDA + TARGET(kCUDA), PRECISION(kFloat), +#else + TARGET(kX86), PRECISION(kFloat) +#endif + }; + config.valid_places = { + paddle::lite::Place({TARGET(kHost), PRECISION(kAny)}), + paddle::lite::Place({TARGET(kX86), PRECISION(kFloat)}), +#ifdef PADDLE_WITH_CUDA + paddle::lite::Place({TARGET(kCUDA), PRECISION(kFloat)}), +#endif + }; + serialize_params(&(config.param), &scope, repetitive_params); + config.model = program.Proto()->SerializeAsString(); + + LOG(INFO) << "create lite_engine desc"; + framework::OpDesc engine_op_desc(nullptr); + engine_op_desc.SetType("lite_engine"); + engine_op_desc.SetInput("Xs", std::vector({"x", "y"})); + engine_op_desc.SetOutput("Ys", std::vector({"out"})); + std::string engine_key = "engine_0"; + engine_op_desc.SetAttr("engine_key", engine_key); + engine_op_desc.SetBlockAttr("sub_block", &block_desc); + + inference::Singleton::Global().Create( + engine_key, config); + + LOG(INFO) << "create engine op"; + auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc); + LOG(INFO) << "engine_op " << engine_op.get(); + + // Execute them. + LOG(INFO) << "engine_op run"; + engine_op->Run(scope, place); +} +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/lite/ut_helper.h b/paddle/fluid/operators/lite/ut_helper.h new file mode 100644 index 0000000000000..cad8c411b8239 --- /dev/null +++ b/paddle/fluid/operators/lite/ut_helper.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#pragma once + +#include +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/analysis/helper.h" + +namespace paddle { +namespace inference { +namespace lite { + +/* + * Get a random float value between [low, high] + */ +float random(float low, float high) { + // static std::random_device rd; + static std::mt19937 mt(100); + std::uniform_real_distribution dist(low, high); + return dist(mt); +} + +void RandomizeTensor(framework::LoDTensor* tensor, + const platform::Place& place) { + auto dims = tensor->dims(); + size_t num_elements = analysis::AccuDims(dims, dims.size()); + PADDLE_ENFORCE_GT(num_elements, 0); + + platform::CPUPlace cpu_place; + framework::LoDTensor temp_tensor; + temp_tensor.Resize(dims); + auto* temp_data = temp_tensor.mutable_data(cpu_place); + + for (size_t i = 0; i < num_elements; i++) { + *(temp_data + i) = random(0., 1.); + } + + TensorCopySync(temp_tensor, place, tensor); +} +} // namespace lite +} // namespace inference +} // namespace paddle