-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
TextTransform.cs
710 lines (614 loc) · 31.6 KB
/
TextTransform.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data.StaticPipe.Runtime;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TextAnalytics;
[assembly: LoadableClass(TextTransform.Summary, typeof(IDataTransform), typeof(TextTransform), typeof(TextTransform.Arguments), typeof(SignatureDataTransform),
TextTransform.UserName, "TextTransform", TextTransform.LoaderSignature)]
[assembly: LoadableClass(TextTransform.Summary, typeof(ITransformer), typeof(TextTransform), null, typeof(SignatureLoadModel),
TextTransform.UserName, "TextTransform", TextTransform.LoaderSignature)]
namespace Microsoft.ML.Runtime.Data
{
using StopWordsArgs = StopWordsRemoverTransform.Arguments;
using TextNormalizerArgs = TextNormalizerTransform.Arguments;
using StopWordsCol = StopWordsRemoverTransform.Column;
using TextNormalizerCol = TextNormalizerTransform.Column;
using StopWordsLang = StopWordsRemoverTransform.Language;
using CaseNormalizationMode = TextNormalizerTransform.CaseNormalizationMode;
// A transform that turns a collection of text documents into numerical feature vectors. The feature vectors are counts
// of (word or character) ngrams in a given text. It offers ngram hashing (finding the ngram token string name to feature
// integer index mapping through hashing) as an option.
/// <include file='doc.xml' path='doc/members/member[@name="TextTransform"]/*' />
public sealed class TextTransform : IEstimator<ITransformer>
{
/// <summary>
/// Text language. This enumeration is serialized.
/// </summary>
// NOTE: since these values are serialized (see summary), existing members must not be
// renumbered or removed; add new languages with new values only.
public enum Language
{
    English = 1,
    French = 2,
    German = 3,
    Dutch = 4,
    Italian = 5,
    Spanish = 6,
    Japanese = 7
}
/// <summary>
/// Text vector normalizer kind.
/// </summary>
public enum TextNormKind
{
    // Leave the feature vector unnormalized.
    None = 0,
    // Rescale by the L1 norm (sum of absolute values).
    L1 = 1,
    // Rescale by the L2 (Euclidean) norm.
    L2 = 2,
    // Rescale by the L-infinity norm (maximum absolute value).
    LInf = 3
}
public sealed class Column : ManyToOneColumn
{
    /// <summary>
    /// Parses a command-line column specification; returns null when the text is malformed.
    /// </summary>
    public static Column Parse(string str)
    {
        var column = new Column();
        return column.TryParse(str) ? column : null;
    }

    /// <summary>
    /// Serializes this column back into its command-line form.
    /// </summary>
    public bool TryUnparse(StringBuilder sb)
    {
        Contracts.AssertValue(sb);
        return TryUnparseCore(sb);
    }
}
/// <summary>
/// This class exposes <see cref="NgramExtractorTransform"/>/<see cref="NgramHashExtractorTransform"/> arguments.
/// </summary>
// Command-line/entry-point argument object; the Argument attributes (names, defaults,
// sort order) are part of the public surface and must stay stable.
public sealed class Arguments : TransformInputBase
{
    [Argument(ArgumentType.Required, HelpText = "New column definition (optional form: name:srcs).", ShortName = "col", SortOrder = 1)]
    public Column Column;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Dataset language or 'AutoDetect' to detect language per row.", ShortName = "lang", SortOrder = 3)]
    public Language Language = DefaultLanguage;

    [Argument(ArgumentType.Multiple, HelpText = "Stopwords remover.", ShortName = "remover", NullName = "<None>", SortOrder = 4)]
    public IStopWordsRemoverFactory StopWordsRemover;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Casing text using the rules of the invariant culture.", ShortName = "case", SortOrder = 5)]
    public CaseNormalizationMode TextCase = CaseNormalizationMode.Lower;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep diacritical marks or remove them.", ShortName = "diac", SortOrder = 6)]
    public bool KeepDiacritics;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep punctuation marks or remove them.", ShortName = "punc", SortOrder = 7)]
    public bool KeepPunctuations = true;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep numbers or remove them.", ShortName = "num", SortOrder = 8)]
    public bool KeepNumbers = true;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output the transformed text tokens as an additional column.", ShortName = "tokens,showtext,showTransformedText", SortOrder = 9)]
    public bool OutputTokens;

