Skip to content

Commit

Permalink
Adding input to stream for testing and cv config
Browse files Browse the repository at this point in the history
  • Loading branch information
eldakms committed May 23, 2017
1 parent 85ed877 commit 18de8f0
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 27 deletions.
14 changes: 10 additions & 4 deletions Source/CNTKv2LibraryDll/API/CNTKLibrary.h
Original file line number Diff line number Diff line change
Expand Up @@ -5405,15 +5405,16 @@ namespace CNTK
CNTK_API CrossValidationConfig(const MinibatchSourcePtr& crossValidationSource,
const MinibatchSizeSchedule& crossValidationSchedule = MinibatchSizeSchedule(64),
size_t crossValidationFrequencyInSamples = std::numeric_limits<size_t>::max(),
size_t maxSamples = std::numeric_limits<size_t>::max()
);
size_t maxSamples = std::numeric_limits<size_t>::max(),
const std::unordered_map<Variable, StreamInformation>& inputVarToStream = {});

private:
friend class TrainingSession;
const MinibatchSourcePtr m_source;
const MinibatchSizeSchedule m_mbSize;
const size_t m_frequency;
const size_t m_maxSamples;
const std::unordered_map<Variable, StreamInformation> m_varToStream;
};

///
Expand Down Expand Up @@ -5454,12 +5455,14 @@ namespace CNTK
/// schedule : a minibatch size schedule
///
CNTK_API TestConfig(const MinibatchSourcePtr& source,
const MinibatchSizeSchedule& schedule = MinibatchSizeSchedule(64));
const MinibatchSizeSchedule& schedule = MinibatchSizeSchedule(64),
const std::unordered_map<Variable, StreamInformation>& inputVarToStream = {});

private:
friend class TrainingSession;
const MinibatchSourcePtr m_source;
const MinibatchSizeSchedule m_mbSize;
const std::unordered_map<Variable, StreamInformation> m_varToStream;
};

///
Expand Down Expand Up @@ -5565,7 +5568,10 @@ namespace CNTK
TrainingSession(const TrainingSession&) = delete; TrainingSession& operator=(const TrainingSession&) = delete; TrainingSession& operator=(TrainingSession&&) = delete; TrainingSession(TrainingSession&&) = delete;

// Auxilary functions.
void GetNextMinibatch(const MinibatchSourcePtr& source, std::unordered_map<Variable, ValuePtr>& minibatch, size_t maxMbSize, size_t workerRank, size_t numberOfWorkers, const DeviceDescriptor& computeDevice);
void GetNextMinibatch(const MinibatchSourcePtr& source,
std::unordered_map<Variable, ValuePtr>& minibatch,
const std::unordered_map<Variable, StreamInformation>& inputVarToStream,
size_t maxMbSize, size_t workerRank, size_t numberOfWorkers, const DeviceDescriptor& computeDevice);
void GetTrainingMinibatch(std::unordered_map<Variable, ValuePtr>& minibatch, size_t maxMbSize, const DeviceDescriptor& computeDevice);
void GetCrossValidationMinibatch(std::unordered_map<Variable, ValuePtr>& minibatch, size_t maxMbSize, const DeviceDescriptor& computeDevice);

Expand Down
31 changes: 21 additions & 10 deletions Source/CNTKv2LibraryDll/TrainingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,23 @@ namespace CNTK
const MinibatchSourcePtr& crossValidationSource,
const MinibatchSizeSchedule& crossValidationSchedule,
size_t crossValidationFrequencyInSamples,
size_t maxSamples
):
size_t maxSamples,
const std::unordered_map<Variable, StreamInformation>& inputVarToStream):
m_source(crossValidationSource),
m_mbSize(crossValidationSchedule),
m_frequency(crossValidationFrequencyInSamples),
m_maxSamples(maxSamples)
m_maxSamples(maxSamples),
m_varToStream(inputVarToStream)
{
}

TestConfig::TestConfig(
const MinibatchSourcePtr& source,
const MinibatchSizeSchedule& schedule) :
const MinibatchSizeSchedule& schedule,
const std::unordered_map<Variable, StreamInformation>& inputVarToStream) :
m_source(source),
m_mbSize(schedule)
m_mbSize(schedule),
m_varToStream(inputVarToStream)
{
}

Expand Down Expand Up @@ -281,7 +284,8 @@ namespace CNTK
std::pair<ValuePtr, size_t> errorAndCount;
while (shouldTest)
{
GetNextMinibatch(m_test.m_source, minibatch, m_test.m_mbSize[totalNumberOfSamples], m_workerRank, m_numberOfWorkers, computeDevice);
GetNextMinibatch(m_test.m_source, minibatch, m_test.m_varToStream.empty() ? m_varToStream : m_test.m_varToStream,
m_test.m_mbSize[totalNumberOfSamples], m_workerRank, m_numberOfWorkers, computeDevice);
shouldTest = m_trainer->TestMinibatch(minibatch, errorAndCount, computeDevice, m_numberOfWorkers != 1);
totalNumberOfSamples += errorAndCount.second;
}
Expand All @@ -307,15 +311,22 @@ namespace CNTK

