forked from microsoft/CNTK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBlockFunction.h
180 lines (151 loc) · 8.61 KB
/
BlockFunction.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "PrimitiveFunction.h"
#include "Utils.h"
#include "Variable.h"
namespace CNTK
{
// A primitive function of op type PrimitiveOpType::Block that encapsulates a composite Function.
// The block's inputs are the composite's Constant/Parameter inputs followed by the Variables that the
// composite's placeholder arguments have been mapped to; the block reports a caller-supplied op name.
class BlockFunction final : public PrimitiveFunction
{
public:
    // Creates a block around 'composite'.
    //   composite    - the Function being encapsulated; all its non-Constant/non-Parameter inputs must be placeholders.
    //   argumentsMap - (composite argument, Variable it is mapped to) pairs; every placeholder argument of
    //                  the composite needs exactly one entry. The order of this list determines the order of
    //                  the mapped Variables among the block's inputs.
    //   blockOpName  - value returned by OpName().
    //   attributes   - forwarded to the PrimitiveFunction base.
    //   blockName    - name forwarded to the PrimitiveFunction base; also used in diagnostics.
    //   uid          - uid forwarded to the PrimitiveFunction base.
    // Validation failures are reported through InvalidArgument (see DetermineInputs).
    BlockFunction(FunctionPtr&& composite,
                  const std::vector<std::pair<Variable, Variable>>& argumentsMap,
                  const std::wstring& blockOpName,
                  Dictionary&& attributes,
                  const std::wstring& blockName = L"",
                  const std::wstring& uid = GenerateUid(PrimitiveOpType::Block))
        : PrimitiveFunction(PrimitiveOpType::Block, DetermineInputs(composite, argumentsMap, blockName), std::move(attributes), blockName, uid),
          m_composite(composite), m_blockOpName(blockOpName)
    {
    }

    // The op name supplied when the block was created.
    virtual const std::wstring& OpName() const override { return m_blockOpName; }

    // The composite Function underlying this block.
    const FunctionPtr& Composite() const { return m_composite; }

    // Mapping from each argument of the composite underlying the block to the corresponding Variable it is mapped to.
    // The result is ordered by the position of the mapped Variable in the block's Inputs().
    // Reports a LogicError if an argument has no mapping (i.e. its mapping equals a default-constructed Variable).
    // Note: the index lookup uses unordered_map::at, so a mapping that is not among Inputs() raises std::out_of_range.
    std::vector<std::pair<Variable, Variable>> CompositeArgumentsMap() const
    {
        const auto arguments = m_composite->Arguments();

        std::vector<std::pair<Variable, Variable>> argumentsMap;
        argumentsMap.reserve(arguments.size());
        for (const auto& argument : arguments)
        {
            const auto& mapping = argument.BlockFunctionVariableMapping();
            if (mapping == Variable())
                LogicError("BlockFunction '%S' with OpName '%S' does not have a mapping for argument '%S'.", AsString().c_str(), OpName().c_str(), argument.AsString().c_str());

            argumentsMap.push_back({ argument, mapping });
        }

        // Now sort the mapping by the order of occurrence of the argument mapping in the block's inputs.
        // insert() keeps the first index when the same Variable appears more than once among the inputs.
        const auto blockInputs = Inputs();
        std::unordered_map<Variable, size_t> inputIndices;
        for (size_t i = 0; i < blockInputs.size(); ++i)
            inputIndices.insert({ blockInputs[i], i });

        std::stable_sort(argumentsMap.begin(), argumentsMap.end(), [&inputIndices](const std::pair<Variable, Variable>& first, const std::pair<Variable, Variable>& second) {
            return inputIndices.at(first.second) < inputIndices.at(second.second);
        });

        return argumentsMap;
    }