    [Argument(ArgumentType.Multiple, HelpText = "A dictionary of whitelisted terms.", ShortName = "dict", NullName = "<None>", SortOrder = 10, Hide = true)]
    public TermLoaderArguments Dictionary;

    [TGUI(Label = "Word Gram Extractor")]
    [Argument(ArgumentType.Multiple, HelpText = "Ngram feature extractor to use for words (WordBag/WordHashBag).", ShortName = "wordExtractor", NullName = "<None>", SortOrder = 11)]
    public INgramExtractorFactoryFactory WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments();

    // Default for char grams: trigrams only (AllLengths = false), matching the estimator ctor defaults.
    [TGUI(Label = "Char Gram Extractor")]
    [Argument(ArgumentType.Multiple, HelpText = "Ngram feature extractor to use for characters (WordBag/WordHashBag).", ShortName = "charExtractor", NullName = "<None>", SortOrder = 12)]
    public INgramExtractorFactoryFactory CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 3, AllLengths = false };

    [Argument(ArgumentType.AtMostOnce, HelpText = "Normalize vectors (rows) individually by rescaling them to unit norm.", ShortName = "norm", SortOrder = 13)]
    public TextNormKind VectorNormalizer = TextNormKind.L2;
}
// Mutable bag of advanced options; callers tweak an instance via the advancedSettings
// delegate passed to the TextTransform constructors.
public sealed class Settings
{
#pragma warning disable MSML_NoInstanceInitializers // No initializers on instance fields or properties
    // Dataset language (default: English, see DefaultLanguage).
    public Language TextLanguage { get; set; } = DefaultLanguage;
    // Case normalization applied to the text (default: lowercase).
    public CaseNormalizationMode TextCase { get; set; } = CaseNormalizationMode.Lower;
    // Whether diacritical marks are preserved (default: removed).
    public bool KeepDiacritics { get; set; } = false;
    // Whether punctuation marks are preserved (default: kept).
    public bool KeepPunctuations { get; set; } = true;
    // Whether numbers are preserved (default: kept).
    public bool KeepNumbers { get; set; } = true;
    // Whether to emit the transformed text tokens as an extra output column.
    public bool OutputTokens { get; set; } = false;
    // Per-row normalization applied to the final feature vector (default: L2).
    public TextNormKind VectorNormalizer { get; set; } = TextNormKind.L2;
#pragma warning restore MSML_NoInstanceInitializers // No initializers on instance fields or properties
}
// Name of the produced feature-vector column.
public readonly string OutputColumn;
// Source text column names, in the order supplied by the caller.
private readonly string[] _inputColumns;
public IReadOnlyCollection<string> InputColumns => _inputColumns.AsReadOnly();
// Options applied on top of the defaults by the caller's advancedSettings delegate.
public Settings AdvancedSettings { get; }
// These parameters are hardcoded for now.
// REVIEW: expose them once sub-transforms are estimators.
private IStopWordsRemoverFactory _stopWordsRemover;
private TermLoaderArguments _dictionary;
private INgramExtractorFactoryFactory _wordFeatureExtractor;
private INgramExtractorFactoryFactory _charFeatureExtractor;
private readonly IHost _host;
/// <summary>
/// A distilled version of the TextTransform Arguments, with all fields marked readonly and
/// only the exact set of information needed to construct the transforms preserved.
/// </summary>
private sealed class TransformApplierParams
{
    public readonly INgramExtractorFactory WordExtractorFactory;
    public readonly INgramExtractorFactory CharExtractorFactory;
    public readonly TextNormKind VectorNormalizer;
    public readonly Language Language;
    public readonly IStopWordsRemoverFactory StopWordsRemover;
    public readonly CaseNormalizationMode TextCase;
    public readonly bool KeepDiacritics;
    public readonly bool KeepPunctuations;
    public readonly bool KeepNumbers;
    public readonly bool OutputTextTokens;
    public readonly TermLoaderArguments Dictionary;

