Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Simple engine #30

Merged
merged 27 commits into from
Aug 31, 2015
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7a3fb78
[simple-engine] cherry pick Minjie's refactoring on interface
hotpxl Aug 23, 2015
6bfcc7d
[simple-engine] fix typo and unmark virtual
hotpxl Aug 23, 2015
e149942
[simple-engine] WIP, need to refactor interface and remove inheritance
hotpxl Aug 23, 2015
cfe80c6
[simple-engine] A not so simple engine that should somehow work
hotpxl Aug 23, 2015
005db3a
[simple-engine] lint
hotpxl Aug 23, 2015
5dfc759
[simple-engine] remove unnecessary files
hotpxl Aug 24, 2015
6b948fd
[simple-engine] that goes worker
hotpxl Aug 24, 2015
c7146a7
[simple-engine] passed some tests
hotpxl Aug 24, 2015
fef28d2
[simple-engine] some document
hotpxl Aug 24, 2015
761e0e1
[simple-engine] implement missing functions
hotpxl Aug 24, 2015
11432db
[simple-engine] fix order
hotpxl Aug 24, 2015
50bf69b
[simple-engine] fix concurrency
hotpxl Aug 24, 2015
1ad9132
[simple-engine] fix a concurrency bug
hotpxl Aug 24, 2015
4662b4e
[simple-engine] lint
hotpxl Aug 24, 2015
d59a125
[simple-engine] Merge branch 'master' into simple-engine
hotpxl Aug 25, 2015
e71055c
[simple-engine] @antinucleon Sorry I had to turn this on to pass Doxygen
hotpxl Aug 25, 2015
be8312d
[simple-engine] heavier test on engine
hotpxl Aug 25, 2015
3a15c13
[simple-engine] add tests to Makefile
hotpxl Aug 25, 2015
e9ab2a2
[simple-engine] use macro for debugging
hotpxl Aug 25, 2015
a3dbb93
[simple-engine] fix engine bug
hotpxl Aug 25, 2015
6b1b9a1
[simple-engine] even more tests
hotpxl Aug 25, 2015
1c57396
[simple-engine] Merge branch 'master' into simple-engine
hotpxl Aug 26, 2015
602a436
[simple-engine] Merge branch 'master' into simple-engine
hotpxl Aug 28, 2015
e85fb0d
[simple-engine] fix concurrency bug
hotpxl Aug 29, 2015
5d4189e
[simple-engine] obj pool raw implementation
hotpxl Aug 29, 2015
bb16dc4
[simple-engine] Merge branch 'master' into simple-engine
hotpxl Aug 29, 2015
8628138
[simple-engine] fix counter
hotpxl Aug 29, 2015
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[simple-engine] implement missing functions
  • Loading branch information
hotpxl committed Aug 24, 2015
commit 761e0e14ad4af1b5e065956b0ede8fd671c039a8
64 changes: 58 additions & 6 deletions src/dag_engine/simple_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#include <cassert>
#include <algorithm>
#include <utility>
#include <condition_variable>
#include <mutex>
#include "../common/cuda_utils.h"

