Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ using OpKernelConfigsMap =
std::unordered_map<OpKernelType, std::vector<KernelConfig>,
OpKernelType::Hash>;

class OpDuppy : public OperatorBase {
public:
OpDuppy() : OperatorBase("duppy", {}, {}, {}) {}

void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
// Placeholder operator, scope and runtime context instances.
// NOTE(review): per the diff header these live in operator.h as non-inline
// namespace-scope objects; if this header is included from more than one
// translation unit that normally yields duplicate-definition link errors --
// confirm how this is built before relying on it.
OpDuppy op_duppy;
Scope scope_duppy;
RuntimeContext runtime_context_duppy({}, {});

class ExecutionContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
Expand All @@ -244,6 +255,13 @@ class ExecutionContext {
ctx_(ctx),
kernel_configs_(configs) {}

ExecutionContext(const platform::DeviceContext& device_context)
: op_(op_duppy),
scope_(scope_duppy),
device_context_(device_context),
ctx_(runtime_context_duppy),
kernel_configs_(nullptr) {}

const OperatorBase& op() const { return op_; }

const Scope& scope() const { return scope_; }
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/incubate/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include_directories(lite)
2 changes: 1 addition & 1 deletion paddle/fluid/lite/api/cxx_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace paddle {
namespace lite {

#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
void CXXPredictor::SaveModel(const std::string &dir) {
void ExecutorLite::SaveModel(const std::string &dir) {
MkDirRecursively(dir.c_str());
program_->PersistModel(dir, program_desc_);
}
Expand Down
68 changes: 65 additions & 3 deletions paddle/fluid/lite/api/cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,24 @@ namespace lite {

struct Config {};

class CXXPredictor {
class ExecutorLite {
public:
CXXPredictor() { scope_ = std::make_shared<Scope>(); }
ExecutorLite() { scope_ = std::make_shared<Scope>(); }
// Creates an executor that shares the caller-provided root scope instead of
// allocating a fresh one.
explicit ExecutorLite(const std::shared_ptr<lite::Scope>& root_scope)
    : scope_(root_scope) {}

// Loads the model found at `model_path` into this executor's scope and
// program desc, then hands over to the ProgramDesc overload of Build with the
// same preferred and valid places.
//
// No Program is constructed here: the previous local was never used and the
// ProgramDesc overload constructs its own.
void Build(const std::string& model_path, const Place& prefer_place,
           const std::vector<Place>& valid_places) {
  LoadModel(model_path, scope_.get(), &program_desc_);
  Build(program_desc_, prefer_place, valid_places);
}

void Build(const framework::proto::ProgramDesc& desc,
const Place& prefer_place,
const std::vector<Place>& valid_places) {
program_desc_ = desc;
Program program(desc, scope_, valid_places);

optimizer_.KernelPickPreferPlace(prefer_place);
core::KernelPickFactor factor;
Expand Down Expand Up @@ -81,5 +91,57 @@ class CXXPredictor {
std::unique_ptr<RuntimeProgram> program_;
};

/*
 * Training-side wrapper around ExecutorLite.
 *
 * Typical flow:
 *
 *   CXXTrainer trainer(root_scope, preferred_place, valid_places);
 *   trainer.RunStartupProgram(startup_desc);
 *   auto& exe = trainer.BuildMainProgramExecutor(main_desc);
 *
 *   for (auto& epoch : epochs) {
 *     auto* tensor0 = exe.GetInput(...);
 *     // fill data for tensor0
 *     exe.Run();
 *   }
 */
class CXXTrainer {
 public:
  // Stores the shared root scope and the place preferences; the executor for
  // the main program is created on the same scope.
  CXXTrainer(const std::shared_ptr<lite::Scope>& root_scope,
             const Place& preferred_place,
             const std::vector<Place>& valid_places)
      : scope_(root_scope),
        preferred_place_(preferred_place),
        valid_places_(valid_places),
        main_program_executor_(scope_) {}

  // Builds the main-program executor from `desc` and returns a reference to
  // it, so the same built executor can be run once per epoch.
  // NOTE: `block_id` is accepted but not used; only the 0-th block is
  // supported for now.
  ExecutorLite& BuildMainProgramExecutor(
      const framework::proto::ProgramDesc& desc, int block_id = 0) {
    main_program_executor_.Build(desc, preferred_place_, valid_places_);
    return main_program_executor_;
  }

  // Executes the startup program exactly once on the shared scope using a
  // throwaway executor; nothing is kept afterwards.
  // NOTE: `block_id` is accepted but not used here either.
  void RunStartupProgram(const framework::proto::ProgramDesc& desc,
                         int block_id = 0) {
    ExecutorLite startup_executor(scope_);
    startup_executor.Build(desc, preferred_place_, valid_places_);
    startup_executor.Run();
  }

 private:
  // Root scope shared with every executor created by this trainer.
  std::shared_ptr<lite::Scope> scope_;

  Place preferred_place_;
  std::vector<Place> valid_places_;

  // Executor holding the training (main) program.
  ExecutorLite main_program_executor_;
};

} // namespace lite
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/lite/api/cxx_api_bin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace paddle {
namespace lite {

void Run(const char* model_dir) {
lite::CXXPredictor predictor;
lite::Executor predictor;
#ifndef LITE_WITH_CUDA
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
#else
Expand Down
44 changes: 41 additions & 3 deletions paddle/fluid/lite/api/cxx_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");

// For training.
DEFINE_string(startup_program_path, "", "");
DEFINE_string(main_program_path, "", "");

namespace paddle {
namespace lite {

TEST(CXXApi, test) {
lite::CXXPredictor predictor;
lite::ExecutorLite predictor;
#ifndef LITE_WITH_CUDA
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
#else
Expand Down Expand Up @@ -64,14 +68,48 @@ TEST(CXXApi, test) {

#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// Builds a predictor from FLAGS_model_dir and writes the resulting model to
// FLAGS_optimized_model.
//
// Only the ExecutorLite declaration is kept; a second `predictor` declared
// with the pre-rename CXXPredictor type would be a redefinition.
TEST(CXXApi, save_model) {
  lite::ExecutorLite predictor;
  // NOTE(review): the preferred place is kCUDA while valid_places only lists
  // kHost -- presumably intentional for this test; confirm.
  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
  predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)},
                  valid_places);

  predictor.SaveModel(FLAGS_optimized_model);
}
#endif
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK

#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// Reads the serialized startup and main programs named by
// FLAGS_startup_program_path / FLAGS_main_program_path, parses them into
// ProgramDescs and logs the main program together with the op types of its
// first block.
TEST(CXXTrainer, train) {
  Place prefer_place({TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)});
  std::vector<Place> valid_places({prefer_place});
  auto scope = std::make_shared<lite::Scope>();

  CXXTrainer trainer(scope, prefer_place, valid_places);

  std::string main_program_pb, startup_program_pb;
  ReadBinaryFile(FLAGS_main_program_path, &main_program_pb);
  ReadBinaryFile(FLAGS_startup_program_path, &startup_program_pb);
  framework::proto::ProgramDesc main_program_desc, startup_program_desc;
  // Stop right away on malformed input rather than continuing with a
  // partially parsed program.
  ASSERT_TRUE(main_program_desc.ParseFromString(main_program_pb));
  ASSERT_TRUE(startup_program_desc.ParseFromString(startup_program_pb));
  // blocks(0) is read below, so the main program must contain a block.
  ASSERT_FALSE(main_program_desc.blocks().empty());

  LOG(INFO) << main_program_desc.DebugString();

  for (const auto& op : main_program_desc.blocks(0).ops()) {
    LOG(INFO) << "get op " << op.type();
  }

  // NOTE(review): the test ends here, so the training steps below never run.
  // Presumably the training path is not ready yet -- confirm before removing
  // this return.
  return;

  trainer.RunStartupProgram(startup_program_desc);
  auto& exe = trainer.BuildMainProgramExecutor(main_program_desc);
  auto* tensor0 = exe.GetInput(0);
  tensor0->Resize(std::vector<int64_t>({100, 100}));
  auto* data0 = tensor0->mutable_data<float>();
  data0[0] = 0;

  exe.Run();
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK

} // namespace lite
} // namespace paddle
Expand Down
10 changes: 9 additions & 1 deletion paddle/fluid/lite/core/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@

#include "paddle/fluid/lite/utils/any.h"
#ifdef LITE_WITH_CUDA
#include <paddle/fluid/lite/cuda/blas.h>
#include "paddle/fluid/lite/cuda/blas.h"
#include "paddle/fluid/lite/cuda/cuda_utils.h"
#endif
#ifdef LITE_WITH_X86
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"
#endif
#include <memory>
#include <set>
#include <vector>
Expand Down Expand Up @@ -54,6 +58,10 @@ struct X86Context {
// overall information

// kernel information

// legacy info.
std::unique_ptr<::paddle::platform::CPUDeviceContext> x86_device_context;
std::unique_ptr<::paddle::framework::ExecutionContext> x86_execution_context;
};
#endif

Expand Down
16 changes: 13 additions & 3 deletions paddle/fluid/lite/core/hvy_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace lite {
class DDimHvy : public DDimBase<DDimHvy> {
public:
DDimHvy() = default;
explicit DDimHvy(const std::vector<value_type>& x) : DDimBase<DDimHvy>() {
DDimHvy(const std::vector<value_type>& x) : DDimBase<DDimHvy>() { // NOLINT
ConstructFrom(x);
}
explicit DDimHvy(const framework::DDim& x) : data_(x) {}
Expand All @@ -47,6 +47,14 @@ class DDimHvy : public DDimBase<DDimHvy> {
size_t size() const { return data_.size(); }
bool empty() const { return data_.size() == 0; }

bool operator==(const DDimHvy& other) {
if (data_.size() != other.data_.size()) return false;
for (int i = 0; i < data_.size(); i++) {
if (data_[i] != other.data_[i]) return false;
}
return true;
}

private:
framework::DDim data_;
};
Expand Down Expand Up @@ -85,8 +93,7 @@ class TensorHvy : public TensorBase<TensorHvy> {

const void* raw_data() const { return data_.raw_data(); }

template <typename DimT>
void Resize(const DimT& dims) {
void Resize(const DDimHvy& dims) {
LOG(INFO) << "dims.size " << dims.size();
data_.Resize(framework::make_ddim(dims.Vectorize()));
}
Expand All @@ -103,6 +110,9 @@ class TensorHvy : public TensorBase<TensorHvy> {
const framework::LoD& lod() const { return data_.lod(); }
framework::LoD* mutable_lod() { return data_.mutable_lod(); }

// Direct access to the wrapped framework::LoDTensor: a read-only overload for
// const objects and a mutable overload otherwise.
const framework::LoDTensor& raw_tensor() const { return data_; }
framework::LoDTensor& raw_tensor() { return data_; }

private:
framework::LoDTensor data_;
};
Expand Down
46 changes: 46 additions & 0 deletions paddle/fluid/lite/core/naive_test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Emit the startup and main program protobufs of a tiny regression model.

Running this script writes `startup_program.pb` and `main_program.pb` into the
current working directory and then exits with status 0.
"""
import sys, os
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward

# Inputs. NOTE(review): `label` is declared with shape [100] while the FC
# output below has size 500 -- confirm the intended label shape.
a = fluid.layers.data(name="a", shape=[100], dtype='float32')
label = fluid.layers.data(name="label", shape=[100], dtype='float32')

# A single bias-free fully connected layer with 500 outputs.
a1 = fluid.layers.fc(input=a, size=500, act=None, bias_attr=False)

cost = fluid.layers.square_error_cost(a1, label)
avg_cost = fluid.layers.mean(cost)

# Optimize the mean loss; previously the per-element `cost` was passed to
# minimize() and `avg_cost` was computed but never used.
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
optimizer.minimize(avg_cost)

cpu = fluid.core.CPUPlace()
exe = fluid.Executor(cpu)

# Run the startup program once, then dump its serialized desc.
exe.run(fluid.default_startup_program())
with open('startup_program.pb', 'wb') as f:
    f.write(fluid.default_startup_program().desc.serialize_to_string())

# Dump the serialized desc of the main program.
prog = fluid.default_main_program()
with open('main_program.pb', 'wb') as f:
    f.write(prog.desc.serialize_to_string())

# The script only generates the two protobuf files; nothing runs after this.
sys.exit(0)

2 changes: 1 addition & 1 deletion paddle/fluid/lite/core/op_lite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ bool OpLite::Run() {
bool OpLite::Attach(const OpDesc &opdesc, lite::Scope *scope) {
// valid_places_.clear();
CHECK(scope != nullptr);
// CHECK(!op_info_.get());
//CHECK(!op_info_.get());
scope_ = scope;
op_info_.reset(new OpInfo); // Force clean the out-of-date infomation.
op_info_->Build(opdesc.ReadonlyProto());
Expand Down
16 changes: 16 additions & 0 deletions paddle/fluid/lite/core/op_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,22 @@ class OpLite : public Registry {
friend class mir::Node;
friend class mir::SSAGraph;

protected:
// some helper functions.
// Finds the variable called `name` in `scope` and returns a read-only pointer
// to the T it holds. The CHECK trips when no such variable exists.
template <typename T>
const T *GetVar(Scope *scope, const std::string &name) {
  auto *found = scope->FindVar(name);
  CHECK(found) << "No var found for " << name;
  const T &held = found->Get<T>();
  return &held;
}
// Finds the variable called `name` in `scope` and returns a writable pointer
// to the T it holds. The CHECK trips when no such variable exists.
template <typename T>
T *GetMutableVar(Scope *scope, const std::string &name) {
  auto *found = scope->FindVar(name);
  CHECK(found) << "No var found for " << name;
  T *held = found->GetMutable<T>();
  return held;
}


protected:
lite::Scope *scope_{};
std::unique_ptr<KernelBase> kernel_;
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/lite/core/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ struct Program {
// Build from a program and scope.
void Build(const framework::proto::ProgramDesc& program) {
CHECK(ops.empty()) << "Executor duplicate Build found";

// Create operators.
for (const auto& proto_op_desc : program.blocks(0).ops()) {
lite::OpDesc op_desc(proto_op_desc);
auto op_type = op_desc.Type();
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG(4) << "create Op [" << op_type << "]";
LOG(INFO) << "create Op [" << op_type << "]";
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
ops.emplace_back(std::move(op));
Expand All @@ -86,6 +86,7 @@ struct Program {

tmp_vars.push_back("feed");
tmp_vars.push_back("fetch");
CHECK(!program.blocks().empty());
for (auto proto_var_desc : program.blocks(0).vars()) {
lite::VarDesc var_desc(proto_var_desc);
if (!var_desc.Persistable()) {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/lite/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite ${tensor_l
add_subdirectory(host)
add_subdirectory(arm)
add_subdirectory(cuda)
add_subdirectory(x86)
6 changes: 6 additions & 0 deletions paddle/fluid/lite/kernels/x86/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Nothing in this directory is built unless LITE_WITH_X86 is enabled.
if(NOT LITE_WITH_X86)
return()
endif()

# x86 kernel libraries. Each one is built from a single source file and depends
# on the shared ${lite_kernel_deps} plus one extra target (activation_op or
# elementwise_op).
cc_library(activation_compute SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op)
cc_library(elementwise_compute SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_op)
Loading