Skip to content

Commit

Permalink
Add local worker's state checkpointing
Browse files Browse the repository at this point in the history
  Add a mechanism to preserve the local state of a distributed worker
  (together with the external state aggregated by the main node and saved
  inside the checkpoint).
  • Loading branch information
Alexey Reznichenko committed Apr 4, 2017
1 parent e333412 commit 8592965
Show file tree
Hide file tree
Showing 24 changed files with 616 additions and 233 deletions.
13 changes: 11 additions & 2 deletions Source/CNTKv2LibraryDll/API/CNTKLibrary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1515,7 +1515,15 @@ namespace CNTK
DictionaryIterator end() const { return m_dictionaryData->end(); }
ConstDictionaryIterator cend() const { return m_dictionaryData->cend(); }

size_t Size() { return m_dictionaryData->size(); }
size_t Size() const { return m_dictionaryData->size(); }

// Returns the set of all keys currently present in this dictionary.
// NOTE(review): nothing here mutates the dictionary, so this accessor could be const-qualified.
std::unordered_set<std::wstring> Keys()
{
    std::unordered_set<std::wstring> result;
    for (auto it = m_dictionaryData->begin(); it != m_dictionaryData->end(); ++it)
        result.insert(it->first);
    return result;
}

friend CNTK_API std::istream& operator>>(std::istream& stream, Dictionary& us);
friend CNTK_API std::ostream& operator<<(std::ostream& stream, const Dictionary& us);
Expand Down Expand Up @@ -4574,7 +4582,8 @@ namespace CNTK
bool TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);
bool TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);

void Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState);
void Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState,
const Dictionary& externalState, const Dictionary& distributedState = {});

void UpdateTrainingProgress(size_t numSamples, const ValuePtr& loss, const ValuePtr& evalCriterion, const DeviceDescriptor& computeDevice);
void AddProgressWriters(const std::vector<ProgressWriterPtr>& progressWriters);
Expand Down
2 changes: 2 additions & 0 deletions Source/CNTKv2LibraryDll/BlockFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "PrimitiveFunction.h"
#include "Utils.h"
#include "Variable.h"

namespace CNTK
{
Expand Down
2 changes: 1 addition & 1 deletion Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,4 @@
<Warning Condition="!$(HasProtobuf)" Text="CNTKv2LibraryDll requires Protocol Buffers to build. Please see https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Windows#protobuf for installation instructions." />
<Error Condition="!$(HasBoost)" Text="CNTKv2LibraryDll requires the Boost library to build. Please see https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Windows#boost for installation instructions." />
</Target>
</Project>
</Project>
22 changes: 16 additions & 6 deletions Source/CNTKv2LibraryDll/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,26 @@ namespace CNTK

static std::atomic_ullong s_currentRandomSeed = ATOMIC_VAR_INIT(0);

// This is used to generate a default seed for stateful nodes (dropout, and both
// flavors of random sample). As a result, in distributed environment, each worker
// ends up having a different seed.

size_t GenerateRandomSeed()
{
DistributedCommunicatorPtr communicator = MPICommunicator();
auto numWorkers = communicator->Workers().size();
auto rank = communicator->CurrentWorker().m_globalRank;
static size_t numWorkers = 1, rank = 0;
static bool initialized = false;
if (MPIWrapper::GetTotalNumberOfMPINodes() != 0 && !initialized)
{
DistributedCommunicatorPtr communicator = MPICommunicator();
numWorkers = communicator->Workers().size();
rank = communicator->CurrentWorker().m_globalRank;

if (numWorkers < 1)
numWorkers = 1;
if (numWorkers < 1)
numWorkers = 1;
}

return (numWorkers * ++s_currentRandomSeed) + rank;
initialized = true;
return (numWorkers * s_currentRandomSeed++) + rank;
}

std::atomic<bool> s_reverseTensorShapesInErrorMessages(false);
Expand Down
236 changes: 146 additions & 90 deletions Source/CNTKv2LibraryDll/CompositeFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,58 @@ namespace CNTK
return dict;
}

