Skip to content

Commit

Permalink
Merge branch 'qiwye/asgd-dev' of https://github.com/Microsoft/CNTK in…
Browse files Browse the repository at this point in the history
…to qiwye/asgd-dev
  • Loading branch information
feiga committed Nov 1, 2016
2 parents 9776a03 + b33e67a commit 6b1e2cd
Show file tree
Hide file tree
Showing 7 changed files with 354 additions and 308 deletions.
22 changes: 17 additions & 5 deletions Source/Common/Include/ASGDCommon.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
//
// <copyright file="ASGDCommon.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// </copyright>
//
#pragma once

namespace Microsoft { namespace MSR { namespace CNTK {

enum class AdjustLearningRateatBeginning : int
// -----------------------------------------------------------------------
// enum class AdjustLearningRateAtBeginning
// Provides an option for DataParallelASGD training so that every node
// can adjust its learning rate on every minibatch during the first N epochs.
// -----------------------------------------------------------------------


// TODO: We can remove these options once we can adjust the learning rate at the minibatch level

enum class AdjustLearningRateAtBeginning : int
{
None = 0,
Linearly = 1,
Staircase = (1 << 1),
None = 0, // default: do not adjust the learning rate
Linearly = 1, // linear adjustment: the learning rate goes from 0 to learningRatesPerMB
Staircase = (1 << 1), // staircased adjustment: the learning rate goes from 0 to learningRatesPerMB, adjusted every adjustNbMinibatch
};

}}}
466 changes: 248 additions & 218 deletions Source/Common/Include/MultiversoWrapper.h

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion Source/Common/Include/NoMultiversoWrapper.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//
// <copyright file="NoMultiversoWrapper.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// </copyright>
//
#pragma once
Expand Down
150 changes: 75 additions & 75 deletions Source/SGDLib/SGD.cpp

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions Source/SGDLib/SGD.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#pragma once

#include "Basics.h"
#include "ASGDCommon.h"
#include "ComputationNetwork.h"
#include "SimpleEvaluator.h"
#include "DataReader.h"
Expand All @@ -20,6 +19,7 @@
#include <random>
#include "Profiler.h"
#include "MASGD.h"
#include "ASGDCommon.h"
using namespace std; // ugh! TODO: get rid of this from .h files!!!

#define CNTK_CHECKPOINT_VERSION_1 1 // 1 -> no version number
Expand Down Expand Up @@ -288,16 +288,16 @@ struct SGDParams : public ScriptableObjects::Object
double m_L1RegWeight;

// Parallel training related with ASGD
intargvector m_numMBsToASGDPushAndPull; // decide how many minibatchs should ASGD to a pull&push to parameter server.
intargvector m_numMiniBatchesToPushAndPullforASGD; // decide how many minibatchs should ASGD to a pull&push to parameter server.
// note that, this will override m_nFramesBetweenASGDSync when set.
intargvector m_nFramesBetweenASGDSync;
bool m_isPipeline;
bool m_isSimulateMA;
AdjustLearningRateatBeginning m_adjustlearningrateatbeginning;
double m_adjustcoefficient;
size_t m_adjustnbminibatch;
AdjustLearningRateAtBeginning m_adjustLearningRateAtBeginning;
double m_adjustCoefficient;
size_t m_adjustPerMinibatches;

//sequence training
// sequence training
double m_hSmoothingWeight;
double m_frameDropThresh;
bool m_doReferenceAlign;
Expand All @@ -316,6 +316,7 @@ struct SGDParams : public ScriptableObjects::Object
template <class ElemType>
class IDistGradAggregator;

// MultiversoHelper is used for parallel training using DataParallelASGD
template <class ElemType>
class MultiversoHelper;

Expand Down Expand Up @@ -578,7 +579,7 @@ class SGD : public SGDParams

private:
void MarkDropoutNodesEvalTimeStampAsOutdated(const ComputationNetworkPtr& net, const ComputationNodeBasePtr& criterionNode);
MultiversoHelper<ElemType>* m_pMultiversoHelper;
shared_ptr<MultiversoHelper<ElemType>> m_pMultiversoHelper;
bool m_pMultiversoHelperBarrier;

bool UsingGradientAggregation(size_t epochNumber) const
Expand Down
2 changes: 1 addition & 1 deletion Source/SGDLib/SGDLib.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
</PrecompiledHeader>
<PreprocessorDefinitions>WIN32;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(CNTK_ENABLE_1BitSGD)'=='true'">QUANTIZED_GRADIENT_AGGREGATION;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(CNTK_ENABLE_ASGD)'=='true'">MULTIVERSO_SUPPORT;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(CNTK_ENABLE_ASGD)'=='true'">ASGD_PARALLEL_SUPPORT;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<DisableSpecificWarnings>4819</DisableSpecificWarnings>
</ClCompile>
<Link>
Expand Down
2 changes: 1 addition & 1 deletion configure
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ do
fi
;;

--asgd*)
--asgd*)
if test x$optarg = xyes || test x$optarg = xno
then
enable_asgd=$optarg
Expand Down

0 comments on commit 6b1e2cd

Please sign in to comment.