Skip to content

Commit

Permalink
First implementation of distributed evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
eldakms committed Apr 6, 2017
1 parent 6a50416 commit 3b4efc3
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 21 deletions.
15 changes: 12 additions & 3 deletions Source/CNTKv2LibraryDll/API/CNTKLibrary.h
Original file line number Diff line number Diff line change
Expand Up @@ -4475,6 +4475,14 @@ namespace CNTK
///
CNTK_API void SummarizeTestProgress();

///
/// Returns the set of progress writers attached to this evaluator.
/// The returned reference points at the evaluator's own m_progressWriters member and is
/// read-only; it stays valid for the lifetime of the evaluator.
///
CNTK_API const std::unordered_set<ProgressWriterPtr>& ProgressWriters() const
{
return m_progressWriters;
}

CNTK_API virtual ~Evaluator() {}

private:
Expand All @@ -4483,9 +4491,10 @@ namespace CNTK

friend class TrainingSession;

// Returns aggregated evaluation criterion value and sample count.
std::pair<ValuePtr, size_t> TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice, bool distributed);
std::pair<ValuePtr, size_t> TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice, bool distributed);
// Returns true if testing should be continued in a distributed mode.
// Aggregated error and sample count can be retrieved using 'result' parameter.
bool TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::pair<ValuePtr, size_t>& result, const DeviceDescriptor& computeDevice, bool distributed = false);
bool TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, std::pair<ValuePtr, size_t>& result, const DeviceDescriptor& computeDevice, bool distributed = false);

std::pair<ValuePtr, size_t> TestLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice);

Expand Down
44 changes: 37 additions & 7 deletions Source/CNTKv2LibraryDll/Evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,36 +106,66 @@ namespace CNTK

double Evaluator::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice)
{
auto evalMinibatchValue = TestMinibatch(arguments, outputsToFetch, computeDevice, false);
std::pair<ValuePtr, size_t> evalMinibatchValue;
TestMinibatch(arguments, outputsToFetch, evalMinibatchValue, computeDevice, false);
return evalMinibatchValue.first->AsScalar<double>() / evalMinibatchValue.second;
}

std::pair<ValuePtr, size_t> Evaluator::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice, bool distributed)
bool Evaluator::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::pair<ValuePtr, size_t>& result, const DeviceDescriptor& computeDevice, bool distributed)
{
std::unordered_map<Variable, ValuePtr> outputsToFetch = {};
return TestMinibatch(arguments, outputsToFetch, computeDevice, distributed);
return TestMinibatch(arguments, outputsToFetch, result, computeDevice, distributed);
}

std::pair<ValuePtr, size_t> Evaluator::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice, bool distributed)
bool Evaluator::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, std::pair<ValuePtr, size_t>& result, const DeviceDescriptor& computeDevice, bool distributed)
{
result = TestLocalMinibatch(arguments, outputsToFetch, computeDevice);
if (distributed)
RuntimeError("Currently distributed testing is not supported.");
return TestLocalMinibatch(arguments, outputsToFetch, computeDevice);
{
if (!outputsToFetch.empty())
RuntimeError("Custom outputs are not yet supported in distributed evaluation.");

double localSampleCount = static_cast<double>(result.second);

auto values = std::vector<NDArrayViewPtr>{ result.first->Data(), MakeSharedObject<NDArrayView>(NDShape{ 1 }, &localSampleCount, 1, DeviceDescriptor::CPUDevice()) };
DistributedCommunicatorPtr communicator = MPICommunicator();
communicator->AggregateInPlace(values, communicator->Workers());
result.second = static_cast<size_t>(localSampleCount);
}

bool hasData = (result.second != 0);
if (hasData)
UpdateTestProgress(result.second, result.first, computeDevice);

return hasData;
}

