Skip to content

Commit ccb6f56

Browse files
committed
documentation
1 parent 9bd7707 commit ccb6f56

File tree

7 files changed

+65
-25
lines changed

7 files changed

+65
-25
lines changed

src/Microsoft.ML.Transforms/Dracula/CountTableBuilder.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ internal interface ICountTableBuilderFactory : IComponentFactory<CountTableBuild
1212
{
1313
}
1414

15+
/// <summary>
16+
/// Builds a table that provides counts to the <see cref="CountTargetEncodingTransformer"/>
17+
/// by going over the training data.
18+
/// </summary>
1519
public abstract class CountTableBuilderBase
1620
{
1721
private protected CountTableBuilderBase()
@@ -20,9 +24,20 @@ private protected CountTableBuilderBase()
2024

2125
internal abstract InternalCountTableBuilderBase GetInternalBuilder(long labelCardinality);
2226

27+
/// <summary>
28+
/// Create a builder that creates the count table using the count-min sketch structure, which has a smaller memory footprint,
29+
/// at the expense of some possible overcounting due to collisions.
30+
/// </summary>
31+
/// <param name="depth">The depth of the count-min sketch table.</param>
32+
/// <param name="width">The width of the count-min sketch table.</param>
2333
public static CountTableBuilderBase CreateCMCountTableBuilder(int depth = 4, int width = 1 << 23)
2434
=> new CMCountTableBuilder(depth, width);
2535

36+
/// <summary>
37+
/// Create a builder that creates the count table by building a dictionary containing the exact count of each
38+
/// categorical feature value.
39+
/// </summary>
40+
/// <param name="garbageThreshold">The garbage threshold (counts below or equal to the threshold are assigned to the garbage bin).</param>
2641
public static CountTableBuilderBase CreateDictionaryCountTableBuilder(float garbageThreshold = 0)
2742
=> new DictCountTableBuilder(garbageThreshold);
2843
}

src/Microsoft.ML.Transforms/Dracula/CountTableTransformer.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ public CountTableTransformer Fit(IDataView input)
169169

170170
var multiCountTable = multiBuilder.CreateMultiCountTable();
171171

172-
var featurizer = new DraculaFeaturizer(_host, _columns.Select(col => col.PriorCoefficient).ToArray(), _columns.Select(col => col.LaplaceScale).ToArray(), labelCardinality, multiCountTable);
172+
var featurizer = new CountTargetEncodingFeaturizer(_host, _columns.Select(col => col.PriorCoefficient).ToArray(), _columns.Select(col => col.LaplaceScale).ToArray(), labelCardinality, multiCountTable);
173173

174174
return new CountTableTransformer(_host, featurizer, labelClassNames,
175175
_columns.Select(col => col.Seed).ToArray(), _columns.Select(col => (col.Name, col.InputColumnName)).ToArray());
@@ -417,13 +417,13 @@ internal static class Defaults
417417
public const bool SharedTable = false;
418418
}
419419

420-
internal readonly DraculaFeaturizer Featurizer;
420+
internal readonly CountTargetEncodingFeaturizer Featurizer;
421421
private readonly string[] _labelClassNames;
422422

423423
internal int[] Seeds { get; }
424424

425425
internal const string Summary = "Transforms the categorical column into the set of features: count of each label class, "
426-
+ "log-odds for each label class, back-off indicator. The input columns must be keys. This is a part of the Dracula transform.";
426+
+ "log-odds for each label class, back-off indicator. The input columns must be keys.";
427427

428428
internal const string LoaderSignature = "CountTableTransform";
429429
internal const string UserName = "Count Table Transform";
@@ -438,7 +438,7 @@ private static VersionInfo GetVersionInfo()
438438
loaderAssemblyName: typeof(CountTableTransformer).Assembly.FullName);
439439
}
440440

441-
internal CountTableTransformer(IHostEnvironment env, DraculaFeaturizer featurizer, string[] labelClassNames,
441+
internal CountTableTransformer(IHostEnvironment env, CountTargetEncodingFeaturizer featurizer, string[] labelClassNames,
442442
int[] seeds, (string outputColumnName, string inputColumnName)[] columns)
443443
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CountTableTransformer)), columns)
444444
{
@@ -551,7 +551,7 @@ private CountTableTransformer(IHost host, ModelLoadContext ctx)
551551
}
552552

553553
Seeds = ctx.Reader.ReadIntArray(ColumnPairs.Length);
554-
ctx.LoadModel<DraculaFeaturizer, SignatureLoadModel>(host, out Featurizer, "DraculaFeaturizer");
554+
ctx.LoadModel<CountTargetEncodingFeaturizer, SignatureLoadModel>(host, out Featurizer, "Featurizer");
555555
}
556556

