Skip to content

Commit

Permalink
documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
yaeldMS committed Dec 27, 2019
1 parent 9bd7707 commit ccb6f56
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 25 deletions.
15 changes: 15 additions & 0 deletions src/Microsoft.ML.Transforms/Dracula/CountTableBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ internal interface ICountTableBuilderFactory : IComponentFactory<CountTableBuild
{
}

/// <summary>
/// Builds a table that provides counts to the <see cref="CountTargetEncodingTransformer"/>
/// by going over the training data.
/// </summary>
public abstract class CountTableBuilderBase
{
private protected CountTableBuilderBase()
Expand All @@ -20,9 +24,20 @@ private protected CountTableBuilderBase()

internal abstract InternalCountTableBuilderBase GetInternalBuilder(long labelCardinality);

/// <summary>
/// Creates a builder that constructs the count table using the count-min sketch structure, which has a smaller
/// memory footprint at the expense of some possible overcounting due to collisions.
/// </summary>
/// <param name="depth">The depth of the count-min sketch table. Defaults to 4.</param>
/// <param name="width">The width of the count-min sketch table. Defaults to 1 &lt;&lt; 23 (8,388,608).</param>
/// <returns>A <see cref="CountTableBuilderBase"/> configured with the given count-min sketch dimensions.</returns>
public static CountTableBuilderBase CreateCMCountTableBuilder(int depth = 4, int width = 1 << 23)
=> new CMCountTableBuilder(depth, width);

/// <summary>
/// Creates a builder that constructs the count table by building a dictionary containing the exact count of each
/// categorical feature value.
/// </summary>
/// <param name="garbageThreshold">The garbage threshold: counts less than or equal to this value are assigned to the garbage bin. Defaults to 0.</param>
/// <returns>A <see cref="CountTableBuilderBase"/> configured with the given garbage threshold.</returns>
public static CountTableBuilderBase CreateDictionaryCountTableBuilder(float garbageThreshold = 0)
=> new DictCountTableBuilder(garbageThreshold);
}
Expand Down
12 changes: 6 additions & 6 deletions src/Microsoft.ML.Transforms/Dracula/CountTableTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public CountTableTransformer Fit(IDataView input)

var multiCountTable = multiBuilder.CreateMultiCountTable();

var featurizer = new DraculaFeaturizer(_host, _columns.Select(col => col.PriorCoefficient).ToArray(), _columns.Select(col => col.LaplaceScale).ToArray(), labelCardinality, multiCountTable);
var featurizer = new CountTargetEncodingFeaturizer(_host, _columns.Select(col => col.PriorCoefficient).ToArray(), _columns.Select(col => col.LaplaceScale).ToArray(), labelCardinality, multiCountTable);

return new CountTableTransformer(_host, featurizer, labelClassNames,
_columns.Select(col => col.Seed).ToArray(), _columns.Select(col => (col.Name, col.InputColumnName)).ToArray());
Expand Down Expand Up @@ -417,13 +417,13 @@ internal static class Defaults
public const bool SharedTable = false;
}

internal readonly DraculaFeaturizer Featurizer;
internal readonly CountTargetEncodingFeaturizer Featurizer;
private readonly string[] _labelClassNames;

internal int[] Seeds { get; }

internal const string Summary = "Transforms the categorical column into the set of features: count of each label class, "
+ "log-odds for each label class, back-off indicator. The input columns must be keys. This is a part of the Dracula transform.";
+ "log-odds for each label class, back-off indicator. The input columns must be keys.";

internal const string LoaderSignature = "CountTableTransform";
internal const string UserName = "Count Table Transform";
Expand All @@ -438,7 +438,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(CountTableTransformer).Assembly.FullName);
}

internal CountTableTransformer(IHostEnvironment env, DraculaFeaturizer featurizer, string[] labelClassNames,
internal CountTableTransformer(IHostEnvironment env, CountTargetEncodingFeaturizer featurizer, string[] labelClassNames,
int[] seeds, (string outputColumnName, string inputColumnName)[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CountTableTransformer)), columns)
{
Expand Down Expand Up @@ -551,7 +551,7 @@ private CountTableTransformer(IHost host, ModelLoadContext ctx)
}

