Skip to content

Commit

Permalink
[PIR][AutoParallel] Support 1F1B/FThenB with PIR (PaddlePaddle#58459)
Browse files Browse the repository at this point in the history
* [PIR][AutoParallel] Support 1F1B/FThenB with PIR

* fix splitfeed

* fix include

* rm fetch

* fix unittest

* fix scope error

* program interpreter use local scope to avoid var conflict

* fix ut

---------

Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
  • Loading branch information
zhiqiu and zhaoyingli authored Oct 30, 2023
1 parent 38b5762 commit 6ccca1d
Show file tree
Hide file tree
Showing 14 changed files with 305 additions and 146 deletions.
46 changes: 42 additions & 4 deletions paddle/fluid/framework/new_executor/feed_fetch_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h"

#include <map>
#include <vector>

#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"

namespace paddle {
namespace framework {

void SetColAttrForFeedFetchOps(std::shared_ptr<ProgramDesc> program_desc,
const int64_t micro_batch_num,
const int64_t micro_batch_id) {
if (micro_batch_num < 2) return;

const std::set<std::string>& valid_feed_fetch_op_types = {
"fetch", "fetch_v2", "feed"};
for (const auto& op_desc : program_desc->MutableBlock(0)->AllOps()) {
Expand All @@ -48,5 +48,43 @@ void SetColAttrForFeedFetchOps(std::shared_ptr<ProgramDesc> program_desc,
}
}

void SplitFeedTensor(const std::vector<std::string>& feed_names,
const int64_t micro_batch_num,
Scope* scope,
std::vector<std::vector<phi::DenseTensor>>* out) {
if (micro_batch_num < 2) return;

out->resize(micro_batch_num);
for (size_t i = 0; i < feed_names.size(); ++i) {
auto feed_name = feed_names[i];
auto feed_var = scope->GetVar(feed_name);

if (feed_var->IsType<phi::DenseTensor>()) {
phi::DenseTensor feed_tensor = feed_var->Get<phi::DenseTensor>();
int64_t numel_size = feed_tensor.dims()[0];
PADDLE_ENFORCE_EQ(numel_size % micro_batch_num,
0,
platform::errors::InvalidArgument(
"Split expects feed data (%s)'s dim[0] (%d) is "
"diviable by micro_batch_num (%d).",
feed_name,
numel_size,
micro_batch_num));
int64_t split_size = (numel_size + micro_batch_num - 1) / micro_batch_num;
VLOG(4) << "Split feed data:" << feed_name << ", dims:("
<< feed_tensor.dims() << "), micro_batch_num:" << micro_batch_num;
for (int64_t j = 0; j < micro_batch_num; ++j) {
(*out)[j].resize(i + 1);
(*out)[j][i].ShareDataWith(
feed_tensor.Slice(j * split_size, j * split_size + split_size));
}
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Type (%s) not support in SplitFeedTensor.",
ToTypeName(feed_var->Type())));
}
}
}

} // namespace framework
} // namespace paddle
6 changes: 6 additions & 0 deletions paddle/fluid/framework/new_executor/feed_fetch_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "paddle/fluid/framework/new_executor/interpreter/plan.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
Expand All @@ -26,5 +27,10 @@ void SetColAttrForFeedFetchOps(std::shared_ptr<ProgramDesc> program_desc,
const int64_t micro_batch_num,
const int64_t micro_batch_id);

// Splits each feed variable listed in `feed_names` (looked up in `scope`)
// into `micro_batch_num` slices along dim 0 and writes them to `out`, indexed
// as (*out)[micro_batch_id][feed_index]. The slices share data with the
// source tensors. Does nothing when `micro_batch_num` < 2. Raises an error if
// dim[0] is not evenly divided by `micro_batch_num` or if a feed variable is
// not a phi::DenseTensor.
void SplitFeedTensor(const std::vector<std::string>& feed_names,
const int64_t micro_batch_num,
Scope* scope,
std::vector<std::vector<phi::DenseTensor>>* out);

} // namespace framework
} // namespace paddle
32 changes: 23 additions & 9 deletions paddle/fluid/framework/new_executor/standalone_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
VLOG(6) << ss.str();

const auto& jobs = plan_.JobList();
for (const auto& job : jobs) {
for (size_t job_idx = 0; job_idx < jobs.size(); ++job_idx) {
const auto& job = jobs[job_idx];
const std::string& job_type = job->Type();
std::shared_ptr<ProgramDesc> program = nullptr;
std::shared_ptr<::pir::Program> ir_program = nullptr;
Expand All @@ -69,7 +70,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
micro_batch_id,
micro_batch_num));