557557
private protected override void SaveModel(ModelSaveContext ctx)
@@ -582,7 +582,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
582582
}
583583

584584
ctx.Writer.WriteIntsNoCount(Seeds);
585-
ctx.SaveModel(Featurizer, "DraculaFeaturizer");
585+
ctx.SaveModel(Featurizer, "Featurizer");
586586
}
587587

588588
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);

src/Microsoft.ML.Transforms/Dracula/DraculaTransform.cs renamed to src/Microsoft.ML.Transforms/Dracula/CountTargetEncodingTransformer.cs

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,30 @@
2424

2525
namespace Microsoft.ML.Transforms
2626
{
27+
/// <summary>
28+
/// Transforms a categorical column into a set of features that includes the count of each label class,
29+
/// the log-odds for each label class and the back-off indicator.
30+
/// </summary>
31+
/// <remarks>
32+
/// <format type="text/markdown"><![CDATA[
33+
/// ### Estimator Characteristics
34+
/// | | |
35+
/// | -- | -- |
36+
/// | Does this estimator need to look at the data to train its parameters? | Yes |
37+
/// | Input column data type | Any |
38+
/// | Output column data type | Vector of <xref:System.Single>. |
39+
///
40+
/// The resulting <xref:Microsoft.ML.Transforms.CountTargetEncodingTransformer> creates a new column, named as specified in the output column name parameters,
41+
/// containing three parts: the count of each label class, the log-odds for each label class and the back-off indicator.
42+
///
43+
/// Check the See Also section for links to usage examples.
44+
/// ]]>
45+
/// </format>
46+
/// </remarks>
47+
/// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, InputOutputColumnPair[], CountTargetEncodingTransformer, string)" />
48+
/// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, InputOutputColumnPair[], string, CountTableBuilderBase, float, float, bool, int, bool, uint)" />
49+
/// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, string, CountTargetEncodingTransformer, string, string)"/>
50+
/// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, string, string, string, CountTableBuilderBase, float, float, int, bool, uint)"/>
2751
public class CountTargetEncodingEstimator : IEstimator<CountTargetEncodingTransformer>
2852
{
2953
/// <summary>
@@ -344,6 +368,9 @@ private SchemaShape CreateHashJoinOutputSchema(SchemaShape inputSchema)
344368
}
345369
}
346370

