Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add api to apply external job pass #8370

Merged
merged 8 commits into from
Jun 8, 2022
36 changes: 35 additions & 1 deletion oneflow/api/cpp/framework/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ class Graph::GraphImpl final {
std::vector<Tensor> Forward(const std::vector<Tensor>& inputs);
void set_batch_size(int batch_size) { batch_size_ = batch_size; }

of::Maybe<void> RegisterJobPass(
const std::function<std::string(const std::string& job)>& pass_fn);

private:
of::Maybe<void> CollectInputOutputInfos();
of::Maybe<void> Compile(const std::vector<Tensor>& inputs);
Expand All @@ -135,6 +138,7 @@ class Graph::GraphImpl final {
of::Maybe<void> BuildGraph();
of::Maybe<void> LoadCheckpoint();
of::Maybe<void> RegisterTensors(const std::vector<Tensor>& inputs);
of::Maybe<of::Job> ApplyJobPasses(const of::Job& job);

std::shared_ptr<of::NNGraph> graph_ = nullptr;
std::string model_path_;
Expand All @@ -149,6 +153,7 @@ class Graph::GraphImpl final {
of::HashMap<std::string, std::shared_ptr<of::one::Tensor>> variable_op_name_to_tensor_;
std::shared_ptr<of::one::TensorTuple> output_tensor_tuple_;
std::shared_ptr<of::one::TensorTuple> parameter_tensor_tuple_;
std::vector<std::function<std::string(const std::string&)>> registered_job_passes_;
};

Graph::Graph(const std::string& model_path, const Device& device)
Expand All @@ -168,6 +173,10 @@ InputOutputInfos Graph::GetInputInfos() { return graph_->GetInputInfos(); }

InputOutputInfos Graph::GetOutputInfos() { return graph_->GetOutputInfos(); }

// Registers a user-supplied job pass on this graph by forwarding `pass_fn`
// to the implementation object held in `graph_`.
//
// The implementation reports an error when the graph has already been
// compiled, so this must be called before the first compile/forward.
// NOTE(review): the returned Maybe is handed to CHECK_JUST, which presumably
// aborts on error rather than reporting it to the caller; confirm against the
// macro definition.
void Graph::RegisterJobPass(const std::function<std::string(const std::string& job)>& pass_fn) {
CHECK_JUST(graph_->RegisterJobPass(pass_fn));
}

IValue Graph::Forward(const IValue& inputs) {
std::vector<Tensor> input_tensors;
if (inputs.IsNone()) {
Expand Down Expand Up @@ -234,6 +243,28 @@ of::Maybe<void> Graph::GraphImpl::CollectInputOutputInfos() {
return of::Maybe<void>::Ok();
}

// Appends a user-supplied job pass to `registered_job_passes_`.
//
// Passes are stored in call order. Registration is only accepted while the
// graph has not been compiled yet.
//
// @param pass_fn Callable that takes a serialized job string and returns a
//                string; must hold a target (non-empty std::function).
// @return Ok on success; a RuntimeError if the graph is already compiled or
//         if `pass_fn` is empty.
of::Maybe<void> Graph::GraphImpl::RegisterJobPass(
    const std::function<std::string(const std::string& job)>& pass_fn) {
  if (is_compiled_) {
    return of::Error::RuntimeError() << "job pass should be registered before compile and forward";
  }
  // Invoking an empty std::function throws std::bad_function_call; reject it
  // here so the failure is reported at registration time instead of surfacing
  // as an exception while the passes are being applied.
  if (!pass_fn) {
    return of::Error::RuntimeError() << "job pass should be a valid (non-empty) callable";
  }
  registered_job_passes_.emplace_back(pass_fn);
  return of::Maybe<void>::Ok();
}

// Runs every registered job pass over a copy of `job`, in registration order,
// and returns the resulting job.
//
// Each pass receives the current job serialized to a string and returns a
// string that is parsed back into a job; that parsed job becomes the input of
// the next pass. The `job` argument itself is never modified.
//
// @param job The job to start from.
// @return The job after all passes, or a RuntimeError if the string returned
//         by a pass cannot be parsed as a job.
of::Maybe<of::Job> Graph::GraphImpl::ApplyJobPasses(const of::Job& job) {
  std::shared_ptr<of::Job> result = std::make_shared<of::Job>(job);
  for (const auto& apply_pass : registered_job_passes_) {
    const std::string pass_output = apply_pass(result->SerializeAsString());
    of::Job parsed_job;
    const bool parsed_ok = parsed_job.ParseFromString(pass_output);
    if (!parsed_ok) {
      return of::Error::RuntimeError() << "invalid serialized job after pass applied";
    }
    // Adopt the freshly parsed job as the input for the next pass.
    result->Swap(&parsed_job);
  }
  return result;
}

std::vector<Tensor> Graph::GraphImpl::Forward(const std::vector<Tensor>& inputs) {
if (!is_compiled_) {
static std::mutex mtx;
Expand Down Expand Up @@ -299,9 +330,12 @@ of::Maybe<void> Graph::GraphImpl::BuildGraph() {
}
JUST(LoadCheckpoint());
JUST(of::CurJobBuildAndInferCtx_Complete());
const std::shared_ptr<of::Job> complete_job = JUST(of::GetCurrentJob());
std::shared_ptr<of::Job> complete_job = JUST(of::GetCurrentJob());
int64_t job_id = JUST(of::JobBuildAndInferCtx_GetCurrentJobId());
CHECK(of::Global<OneFlowEnv>::Get() != nullptr);

// apply custom job passes
complete_job = JUST(ApplyJobPasses(*complete_job));
graph_ = std::make_shared<of::NNGraph>(job_.job_conf().job_name(), *complete_job, job_id,
of::Global<OneFlowEnv>::Get()->GetSessionCtx());
{
Expand Down
3 changes: 3 additions & 0 deletions oneflow/api/cpp/framework/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "tensor.h"
#include <cstddef>
#include <string>
#include <functional>
#include <unordered_map>

namespace oneflow {
Expand Down Expand Up @@ -64,6 +65,8 @@ class Graph {
IValue Forward(const IValue& inputs);
void set_batch_size(int batch_size);

// Registers a custom job pass to be applied to the job when the graph is built.
//
// `pass_fn` receives the job serialized as a string and must return a string
// that parses back into a job. Passes are applied in the order they were
// registered. Registration must happen before the graph is compiled (i.e.
// before the first forward); registering afterwards is reported as an error by
// the implementation.
void RegisterJobPass(const std::function<std::string(const std::string& job)>& pass_fn);

static Graph Load(const std::string& model_path, const Device& device = Device("cpu"));

private:
Expand Down