    // The stop-words remover declares its own Language enum; translate by member name.
    public StopWordsRemoverTransform.Language StopwordsLanguage =>
        (StopWordsRemoverTransform.Language)Enum.Parse(typeof(StopWordsRemoverTransform.Language), Language.ToString());

    // Maps the public normalizer kind onto the LpNormNormalizerTransform's enum.
    public LpNormNormalizerTransform.NormalizerKind LpNormalizerKind
    {
        get
        {
            switch (VectorNormalizer)
            {
                case TextNormKind.L1:
                    return LpNormNormalizerTransform.NormalizerKind.L1Norm;
                case TextNormKind.L2:
                    return LpNormNormalizerTransform.NormalizerKind.L2Norm;
                case TextNormKind.LInf:
                    return LpNormNormalizerTransform.NormalizerKind.LInf;
                default:
                    Contracts.Assert(false, "Unexpected normalizer type");
                    return LpNormNormalizerTransform.NormalizerKind.L2Norm;
            }
        }
    }

    // These properties encode the logic needed to determine which transforms to apply.
    #region NeededTransforms
    // Word tokenization is needed by the word extractor, the stop-words remover, and token output.
    public bool NeedsWordTokenizationTransform => WordExtractorFactory != null || NeedsRemoveStopwordsTransform || OutputTextTokens;

    public bool NeedsRemoveStopwordsTransform => StopWordsRemover != null;

    public bool NeedsNormalizeTransform =>
        TextCase != CaseNormalizationMode.None ||
        !KeepDiacritics ||
        !KeepPunctuations ||
        !KeepNumbers;

    // A null factory counts as "hashing" here so a single non-hashing extractor forces the concat.
    private bool UsesHashExtractors =>
        (WordExtractorFactory == null || WordExtractorFactory.UseHashingTrick) &&
        (CharExtractorFactory == null || CharExtractorFactory.UseHashingTrick);

    // If we're performing language auto detection, or either of our extractors aren't hashing then
    // we need all the input text concatenated into a single Vect<DvText>, for the LanguageDetectionTransform
    // to operate on the entire text vector, and for the Dictionary feature extractor to build its bound dictionary
    // correctly.
    public bool NeedInitialSourceColumnConcatTransform => !UsesHashExtractors;
    #endregion

