Skip to content

Commit

Permalink
CNTK v2 library: Fix a bug with replacing multiple parameters of a Block Function
Browse files Browse the repository at this point in the history
  • Loading branch information
amitaga committed Mar 21, 2017
1 parent 637524c commit e2044d2
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 24 deletions.
13 changes: 11 additions & 2 deletions Source/CNTKv2LibraryDll/BlockFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,20 @@ namespace CNTK
{
// Substitute any placeholder replacements in the arguments map
auto arguments = m_composite->Arguments();
std::unordered_map<Variable, Variable> blockCompositePlaceholderReplacements;
for (auto argument : arguments)
{
if (replacedPlaceholders.find(argument.BlockFunctionVariableMapping()) != replacedPlaceholders.end())
argument.m_dataFields->m_blockFunctionVariableMapping = placeholderReplacements.at(argument.BlockFunctionVariableMapping());
{
auto replacement = placeholderReplacements.at(argument.BlockFunctionVariableMapping());
if (IsArgument(replacement))
argument.m_dataFields->m_blockFunctionVariableMapping = replacement;
else
blockCompositePlaceholderReplacements.insert({ argument, replacement });
}
}

m_composite->ReplacePlaceholders(blockCompositePlaceholderReplacements);
}

private:
Expand Down Expand Up @@ -140,7 +149,7 @@ namespace CNTK
for (auto currentArgument : currentArguments)
{
auto currentArgumentMapping = currentArgument.BlockFunctionVariableMapping();
auto newArgument = PlaceholderVariable(currentArgumentMapping.Shape(), currentArgumentMapping.GetDataType(), currentArgumentMapping.Name(), currentArgumentMapping.DynamicAxes());
auto newArgument = PlaceholderLike(currentArgumentMapping);
newArgument.m_dataFields->m_blockFunctionVariableMapping = currentArgumentMapping;

replacementMap.insert({ currentArgument, newArgument });
Expand Down
44 changes: 27 additions & 17 deletions Source/CNTKv2LibraryDll/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ namespace CNTK
Variable clonedInput;
if (replacements.find(cloneeInput) != replacements.end())
{
clonedInput = PlaceholderVariable(cloneeInput.Shape(), cloneeInput.DynamicAxes());
clonedInput = PlaceholderLike(cloneeInput);
placeholderReplacements[clonedInput] = replacements.at(cloneeInput);
}
else
Expand Down Expand Up @@ -658,7 +658,7 @@ namespace CNTK

if (existingPlaceholderReplacement == placeholderReplacements.end())
{
clonedInput = PlaceholderVariable(cloneeInput.Shape(), cloneeInput.DynamicAxes());
clonedInput = PlaceholderLike(cloneeInput);
placeholderReplacements[clonedInput] = cloneeInput;
}
else
Expand Down Expand Up @@ -689,29 +689,33 @@ namespace CNTK
std::vector<std::pair<Variable, Variable>> clonedBlockCompositeArgumentsMap;

// When cloning the block, we need to replace any Parameter/Constants inside the block with
// the corresponding replacements if any were specified
for (size_t i = 0; i < cloneeCompositeInputs.size(); ++i)
// the corresponding replacements if any
for (size_t i = 0; i < inputs.size(); ++i)
{
auto cloneeCompositeInput = cloneeCompositeInputs[i];
if (replacements.find(cloneeCompositeInput) != replacements.end())
auto cloneeInput = cloneeInputs[i];
auto clonedInput = inputs[i];
if ((cloneeInput != clonedInput) && (cloneeInput.IsParameter() || cloneeInput.IsConstant()))
{
if (IsArgument(cloneeCompositeInput))
{
InvalidArgument("Function '%S': Illegal to replace internal variable '%S' of nested Block Function '%S'.",
clonee->AsString().c_str(),
cloneeCompositeInput.AsString().c_str(),
blockFunction->AsString().c_str());
}
else
auto iter = std::find(cloneeCompositeInputs.begin(), cloneeCompositeInputs.end(), cloneeInput);
if (iter != cloneeCompositeInputs.end())
{
auto replacement = PlaceholderVariable(cloneeCompositeInput.Shape(), cloneeCompositeInput.DynamicAxes());
auto cloneeCompositeInput = *iter;
Variable replacement = clonedInput;
if (IsArgument(replacement))
{
replacement = PlaceholderLike(cloneeCompositeInput);
clonedBlockCompositeArgumentsMap.push_back({ replacement, inputs[i] });
}

cloneeCompositeReplacements.insert({ cloneeCompositeInput, replacement });
clonedBlockCompositeArgumentsMap.push_back({ replacement, inputs[i]});
}
}
}

auto clonedComposite = cloneeComposite->Clone(parameterCloneMethod, cloneeCompositeReplacements);
// We will not have the block's internal composite create new clones of Parameters/Constants since
// in the case we want to really clone, they have been cloned as part of cloning the inputs of the
// block and will be handled through the replacements
auto clonedComposite = cloneeComposite->Clone(ParameterCloningMethod::Share, cloneeCompositeReplacements);

auto clonedCompositeInputs = clonedComposite->Inputs();
std::unordered_map<Variable, Variable> cloneeToClonedBlockCompositeArgumentsMap;
Expand All @@ -726,6 +730,12 @@ namespace CNTK
clonedBlockCompositeArgumentsMap.push_back({ cloneeToClonedBlockCompositeArgumentsMap.at(cloneeArgumentMapping.first), cloneeToClonedInputMap.at(cloneeArgumentMapping.second) });

clonedFunction = MakeSharedObject<BlockFunction>(std::move(clonedComposite), clonedBlockCompositeArgumentsMap, blockFunction->OpName(), Dictionary(blockFunction->Attributes()), blockFunction->Name());
auto clonedFunctionInputs = clonedFunction->Inputs();
if (clonedFunctionInputs != inputs)
LogicError("Block Function '%S': Inputs '%S' of the new clone do not match the cloned inputs '%S' of the clonee Block Function.",
clonedFunction->AsString().c_str(),
NamedListString(clonedFunctionInputs).c_str(),
NamedListString(inputs).c_str());
}
else
clonedFunction = clonee->Clone(inputs);
Expand Down
5 changes: 5 additions & 0 deletions Source/CNTKv2LibraryDll/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,11 @@ namespace CNTK
return var.IsConstant() && (var.Shape().TotalSize() == 1);
}