Seeds = ctx.Reader.ReadIntArray(ColumnPairs.Length);
ctx.LoadModel<DraculaFeaturizer, SignatureLoadModel>(host, out Featurizer, "DraculaFeaturizer");
ctx.LoadModel<CountTargetEncodingFeaturizer, SignatureLoadModel>(host, out Featurizer, "Featurizer");
}

private protected override void SaveModel(ModelSaveContext ctx)
Expand Down Expand Up @@ -582,7 +582,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
}

ctx.Writer.WriteIntsNoCount(Seeds);
ctx.SaveModel(Featurizer, "DraculaFeaturizer");
ctx.SaveModel(Featurizer, "Featurizer");
}

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,30 @@

namespace Microsoft.ML.Transforms
{
/// <summary>
/// Transforms a categorical column into a set of features that includes the count of each label class,
/// the log-odds for each label class and the back-off indicator.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// ### Estimator Characteristics
/// | | |
/// | -- | -- |
/// | Does this estimator need to look at the data to train its parameters? | Yes |
/// | Input column data type | Any |
/// | Output column data type | Vector of <xref:System.Single>. |
///
/// The resulting <xref:Microsoft.ML.Transforms.CountTargetEncodingTransformer> creates a new column, named as specified in the output column name parameters,
/// containing three parts: the count of each label class, the log-odds for each label class and the back-off indicator.
///
/// Check the See Also section for links to usage examples.
/// ]]>
/// </format>
/// </remarks>
/// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, InputOutputColumnPair[], CountTargetEncodingTransformer, string)" />
/// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, InputOutputColumnPair[], string, CountTableBuilderBase, float, float, bool, int, bool, uint)" />
/// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, string, CountTargetEncodingTransformer, string, string)"/>
/// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, string, string, string, CountTableBuilderBase, float, float, int, bool, uint)"/>
public class CountTargetEncodingEstimator : IEstimator<CountTargetEncodingTransformer>
{
/// <summary>
Expand Down Expand Up @@ -344,6 +368,9 @@ private SchemaShape CreateHashJoinOutputSchema(SchemaShape inputSchema)
}
}

/// <summary>
/// <see cref="ITransformer"/> resulting from fitting a <see cref="CountTargetEncodingEstimator"/>.
/// </summary>
public sealed class CountTargetEncodingTransformer : ITransformer
{
private readonly IHost _host;
Expand All @@ -353,8 +380,8 @@ public sealed class CountTargetEncodingTransformer : ITransformer

internal const string Summary = "Transforms the categorical column into the set of features: count of each label class, "
+ "log-odds for each label class, back-off indicator. The columns can be of arbitrary type.";
internal const string LoaderSignature = "DraculaTransform";
internal const string UserName = "Dracula Transform";
internal const string LoaderSignature = "CountTargetEncode";
internal const string UserName = "Count Target Encoding Transform";

bool ITransformer.IsRowToRowMapper => true;

Expand Down
14 changes: 7 additions & 7 deletions src/Microsoft.ML.Transforms/Dracula/Featurizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(typeof(DraculaFeaturizer), null, typeof(SignatureLoadModel),
"Dracula Featurizer", DraculaFeaturizer.RegistrationName)]
[assembly: LoadableClass(typeof(CountTargetEncodingFeaturizer), null, typeof(SignatureLoadModel),
"Count Target Encoding Featurizer", CountTargetEncodingFeaturizer.RegistrationName)]

namespace Microsoft.ML.Transforms
{
internal sealed class DraculaFeaturizer : ICanSaveModel
internal sealed class CountTargetEncodingFeaturizer : ICanSaveModel
{
private readonly IHost _host;
private readonly int _labelBinCount;
Expand All @@ -27,7 +27,7 @@ internal sealed class DraculaFeaturizer : ICanSaveModel
public float[] PriorCoef { get; }
public float[] LaplaceScale { get; }

internal const string RegistrationName = "DraculaFeaturizer";
internal const string RegistrationName = "CountTargetEncodingFeaturizer";
private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
Expand All @@ -36,14 +36,14 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: RegistrationName,
loaderAssemblyName: typeof(DraculaFeaturizer).Assembly.FullName);
loaderAssemblyName: typeof(CountTargetEncodingFeaturizer).Assembly.FullName);
}