    public TransformApplierParams(TextTransform parent)
    {
        var host = parent._host;
        // Guard against out-of-range enum values smuggled in through the settings delegate.
        host.Check(Enum.IsDefined(typeof(Language), parent.AdvancedSettings.TextLanguage));
        host.Check(Enum.IsDefined(typeof(CaseNormalizationMode), parent.AdvancedSettings.TextCase));
        WordExtractorFactory = parent._wordFeatureExtractor?.CreateComponent(host, parent._dictionary);
        CharExtractorFactory = parent._charFeatureExtractor?.CreateComponent(host, parent._dictionary);
        VectorNormalizer = parent.AdvancedSettings.VectorNormalizer;
        Language = parent.AdvancedSettings.TextLanguage;
        StopWordsRemover = parent._stopWordsRemover;
        TextCase = parent.AdvancedSettings.TextCase;
        KeepDiacritics = parent.AdvancedSettings.KeepDiacritics;
        KeepPunctuations = parent.AdvancedSettings.KeepPunctuations;
        KeepNumbers = parent.AdvancedSettings.KeepNumbers;
        OutputTextTokens = parent.AdvancedSettings.OutputTokens;
        Dictionary = parent._dictionary;
    }
}
internal const string Summary = "A transform that turns a collection of text documents into numerical feature vectors. " +
    "The feature vectors are normalized counts of (word and/or character) ngrams in a given tokenized text.";
internal const string UserName = "Text Transform";
internal const string LoaderSignature = "Text";
// Language assumed when the caller does not specify one.
public const Language DefaultLanguage = Language.English;
// Name format for the optional column carrying the transformed text tokens ({0} = output column).
private const string TransformedTextColFormat = "{0}_TransformedText";
/// <summary>
/// Convenience constructor for a single input column. When <paramref name="outputColumn"/>
/// is null, the feature vector is written back under the input column's name.
/// </summary>
public TextTransform(IHostEnvironment env, string inputColumn, string outputColumn = null,
    Action<Settings> advancedSettings = null)
    : this(env, new[] { inputColumn }, outputColumn ?? inputColumn, advancedSettings)
{
}
/// <summary>
/// Constructs a text featurization estimator over one or more input text columns.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="inputColumns">Names of the source text columns; must be non-empty with no null/whitespace entries.</param>
/// <param name="outputColumn">Name of the produced feature-vector column.</param>
/// <param name="advancedSettings">Optional delegate that tweaks the defaults in <see cref="Settings"/>.</param>
public TextTransform(IHostEnvironment env, IEnumerable<string> inputColumns, string outputColumn,
    Action<Settings> advancedSettings = null)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(nameof(TextTransform));
    _host.CheckValue(inputColumns, nameof(inputColumns));
    // Materialize the sequence exactly once: the original code enumerated the (possibly
    // lazy) IEnumerable three times (Any, Any(...), ToArray), which is wasteful and can
    // even observe different contents for a side-effecting source.
    _inputColumns = inputColumns.ToArray();
    _host.CheckParam(_inputColumns.Length > 0, nameof(inputColumns));
    _host.CheckParam(!_inputColumns.Any(string.IsNullOrWhiteSpace), nameof(inputColumns));
    _host.CheckNonEmpty(outputColumn, nameof(outputColumn));
    _host.CheckValueOrNull(advancedSettings);
    OutputColumn = outputColumn;
    AdvancedSettings = new Settings();
    advancedSettings?.Invoke(AdvancedSettings);
    // Sub-components below are not yet exposed through the estimator API; the
    // command-line Create(...) path overwrites them from Arguments after construction.
    _stopWordsRemover = null;
    _dictionary = null;
    _wordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments();
    _charFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 3, AllLengths = false };
}
/// <summary>
/// Assembles and fits the full chain of underlying transforms over <paramref name="input"/>
/// (concat, normalize, tokenize, stop-words, ngram extraction, LP-normalize, final concat,
/// temp-column cleanup) and returns a transformer that replays that chain on new data.
/// The stages are strictly ordered; each one rebinds the working column names for the next.
/// </summary>
public ITransformer Fit(IDataView input)
{
    var h = _host;
    h.CheckValue(input, nameof(input));

    var tparams = new TransformApplierParams(this);
    // Working state threaded through the stages below. textCols always names the current
    // text columns; tempCols accumulates every intermediate column so it can be dropped at the end.
    string[] textCols = _inputColumns;
    string[] wordTokCols = null;
    string[] charTokCols = null;
    string wordFeatureCol = null;
    string charFeatureCol = null;
    List<string> tempCols = new List<string>();
    IDataView view = input;

    // Stage 1: when a non-hashing extractor needs the whole text, concat all sources into one column.
    if (tparams.NeedInitialSourceColumnConcatTransform && textCols.Length > 1)
    {
        var xfCols = new ConcatTransform.Column[] { new ConcatTransform.Column() };
        xfCols[0].Source = textCols;
        textCols = new[] { GenerateColumnName(input.Schema, OutputColumn, "InitialConcat") };
        xfCols[0].Name = textCols[0];
        tempCols.Add(textCols[0]);
        view = new ConcatTransform(h, new ConcatTransform.Arguments() { Column = xfCols }, view);
    }

    // Stage 2: case/diacritics/punctuation/number normalization.
    if (tparams.NeedsNormalizeTransform)
    {
        var xfCols = new TextNormalizerCol[textCols.Length];
        string[] dstCols = new string[textCols.Length];
        for (int i = 0; i < textCols.Length; i++)
        {
            dstCols[i] = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer");
            tempCols.Add(dstCols[i]);
            xfCols[i] = new TextNormalizerCol() { Source = textCols[i], Name = dstCols[i] };
        }
        view = new TextNormalizerTransform(h,
            new TextNormalizerArgs()
            {
                Column = xfCols,
                KeepDiacritics = tparams.KeepDiacritics,
                KeepNumbers = tparams.KeepNumbers,
                KeepPunctuations = tparams.KeepPunctuations,
                TextCase = tparams.TextCase
            }, view);
        textCols = dstCols;
    }

    // Stage 3: split text into word tokens (needed by word extraction, stop-words, or token output).
    if (tparams.NeedsWordTokenizationTransform)
    {
        var xfCols = new DelimitedTokenizeTransform.Column[textCols.Length];
        wordTokCols = new string[textCols.Length];
        for (int i = 0; i < textCols.Length; i++)
        {
            var col = new DelimitedTokenizeTransform.Column();
            col.Source = textCols[i];
            col.Name = GenerateColumnName(view.Schema, textCols[i], "WordTokenizer");
            xfCols[i] = col;
            wordTokCols[i] = col.Name;
            tempCols.Add(col.Name);
        }
        view = new DelimitedTokenizeTransform(h, new DelimitedTokenizeTransform.Arguments() { Column = xfCols }, view);
    }

    // Stage 4: remove stop words from the word tokens.
    if (tparams.NeedsRemoveStopwordsTransform)
    {
        Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text.");
        var xfCols = new StopWordsCol[wordTokCols.Length];
        var dstCols = new string[wordTokCols.Length];
        for (int i = 0; i < wordTokCols.Length; i++)
        {
            var col = new StopWordsCol();
            col.Source = wordTokCols[i];
            col.Name = GenerateColumnName(view.Schema, wordTokCols[i], "StopWordsRemoverTransform");
            dstCols[i] = col.Name;
            tempCols.Add(col.Name);
            col.Language = tparams.StopwordsLanguage;
            xfCols[i] = col;
        }
        view = tparams.StopWordsRemover.CreateComponent(h, view, xfCols);
        wordTokCols = dstCols;
    }

    // Stage 5: word ngram feature extraction over the (possibly filtered) word tokens.
    if (tparams.WordExtractorFactory != null)
    {
        var dstCol = GenerateColumnName(view.Schema, OutputColumn, "WordExtractor");
        tempCols.Add(dstCol);
        view = tparams.WordExtractorFactory.Create(h, view, new[] {
            new ExtractorColumn()
            {
                Name = dstCol,
                Source = wordTokCols,
                FriendlyNames = _inputColumns
            }});
        wordFeatureCol = dstCol;
    }

    // Stage 6: optionally surface the transformed tokens as "<Output>_TransformedText".
    // Note this column is NOT added to tempCols, so it survives the final drop.
    if (tparams.OutputTextTokens)
    {
        string[] srcCols = wordTokCols ?? textCols;
        view = new ConcatTransform(h,
            new ConcatTransform.Arguments()
            {
                Column = new[] { new ConcatTransform.Column()
                {
                    Name = string.Format(TransformedTextColFormat, OutputColumn),
                    Source = srcCols
                }}
            }, view);
    }

    // Stage 7: character tokenization followed by char ngram feature extraction.
    if (tparams.CharExtractorFactory != null)
    {
        {
            // Char tokens come from the stop-word-filtered tokens when available,
            // otherwise from the normalized text.
            var srcCols = tparams.NeedsRemoveStopwordsTransform ? wordTokCols : textCols;
            charTokCols = new string[srcCols.Length];
            var xfCols = new CharTokenizeTransform.Column[srcCols.Length];
            for (int i = 0; i < srcCols.Length; i++)
            {
                var col = new CharTokenizeTransform.Column();
                col.Source = srcCols[i];
                col.Name = GenerateColumnName(view.Schema, srcCols[i], "CharTokenizer");
                tempCols.Add(col.Name);
                charTokCols[i] = col.Name;
                xfCols[i] = col;
            }
            view = new CharTokenizeTransform(h, new CharTokenizeTransform.Arguments() { Column = xfCols }, view);
        }
        {
            charFeatureCol = GenerateColumnName(view.Schema, OutputColumn, "CharExtractor");
            tempCols.Add(charFeatureCol);
            view = tparams.CharExtractorFactory.Create(h, view, new[] {
                new ExtractorColumn()
                {
                    Source = charTokCols,
                    FriendlyNames = _inputColumns,
                    Name = charFeatureCol
                }});
        }
    }

    // Stage 8: per-row LP normalization of whichever feature columns were produced.
    if (tparams.VectorNormalizer != TextNormKind.None)
    {
        var xfCols = new List<LpNormNormalizerTransform.Column>(2);
        if (charFeatureCol != null)
        {
            var dstCol = GenerateColumnName(view.Schema, charFeatureCol, "LpCharNorm");
            tempCols.Add(dstCol);
            xfCols.Add(new LpNormNormalizerTransform.Column()
            {
                Source = charFeatureCol,
                Name = dstCol
            });
            charFeatureCol = dstCol;
        }
        if (wordFeatureCol != null)
        {
            var dstCol = GenerateColumnName(view.Schema, wordFeatureCol, "LpWordNorm");
            tempCols.Add(dstCol);
            xfCols.Add(new LpNormNormalizerTransform.Column()
            {
                Source = wordFeatureCol,
                Name = dstCol
            });
            wordFeatureCol = dstCol;
        }
        if (xfCols.Count > 0)
            view = new LpNormNormalizerTransform(h, new LpNormNormalizerTransform.Arguments()
            {
                NormKind = tparams.LpNormalizerKind,
                Column = xfCols.ToArray()
            }, view);
    }

    // Stage 9: concat the word/char feature columns into the single output column.
    {
        var srcTaggedCols = new List<KeyValuePair<string, string>>(2);
        if (charFeatureCol != null && wordFeatureCol != null)
        {
            // If we're producing both char and word grams, then we need to disambiguate
            // between them (e.g. the word 'a' vs. the char gram 'a').
            srcTaggedCols.Add(new KeyValuePair<string, string>("Char", charFeatureCol));
            srcTaggedCols.Add(new KeyValuePair<string, string>("Word", wordFeatureCol));
        }
        else
        {
            // Otherwise, simply use the slot names, omitting the original source column names
            // entirely. For the Concat transform setting the Key == Value of the TaggedColumn
            // KVP signals this intent.
            Contracts.Assert(charFeatureCol != null || wordFeatureCol != null || tparams.OutputTextTokens);
            if (charFeatureCol != null)
                srcTaggedCols.Add(new KeyValuePair<string, string>(charFeatureCol, charFeatureCol));
            else if (wordFeatureCol != null)
                srcTaggedCols.Add(new KeyValuePair<string, string>(wordFeatureCol, wordFeatureCol));
        }
        if (srcTaggedCols.Count > 0)
            view = new ConcatTransform(h, new ConcatTransform.TaggedArguments()
            {
                Column = new[] { new ConcatTransform.TaggedColumn() {
                    Name = OutputColumn,
                    Source = srcTaggedCols.ToArray()
                }}
            }, view);
    }

    // Stage 10: drop every intermediate working column, then capture the fitted chain.
    view = new DropColumnsTransform(h,
        new DropColumnsTransform.Arguments() { Column = tempCols.ToArray() }, view);

    return new Transformer(_host, input, view);
}
// Factory invoked by the model-loading machinery (see the SignatureLoadModel LoadableClass attribute).
public static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
    => new Transformer(env, ctx);
