12
12
using Microsoft . ML . EntryPoints ;
13
13
using Microsoft . ML . Internal . Utilities ;
14
14
using Microsoft . ML . Model . Onnx ;
15
+ using Microsoft . ML . UniversalModelFormat . Onnx ;
15
16
using Newtonsoft . Json ;
16
17
17
18
[ assembly: LoadableClass ( SaveOnnxCommand . Summary , typeof ( SaveOnnxCommand ) , typeof ( SaveOnnxCommand . Arguments ) , typeof ( SignatureCommand ) ,
@@ -113,9 +114,10 @@ public override void Run()
113
114
}
114
115
}
115
116
116
- private void GetPipe ( OnnxContextImpl ctx , IChannel ch , IDataView end , out IDataView source , out IDataView trueEnd , out LinkedList < ITransformCanSaveOnnx > transforms )
117
+ internal static void GetPipe ( OnnxContextImpl ctx , IChannel ch , IDataView end , out IDataView source , out IDataView trueEnd , out LinkedList < ITransformCanSaveOnnx > transforms )
117
118
{
118
- Host . AssertValue ( end ) ;
119
+ ch . AssertValue ( end ) ;
120
+
119
121
source = trueEnd = ( end as CompositeDataLoader ) ? . View ?? end ;
120
122
IDataTransform transform = source as IDataTransform ;
121
123
transforms = new LinkedList < ITransformCanSaveOnnx > ( ) ;
@@ -134,7 +136,53 @@ private void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataV
134
136
transform = ( source = transform . Source ) as IDataTransform ;
135
137
}
136
138
137
- Host . AssertValue ( source ) ;
139
+ ch . AssertValue ( source ) ;
140
+ }
141
+
142
+ internal static ModelProto ConvertTransformListToOnnxModel ( OnnxContextImpl ctx , IChannel ch , IDataView inputData , IDataView outputData ,
143
+ LinkedList < ITransformCanSaveOnnx > transforms , HashSet < string > inputColumnNamesToDrop = null , HashSet < string > outputColumnNamesToDrop = null )
144
+ {
145
+ inputColumnNamesToDrop = inputColumnNamesToDrop ?? new HashSet < string > ( ) ;
146
+ outputColumnNamesToDrop = outputColumnNamesToDrop ?? new HashSet < string > ( ) ;
147
+ HashSet < string > inputColumns = new HashSet < string > ( ) ;
148
+ // Create graph inputs.
149
+ for ( int i = 0 ; i < inputData . Schema . Count ; i ++ )
150
+ {
151
+ string colName = inputData . Schema [ i ] . Name ;
152
+ if ( inputColumnNamesToDrop . Contains ( colName ) )
153
+ continue ;
154
+
155
+ ctx . AddInputVariable ( inputData . Schema [ i ] . Type , colName ) ;
156
+ inputColumns . Add ( colName ) ;
157
+ }
158
+
159
+ // Create graph nodes, outputs and intermediate values.
160
+ foreach ( var trans in transforms )
161
+ {
162
+ ch . Assert ( trans . CanSaveOnnx ( ctx ) ) ;
163
+ trans . SaveAsOnnx ( ctx ) ;
164
+ }
165
+
166
+ // Add graph outputs.
167
+ for ( int i = 0 ; i < outputData . Schema . Count ; ++ i )
168
+ {
169
+ if ( outputData . Schema [ i ] . IsHidden )
170
+ continue ;
171
+
172
+ var idataviewColumnName = outputData . Schema [ i ] . Name ;
173
+
174
+ // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
175
+ // inputColumnNamesToDrop should be removed too.
176
+ if ( inputColumnNamesToDrop . Contains ( idataviewColumnName ) || outputColumnNamesToDrop . Contains ( idataviewColumnName ) )
177
+ continue ;
178
+
179
+ var variableName = ctx . TryGetVariableName ( idataviewColumnName ) ;
180
+ var trueVariableName = ctx . AddIntermediateVariable ( null , idataviewColumnName , true ) ;
181
+ ctx . CreateNode ( "Identity" , variableName , trueVariableName , ctx . GetNodeName ( "Identity" ) , "" ) ;
182
+ ctx . AddOutputVariable ( outputData . Schema [ i ] . Type , trueVariableName ) ;
183
+ }
184
+
185
+ return ctx . MakeModel ( ) ;
138
186
}
139
187
140
188
private void Run ( IChannel ch )
@@ -210,45 +258,8 @@ private void Run(IChannel ch)
210
258
nameof ( Arguments . LoadPredictor ) , "We were explicitly told to load the predictor but one was not present." ) ;
211
259
}
212
260
213
- HashSet < string > inputColumns = new HashSet < string > ( ) ;
214
- //Create graph inputs.
215
- for ( int i = 0 ; i < source . Schema . Count ; i ++ )
216
- {
217
- string colName = source . Schema [ i ] . Name ;
218
- if ( _inputsToDrop . Contains ( colName ) )
219
- continue ;
220
-
221
- ctx . AddInputVariable ( source . Schema [ i ] . Type , colName ) ;
222
- inputColumns . Add ( colName ) ;
223
- }
224
-
225
- //Create graph nodes, outputs and intermediate values.
226
- foreach ( var trans in transforms )
227
- {
228
- Host . Assert ( trans . CanSaveOnnx ( ctx ) ) ;
229
- trans . SaveAsOnnx ( ctx ) ;
230
- }
231
-
232
- //Add graph outputs.
233
- for ( int i = 0 ; i < end . Schema . Count ; ++ i )
234
- {
235
- if ( end . Schema [ i ] . IsHidden )
236
- continue ;
237
-
238
- var idataviewColumnName = end . Schema [ i ] . Name ;
239
-
240
- // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
241
- // _inputToDrop should be removed too.
242
- if ( _inputsToDrop . Contains ( idataviewColumnName ) || _outputsToDrop . Contains ( idataviewColumnName ) )
243
- continue ;
244
-
245
- var variableName = ctx . TryGetVariableName ( idataviewColumnName ) ;
246
- var trueVariableName = ctx . AddIntermediateVariable ( null , idataviewColumnName , true ) ;
247
- ctx . CreateNode ( "Identity" , variableName , trueVariableName , ctx . GetNodeName ( "Identity" ) , "" ) ;
248
- ctx . AddOutputVariable ( end . Schema [ i ] . Type , trueVariableName ) ;
249
- }
261
+ var model = ConvertTransformListToOnnxModel ( ctx , ch , source , end , transforms , _inputsToDrop , _outputsToDrop ) ;
250
262
251
- var model = ctx . MakeModel ( ) ;
252
263
using ( var file = Host . CreateOutputFile ( _outputModelPath ) )
253
264
using ( var stream = file . CreateWriteStream ( ) )
254
265
model . WriteTo ( stream ) ;
0 commit comments