Skip to content
This repository has been archived by the owner on Jan 20, 2020. It is now read-only.

fix group-symbols caching bug #11

Merged
merged 1 commit into from
Nov 20, 2016
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions src/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,15 @@ class TorchSession : public Session {
private:
// entry to store cached executor
// Cache record for one compiled executor, keyed by a hash of the
// requested symbol's outputs.
struct ExecEntry {
// Snapshot of the symbol the executor was built from; compared
// output-by-output against the incoming symbol on lookup to detect a
// stale entry (e.g. a hash collision or a regrouped symbol).
nnvm::Symbol cached_symbol;
// Executor compiled for cached_symbol.
std::shared_ptr<TorchExecutor> exec;
// Number of cache hits this entry has served.
size_t use_count{0};
};
int default_dev_mask_{kCPU};
// local cached variable states.
VarStateMap states_;
// cached executor
std::unordered_map<Symbol*, ExecEntry> cached_execs_;
std::unordered_map<uint64_t, ExecEntry> cached_execs_;
};


Expand Down Expand Up @@ -145,17 +146,23 @@ Session* Session::Create(const std::string& option) {
}

// NOTE(review): this span is a unified diff as rendered by GitHub — removed
// (pre-fix) and added (post-fix) lines are interleaved below; it is not one
// compilable function. The fix replaces the executor-cache key: previously
// the caller's raw Symbol* was the key, which goes stale/ambiguous when the
// same pointer is reused for a different (e.g. regrouped) symbol; now the key
// is a hash combined over the symbol's output NodeEntry node pointers, with
// the full cached_symbol compared on lookup to guard against collisions.
const std::vector<TBlob>& TorchSession::Run(
// removed (old parameter name):
nnvm::Symbol* sym,
// added (renamed to distinguish from the cached copy):
nnvm::Symbol* new_sym,
const std::unordered_map<std::string, TBlob>& inputs) {
// removed (old lookup keyed by raw Symbol pointer):
if (cached_execs_.count(sym) != 0) {
auto& entry = cached_execs_.at(sym);
const nnvm::Symbol& s = entry.exec->symbol();
bool stale_exec = (s.outputs.size() != sym->outputs.size());
// added: compute the cache key — boost::hash_combine-style mixing
// (golden-ratio constant 0x9e3779b9) over each output's node pointer,
// seeded with the output count.
uint64_t hash_value = new_sym->outputs.size();
for (NodeEntry& e : new_sym->outputs) {
uint64_t value = reinterpret_cast<uint64_t>(e.node.get());
hash_value ^= value + 0x9e3779b9 + (hash_value << 6) + (hash_value >> 2);
}
// added: look up by hash and verify against the cached symbol snapshot.
if (cached_execs_.count(hash_value) != 0) {
auto& entry = cached_execs_.at(hash_value);
const nnvm::Symbol& old_sym = entry.cached_symbol;
bool stale_exec = (old_sym.outputs.size() != new_sym->outputs.size());
if (!stale_exec) {
// removed (old per-output comparison against the executor's symbol):
for (size_t i = 0; i < s.outputs.size(); ++i) {
if (s.outputs[i].node.get() != sym->outputs[i].node.get() ||
s.outputs[i].index != sym->outputs[i].index ||
s.outputs[i].version != sym->outputs[i].version) {
// added: compare against the cached snapshot instead; any mismatch in
// node identity, output index, or version invalidates the entry.
for (size_t i = 0; i < old_sym.outputs.size(); ++i) {
if (old_sym.outputs[i].node.get() != new_sym->outputs[i].node.get() ||
old_sym.outputs[i].index != new_sym->outputs[i].index ||
old_sym.outputs[i].version != new_sym->outputs[i].version) {
stale_exec = true; break;
}
}
Expand All @@ -164,16 +171,18 @@ const std::vector<TBlob>& TorchSession::Run(
// cache hit: reuse the compiled executor.
++entry.use_count;
return entry.exec->Run(inputs);
} else {
// stale entry: evict before rebuilding below.
// removed (old erase by pointer key):
cached_execs_.erase(sym);
// added (erase by hash key):
cached_execs_.erase(hash_value);
}
}
// dumb technique, remove all previous executors
// better strategy, LRU?
LOG(INFO) << "New Executor";
cached_execs_.clear();
ExecEntry e;
// added: keep a copy of the symbol so future lookups can verify the hash
// actually matches this graph (collision guard).
e.cached_symbol = *new_sym;
e.exec = std::make_shared<TorchExecutor>();
// removed (old init from the raw pointer parameter):
e.exec->Init(*sym, &states_, default_dev_mask_);
cached_execs_[sym] = e;
// added (init from renamed parameter, insert under hash key):
e.exec->Init(*new_sym, &states_, default_dev_mask_);
cached_execs_[hash_value] = e;
return e.exec->Run(inputs);
}

Expand Down