forked from microsoft/CNTK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path NetworkFactory.cpp
196 lines (179 loc) · 10.2 KB
/
NetworkFactory.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// NetworkFactory.cpp -- CNTK network creation related functions
//
#include "stdafx.h"
#include "Actions.h"
#include "SimpleNetworkBuilder.h"
#include "NDLNetworkBuilder.h"
#include "ScriptableObjects.h"
#include "BrainScriptEvaluator.h"
#include "BrainScriptParser.h"
// Turns the BrainScript lambda stored under 'createNetwork' in a BrainScript config record
// into a C++ factory function, which can then be passed to Train().
// The returned factory ignores its device argument. It applies the BrainScript lambda with no
// positional and no named arguments and casts the resulting value to a ComputationNetwork.
function<ComputationNetworkPtr(DEVICEID_TYPE)> GetCreateNetworkFn(const ScriptableObjects::IConfigRecord& config)
{
    auto bsFactoryLambda = config[L"createNetwork"].AsPtr<ScriptableObjects::ConfigLambda>();

    return [bsFactoryLambda](DEVICEID_TYPE /*deviceId*/)
    {
        // 'createNetwork' takes neither positional nor named arguments
        vector<ScriptableObjects::ConfigValuePtr> noPositionalArgs;
        ScriptableObjects::ConfigLambda::NamedParams noNamedArgs;

        // running the BrainScript lambda is what actually builds the network
        let resultValue = bsFactoryLambda->Apply(move(noPositionalArgs), move(noNamedArgs), L"BuildNetworkFromDescription");

        // hand the result back as a ComputationNetwork pointer
        return resultValue.AsPtr<ComputationNetwork>();
    };
}
// Overload for the legacy CNTK config format. That format has no lambdas, so a
// 'createNetwork' factory can never be obtained from it; this always reports NOT_IMPLEMENTED.
function<ComputationNetworkPtr(DEVICEID_TYPE)> GetCreateNetworkFn(const ConfigParameters&)
{
    NOT_IMPLEMENTED;
}
template <class ConfigRecordType, typename ElemType>
bool TryGetNetworkFactory(const ConfigRecordType& config, function<ComputationNetworkPtr(DEVICEID_TYPE)>& createNetworkFn)
{
DEVICEID_TYPE deviceId = DeviceFromConfig(config);
int traceLevel = config(L"traceLevel", 0);
if (config.Exists(L"createNetwork"))
{
createNetworkFn = GetCreateNetworkFn(config); // (we need a separate function needed due to template code)
return true;
}
else if (config.Exists(L"SimpleNetworkBuilder"))
{
const ConfigRecordType& simpleNetworkBuilderConfig(config(L"SimpleNetworkBuilder"));
auto netBuilder = make_shared<SimpleNetworkBuilder<ElemType>>(simpleNetworkBuilderConfig); // parses the configuration and stores it in the SimpleNetworkBuilder object
createNetworkFn = [netBuilder, traceLevel](DEVICEID_TYPE deviceId)
{
auto net = shared_ptr<ComputationNetwork>(netBuilder->BuildNetworkFromDescription()); // this operates based on the configuration saved above
net->SetTraceLevel(traceLevel);
return net;
};
return true;
}
// legacy NDL
else if (config.Exists(L"NDLNetworkBuilder"))
{
const ConfigRecordType& ndlNetworkBuilderConfig(config(L"NDLNetworkBuilder"));
shared_ptr<NDLBuilder<ElemType>> netBuilder = make_shared<NDLBuilder<ElemType>>(ndlNetworkBuilderConfig);
createNetworkFn = [netBuilder, traceLevel](DEVICEID_TYPE deviceId)
{
auto net = shared_ptr<ComputationNetwork>(netBuilder->BuildNetworkFromDescription());
net->SetTraceLevel(traceLevel);
return net;
};
return true;
}
// legacy test mode for BrainScript. Will go away once we fully integrate with BS.
else if (config.Exists(L"BrainScriptNetworkBuilder") || config.Exists(L"ExperimentalNetworkBuilder" /*legacy name*/))
{
// We interface with outer old CNTK config by taking the inner part, which we get as a string, as BrainScript.
// We prepend a few standard definitions, and also definition of deviceId and precision, which all objects will pull out again when they are being constructed.
// BUGBUG: We are not getting TextLocations right in this way! Do we need to inject location markers into the source? Moot once we fully switch to BS
wstring sourceOfNetwork = config.Exists(L"BrainScriptNetworkBuilder") ? config(L"BrainScriptNetworkBuilder") : config(L"ExperimentalNetworkBuilder");
if (sourceOfNetwork.find_first_of(L"([{") != 0)
InvalidArgument("BrainScript network description must be either a BS expression in ( ) or a config record in { }");
// set the include paths to all paths that configs were read from; no additional configurable include paths are supported by BrainScriptNetworkBuilder
auto includePaths = ConfigParameters::GetBrainScriptNetworkBuilderIncludePaths();
// inject additional items into the source code
// We support two ways of specifying the network in BrainScript:
// - BrainScriptNetworkBuilder = ( any BS expression that evaluates to a ComputationNetwork )
// - BrainScriptNetworkBuilder = { constructor parameters for a ComputationNetwork }
// For back-compat, [ ] is allowed and means the same as { }
if (sourceOfNetwork[0] == '{' || sourceOfNetwork[0] == '[') // if { } form then we turn it into ComputationNetwork by constructing a ComputationNetwork from it
sourceOfNetwork = L"new ComputationNetwork " + sourceOfNetwork;
let sourceOfBS = msra::strfun::wstrprintf(L"include \'cntk.core.bs\'\n" // include our core lib. Note: Using lowercase here to match the Linux name of the CNTK exe.
L"deviceId = %d\n" // deviceId as passed in
L"traceLevel = %d\n"
L"precision = '%ls'\n" // 'float' or 'double'
L"network = %ls", // source code of expression that evaluates to a ComputationNetwork
(int)deviceId, traceLevel, ElemTypeName<ElemType>(), sourceOfNetwork.c_str());
let expr = BS::ParseConfigDictFromString(sourceOfBS, L"BrainScriptNetworkBuilder", move(includePaths));
// the rest is done in a lambda that is only evaluated when a virgin network is needed
// Note that evaluating the BrainScript *is* instantiating the network, so the evaluate call must be inside the lambda.
createNetworkFn = [expr](DEVICEID_TYPE /*deviceId*/)
{
// evaluate the parse tree, particularly the top-level field 'network'
// Evaluating it will create the network.
let object = EvaluateField(expr, L"network"); // this comes back as a BS::Object
let network = dynamic_pointer_cast<ComputationNetwork>(object); // cast it
if (!network)
LogicError("BuildNetworkFromDescription: ComputationNetwork not what it was meant to be");
return network;
};
return true;
}
else
return false;
}
// Returns the network factory function described by 'config' (see TryGetNetworkFactory()).
// If no supported network builder is configured, this reports the problem via RuntimeError().
template <class ConfigRecordType, typename ElemType>
function<ComputationNetworkPtr(DEVICEID_TYPE)> GetNetworkFactory(const ConfigRecordType& config)
{
    function<ComputationNetworkPtr(DEVICEID_TYPE)> factory;
    if (TryGetNetworkFactory<ConfigRecordType, ElemType>(config, factory))
        return factory;

    RuntimeError("No network builder found in the config file. NDLNetworkBuilder, SimpleNetworkBuilder, or BrainScriptNetworkBuilder must be specified");
}
// Replaces the network's current output nodes (node group L"output") with the nodes named in 'outputNodeNames'.
// Names that do not exist in 'net' are reported on stderr and skipped.
// Every name that was actually added is appended to 'outputNodeNamesVector'.
static void PatchOutputNodes(const ComputationNetworkPtr& net, const ConfigArray& outputNodeNames, vector<wstring>& outputNodeNamesVector)
{
    // empty the output group, removing one node at a time from the back
    while (!net->OutputNodes().empty())
        net->RemoveFromNodeGroup(L"output", net->OutputNodes().back());

    // add the requested nodes, keeping track of the ones that were found
    for (wstring nodeName : outputNodeNames)
    {
        if (net->NodeNameExists(nodeName))
        {
            outputNodeNamesVector.push_back(nodeName);
            net->AddToNodeGroup(L"output", net->GetNodeFromName(nodeName));
        }
        else
        {
            fprintf(stderr, "PatchOutputNodes: No node named '%ls'; skipping\n", nodeName.c_str());
        }
    }
}
// Creates or loads the model described by 'config'.
// If a network builder is configured (see TryGetNetworkFactory()), the network is created by it.
// Otherwise it is read from the file given by 'modelPath' and then compiled.
// If the config entry named by 'outputNodeNamesConfig' lists any node names, they replace the
// network's output nodes, and the names actually applied are appended to 'outputNodeNamesVector'.
template <class ConfigRecordType, typename ElemType>
ComputationNetworkPtr GetModelFromConfig(const ConfigRecordType& config, const wstring& outputNodeNamesConfig, vector<wstring>& outputNodeNamesVector)
{
    DEVICEID_TYPE deviceId = DeviceFromConfig(config);
    ConfigArray outputNodeNames = config(outputNodeNamesConfig.c_str(), ConfigArray(""));
    const bool mustPatchOutputs = outputNodeNames.size() > 0;

    // preferred path: a network builder given in the config
    function<ComputationNetworkPtr(DEVICEID_TYPE)> builderFn;
    if (TryGetNetworkFactory<ConfigRecordType, ElemType>(config, builderFn))
    {
        ComputationNetworkPtr builtNet = builderFn(deviceId);
        if (mustPatchOutputs)
        {
            // Drop the compiled state before changing the output set, then compile again.
            // BUGBUG: This will generate double Validation output in the log
            builtNet->InvalidateCompiledNetwork();
            PatchOutputNodes(builtNet, outputNodeNames, outputNodeNamesVector);
            builtNet->CompileNetwork();
        }
        return builtNet;
    }

    // fallback: load the model from 'modelPath'
    // CreateFromFile() is not used here because the user might specify OutputNodeNames in the config.
    // Patching before the first compilation avoids the double validation log output.
    wstring modelPath = config(L"modelPath");
    ComputationNetworkPtr loadedNet = make_shared<ComputationNetwork>(deviceId);
    loadedNet->SetTraceLevel(config(L"traceLevel", 0));
    loadedNet->Read<ElemType>(modelPath);
    if (mustPatchOutputs)
        PatchOutputNodes(loadedNet, outputNodeNames, outputNodeNamesVector);
    loadedNet->CompileNetwork();
    return loadedNet;
}
// Explicit instantiations of the templates defined in this file.
// GetNetworkFactory: both config flavors (BrainScript IConfigRecord, legacy ConfigParameters) x both element types.
template function<ComputationNetworkPtr(DEVICEID_TYPE)> GetNetworkFactory<ScriptableObjects::IConfigRecord, float>(const ScriptableObjects::IConfigRecord&);
template function<ComputationNetworkPtr(DEVICEID_TYPE)> GetNetworkFactory<ScriptableObjects::IConfigRecord, double>(const ScriptableObjects::IConfigRecord&);
template function<ComputationNetworkPtr(DEVICEID_TYPE)> GetNetworkFactory<ConfigParameters, float>(const ConfigParameters&);
template function<ComputationNetworkPtr(DEVICEID_TYPE)> GetNetworkFactory<ConfigParameters, double>(const ConfigParameters&);
// GetModelFromConfig: only the legacy ConfigParameters flavor is instantiated.
template ComputationNetworkPtr GetModelFromConfig<ConfigParameters, float>(const ConfigParameters&, const wstring&, vector<wstring>&);
template ComputationNetworkPtr GetModelFromConfig<ConfigParameters, double>(const ConfigParameters&, const wstring&, vector<wstring>&);