Skip to content

Commit

Permalink
Fixed Averaged Perceptron default value (#5586)
Browse files Browse the repository at this point in the history
* fixed missed averaged perceptron default value

* fixed extension api

* fixed test baselines
  • Loading branch information
michaelgsharp authored Jan 21, 2021
1 parent 927a61a commit c9ed772
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
using static Microsoft.ML.Trainers.AveragedLinearOptions;

[assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
Expand Down Expand Up @@ -76,6 +77,11 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPred

private readonly Options _args;

internal class AveragedPerceptronDefault : AveragedDefault
{
public new const int NumberOfIterations = 10;
}

/// <summary>
/// Options for the <see cref="AveragedPerceptronTrainer"/> as used in
/// <see cref="Microsoft.ML.StandardTrainersCatalog.AveragedPerceptron(BinaryClassificationCatalog.BinaryClassificationTrainers, Options)"/>.
Expand All @@ -84,7 +90,7 @@ public sealed class Options : AveragedLinearOptions
{
public Options()
{
NumberOfIterations = 10;
NumberOfIterations = AveragedPerceptronDefault.NumberOfIterations;
}

/// <summary>
Expand Down Expand Up @@ -166,10 +172,10 @@ internal AveragedPerceptronTrainer(IHostEnvironment env,
string labelColumnName = DefaultColumnNames.Label,
string featureColumnName = DefaultColumnNames.Features,
IClassificationLoss lossFunction = null,
float learningRate = Options.AveragedDefault.LearningRate,
bool decreaseLearningRate = Options.AveragedDefault.DecreaseLearningRate,
float l2Regularization = Options.AveragedDefault.L2Regularization,
int numberOfIterations = Options.AveragedDefault.NumberOfIterations)
float learningRate = AveragedPerceptronDefault.LearningRate,
bool decreaseLearningRate = AveragedPerceptronDefault.DecreaseLearningRate,
float l2Regularization = AveragedPerceptronDefault.L2Regularization,
int numberOfIterations = AveragedPerceptronDefault.NumberOfIterations)
: this(env, new Options
{
LabelColumnName = labelColumnName,
Expand Down
9 changes: 5 additions & 4 deletions src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

namespace Microsoft.ML
{
using static Microsoft.ML.Trainers.AveragedPerceptronTrainer;
using LROptions = LbfgsLogisticRegressionBinaryTrainer.Options;

/// <summary>
Expand Down Expand Up @@ -417,10 +418,10 @@ public static AveragedPerceptronTrainer AveragedPerceptron(
string labelColumnName = DefaultColumnNames.Label,
string featureColumnName = DefaultColumnNames.Features,
IClassificationLoss lossFunction = null,
float learningRate = AveragedLinearOptions.AveragedDefault.LearningRate,
bool decreaseLearningRate = AveragedLinearOptions.AveragedDefault.DecreaseLearningRate,
float l2Regularization = AveragedLinearOptions.AveragedDefault.L2Regularization,
int numberOfIterations = AveragedLinearOptions.AveragedDefault.NumberOfIterations)
float learningRate = AveragedPerceptronDefault.LearningRate,
bool decreaseLearningRate = AveragedPerceptronDefault.DecreaseLearningRate,
float l2Regularization = AveragedPerceptronDefault.L2Regularization,
int numberOfIterations = AveragedPerceptronDefault.NumberOfIterations)
{
Contracts.CheckValue(catalog, nameof(catalog));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#@ col=FeatureContributions:R4:32-37
#@ col=FeatureContributions:R4:38-43
#@ }
950 757 692 720 297 7515 1 1 950 757 692 720 297 7515 0.956696868 0.760804 0.7872582 0.754716933 0.297893673 0.7578661 0 0.1527809 0 0 0 1 -0.6583012 0 -1 -0.517060339 0 0 12 2:-0.13028869 5:1 8:-0.370813 11:2.84608746
459 961 0 659 274 2147 0 0 459 961 0 659 274 2147 0.462235659 0.965829134 0 0.690775633 0.27482447 0.21651876 0 0.6788808 0 0 0 0.99999994 -0.6720779 0 0 -1 -0.870772958 0 12 3:-0.215823054 5:0.99999994 9:-0.175488681 11:0.8131137
672 275 0 65 195 9818 1 1 672 275 0 65 195 9818 0.6767372 0.2763819 0 0.06813417 0.195586756 0.990116954 0 0.04248268 0 0 0 1 -1 0 0 -0.100242466 -0.6298147 0 12 0:-0.04643902 5:1 6:-0.172673345 11:3.71828127
186 301 0 681 526 1456 0 0 186 301 0 681 526 1456 0.187311172 0.302512556 0 0.713836432 0.527582765 0.1468334 0 0.313550383 0 0 0 1 -0.162922 0 0 -0.6181894 -1 0 12 4:-0.5319963 5:1 10:-0.293352127 11:0.5514176
950 757 692 720 297 7515 1 1 950 757 692 720 297 7515 0.956696868 0.760804 0.7872582 0.754716933 0.297893673 0.7578661 0.094661206 0.00312626758 0 0 0 1 0 0 -0.704976737 -0.99999994 -0.322129458 0 12 3:-0.0539601371 5:1 9:-0.48089987 11:8.912132
459 961 0 659 274 2147 0 0 459 961 0 659 274 2147 0.462235659 0.965829134 0 0.690775633 0.27482447 0.21651876 0.160087749 0.013891547 0 0 0 1 18 3:-1 4:-0.32469207 9:-0.1728713 11:1 15:-0.440156966 17:2.546154
672 275 0 65 195 9818 1 1 672 275 0 65 195 9818 0.6767372 0.2763819 0 0.06813417 0.195586756 0.990116954 0.05125352 0.0008692985 0 0 0 1 18 3:-0.426846981 4:-1 10:-0.008735497 11:1 16:-0.10170991 17:11.6432886
186 301 0 681 526 1456 0 0 186 301 0 681 526 1456 0.187311172 0.302512556 0 0.713836432 0.527582765 0.1468334 0.0956596956 0.006416 0 0 0 1 18 3:-1 4:-0.603177547 9:-0.263423949 11:1 15:-0.454851121 17:1.72668862

0 comments on commit c9ed772

Please sign in to comment.