// Copy the internal state from the network into the function graph,
// specifically from RngUser nodes into the attributes dictionaries of
// the corresponding stateful primitive functions.
void CompositeFunction::UpdateInternalState() const
{
if (!m_computationNetwork)
return;

for (auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction->IsStateful())
continue;

// TODO: same for BatchNorm

auto& outputs = primitiveFunction->RawOutputs();
if (outputs.size() != 1)
LogicError("Function '%S' UpdateInternalState: a stateful primitive function must have a single output.", AsString().c_str());

const auto& rng = m_variableToNodeMap.at(outputs[0])->As<RngUser>();

Dictionary state;
state[PrimitiveFunction::AttributeNameRngSeed] = static_cast<size_t>(rng->GetRngSeed());
state[PrimitiveFunction::AttributeNameRngOffset] = static_cast<size_t>(rng->GetRngOffset());
primitiveFunction->SetState(state);
}
}

// Builds a dictionary describing the internal (local) state of the function graph.
// The state is first refreshed via UpdateInternalState(); the result then maps the uid
// of every stateful primitive function to the dictionary returned by its GetState().
Dictionary CompositeFunction::GetInternalState() const
{
    UpdateInternalState();

    Dictionary result;
    for (auto& func : m_allPrimitiveFunctions)
    {
        auto primitive = dynamic_cast<const PrimitiveFunction*>(func.get());
        if (primitive->IsStateful())
        {
            // TODO: same for BatchNorm
            result[primitive->Uid()] = primitive->GetState();
        }
    }
    return result;
}

/*virtual*/ Dictionary CompositeFunction::Serialize() const
{
UpdateInternalState();

Dictionary dict = SerializeBlockComposite();

// Find cycles in the graph and "break" them by inserting placeholders.
Expand Down Expand Up @@ -129,29 +179,6 @@ namespace CNTK
}

dict[functionsKey] = std::move(functionDictionaries);

// Now, collect and store the internal state for all non-pure (stateful) functions in the graph
// (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
Dictionary stateDictionary;
for (const auto& kv : m_variableToNodeMap)
{
if (kv.second->Is<RngUser>() && kv.first.IsOutput())
{
// The RNG state should be associated with the actual function that the computation node
// corresponds to, and not the block primitives that wrap the actual function
auto ownerFunction = kv.first.Owner().get();
if (!ownerFunction->IsBlock())
{
auto rng = kv.second->As<RngUser>();
Dictionary state;
state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
stateDictionary[ownerFunction->Uid()] = state;
}
}
}

dict[stateKey] = std::move(stateDictionary);

return dict;
}
Expand Down Expand Up @@ -217,10 +244,6 @@ namespace CNTK
uidToInputMap[inputVar.Uid()] = inputVar;
}

Dictionary stateDictionary;
if (dict.Contains(stateKey))
stateDictionary = dict[stateKey].Value<Dictionary>();

const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();

std::unordered_map<Variable, Variable> allPlaceholderReplacements;
Expand All @@ -238,25 +261,6 @@ namespace CNTK
if (opType == PrimitiveOpType::Combine)
continue;

if (primitiveFunction->IsStateful())
{
if (stateDictionary.Contains(primitiveFunction->Uid()))
{
auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
auto seed = state[rngSeedKey].Value<size_t>();
auto offset = state[rngOffsetKey].Value<size_t>();
primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed;
primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset;
}
else if (Internal::GetComputationNetworkTraceLevel() > 0)
{
// TODO: all logging functionality should be refactored to live in a logging utility class.
fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
"when deserializing from a dictionary (version=%zu). "
"Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
}
}

for (const auto& output : root->RawOutputs())
{
const auto& it = uidToInputMap.find(output.Uid());
Expand All @@ -276,63 +280,122 @@ namespace CNTK
}
}


// starting with the serialization version = 3, the state is preserved inside the attribute dictionaries of the
// corresponding primitive functions. Earlier versions have a dedicated key-value pair in the composite function dict.
if (version < 3)
RestoreStatefulFunctions(version, dict, allPrimitiveFunctions);

return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
}