    // Mapping from each output of the block to the corresponding output of underlying composite.
    // Reports a LogicError if an output has no mapping (i.e. its mapping equals a default-constructed Variable).
    std::unordered_map<Variable, Variable> CompositeOutputsMap() const
    {
        std::unordered_map<Variable, Variable> outputsMap;
        const auto outputs = RawOutputs();
        for (const auto& output : outputs)
        {
            const auto& mapping = output.BlockFunctionVariableMapping();
            if (mapping == Variable())
                LogicError("BlockFunction '%S' with OpName '%S' does not have a mapping for output '%S'", AsString().c_str(), OpName().c_str(), output.AsString().c_str());

            outputsMap[output] = mapping;
        }

        return outputsMap;
    }

protected:
    // Propagates placeholder replacements into the block's argument mapping and composite.
    // For every composite argument whose mapped Variable is in 'replacedPlaceholders':
    //   - if the replacement satisfies IsArgument(), the argument is simply re-mapped to the replacement;
    //   - otherwise the replacement is substituted for the argument inside the composite itself.
    // The replacement is looked up with unordered_map::at, so a replaced placeholder that is missing from
    // 'placeholderReplacements' raises std::out_of_range.
    virtual void OnPlaceholdersReplaced(const std::unordered_map<Variable, Variable>& placeholderReplacements,
                                        std::unordered_set<Variable>& replacedPlaceholders) override
    {
        // Substitute any placeholder replacements in the arguments map
        const auto arguments = m_composite->Arguments();
        std::unordered_map<Variable, Variable> blockCompositePlaceholderReplacements;
        for (const auto& argument : arguments)
        {
            const auto& currentMapping = argument.BlockFunctionVariableMapping();
            if (replacedPlaceholders.find(currentMapping) == replacedPlaceholders.end())
                continue;

            // Copy the replacement before touching the mapping field below.
            auto replacement = placeholderReplacements.at(currentMapping);
            if (IsArgument(replacement))
                argument.m_dataFields->m_blockFunctionVariableMapping = replacement;
            else
                blockCompositePlaceholderReplacements.insert({ argument, replacement });
        }

        m_composite->ReplacePlaceholders(blockCompositePlaceholderReplacements);
    }

private:
    // Validates 'argumentsMap' against 'composite' and computes the block's input list:
    // first the composite's Constant/Parameter inputs (in composite order), then the mapped Variables
    // in the order given by 'argumentsMap'. As a side effect, each mapped composite argument records its
    // mapping in its m_blockFunctionVariableMapping field.
    // Reports InvalidArgument when an argument is mapped more than once, when a non-Constant/non-Parameter
    // input of the composite is not a placeholder, or when a placeholder argument has no mapping.
    //
    // This must be static: the constructor calls it while computing the arguments for the
    // PrimitiveFunction base, i.e. before any part of this object has been constructed. Invoking a
    // non-static member function at that point is undefined behavior, and the function needs no
    // instance state anyway.
    static std::vector<Variable> DetermineInputs(const FunctionPtr& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, const std::wstring& blockName)
    {
        std::unordered_map<Variable, Variable> argumentsMappingAsMap;
        for (const auto& argumentMapping : argumentsMap)
        {
            const bool wasInserted = argumentsMappingAsMap.insert(argumentMapping).second;
            if (!wasInserted)
                InvalidArgument("Multiple mappings provided for argument '%S' of the Block composite '%S'", argumentMapping.first.AsString().c_str(), composite->AsString().c_str());
        }

        std::vector<Variable> blockFunctionInputs;
        std::vector<Variable> unmappedArguments;
        const auto compositeInputs = composite->Inputs();
        for (const auto& compositeInput : compositeInputs)
        {
            assert(!compositeInput.IsOutput());

            if (compositeInput.IsConstant() || compositeInput.IsParameter())
            {
                // Constants and Parameters of the composite are direct inputs of the block.
                blockFunctionInputs.push_back(compositeInput);
                continue;
            }

            if (!compositeInput.IsPlaceholder())
            {
                InvalidArgument("The composite implementing Block '%S' has an argument '%S' which is not a placeholder. "
                                "All arguments of the composite underlying a Block must be placeholders",
                                blockName.c_str(), compositeInput.AsString().c_str());
            }

            // Verify that a mapping was provided for each argument of the composite
            if (argumentsMappingAsMap.find(compositeInput) == argumentsMappingAsMap.end())
                unmappedArguments.push_back(compositeInput);
        }

        if (!unmappedArguments.empty())
        {
            InvalidArgument("%zu of the arguments '%S' of the underlying composite Function of Block '%S' have not been mapped when encapsulating the composite as a Block.",
                            unmappedArguments.size(), NamedListString(unmappedArguments).c_str(), blockName.c_str());
        }

        // We now append the mapped arguments of the composite to the block inputs in the order of the map
        // instead of the original order they appear in the composite itself
        for (const auto& argumentMapping : argumentsMap)
        {
            // The write goes through m_dataFields, so it is recorded on the caller's Variable as well.
            argumentMapping.first.m_dataFields->m_blockFunctionVariableMapping = argumentMapping.second;
            blockFunctionInputs.push_back(argumentMapping.second);
        }

        return blockFunctionInputs;
    }

    // Appends one output Variable per raw output of the composite to 'outputs'.
    // Each new output takes its shape, data type, dynamic axes and needs-gradient flag from the
    // corresponding composite output, is named after this block, and records that composite output as
    // its block-function variable mapping.
    void InferOutputs(std::vector<Variable>& outputs) override
    {
        // We determine the outputs by replacing the arguments of the composite with new placeholders with updated
        // shape etc. information matching the corresponding mapped input
        const auto currentArguments = m_composite->Arguments();
        std::unordered_map<Variable, Variable> replacementMap;
        for (const auto& currentArgument : currentArguments)
        {
            auto currentArgumentMapping = currentArgument.BlockFunctionVariableMapping();
            auto newArgument = PlaceholderLike(currentArgumentMapping);
            // The fresh placeholder keeps pointing at the same mapped Variable as the one it replaces.
            newArgument.m_dataFields->m_blockFunctionVariableMapping = currentArgumentMapping;
            replacementMap.insert({ currentArgument, newArgument });
        }

        m_composite->ReplacePlaceholders(replacementMap);

        const auto compositeOutputs = m_composite->RawOutputs();
        for (const auto& compositeOutput : compositeOutputs)
        {
            auto output = OutputVariable(compositeOutput.Shape(), compositeOutput.GetDataType(), compositeOutput.DynamicAxes(), compositeOutput.NeedsGradient(), Name());
            output.m_dataFields->m_blockFunctionVariableMapping = compositeOutput;
            outputs.push_back(output);
        }
    }

private:
    // The composite Function encapsulated by this block.
    FunctionPtr m_composite;

    // The op name reported by OpName().
    std::wstring m_blockOpName;

    // Increasing s_serializationVersion every time we add more ops allows us to print
    // a more meaningful message when trying to load a new model with a stale binary.
    static const size_t s_serializationVersion = 1;
};
}