-
Notifications
You must be signed in to change notification settings - Fork 368
Conversation
ZihengJiang
commented
Nov 18, 2016
•
edited
Loading
edited
- change the interface of Session.Run(...)
- use a hash function to generate the hash value of the vector of Symbols*
uint64_t hash_value = hash_symbols(syms); | ||
|
||
if (cached_execs_.count(hash_value) != 0) { | ||
auto& entry = cached_execs_.at(hash_value); | ||
const nnvm::Symbol& s = entry.exec->symbol(); | ||
bool stale_exec = (s.outputs.size() != sym->outputs.size()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check is not enough, need to check the ptr of all symbols
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tqchen looks like this requires us to save the pointer of symbols before. Do you mean there is a condition that two vector of symbols have the same hash value and also can pass the stale executor check below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. thatiswhat I mean
for (uint32_t i = 0; i < syms.size(); ++i) { | ||
sym_arr.push_back(*syms[i]); | ||
} | ||
*new_sym = nnvm::Symbol::CreateGroup(sym_arr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
simply write nnvm::Symbol new_sym instead of pointer
|
||
// compute the hash value of the input symbols | ||
uint64_t hash_value = syms.size(); | ||
for (nnvm::Symbol* i : syms) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let us add a new function called FindCachedExecs
|
||
// compute the hash value of the input symbols | ||
uint64_t hash_value = syms.size(); | ||
for (nnvm::Symbol* i : syms) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use pointer of nodes as hash keys, this removes the need of symbol pointers
@@ -67,6 +67,8 @@ class TorchSession : public Session { | |||
private: | |||
// entry to store cached executor | |||
struct ExecEntry { | |||
// cached symbols | |||
std::vector<nnvm::Symbol*> cached_symbols_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simply use nnvm::Symbol, use the node pointers as key
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you mean group symbol? then use the pointer of output node?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the output node will change, so it should be symbol.outputs[0].inputs
nodes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the output node won't change, as they are refering back to the same NodePtr
|