|
11 | 11 | using Microsoft.ML.CommandLine;
|
12 | 12 | using Microsoft.ML.Data;
|
13 | 13 | using Microsoft.ML.Internal.Utilities;
|
| 14 | +using Microsoft.ML.Model.OnnxConverter; |
14 | 15 | using Microsoft.ML.Model.Pfa;
|
15 | 16 | using Microsoft.ML.Runtime;
|
16 | 17 | using Microsoft.ML.Transforms;
|
@@ -152,7 +153,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
|
152 | 153 |
|
/// <summary>
/// Creates the row mapper for <paramref name="inputSchema"/> by constructing a <see cref="Mapper"/>
/// over this transformer and that schema.
/// </summary>
153 | 154 | private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema);
|
154 | 155 |
|
155 |
| - private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa |
| 156 | + private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa, ISaveAsOnnx |
156 | 157 | {
|
157 | 158 | private readonly KeyToValueMappingTransformer _parent;
|
158 | 159 | private readonly DataViewType[] _types;
|
@@ -298,6 +299,8 @@ protected KeyToValueMap(Mapper mapper, PrimitiveDataViewType typeVal, int iinfo)
|
/// <summary>
/// Returns the getter delegate for this mapping over the given input row.
/// NOTE(review): the concrete delegate type is chosen by the implementing class and is not visible here.
/// </summary>
298 | 299 | public abstract Delegate GetMappingGetter(DataViewRow input);
|
299 | 300 |
|
/// <summary>
/// Produces the PFA token for this mapping applied to <paramref name="srcToken"/>.
/// The generic implementation in this file returns a PFA conditional: a default token when the
/// source value is less than zero, otherwise the stored value indexed by the source value.
/// </summary>
300 | 301 | public abstract JToken SavePfa(BoundPfaContext ctx, JToken srcToken);
|
| 302 | + |
/// <summary>
/// Writes the ONNX representation of this mapping into <paramref name="ctx"/>, reading from
/// <paramref name="srcVariableName"/> and writing to <paramref name="dstVariableName"/>.
/// </summary>
/// <returns>
/// <see langword="true"/> when the mapping was exported; <see langword="false"/> when it could not be
/// (the generic implementation in this file returns false for value types other than Int64, Single and text).
/// </returns>
| 303 | + public abstract bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName); |
301 | 304 | }
|
302 | 305 |
|
303 | 306 | private class KeyToValueMap<TKey, TValue> : KeyToValueMap
|
@@ -494,8 +497,65 @@ public override JToken SavePfa(BoundPfaContext ctx, JToken srcToken)
|
494 | 497 | }
|
495 | 498 | return PfaUtils.If(PfaUtils.Call("<", srcToken, 0), defaultToken, PfaUtils.Index(cellRef, srcToken));
|
496 | 499 | }
|
| 500 | + |
/// <summary>
/// Exports this key-to-value mapping to ONNX as a <c>Cast</c> node (input keys to int64) followed by
/// a <c>LabelEncoder</c> node whose keys are 1..N and whose values are the stored key values.
/// </summary>
/// <param name="ctx">ONNX context that receives the nodes and the intermediate variable.</param>
/// <param name="srcVariableName">ONNX variable that holds the input keys.</param>
/// <param name="dstVariableName">ONNX variable that receives the mapped values.</param>
/// <returns>
/// <see langword="true"/> when the nodes were emitted; <see langword="false"/> when the output type is
/// not Int64, Single or text, in which case nothing is added to <paramref name="ctx"/>.
/// </returns>
public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
    // Decide whether the value type is exportable BEFORE touching the graph, so that an
    // unsupported type does not leave orphan Cast/LabelEncoder nodes behind.
    bool isInt64 = TypeOutput == NumberDataViewType.Int64;
    bool isSingle = TypeOutput == NumberDataViewType.Single;
    bool isText = TypeOutput == TextDataViewType.Instance;
    if (!isInt64 && !isSingle && !isText)
        return false;

    // Onnx expects the input keys to be int64s. But the input data can come from an ML.NET node that
    // may output a uint32. So cast it here to ensure that the data is treated correctly.
    // The Cast node produces int64, so its output variable is declared as Int64 (not as the
    // mapped value type, which may be float or text).
    string opType = "Cast";
    var castNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastNodeOutput", true);
    var castNode = ctx.CreateNode(opType, srcVariableName, castNodeOutput, ctx.GetNodeName(opType), "");
    var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType();
    castNode.AddAttribute("to", t);

    opType = "LabelEncoder";
    var node = ctx.CreateNode(opType, castNodeOutput, dstVariableName, ctx.GetNodeName(opType));

    // Key values are 1-based; the i-th stored value corresponds to key i + 1.
    var keys = Array.ConvertAll<int, long>(Enumerable.Range(1, _values.Length).ToArray(), item => Convert.ToInt64(item));
    node.AddAttribute("keys_int64s", keys);

    TValue[] mappedValues = _values.GetValues().ToArray();
    if (isInt64)
    {
        long[] values = Array.ConvertAll<TValue, long>(mappedValues, item => Convert.ToInt64(item));
        node.AddAttribute("values_int64s", values);
    }
    else if (isSingle)
    {
        float[] values = Array.ConvertAll<TValue, float>(mappedValues, item => Convert.ToSingle(item));
        node.AddAttribute("values_floats", values);
    }
    else
    {
        // Only the text case remains after the guard at the top.
        string[] values = Array.ConvertAll<TValue, string>(mappedValues, item => Convert.ToString(item));
        node.AddAttribute("values_strings", values);
    }

    return true;
}
497 | 538 | }
|
498 | 539 |
|
/// <summary>
/// Reports whether this mapper can be exported to ONNX. It answers <see langword="true"/>
/// unconditionally; per-column support is decided while saving.
/// </summary>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
| 541 | + |
/// <summary>
/// Exports every column pair of the parent transformer to ONNX. Columns whose input is not tracked
/// by <paramref name="ctx"/>, or whose mapping cannot be exported, are dropped from the context.
/// </summary>
/// <param name="ctx">ONNX context that receives the exported nodes and variables.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; ++iinfo)
    {
        var info = _parent.ColumnPairs[iinfo];
        var inputColumnName = info.inputColumnName;
        var outputColumnName = info.outputColumnName;

        if (!ctx.ContainsColumn(inputColumnName))
        {
            // The input is unavailable, so this output cannot be produced. Stop tracking any
            // earlier column with the same output name so it is not mistaken for our result.
            ctx.RemoveColumn(outputColumnName, false);
            continue;
        }

        // Nodes are wired by ONNX variable name, which may differ from the column name.
        var srcVariableName = ctx.GetVariableName(inputColumnName);
        var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], outputColumnName, true);
        if (!_kvMaps[iinfo].SaveOnnx(ctx, srcVariableName, dstVariableName))
        {
            // Export failed: discard the output variable registered just above. The input column
            // belongs to an upstream node and must stay tracked.
            ctx.RemoveColumn(outputColumnName, true);
        }
    }
}
499 | 559 | }
|
500 | 560 | }
|
501 | 561 |
|
|
0 commit comments