// Builds a new placeholder Variable modeled on 'var': the placeholder is created
// from var's shape, data type, name and dynamic axes.
inline Variable PlaceholderLike(const Variable& var)
{
    const auto& shape = var.Shape();
    const auto dataType = var.GetDataType();
    const auto& name = var.Name();
    const auto& dynamicAxes = var.DynamicAxes();

    return PlaceholderVariable(shape, dataType, name, dynamicAxes);
}

std::vector<Axis> DynamicAxesFromInternalDynamicAxisName(const std::wstring& internalDynamicAxisName);

// Construct the dynamic axis name to be used internally for the CNTK InputNodes
Expand Down
9 changes: 5 additions & 4 deletions bindings/python/cntk/ops/tests/block_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,13 @@ def test_block_clone():
dense_block = as_block(block_composite, [(operand_placeholder, x)], 'dense')

w_new = parameter(shape=(1,1), init=3)
dense_block_clone = dense_block.clone('share', {w : w_new})
assert dense_block_clone.parameters[0].uid == b.uid
assert dense_block_clone.inputs[1].uid == w_new.uid
b_new = parameter(shape=(1,), init=4)
dense_block_clone = dense_block.clone('share', {w : w_new, b : b_new})
assert dense_block_clone.inputs[0].uid == w_new.uid
assert dense_block_clone.inputs[1].uid == b_new.uid

result = dense_block_clone.eval({dense_block_clone.arguments[0] : [np.asarray([2.], dtype=np.float32)]})
assert np.array_equal(result, [[[8.]]])
assert np.array_equal(result, [[[10.]]])


def test_root_block_clone():
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/cntk/ops/tests/userfunction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,4 @@ def test_udf_op_name():
p = parameter(shape=(dim,), init=10, name='p')
i = input_variable(dim, needs_gradient=True, name='i_var')
m = user_function(MyPlus(i, constant(3)))
print(m.root_function)
assert str(m.root_function) != ''

0 comments on commit e2044d2

Please sign in to comment.