/// <summary>
/// Produces a schema-unique working-column name of the form "{srcName}_{xfTag}";
/// GetTempColumnName de-duplicates against columns already present in the schema.
/// </summary>
private static string GenerateColumnName(ISchema schema, string srcName, string xfTag)
{
    // String interpolation is the idiomatic replacement for string.Format here;
    // the produced name is identical.
    return schema.GetTempColumnName($"{srcName}_{xfTag}");
}
/// <summary>
/// Projects the expected output schema: validates that every source column exists and
/// holds text, then adds the R4 feature-vector column (and, when token output is on,
/// the variable-length text-tokens column).
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));

    // Every source column must be present and contain text (scalar or vector).
    foreach (var srcName in _inputColumns)
    {
        var found = inputSchema.FindColumn(srcName);
        if (found == null)
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName);
        if (!found.ItemType.IsText)
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, "scalar or vector of text", found.GetTypeString());
    }

    var columns = inputSchema.Columns.ToDictionary(c => c.Name);

    var metadata = new List<string> { MetadataUtils.Kinds.SlotNames };
    if (AdvancedSettings.VectorNormalizer != TextNormKind.None)
        metadata.Add(MetadataUtils.Kinds.IsNormalized);
    columns[OutputColumn] = new SchemaShape.Column(OutputColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false,
        metadata.ToArray());

    if (AdvancedSettings.OutputTokens)
    {
        var tokensName = string.Format(TransformedTextColFormat, OutputColumn);
        columns[tokensName] = new SchemaShape.Column(tokensName, SchemaShape.Column.VectorKind.VariableVector, TextType.Instance, false);
    }

    return new SchemaShape(columns.Values);
}
/// <summary>
/// Command-line/entry-point factory: maps <see cref="Arguments"/> onto the estimator,
/// fits it on <paramref name="data"/>, and returns the resulting data transform.
/// </summary>
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView data)
{
    // Copy the flat command-line arguments into the estimator's Settings object.
    Action<Settings> applyArgs = s =>
    {
        s.TextLanguage = args.Language;
        s.TextCase = args.TextCase;
        s.KeepDiacritics = args.KeepDiacritics;
        s.KeepPunctuations = args.KeepPunctuations;
        s.KeepNumbers = args.KeepNumbers;
        s.OutputTokens = args.OutputTokens;
        s.VectorNormalizer = args.VectorNormalizer;
    };

    // A missing source list means the output column doubles as the (single) source.
    var sources = args.Column.Source ?? new[] { args.Column.Name };
    var estimator = new TextTransform(env, sources, args.Column.Name, applyArgs)
    {
        // These sub-components are only settable through the Arguments path for now.
        _stopWordsRemover = args.StopWordsRemover,
        _dictionary = args.Dictionary,
        _wordFeatureExtractor = args.WordFeatureExtractor,
        _charFeatureExtractor = args.CharFeatureExtractor
    };
    return estimator.Fit(data).Transform(data) as IDataTransform;
}
// The fitted product of TextTransform.Fit: wraps the chain of data transforms and knows
// how to replay it on new data and how to save/load itself as a sub-model.
private sealed class Transformer : ITransformer, ICanSaveModel
{
    // Directory name format for each saved transform step (Step_000, Step_001, ...).
    private const string TransformDirTemplate = "Step_{0:000}";