if (micro_batch_num > 1 && !FLAGS_enable_pir_api) {
if (!FLAGS_enable_pir_api && !FLAGS_enable_new_ir_in_executor) {
SetColAttrForFeedFetchOps(program, micro_batch_num, micro_batch_id);
}

Expand All @@ -79,6 +80,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,

// TODO(phlrain) we only support cpu for now
if (FLAGS_enable_new_ir_in_executor) {
auto inner_scope =
micro_batch_num == 1 ? scope : micro_batch_scopes_[micro_batch_id];
std::shared_ptr<::pir::Program> base_program = ir_program;
auto block = base_program->block();
for (auto it = block->begin(); it != block->end(); ++it) {
Expand All @@ -104,7 +107,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(base_program.get(), place);
std::shared_ptr<pir::Program> shared_program = std::move(kernel_program);
plan_.UpdateIrProgram("base", shared_program);
plan_.UpdateIrProgram("job_" + std::to_string(job_idx), shared_program);

if (FLAGS_new_ir_apply_inplace_pass) {
pir::PassManager pm(pir::IrContext::Instance(), 3);
Expand All @@ -116,7 +119,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
std::make_shared<InterpreterCore>(place_,
fetch_var_names_,
shared_program->block(),
scope_,
inner_scope,
execution_config));
} else {
interpretercores_.emplace_back(
Expand Down Expand Up @@ -175,6 +178,11 @@ paddle::framework::FetchList StandaloneExecutor::Run(
is_interpretercore_build_result_shared_ = true;
}

std::vector<std::vector<phi::DenseTensor>> splited_feeds;
if (FLAGS_enable_new_ir_in_executor) {
SplitFeedTensor(feed_names, plan_.MicroBatchNum(), scope_, &splited_feeds);
}

for (size_t job_idx = 0; job_idx < jobs.size(); ++job_idx) {
const auto& job = jobs[job_idx];
const std::string& job_type = job->Type();
Expand All @@ -192,12 +200,18 @@ paddle::framework::FetchList StandaloneExecutor::Run(
interpretercores_[job_idx]->ShareBuildResultsFrom(
interpretercores_[type_to_first_id[job_type]]);
}
// TODO(zhaoyinglia): use a more general method
if (jobs.size() > 1 && job_type != "forward") {
const std::vector<std::string> tmp_feed_names = {};
interpretercores_[job_idx]->Run(tmp_feed_names, /*need_fetch = */ false);

if (FLAGS_enable_new_ir_in_executor && splited_feeds.size() > 0) {
interpretercores_[job_idx]->Run(feed_names,
splited_feeds[job->MicroBatchId()]);
} else {
interpretercores_[job_idx]->Run(feed_names, /*need_fetch = */ false);
if (jobs.size() > 1 && job_type != "forward") {
const std::vector<std::string> tmp_feed_names = {};
interpretercores_[job_idx]->Run(tmp_feed_names,
/*need_fetch = */ false);
} else {
interpretercores_[job_idx]->Run(feed_names, /*need_fetch = */ false);
}
}
}

Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/framework/new_executor/standalone_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,10 @@ class StandaloneExecutor {
bool is_interpretercore_build_result_shared_{false};
const platform::Place place_;
interpreter::Plan plan_;

std::vector<framework::Scope*> micro_batch_scopes_;
std::vector<std::shared_ptr<InterpreterCore>> interpretercores_;

Scope* scope_;
std::vector<Scope*> micro_batch_scopes_;

std::vector<std::string> fetch_var_names_;

Expand Down
6 changes: 5 additions & 1 deletion python/paddle/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,11 @@ def run(self, feed_names, return_numpy=True):
tensors = self._new_exe.run(feed_names)._move_to_list()
if return_numpy:
tensors = as_numpy(tensors, copy=True)
return _merge_tensors(tensors, self._plan.micro_batch_num())
if not get_flags("FLAGS_enable_new_ir_in_executor")[
'FLAGS_enable_new_ir_in_executor'
]:
return _merge_tensors(tensors, self._plan.micro_batch_num())
return tensors
else:
if self._plan.micro_batch_num() > 1:
raise RuntimeError(
Expand Down
Loading

0 comments on commit 6ccca1d

Please sign in to comment.