371+
/// <summary>
372+
/// <see cref="ITransformer"/> resulting from fitting a <see cref="LpNormNormalizingEstimator"/> or <see cref="CountTargetEncodingEstimator"/>.
373+
/// </summary>
347374
public sealed class CountTargetEncodingTransformer : ITransformer
348375
{
349376
private readonly IHost _host;
@@ -353,8 +380,8 @@ public sealed class CountTargetEncodingTransformer : ITransformer
353380

354381
internal const string Summary = "Transforms the categorical column into the set of features: count of each label class, "
355382
+ "log-odds for each label class, back-off indicator. The columns can be of arbitrary type.";
356-
internal const string LoaderSignature = "DraculaTransform";
357-
internal const string UserName = "Dracula Transform";
383+
internal const string LoaderSignature = "CountTargetEncode";
384+
internal const string UserName = "Count Target Encoding Transform";
358385

359386
bool ITransformer.IsRowToRowMapper => true;
360387

src/Microsoft.ML.Transforms/Dracula/Featurizer.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
using Microsoft.ML.Runtime;
1212
using Microsoft.ML.Transforms;
1313

14-
[assembly: LoadableClass(typeof(DraculaFeaturizer), null, typeof(SignatureLoadModel),
15-
"Dracula Featurizer", DraculaFeaturizer.RegistrationName)]
14+
[assembly: LoadableClass(typeof(CountTargetEncodingFeaturizer), null, typeof(SignatureLoadModel),
15+
"Count Target Encoding Featurizer", CountTargetEncodingFeaturizer.RegistrationName)]
1616

1717
namespace Microsoft.ML.Transforms
1818
{
19-
internal sealed class DraculaFeaturizer : ICanSaveModel
19+
internal sealed class CountTargetEncodingFeaturizer : ICanSaveModel
2020
{
2121
private readonly IHost _host;
2222
private readonly int _labelBinCount;
@@ -27,7 +27,7 @@ internal sealed class DraculaFeaturizer : ICanSaveModel
2727
public float[] PriorCoef { get; }
2828
public float[] LaplaceScale { get; }
2929

30-
internal const string RegistrationName = "DraculaFeaturizer";
30+
internal const string RegistrationName = "CountTargetEncodingFeaturizer";
3131
private static VersionInfo GetVersionInfo()
3232
{
3333
return new VersionInfo(
@@ -36,14 +36,14 @@ private static VersionInfo GetVersionInfo()
3636
verReadableCur: 0x00010001,
3737
verWeCanReadBack: 0x00010001,
3838
loaderSignature: RegistrationName,
39-
loaderAssemblyName: typeof(DraculaFeaturizer).Assembly.FullName);
39+
loaderAssemblyName: typeof(CountTargetEncodingFeaturizer).Assembly.FullName);
4040
}
4141

4242
public int ColCount => _countTables.ColCount;
4343

4444
public ReadOnlySpan<int> SlotCount => _countTables.SlotCount;
4545

46-
public DraculaFeaturizer(IHostEnvironment env, float[] priorCoef, float[] laplaceScale, long labelBinCount, MultiCountTableBase countTable)
46+
public CountTargetEncodingFeaturizer(IHostEnvironment env, float[] priorCoef, float[] laplaceScale, long labelBinCount, MultiCountTableBase countTable)
4747
{
4848
Contracts.CheckValue(env, nameof(env));
4949
_host = env.Register(RegistrationName);
@@ -61,7 +61,7 @@ public DraculaFeaturizer(IHostEnvironment env, float[] priorCoef, float[] laplac
6161
_countTables = countTable;
6262
}
6363

64-
public DraculaFeaturizer(IHostEnvironment env, ModelLoadContext ctx)
64+
public CountTargetEncodingFeaturizer(IHostEnvironment env, ModelLoadContext ctx)
6565
{
6666
Contracts.CheckValue(env, nameof(env));
6767
_host = env.Register(RegistrationName);

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4697,10 +4697,10 @@ public void EntryPointHashJoinCountTable()
46974697
}
46984698

46994699
[Fact]
4700-
public void EntryPointDracula()
4700+
public void EntryPointCountTargetEncoding()
47014701
{
47024702
var dataPath = GetDataPath("breast-cancer.txt");
4703-
var countsModel = DeleteOutputPath("Dracula-trained-counts.zip");
4703+
var countsModel = DeleteOutputPath("cte-trained-counts.zip");
47044704

47054705
var data = ML.Data.LoadFromTextFile(dataPath, new[]
47064706
{
@@ -4711,7 +4711,6 @@ public void EntryPointDracula()
47114711
var transformer = estimator.Fit(data);
47124712
ML.Model.Save(transformer, data.Schema, countsModel);
47134713

4714-
//var countsFile = GetDataPath(@"Dracula/ext-count-table.tsv");
47154714
TestEntryPointPipelineRoutine(dataPath, "col=Text:TX:1-9 col=OneText:TX:1 col=Label:0",
47164715
new[]
47174716
{

test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,7 +1327,7 @@ public void SavePipeCountTableShared()
13271327
}
13281328

13291329
[Fact]
1330-
public void SavePipeDracula()
1330+
public void SavePipeCountTargetEncoding()
13311331
{
13321332
TestCore(null, false,
13331333
new[] {
@@ -1340,7 +1340,7 @@ public void SavePipeDracula()
13401340
}
13411341

13421342
[Fact]
1343-
public void SavePipeDraculaKeyLabel()
1343+
public void SavePipeCountTargetEncodingKeyLabel()
13441344
{
13451345
TestCore(null, false,
13461346
new[] {
@@ -1354,12 +1354,11 @@ public void SavePipeDraculaKeyLabel()
13541354
}
13551355

13561356
[Fact]
1357-
public void SavePipeDraculaExternalCounts()
1357+
public void SavePipeCountTargetEncodingLoadModel()
13581358
{
1359-
//var countsFile = GetDataPath("Dracula", "ext-count-table.tsv");
13601359
var inputData = GetDataPath("breast-cancer.txt");
1361-
var initialCountsModel = DeleteOutputPath("Dracula", "initialCounts.zip");
1362-
var outputData = DeleteOutputPath("Dracula", "countsData.txt");
1360+
var initialCountsModel = DeleteOutputPath("CTE", "initialCounts.zip");
1361+
var outputData = DeleteOutputPath("CTE", "countsData.txt");
13631362
var loaderArg = "loader=Text{col=Text:TX:1-9 col=OneText:TX:1 col=Label:0}";
13641363
MainForTest($"SaveData data={inputData} {loaderArg} xf=Dracula{{lab=Label col={{name=DT src=Text customSlotMap=0,1;2,3,4,5}} table = Dict}} out={initialCountsModel} dout={outputData}");
13651364

test/Microsoft.ML.TestFramework/TestCommandBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1760,7 +1760,7 @@ public void CommandSaveData()
17601760
Done();
17611761
}
17621762

1763-
[TestCategory(Cat), TestCategory("Dracula")]
1763+
[TestCategory(Cat), TestCategory("CountTargetEncoding")]
17641764
[Fact(Skip = "Need CoreTLC specific baseline update")]
17651765
public void CommandDraculaInfer()
17661766
{

0 commit comments

Comments
 (0)