@@ -5036,6 +5036,9 @@ bool CompareTensorShapeProtoEqual(const ::onnx::TensorShapeProto& shape0, const
5036
5036
return true ;
5037
5037
}
5038
5038
5039
+ // forward declaration
5040
+ static std::string SerializeDictionaryToString (const Dictionary& dict);
5041
+
5039
5042
// process scan loops. also check if the caller (CreateNode) shall continue node creating process with the input src.
5040
5043
// caller shall not continue if:
5041
5044
// - we are still creating a scan op and src is not part of the scan body.
@@ -5163,6 +5166,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
5163
5166
}
5164
5167
5165
5168
int inputIndex = 0 ;
5169
+ std::string futureValueCustomAttrStr = " " ;
5166
5170
for (auto &scanLoopState : scanLoops[loopIndex].scanLoopStates )
5167
5171
{
5168
5172
// IMPORTANT TRICK: initial state is usually a scalar. State initializer tensor is prepared
@@ -5277,6 +5281,29 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
5277
5281
graph->AddInitializedTensor (scanLoopState.m_initialStateTensor );
5278
5282
}
5279
5283
// else initializer is input.
5284
+
5285
+ // FutureValue's custom attrubute will be serialized into the Scan node's description
5286
+ auto fvOp = scanLoopState.m_stateOutput .Owner ();
5287
+ if (fvOp && fvOp->OpName () == L" FutureValue" )
5288
+ {
5289
+ const Dictionary& dict = fvOp->GetCustomAttributes ();
5290
+ if (dict.Size () > 0 )
5291
+ {
5292
+ if (futureValueCustomAttrStr == " " )
5293
+ {
5294
+ futureValueCustomAttrStr = " {custom_attributes:" + SerializeDictionaryToString (dict) + " }" ;
5295
+ }
5296
+ else
5297
+ {
5298
+ std::string attrStr = " {custom_attributes:" + SerializeDictionaryToString (dict) + " }" ;
5299
+ if (attrStr != futureValueCustomAttrStr)
5300
+ {
5301
+ CNTK::LogicError (" Scan node has multiple FutureValue custom attributes from state: %s" ,
5302
+ scanInitialStateNodeArgName.c_str ());
5303
+ }
5304
+ }
5305
+ }
5306
+ }
5280
5307
}
5281
5308
5282
5309
for (auto &scanInput : scanLoops[loopIndex].m_scanInputs )
@@ -5326,7 +5353,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
5326
5353
5327
5354
scanGraph.SetInputOrder (scanSubgraphOrderedInputs);
5328
5355
scanGraph.SetOutputOrder (scanSubgraphOrderedOutputs);
5329
- Node *scanNode = &graph->AddNode (scanNodeName, " Scan" , " " , input_args, output_args);
5356
+ Node *scanNode = &graph->AddNode (scanNodeName, " Scan" , /* description */ futureValueCustomAttrStr , input_args, output_args);
5330
5357
5331
5358
ResolveGraphAndSaveModel (scanSubModel.get ());
5332
5359
0 commit comments