public int ColCount => _countTables.ColCount;

public ReadOnlySpan<int> SlotCount => _countTables.SlotCount;

public DraculaFeaturizer(IHostEnvironment env, float[] priorCoef, float[] laplaceScale, long labelBinCount, MultiCountTableBase countTable)
public CountTargetEncodingFeaturizer(IHostEnvironment env, float[] priorCoef, float[] laplaceScale, long labelBinCount, MultiCountTableBase countTable)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(RegistrationName);
Expand All @@ -61,7 +61,7 @@ public DraculaFeaturizer(IHostEnvironment env, float[] priorCoef, float[] laplac
_countTables = countTable;
}

public DraculaFeaturizer(IHostEnvironment env, ModelLoadContext ctx)
public CountTargetEncodingFeaturizer(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(RegistrationName);
Expand Down
5 changes: 2 additions & 3 deletions test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4697,10 +4697,10 @@ public void EntryPointHashJoinCountTable()
}

[Fact]
public void EntryPointDracula()
public void EntryPointCountTargetEncoding()
{
var dataPath = GetDataPath("breast-cancer.txt");
var countsModel = DeleteOutputPath("Dracula-trained-counts.zip");
var countsModel = DeleteOutputPath("cte-trained-counts.zip");

var data = ML.Data.LoadFromTextFile(dataPath, new[]
{
Expand All @@ -4711,7 +4711,6 @@ public void EntryPointDracula()
var transformer = estimator.Fit(data);
ML.Model.Save(transformer, data.Schema, countsModel);

//var countsFile = GetDataPath(@"Dracula/ext-count-table.tsv");
TestEntryPointPipelineRoutine(dataPath, "col=Text:TX:1-9 col=OneText:TX:1 col=Label:0",
new[]
{
Expand Down
11 changes: 5 additions & 6 deletions test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1327,7 +1327,7 @@ public void SavePipeCountTableShared()
}

[Fact]
public void SavePipeDracula()
public void SavePipeCountTargetEncoding()
{
TestCore(null, false,
new[] {
Expand All @@ -1340,7 +1340,7 @@ public void SavePipeDracula()
}

[Fact]
public void SavePipeDraculaKeyLabel()
public void SavePipeCountTargetEncodingKeyLabel()
{
TestCore(null, false,
new[] {
Expand All @@ -1354,12 +1354,11 @@ public void SavePipeDraculaKeyLabel()
}

[Fact]
public void SavePipeDraculaExternalCounts()
public void SavePipeCountTargetEncodingLoadModel()
{
//var countsFile = GetDataPath("Dracula", "ext-count-table.tsv");
var inputData = GetDataPath("breast-cancer.txt");
var initialCountsModel = DeleteOutputPath("Dracula", "initialCounts.zip");
var outputData = DeleteOutputPath("Dracula", "countsData.txt");
var initialCountsModel = DeleteOutputPath("CTE", "initialCounts.zip");
var outputData = DeleteOutputPath("CTE", "countsData.txt");
var loaderArg = "loader=Text{col=Text:TX:1-9 col=OneText:TX:1 col=Label:0}";
MainForTest($"SaveData data={inputData} {loaderArg} xf=Dracula{{lab=Label col={{name=DT src=Text customSlotMap=0,1;2,3,4,5}} table = Dict}} out={initialCountsModel} dout={outputData}");

Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.TestFramework/TestCommandBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1760,7 +1760,7 @@ public void CommandSaveData()
Done();
}

[TestCategory(Cat), TestCategory("Dracula")]
[TestCategory(Cat), TestCategory("CountTargetEncoding")]
[Fact(Skip = "Need CoreTLC specific baseline update")]
public void CommandDraculaInfer()
{
Expand Down

0 comments on commit ccb6f56

Please sign in to comment.