    private readonly IHost _host;
    // The fitted transform chain, rebased onto an empty view carrying the training schema.
    private readonly IDataView _xf;

    public Transformer(IHostEnvironment env, IDataView input, IDataView view)
    {
        _host = env.Register(nameof(Transformer));
        // Rebase the fitted chain onto an EmptyDataView so _xf captures the transform
        // structure and schema without holding onto the training data.
        _xf = ApplyTransformUtils.ApplyAllTransformsToData(_host, view, new EmptyDataView(_host, input.Schema), input);
    }

    public ISchema GetOutputSchema(ISchema inputSchema)
    {
        _host.CheckValue(inputSchema, nameof(inputSchema));
        // Apply the chain to an empty view with the given schema; its output schema is the answer.
        return Transform(new EmptyDataView(_host, inputSchema)).Schema;
    }

    public IDataView Transform(IDataView input)
    {
        _host.CheckValue(input, nameof(input));
        return ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
    }

    public void Save(ModelSaveContext ctx)
    {
        _host.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel();
        ctx.SetVersionInfo(GetVersionInfo());

        // Walk the chain from the last transform back to its root, collecting each step.
        var dataPipe = _xf;
        var transforms = new List<IDataTransform>();
        while (dataPipe is IDataTransform xf)
        {
            transforms.Add(xf);
            dataPipe = xf.Source;
            Contracts.AssertValue(dataPipe);
        }
        // Reverse so the steps are stored (and reloaded) in application order.
        transforms.Reverse();

        // Persist the root schema as a loader sub-model, then each transform step.
        ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema));

        ctx.Writer.Write(transforms.Count);
        for (int i = 0; i < transforms.Count; i++)
        {
            var dirName = string.Format(TransformDirTemplate, i);
            ctx.SaveModel(transforms[i], dirName);
        }
    }

    // Deserialization constructor; the mirror image of Save above.
    public Transformer(IHostEnvironment env, ModelLoadContext ctx)
    {
        Contracts.CheckValue(env, nameof(env));
        _host = env.Register(nameof(Transformer));
        _host.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel(GetVersionInfo());
        int n = ctx.Reader.ReadInt32();
        // Reload the root loader (schema only), then re-apply each saved step in order.
        ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));
        IDataView data = loader;
        for (int i = 0; i < n; i++)
        {
            var dirName = string.Format(TransformDirTemplate, i);
            ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
            data = xf;
        }
        _xf = data;
    }

    private static VersionInfo GetVersionInfo()
    {
        return new VersionInfo(
            modelSignature: "TEXT XFR",
            verWrittenCur: 0x00010001, // Initial
            verReadableCur: 0x00010001,
            verWeCanReadBack: 0x00010001,
            loaderSignature: LoaderSignature);
    }
}
// Static-pipeline output column produced by FeaturizeText; carries the source text
// columns so the Reconciler can recover their names.
internal sealed class OutPipelineColumn : Vector<float>
{
    public readonly Scalar<string>[] Inputs;