size_t mbSize = GetMinibatchSize();
mbSize = std::min(mbSize, maxMbSize);
GetNextMinibatch(m_source, minibatch, mbSize, workerRank, numberOfWorkers, computeDevice);
GetNextMinibatch(m_source, minibatch, m_varToStream, mbSize, workerRank, numberOfWorkers, computeDevice);
}

void TrainingSession::GetCrossValidationMinibatch(std::unordered_map<Variable, ValuePtr>& minibatch, size_t maxMbSize, const DeviceDescriptor& computeDevice)
{
GetNextMinibatch(m_cv.m_source, minibatch, maxMbSize, m_workerRank, m_numberOfWorkers, computeDevice);
GetNextMinibatch(m_cv.m_source, minibatch, m_cv.m_varToStream.empty() ? m_varToStream : m_cv.m_varToStream, maxMbSize, m_workerRank, m_numberOfWorkers, computeDevice);
}

void TrainingSession::GetNextMinibatch(const MinibatchSourcePtr& source, std::unordered_map<Variable, ValuePtr>& minibatch, size_t mbSize, size_t workerRank, size_t numberOfWorkers, const DeviceDescriptor& computeDevice)
void TrainingSession::GetNextMinibatch(
const MinibatchSourcePtr& source,
std::unordered_map<Variable, ValuePtr>& minibatch,
const std::unordered_map<Variable, StreamInformation>& inputVarToStream,
size_t mbSize,
size_t workerRank,
size_t numberOfWorkers,
const DeviceDescriptor& computeDevice)
{
minibatch.clear();

Expand All @@ -327,7 +338,7 @@ namespace CNTK
if (minibatchData.empty())
return;

for (auto v : m_varToStream)
for (auto v : inputVarToStream)
minibatch.insert({ v.first, minibatchData[v.second].data });
}