namespace mxnet {

Expand All @@ -15,7 +18,8 @@ SimpleVar* SimpleVar::CastFromBase(Var* v) { return v->Cast<SimpleVar>(); }

SimpleOpr* SimpleOpr::CastFromBase(Opr* o) { return o->Cast<SimpleOpr>(); }

SimpleEngine::SimpleEngine() : thread_pool_{[this]() { ThreadWorker(); }} {}
SimpleEngine::SimpleEngine()
: pending_{0}, thread_pool_{[this]() { ThreadWorker(); }} {}

SimpleEngine::~SimpleEngine() noexcept(false) { task_queue_.SignalForKill(); }

Expand All @@ -42,9 +46,9 @@ SimpleEngine::Operator SimpleEngine::NewOperator(
void SimpleEngine::DeleteOperator(Operator op) { delete op; }

void SimpleEngine::Push(Operator op, Context exec_ctx) {
static_cast<void>(exec_ctx);
auto opr = SimpleOpr::CastFromBase(op);
auto opr_block = new OprBlock{};
++pending_;
opr_block->wait.store(opr->use_vars.size() + opr->mutate_vars.size() + 1);
// Add reading dependencies.
auto add_dependency = [&opr_block](SimpleVar* i) {
Expand Down Expand Up @@ -97,15 +101,60 @@ void SimpleEngine::Push(Operator op, Context exec_ctx) {
previous->lock.unlock();
}
auto callback = [this, first]() { OnComplete(first); };
// TODO(hotpxl) do something useful
RunContext ctx{};
ctx.stream = nullptr;
opr_block->fn = [opr, ctx, callback]() { opr->fn(ctx, callback); };
RunContext rctx{};
rctx.stream = nullptr;
opr_block->fn = [exec_ctx, opr, rctx, callback]() {
if (exec_ctx.dev_mask == gpu::kDevMask) {
#if MXNET_USE_CUDA
CUDA_CALL(cudaSetDevice(exec_ctx.dev_id));
#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
}
opr->fn(rctx, callback);
};
if (--opr_block->wait == 0) {
task_queue_.Push(opr_block);
}
}

void SimpleEngine::PushAsync(AsyncFn fn, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) {
auto&& opr = NewOperator(fn, use_vars, mutate_vars);
Push(opr, exec_ctx);
DeleteOperator(opr);
}

void SimpleEngine::PushDelete(Fn delete_fn, Context exec_ctx, Variable var) {
auto&& callback = [delete_fn, var](RunContext ctx) {
delete var;
// If you used `var` after `PushDelete`, then the following will be
// undefined
delete SimpleVar::CastFromBase(var)->var;
delete_fn(ctx);
};
Push(callback, exec_ctx, {}, {var});
}

void SimpleEngine::WaitForVar(Variable var) {
std::condition_variable cv;
std::mutex m;
std::unique_lock<std::mutex> lock{m};
std::atomic<bool> done{false};
auto&& callback = [&cv, &done](RunContext) {
done.store(true);
cv.notify_all();
};
cv.wait(lock, [&done]() { return done.load(); });
Push(callback, Context{}, {var}, {});
}

void SimpleEngine::WaitForAll() {
std::unique_lock<std::mutex> lock{finished_m_};
finished_cv_.wait(lock, [this]() { return pending_.load() == 0; });
}

void SimpleEngine::OnComplete(VersionedVarBlock* trigger) {
auto head = trigger;
while (head != nullptr) {
Expand Down Expand Up @@ -145,6 +194,9 @@ void SimpleEngine::ThreadWorker() {
assert(opr_block->wait.load() == 0);
opr_block->fn();
delete opr_block;
if (--pending_ == 0) {
finished_cv_.notify_all();
}
}
}

Expand Down
26 changes: 18 additions & 8 deletions src/dag_engine/simple_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <vector>
#include <functional>
#include <atomic>
#include <condition_variable>
#include <mutex>
#include "mxnet/dag_engine.h"
#include "dag_engine_impl.h"
#include "thread_pool.h"
Expand Down Expand Up @@ -76,15 +78,17 @@ class SimpleEngine final : public DAGEngine {
* \brief Overriding methods.
*/
Variable NewVar() override;
Operator NewOperator(AsyncFn, std::vector<Variable> const&,
std::vector<Variable> const&) override;
Operator NewOperator(AsyncFn fn, std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) override;
void DeleteOperator(Operator op) override;
void Push(Operator op, Context) override;
void PushAsync(AsyncFn, Context, std::vector<Variable> const&,
std::vector<Variable> const&) override{};
void PushDelete(Fn, Context, Variable) override{};
void WaitForVar(Variable) override{};
void WaitForAll() override{};
void Push(Operator op, Context exec_ctx) override;
using DAGEngine::Push;
void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) override;
void PushDelete(Fn delete_fn, Context exec_ctx, Variable var) override;
void WaitForVar(Variable var) override;
void WaitForAll() override;
/*!
* \brief Callback on operation completion.
*
Expand All @@ -103,6 +107,12 @@ class SimpleEngine final : public DAGEngine {
* \brief Concurrency for thread pool.
*/
static constexpr std::size_t kNumWorkingThreads = 16;
/*!
* \brief Number of pending operations.
*/
std::atomic<std::size_t> pending_;
std::condition_variable finished_cv_;
std::mutex finished_m_;
/*!
* \brief Task queue.
*/
Expand Down
18 changes: 4 additions & 14 deletions test/test_simple_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ int main() {
{var}, {}));
engine->Push(oprs.at(i), mxnet::Context{});
}
std::this_thread::sleep_for(std::chrono::seconds{1});
engine->WaitForAll();
// std::this_thread::sleep_for(std::chrono::seconds{1});

printf("============= Test #2 ==============\n");
var = engine->NewVar();
Expand All @@ -44,19 +45,8 @@ int main() {
{}, {var}));
engine->Push(oprs.at(i), mxnet::Context{});
}
std::this_thread::sleep_for(std::chrono::seconds{1});
// std::this_thread::sleep_for(std::chrono::seconds{1});
engine->WaitForAll();

// usleep(1000000);

// // Test #2
// cout << "============= Test #2 ==============" << endl;
// for (int i = 0; i < 10; ++i) {
// engine->Push([i] (RunContext rctx) { Foo(rctx, i); },
// exec_ctx, {}, vars);
// }

// usleep(1000000);

// // Test #3
return 0;
}