void CompositeFunction::CopyState(const CompositeFunction& source)
void CompositeFunction::RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> functions)
{
// Create a map with all non-pure (stateful) functions in the function graph.
auto collectStatefulFunctions = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions) -> std::map<std::wstring, FunctionPtr> {
std::map<std::wstring, FunctionPtr> functionMap;
for (auto funcPtr : allPrimitiveFunctions)
Dictionary stateDictionary;
if (dict.Contains(stateKey))
stateDictionary = dict[stateKey].Value<Dictionary>();

for (auto& function : functions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction->IsStateful())
continue;

if (stateDictionary.Contains(primitiveFunction->Uid()))
{
auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
// Add key-value pairs expected by the SetState method to the state dictionary.
state[PrimitiveFunction::AttributeNameRngSeed] = state[rngSeedKey].Value<size_t>();
state[PrimitiveFunction::AttributeNameRngOffset] = state[rngOffsetKey].Value<size_t>();
primitiveFunction->SetState(state);
}
else
{
if (Internal::GetComputationNetworkTraceLevel() > 0) {
// TODO: all logging functionality should be refactored to live in a logging utility class.
fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
"when deserializing from a dictionary (version=%zu). "
"Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
}

// Create state from scratch, so that function attributes contain all the required key-value pairs.
Dictionary state;
state[PrimitiveFunction::AttributeNameRngSeed] = Internal::GenerateRandomSeed();
state[PrimitiveFunction::AttributeNameRngOffset] = 0;
primitiveFunction->SetState(state);
}
}
}

void CompositeFunction::CopyState(const CompositeFunction& source)
{
// Collect a vector of stateful function uids using a pre-order traversal of a function graph.
auto collectStatefulFunctionUIDs = [](const Function& function) -> vector<wstring> {
vector<wstring> uids;
PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) {
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
if (primitiveFunction->IsStateful())
if (primitiveFunction->IsStateful())
{
functionMap[primitiveFunction->Uid()] = funcPtr;
uids.push_back(funcPtr->Uid());
}
}
return functionMap;
}, true);

return uids;
};

std::map<std::wstring, FunctionPtr> statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions);
std::map<std::wstring, FunctionPtr> statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions);
auto theirUIDs = collectStatefulFunctionUIDs(source);
auto ourUIDs = collectStatefulFunctionUIDs(*this);

assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size());
if (statefulFunctionsFrom.size() == 0)
{
return;
}
if (theirUIDs.size() != ourUIDs.size())
CNTK::LogicError("Cannot copy internal state, the source and the destination contain different number of stateful functions.");

auto state = source.GetInternalState();

// Copy state captured in the attributes dictionaries.
for (const auto& kv : statefulFunctionsFrom)
if (theirUIDs == ourUIDs)
{
statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes();
// uids are identical, no need to remap.
SetInternalState(state);
return;
}

// build a map of source function to the destination (this) function UIDs.
map<wstring, wstring> uidMap;
for (auto i = 0; i < theirUIDs.size(); i++)
uidMap[theirUIDs[i]] = ourUIDs[i];

Dictionary remappedState;
for (auto& kv : state)
remappedState[uidMap[kv.first]] = kv.second;

UpdateInternalNetworkState();
SetInternalState(remappedState);
}

void CompositeFunction::UpdateInternalNetworkState()
void CompositeFunction::SetInternalState(const Dictionary& state)
{
if (!m_computationNetwork)
{
if (state.Size() == 0)
return;
}

for (const auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
if (primitiveFunction->IsStateful())
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction->IsStateful())
continue;

auto functionState = state[primitiveFunction->Uid()].Value<Dictionary>();

primitiveFunction->SetState(functionState);

if (!m_computationNetwork)
continue;

auto seed = functionState[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto offset = functionState[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();

// copy the state directly into the network
for (const auto& output : function->RawOutputs())
{
for (const auto& output : function->RawOutputs())
{
auto node = m_variableToNodeMap.at(output);
auto attributes = function->Attributes();
auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
node->As<RngUser>()->SetRngState(seed, offset);
}
auto node = m_variableToNodeMap.at(output);
node->As<RngUser>()->SetRngState(seed, offset);
}
}
}
Expand Down Expand Up @@ -895,16 +958,9 @@ namespace CNTK

if (computationNodePtr->Is<RngUser>())
{
if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed))
{
auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
uint64_t offset = 0;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset))
{
offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
}
computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
}
auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
}
}
else
Expand Down
Loading

0 comments on commit 8592965

Please sign in to comment.