Updated xml docs for tree-based trainers. #2970


Merged · 2 commits · Mar 15, 2019
40 changes: 0 additions & 40 deletions docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs

This file was deleted.

@@ -0,0 +1,55 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;

namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression
{
    public static class FastTree
    {
@singlis (Member) commented on Mar 15, 2019:

So do we need Binary and Ranking examples for FastTree? #Resolved

@shmoradims (Author) replied on Mar 15, 2019:

All the samples will come in my next PR. I added this one as a template for discussion. Please see my email about in-memory samples.

        // This example requires installation of additional NuGet package
        // <a href="https://www.nuget.org/packages/Microsoft.ML.FastTree/">Microsoft.ML.FastTree</a>.
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations, and as the source of randomness.
            // Setting the seed to a fixed number in this example makes outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training examples.
            var examples = GenerateRandomDataPoints(1000);

            // Convert the examples list to an IDataView object, which is consumable by the ML.NET API.
            var data = mlContext.Data.LoadFromEnumerable(examples);

            // Define the trainer. This sample trains a regression model, so it uses the regression catalog.
            var pipeline = mlContext.Regression.Trainers.FastTree();

            // Train the model.
            var model = pipeline.Fit(data);
@wschin (Member) commented on Mar 15, 2019:

This is a minimal version without prediction. I personally like to see in-memory prediction, which is what will happen immediately in production.

Why is in-memory prediction important?
(1) Users have no idea about the IDataView produced by the trained model. If we don't tell them how to extract data into a C# data structure, they will have to look for tutorials on IDataView, ITransformer, and the IDataView-C# bridge.
(2) Prediction formats vary across models and are ML.NET-specific, so it is also hard to figure out which one should be used.
(3) Prediction is how the trained model will be used. One might think scikit-learn doesn't do so, so we shouldn't. My suggestion is we should! Here is my reason: scikit-learn produces NumPy data structures and everyone knows how to manipulate them (by Googling for NumPy), but IDataView is not at that stage yet.
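A minimal sketch of the kind of in-memory prediction being requested, assuming a hypothetical Prediction class that maps the model's Score column (illustrative only, not part of this diff):

            // Run the trained model over the data and bridge the resulting
            // IDataView back into C# objects that can be inspected directly.
            var predictions = model.Transform(data);
            var scored = mlContext.Data.CreateEnumerable<Prediction>(predictions, reuseRowObject: false);
            foreach (var p in scored.Take(5))
                Console.WriteLine(p.Score);

            // Hypothetical output class:
            // private class Prediction
            // {
            //     public float Score { get; set; }
            // }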

        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count)
        {
            var random = new Random(0);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = randomFloat();
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with label.
                    Features = Enumerable.Repeat(label, 50).Select(x => x + randomFloat()).ToArray()
                };
            }
        }
        private class DataPoint
Member commented:

This is really good! An example with 50 features!
        {
            public float Label { get; set; }

            [VectorType(50)]
            public float[] Features { get; set; }
        }
    }
}
13 changes: 8 additions & 5 deletions src/Microsoft.ML.Data/Training/TrainerInputBase.cs
@@ -41,17 +41,17 @@ private protected TrainerInputBase() { }
internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;

/// <summary>
/// Whether learner should cache input training data. Used only in entry-points, since the intended API mechanism
/// Whether trainer should cache input training data. Used only in entry-points, since the intended API mechanism
/// is that the user will use the <see cref="DataOperationsCatalog.Cache(IDataView, string[])"/> or other method
/// like <see cref="EstimatorChain{TLastTransformer}.AppendCacheCheckpoint(IHostEnvironment)"/>.
/// </summary>
[BestFriend]
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether trainer should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal CachingOptions Caching = CachingOptions.Auto;
}
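For reference, the intended API mechanism described above looks roughly like this (a sketch assuming an existing mlContext, an IDataView named data, and a feature column name; not part of this diff):

    // Cache the data view itself...
    var cachedData = mlContext.Data.Cache(data);
    // ...or add a cache checkpoint inside an estimator chain, so downstream
    // trainers that make multiple passes over the data read from the cache.
    var pipeline = mlContext.Transforms.Concatenate("Features", "FeatureColumn")
        .AppendCacheCheckpoint(mlContext)
        .Append(mlContext.Regression.Trainers.FastTree());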

/// <summary>
/// The base class for all learner inputs that support a Label column.
/// The base class for all trainer inputs that support a Label column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
public abstract class TrainerInputBaseWithLabel : TrainerInputBase
@@ -67,7 +67,7 @@ private protected TrainerInputBaseWithLabel() { }

// REVIEW: This is a known antipattern, but the solution involves the decorator pattern which can't be used in this case.
/// <summary>
/// The base class for all learner inputs that support a weight column.
/// The base class for all trainer inputs that support a weight column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
public abstract class TrainerInputBaseWithWeight : TrainerInputBaseWithLabel
@@ -82,7 +82,7 @@ private protected TrainerInputBaseWithWeight() { }
}

/// <summary>
/// The base class for all unsupervised learner inputs that support a weight column.
/// The base class for all unsupervised trainer inputs that support a weight column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
public abstract class UnsupervisedTrainerInputBaseWithWeight : TrainerInputBase
@@ -96,6 +96,9 @@ private protected UnsupervisedTrainerInputBaseWithWeight() { }
public string ExampleWeightColumnName = null;
}

/// <summary>
/// The base class for all trainer inputs that support a group column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
public abstract class TrainerInputBaseWithGroupId : TrainerInputBaseWithWeight
{
77 changes: 66 additions & 11 deletions src/Microsoft.ML.FastTree/FastTreeArguments.cs
@@ -51,15 +51,18 @@ public enum EarlyStoppingRankingMetric
NdcgAt3 = 3
}

/// <include file='doc.xml' path='doc/members/member[@name="FastTree"]/*' />
// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
public sealed partial class FastTreeBinaryClassificationTrainer
{
/// <summary>
/// Options for the <see cref="FastTreeBinaryClassificationTrainer"/>.
/// </summary>
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
{

/// <summary>
/// Option for using derivatives optimized for unbalanced sets.
/// Whether to use derivatives optimized for unbalanced training data.
/// </summary>
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Option for using derivatives optimized for unbalanced sets", ShortName = "us")]
[TGUI(Label = "Optimize for unbalanced")]
@@ -90,6 +93,9 @@ public EarlyStoppingMetric EarlyStoppingMetric
}
}

/// <summary>
/// Create a new <see cref="Options"/> object with default values.
/// </summary>
public Options()
{
// Use L1 by default.
@@ -100,8 +106,12 @@ public Options()
}
}
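Taken together, these options are typically consumed by constructing an Options object and passing it to the trainer factory (a sketch assuming an mlContext and the field names as they appear in this version; not part of this diff):

    var options = new FastTreeBinaryClassificationTrainer.Options
    {
        UnbalancedSets = true,   // the option documented above
        NumberOfLeaves = 20,
        NumberOfTrees = 100
    };
    var trainer = mlContext.BinaryClassification.Trainers.FastTree(options);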

// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
public sealed partial class FastTreeRegressionTrainer
{
/// <summary>
/// Options for the <see cref="FastTreeRegressionTrainer"/>.
/// </summary>
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
{
@@ -127,6 +137,9 @@ public EarlyStoppingMetric EarlyStoppingMetric
}
}

/// <summary>
/// Create a new <see cref="Options"/> object with default values.
/// </summary>
public Options()
{
EarlyStoppingMetric = EarlyStoppingMetric.L1Norm; // Use L1 by default.
@@ -136,14 +149,22 @@ public Options()
}
}

// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
public sealed partial class FastTreeTweedieTrainer
{
/// <summary>
/// Options for the <see cref="FastTreeTweedieTrainer"/>.
/// </summary>
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
{
// REVIEW: It is possible to estimate this index parameter from the distribution of data, using
// a combination of univariate optimization and grid search, following section 4.2 of the paper. However
// it is probably not worth doing unless and until explicitly asked for.
/// <summary>
/// The index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss,
/// and intermediate values are compound Poisson loss.
/// </summary>
[Argument(ArgumentType.LastOccurenceWins, HelpText =
"Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, " +
"and intermediate values are compound Poisson loss.")]
@@ -174,6 +195,9 @@ public EarlyStoppingMetric EarlyStoppingMetric
}
}

/// <summary>
/// Create a new <see cref="Options"/> object with default values.
/// </summary>
public Options()
{
EarlyStoppingMetric = EarlyStoppingMetric.L1Norm; // Use L1 by default.
@@ -183,15 +207,25 @@ public Options()
}
}
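A sketch of setting the Tweedie index parameter described above (the option field, assumed here to be named Index, is hidden in the collapsed hunks; illustrative only):

    var options = new FastTreeTweedieTrainer.Options
    {
        Index = 1.5   // 1 = Poisson loss, 2 = gamma loss, in between = compound Poisson
    };
    var trainer = mlContext.Regression.Trainers.FastTreeTweedie(options);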

// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
public sealed partial class FastTreeRankingTrainer
{
/// <summary>
/// Options for the <see cref="FastTreeRankingTrainer"/>.
/// </summary>
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Comma seperated list of gains associated to each relevance label.", ShortName = "gains")]
/// <summary>
/// Comma-separated list of gains associated with each relevance label.
/// </summary>
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Comma-separated list of gains associated to each relevance label.", ShortName = "gains")]
[TGUI(NoSweep = true)]
public double[] CustomGains = new double[] { 0, 3, 7, 15, 31 };
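With these defaults, each step up in relevance label roughly doubles the gain (3, 7, 15, 31), so highly relevant documents dominate the (N)DCG computation.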

/// <summary>
/// Whether to train using discounted cumulative gain (DCG) instead of normalized DCG (NDCG).
/// </summary>
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Train DCG instead of NDCG", ShortName = "dcg")]
public bool UseDcg;

@@ -204,7 +238,11 @@ public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
[TGUI(NotGui = true)]
internal string SortingAlgorithm = "DescendingStablePessimistic";

[Argument(ArgumentType.AtMostOnce, HelpText = "max-NDCG truncation to use in the Lambda Mart algorithm", ShortName = "n", Hide = true)]
/// <summary>
/// The maximum NDCG truncation to use in the
/// <a href="https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf">LambdaMART algorithm</a>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "max-NDCG truncation to use in the LambdaMART algorithm", ShortName = "n", Hide = true)]
[TGUI(NotGui = true)]
public int NdcgTruncationLevel = 100;
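A sketch of configuring these ranking-specific options (assuming an mlContext and that the group column option inherited from the base class is named RowGroupColumnName; illustrative only):

    var options = new FastTreeRankingTrainer.Options
    {
        CustomGains = new double[] { 0, 3, 7, 15, 31 },
        NdcgTruncationLevel = 10,
        RowGroupColumnName = "GroupId"
    };
    var trainer = mlContext.Ranking.Trainers.FastTree(options);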

@@ -253,6 +291,9 @@ public EarlyStoppingRankingMetric EarlyStoppingMetric
}
}

/// <summary>
/// Create a new <see cref="Options"/> object with default values.
/// </summary>
public Options()
{
EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt1; // Use NDCG@1 by default.
@@ -295,6 +336,9 @@ internal static class Defaults
public const double LearningRate = 0.2;
}

/// <summary>
/// Options for tree trainers.
/// </summary>
public abstract class TreeOptions : TrainerInputBaseWithGroupId
{
/// <summary>
@@ -428,11 +472,13 @@ public abstract class TreeOptions : TrainerInputBaseWithGroupId
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The feature re-use penalty (regularization) coefficient", ShortName = "frup")]
public Double FeatureReusePenalty;

/// Only consider a gain if its likelihood versus a random choice gain is above a certain value.
/// So 0.95 would mean restricting to gains that have less than a 0.05 change of being generated randomly through choice of a random split.
/// <summary>
/// Tree fitting gain confidence requirement (should be in the range [0,1) ).
/// Tree fitting gain confidence requirement. Only consider a gain if its likelihood versus a random choice gain is above this value.
/// </summary>
/// <value>
/// Value of 0.95 would mean restricting to gains that have less than a 0.05 chance of being generated randomly through choice of a random split.
/// Valid range is [0,1).
/// </value>
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Tree fitting gain confidence requirement (should be in the range [0,1) ).", ShortName = "gainconf")]
public Double GainConfidenceLevel;

@@ -458,7 +504,7 @@ public abstract class TreeOptions : TrainerInputBaseWithGroupId
public int NumberOfLeaves = Defaults.NumberOfLeaves;

/// <summary>
/// The minimal number of examples allowed in a leaf of a regression tree, out of the subsampled data.
/// The minimal number of data points required to form a new tree leaf.
/// </summary>
// REVIEW: Arrays not supported in GUI
// REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
@@ -582,6 +628,9 @@ internal virtual void Check(IExceptionContext ectx)
}
}

/// <summary>
/// Options for boosting tree trainers.
/// </summary>
public abstract class BoostedTreeOptions : TreeOptions
{
// REVIEW: TLC FR likes to call it bestStepRegressionTrees which might be more appropriate.
@@ -594,7 +643,7 @@ public abstract class BoostedTreeOptions : TreeOptions
public bool BestStepRankingRegressionTrees = false;

/// <summary>
/// Should we use line search for a step size.
/// Determines whether to use line search for a step size.
/// </summary>
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use line search for a step size", ShortName = "ls")]
public bool UseLineSearch;
@@ -611,11 +660,17 @@ public abstract class BoostedTreeOptions : TreeOptions
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Minimum line search step size", ShortName = "minstep")]
public Double MinimumStepSize;

/// <summary>
/// Types of optimization algorithms.
/// </summary>
public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDescent, ConjugateGradientDescent };

/// <summary>
/// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent).
/// Optimization algorithm to be used.
/// </summary>
/// <value>
/// See <see cref="OptimizationAlgorithmType"/> for available optimizers.
/// </value>
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent)", ShortName = "oa")]
public OptimizationAlgorithmType OptimizationAlgorithm = OptimizationAlgorithmType.GradientDescent;
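For example, switching a boosted-tree options object to the accelerated optimizer (illustrative sketch):

    options.OptimizationAlgorithm = BoostedTreeOptions.OptimizationAlgorithmType.AcceleratedGradientDescent;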

@@ -655,7 +710,7 @@ public EarlyStoppingRuleBase EarlyStoppingRule
internal int EarlyStoppingMetrics;

/// <summary>
/// Enable post-training pruning to avoid overfitting. (a validation set is required).
/// Enable post-training tree pruning to avoid overfitting. It requires a validation set.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Enable post-training pruning to avoid overfitting. (a validation set is required)", ShortName = "pruning")]
public bool EnablePruning;
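When pruning is enabled, the validation set is typically supplied through a trainer Fit overload that accepts a separate validation IDataView, where the trainer exposes one.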
5 changes: 4 additions & 1 deletion src/Microsoft.ML.FastTree/FastTreeClassification.cs
@@ -98,7 +98,10 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
}

/// <include file = 'doc.xml' path='doc/members/member[@name="FastTree"]/*' />
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> for training a decision tree binary classification model using FastTree.
/// </summary>
/// <include file='doc.xml' path='doc/members/member[@name="FastTree_remarks"]/*' />
public sealed partial class FastTreeBinaryClassificationTrainer :
BoostingFastTreeTrainerBase<FastTreeBinaryClassificationTrainer.Options,
BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>,
5 changes: 4 additions & 1 deletion src/Microsoft.ML.FastTree/FastTreeRanking.cs
@@ -39,7 +39,10 @@

namespace Microsoft.ML.Trainers.FastTree
{
/// <include file='doc.xml' path='doc/members/member[@name="FastTree"]/*' />
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> for training a decision tree ranking model using FastTree.
/// </summary>
/// <include file='doc.xml' path='doc/members/member[@name="FastTree_remarks"]/*' />
public sealed partial class FastTreeRankingTrainer
: BoostingFastTreeTrainerBase<FastTreeRankingTrainer.Options, RankingPredictionTransformer<FastTreeRankingModelParameters>, FastTreeRankingModelParameters>
{