    public OutPipelineColumn(IEnumerable<Scalar<string>> inputs, Action<Settings> advancedSettings)
        : this(inputs.ToArray(), advancedSettings)
    {
    }

    // Materializes the inputs exactly once: the original enumerated the (possibly lazy)
    // sequence twice via two ToArray() calls, handing the base class and Inputs two
    // distinct arrays and re-running any deferred source.
    private OutPipelineColumn(Scalar<string>[] inputs, Action<Settings> advancedSettings)
        : base(new Reconciler(advancedSettings), inputs)
    {
        Inputs = inputs;
    }
}
// Bridges the static pipeline back to the dynamic API: turns an OutPipelineColumn
// into a concrete TextTransform estimator once column names are known.
private sealed class Reconciler : EstimatorReconciler
{
    private readonly Action<Settings> _settings;

    public Reconciler(Action<Settings> advancedSettings)
    {
        _settings = advancedSettings;
    }

    public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
        PipelineColumn[] toOutput,
        IReadOnlyDictionary<PipelineColumn, string> inputNames,
        IReadOnlyDictionary<PipelineColumn, string> outputNames,
        IReadOnlyCollection<string> usedNames)
    {
        Contracts.Assert(toOutput.Length == 1);
        var output = (OutPipelineColumn)toOutput[0];
        // Resolve each static-pipeline input column to its concrete name in the data.
        var sourceNames = output.Inputs.Select(input => inputNames[input]);
        return new TextTransform(env, sourceNames, outputNames[output], _settings);
    }
}
}
/// <summary>
/// Extension methods for the static-pipeline over <see cref="PipelineColumn"/> objects.
/// </summary>
public static class TextFeaturizerStaticPipe
{
    /// <summary>
    /// Featurizes <paramref name="input"/> (optionally merged with <paramref name="otherInputs"/>)
    /// into a numeric vector using default settings.
    /// </summary>
    public static Vector<float> FeaturizeText(this Scalar<string> input, params Scalar<string>[] otherInputs)
        => FeaturizeText(input, otherInputs, null);

    /// <summary>
    /// Featurizes <paramref name="input"/> (optionally merged with <paramref name="otherInputs"/>)
    /// into a numeric vector, with optional advanced settings.
    /// </summary>
    public static Vector<float> FeaturizeText(this Scalar<string> input, Scalar<string>[] otherInputs = null, Action<TextTransform.Settings> advancedSettings = null)
    {
        Contracts.CheckValue(input, nameof(input));
        Contracts.CheckValueOrNull(otherInputs);
        // The primary input always comes first; extra columns follow when supplied.
        var allInputs = new List<Scalar<string>> { input };
        if (otherInputs != null)
            allInputs.AddRange(otherInputs);
        return new TextTransform.OutPipelineColumn(allInputs, advancedSettings);
    }
}
}