Skip to content

Commit c3d7440

Browse files
committed
add FutureValue Op's custom attributes into Scan node's description
1 parent 91b78a7 commit c3d7440

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5036,6 +5036,9 @@ bool CompareTensorShapeProtoEqual(const ::onnx::TensorShapeProto& shape0, const
50365036
return true;
50375037
}
50385038

5039+
// forward declaration
5040+
static std::string SerializeDictionaryToString(const Dictionary& dict);
5041+
50395042
// process scan loops. also check if the caller (CreateNode) shall continue node creating process with the input src.
50405043
// caller shall not continue if:
50415044
// - we are still creating a scan op and src is not part of the scan body.
@@ -5163,6 +5166,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
51635166
}
51645167

51655168
int inputIndex = 0;
5169+
std::string futureValueCustomAttrStr = "";
51665170
for (auto &scanLoopState : scanLoops[loopIndex].scanLoopStates)
51675171
{
51685172
// IMPORTANT TRICK: initial state is usually a scalar. State initializer tensor is prepared
@@ -5277,6 +5281,29 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
52775281
graph->AddInitializedTensor(scanLoopState.m_initialStateTensor);
52785282
}
52795283
// else initializer is input.
5284+
5285+
// FutureValue's custom attrubute will be serialized into the Scan node's description
5286+
auto fvOp = scanLoopState.m_stateOutput.Owner();
5287+
if (fvOp && fvOp->OpName() == L"FutureValue")
5288+
{
5289+
const Dictionary& dict = fvOp->GetCustomAttributes();
5290+
if (dict.Size() > 0)
5291+
{
5292+
if (futureValueCustomAttrStr == "")
5293+
{
5294+
futureValueCustomAttrStr = "{custom_attributes:" + SerializeDictionaryToString(dict) + "}";
5295+
}
5296+
else
5297+
{
5298+
std::string attrStr = "{custom_attributes:" + SerializeDictionaryToString(dict) + "}";
5299+
if (attrStr != futureValueCustomAttrStr)
5300+
{
5301+
CNTK::LogicError("Scan node has multiple FutureValue custom attributes from state: %s",
5302+
scanInitialStateNodeArgName.c_str());
5303+
}
5304+
}
5305+
}
5306+
}
52805307
}
52815308

52825309
for (auto &scanInput : scanLoops[loopIndex].m_scanInputs)
@@ -5326,7 +5353,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
53265353

53275354
scanGraph.SetInputOrder(scanSubgraphOrderedInputs);
53285355
scanGraph.SetOutputOrder(scanSubgraphOrderedOutputs);
5329-
Node *scanNode = &graph->AddNode(scanNodeName, "Scan", "", input_args, output_args);
5356+
Node *scanNode = &graph->AddNode(scanNodeName, "Scan", /*description*/ futureValueCustomAttrStr, input_args, output_args);
53305357

53315358
ResolveGraphAndSaveModel(scanSubModel.get());
53325359

0 commit comments

Comments
 (0)