Expand Down
64 changes: 60 additions & 4 deletions bindings/python/cntk/train/tests/training_session_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,43 @@
10 |S0 61:1 |# A |S1 32:1 |# ~AH
'''

ctf_data2 = '''\
0 |S4 3:1 |# <s> |S5 3:1 |# <s>
0 |S4 4:1 |# A |S5 32:1 |# ~AH
0 |S4 5:1 |# B |S5 36:1 |# ~B
0 |S4 4:1 |# A |S5 31:1 |# ~AE
0 |S4 7:1 |# D |S5 38:1 |# ~D
0 |S4 12:1 |# I |S5 47:1 |# ~IY
0 |S4 1:1 |# </s> |S5 1:1 |# </s>
2 |S4 60:1 |# <s> |S5 3:1 |# <s>
2 |S4 61:1 |# A |S5 32:1 |# ~AH
3 |S4 60:1 |# <s> |S5 3:1 |# <s>
3 |S4 61:1 |# A |S5 32:1 |# ~AH
4 |S4 60:1 |# <s> |S5 3:1 |# <s>
4 |S4 61:1 |# A |S5 32:1 |# ~AH
5 |S4 60:1 |# <s> |S5 3:1 |# <s>
5 |S4 61:1 |# A |S5 32:1 |# ~AH
6 |S4 60:1 |# <s> |S5 3:1 |# <s>
6 |S4 61:1 |# A |S5 32:1 |# ~AH
7 |S4 60:1 |# <s> |S5 3:1 |# <s>
7 |S4 61:1 |# A |S5 32:1 |# ~AH
8 |S4 60:1 |# <s> |S5 3:1 |# <s>
8 |S4 61:1 |# A |S5 32:1 |# ~AH
9 |S4 60:1 |# <s> |S5 3:1 |# <s>
9 |S4 61:1 |# A |S5 32:1 |# ~AH
10 |S4 60:1 |# <s> |S5 3:1 |# <s>
10 |S4 61:1 |# A |S5 32:1 |# ~AH
'''


def mb_source(tmpdir, fileprefix, max_samples=FULL_DATA_SWEEP):
def mb_source(tmpdir, fileprefix, max_samples=FULL_DATA_SWEEP, ctf=ctf_data, streams = ['S0', 'S1']):
ctf_file = str(tmpdir / (fileprefix + '2seqtest.txt'))
with open(ctf_file, 'w') as f:
f.write(ctf_data)
f.write(ctf)

mbs = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
features=StreamDef(field='S0', shape=input_dim, is_sparse=True),
labels=StreamDef(field='S1', shape=input_dim, is_sparse=True)
features=StreamDef(field=streams[0], shape=input_dim, is_sparse=True),
labels=StreamDef(field=streams[1], shape=input_dim, is_sparse=True)
)),
randomize=False, max_samples=max_samples)
return mbs
Expand Down Expand Up @@ -509,3 +537,31 @@ def test_session_with_test(tmpdir, device_id):
assert(t.total_number_of_samples_seen == 61)
assert(writer.test_summary_counter == 1)


def test_session_with_test_own_inputs(tmpdir, device_id):
device = cntk_device(device_id)
writer = MockProgressWriter(expected_test_summary=[[92, 25]])
t, feature, label = create_sample_model(device, writer)

mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
mbs1 = mb_source(tmpdir, "test", ctf=ctf_data2, streams=['S4', 'S5'])

input_map = {
feature: mbs.streams.features,
label: mbs.streams.labels
}

input_map1 = {
feature: mbs1.streams.features,
label: mbs1.streams.labels
}

C.training_session(
trainer=t, mb_source=mbs,
mb_size=4, model_inputs_to_streams=input_map,
max_samples=60,
test_config = C.TestConfig(source=mbs1, mb_size=2, model_inputs_to_streams = input_map1),
).train(device)

assert(t.total_number_of_samples_seen == 61)
assert(writer.test_summary_counter == 1)
29 changes: 20 additions & 9 deletions bindings/python/cntk/train/training_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class CheckpointConfig(cntk_py.CheckpointConfig):
filename (str): checkpoint file name.
frequency (int): checkpoint frequency in samples. If 0, no checkpointing takes place.
If ``sys.maxsize``, a single checkpoint is taken at the end of the training.
preserve_all (bool): saves all checkpoints, using ``filename`` as prefix and checkpoint index as a suffix.
restore (bool): flag, indicating whether to restore from available checkpoint before the start of the training
preserve_all (bool): saves all checkpoints, using ``filename`` as prefix and checkpoint index as a suffix.
'''
def __init__(self, filename, frequency=None,
restore=True, preserve_all=False):
Expand All @@ -30,8 +30,8 @@ def __init__(self, filename, frequency=None,
filename (str): checkpoint file name.
frequency (int): checkpoint frequency in samples. If 0, no checkpointing takes place.
If ``sys.maxsize``, a single checkpoint is taken at the end of the training.
preserve_all (bool): saves all checkpoints, using ``filename`` as prefix and checkpoint index as a suffix.
restore (bool): flag, indicating whether to restore from available checkpoint before the start of the training
preserve_all (bool): saves all checkpoints, using ``filename`` as prefix and checkpoint index as a suffix.
Returns:
Reconfigured self.
Expand All @@ -55,17 +55,19 @@ class CrossValidationConfig(cntk_py.CrossValidationConfig):
Args:
source (:class:`~cntk.io.MinibatchSource`): minibatch source used for cross validation
mb_size(:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for cross validation
frequency (int): frequency in samples for cross validation
If None or ``sys.maxsize``, a single cross validation is performed at the end of training.
schedule (:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for cross validation
callback (func (index, average_error, cv_num_samples, cv_num_minibatches)): Callback that will
be called with frequency which can implement custom cross validation logic,
returns False if training should be stopped.
max_samples (int, default None): number of samples to perform
cross-validation on. If None, all samples are taken.
model_inputs_to_streams (dict, default None): mapping between input variables and input streams.
If None, the mapping provided to the training session constructor is used.
'''
def __init__(self, source=None, mb_size=None, frequency=None,
callback=None, max_samples=None):
callback=None, max_samples=None, model_inputs_to_streams=None):
self.callback = callback

if source is None and callback is None:
Expand All @@ -92,18 +94,24 @@ def __init__(self, source=None, mb_size=None, frequency=None,
if max_samples is None:
max_samples = sys.maxsize

super(CrossValidationConfig, self).__init__(
source, schedule, frequency, max_samples)
if model_inputs_to_streams is not None:
super(CrossValidationConfig, self).__init__(
source, schedule, frequency, max_samples, model_inputs_to_streams)
else:
super(CrossValidationConfig, self).__init__(
source, schedule, frequency, max_samples)

class TestConfig(cntk_py.TestConfig):
'''
A test configuration for the training session.
Args:
source (:class:`~cntk.io.MinibatchSource`): minibatch source used for testing
schedule (:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for testing
mb_size(:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for testing
model_inputs_to_streams (dict): mapping between input variables and input streams
If None, the mapping provided to the training session constructor is used.
'''
def __init__(self, source, mb_size=None):
def __init__(self, source, mb_size=None, model_inputs_to_streams=None):
schedule = mb_size
if isinstance(mb_size, int):
schedule = minibatch_size_schedule(mb_size)
Expand All @@ -116,7 +124,10 @@ def __init__(self, source, mb_size=None):
'it must be an output of minibatch_size_schedule() function'
% type(schedule))

super(TestConfig, self).__init__(source, schedule)
if model_inputs_to_streams is not None:
super(TestConfig, self).__init__(source, schedule, model_inputs_to_streams)
else:
super(TestConfig, self).__init__(source, schedule)

class TrainingSession(cntk_py.TrainingSession):
'''
Expand Down

0 comments on commit 18de8f0

Please sign in to comment.