std::pair<ValuePtr, size_t> Evaluator::TestLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice)
{
if (!m_aggregatedEvaluationFunction)
InvalidArgument("Evaluator::TestMinibatch: Cannot test when no evaluation function was specified during construction.");

if (arguments.empty()) // Empty minibatch, return 0.
{
auto zeroValue = MakeSharedObject<Value>(
MakeSharedObject<NDArrayView>(
m_aggregatedEvaluationFunction->Output().GetDataType(),
m_aggregatedEvaluationFunction->Output().IsSparse() ? StorageFormat::SparseCSC : StorageFormat::Dense,
m_aggregatedEvaluationFunction->Output().Shape(), computeDevice));
if(zeroValue->GetDataType() == DataType::Float)
zeroValue->Data()->SetValue(0.0f);
else
zeroValue->Data()->SetValue(0.0);
return std::make_pair(zeroValue, 0);
}

std::unordered_map<Variable, ValuePtr> outputs = { { m_aggregatedEvaluationFunction, nullptr }, { m_testSampleCountVar, nullptr } };
outputs.insert(outputsToFetch.begin(), outputsToFetch.end());

m_combinedEvalFunction->Forward(arguments, outputs, computeDevice);

const ValuePtr& aggregateEvalCriterionValue = outputs[m_aggregatedEvaluationFunction];
auto sampleCount = GetSampleCount(m_testSampleCountVar, outputs[m_testSampleCountVar]);
UpdateTestProgress(sampleCount, aggregateEvalCriterionValue, computeDevice);

// Copy back output values for requested variables only.
for (auto& o : outputsToFetch)
Expand Down
31 changes: 20 additions & 11 deletions Source/CNTKv2LibraryDll/TrainingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,22 @@ namespace CNTK
size_t totalNumberOfSamples = 0;
size_t numberOfMinibatches = 0;

std::pair<ValuePtr, size_t> errorAndCount;
auto checkpoint = m_cv.m_source->GetCheckpointState();
while (GetCrossValidationMinibatch(minibatch, m_cv.m_mbSize[totalNumberOfSamples], computeDevice), !minibatch.empty())
bool shouldCV = true;
while (shouldCV)
{
GetCrossValidationMinibatch(minibatch, m_cv.m_mbSize[totalNumberOfSamples], computeDevice);

// TODO: it may be slow to rely on TestMinibatch to return error each time, since it may require transfer
// of error from the GPU each time.
auto result = m_trainer->TestMinibatch(minibatch, computeDevice, false);
accumulatedError += result.first->AsScalar<double>();
totalNumberOfSamples += result.second;
numberOfMinibatches++;
// of error from the GPU each time, accumulatedError can be allocated on GPU
shouldCV = m_trainer->TestMinibatch(minibatch, errorAndCount, computeDevice, m_numberOfWorkers != 1);
if (shouldCV)
{
accumulatedError += errorAndCount.first->AsScalar<double>();
totalNumberOfSamples += errorAndCount.second;
numberOfMinibatches++;
}
}

m_cv.m_source->RestoreFromCheckpoint(checkpoint);
Expand All @@ -266,10 +273,13 @@ namespace CNTK

std::unordered_map<Variable, ValuePtr> minibatch;
size_t totalNumberOfSamples = 0;
while (GetNextMinibatch(m_test.m_source, minibatch, m_test.m_mbSize[totalNumberOfSamples], 0, 1, computeDevice), !minibatch.empty())
bool shouldTest = true;
std::pair<ValuePtr, size_t> errorAndCount;
while (shouldTest)
{
auto result = m_trainer->TestMinibatch(minibatch, computeDevice, false);
totalNumberOfSamples += result.second;
GetNextMinibatch(m_test.m_source, minibatch, m_test.m_mbSize[totalNumberOfSamples], m_workerRank, m_numberOfWorkers, computeDevice);
shouldTest = m_trainer->TestMinibatch(minibatch, errorAndCount, computeDevice, m_numberOfWorkers != 1);
totalNumberOfSamples += errorAndCount.second;
}

m_trainer->SummarizeTestProgress();
Expand Down Expand Up @@ -298,8 +308,7 @@ namespace CNTK

///
/// Fetches the next cross-validation minibatch, partitioned across all distributed workers
/// (each worker reads its own shard via its rank / total worker count).
///
void TrainingSession::GetCrossValidationMinibatch(std::unordered_map<Variable, ValuePtr>& minibatch, size_t maxMbSize, const DeviceDescriptor& computeDevice)
{
    GetNextMinibatch(m_cv.m_source, minibatch, maxMbSize, m_workerRank, m_numberOfWorkers, computeDevice);
}

void TrainingSession::GetNextMinibatch(const MinibatchSourcePtr& source, std::unordered_map<Variable, ValuePtr>& minibatch, size_t mbSize, size_t workerRank, size_t numberOfWorkers, const DeviceDescriptor& computeDevice)
Expand Down

0 comments on commit 3b4efc3

Please sign in to comment.