Skip to content

KMeans and Implicit weight cleanup #2158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/EntryPoints/InputBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
[TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
public abstract class UnsupervisedLearnerInputBaseWithWeight : LearnerInputBase
{
/// <summary>
/// Column to use for example weight.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
}
Expand Down
7 changes: 7 additions & 0 deletions src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(Sche
return new SchemaShape(outColumns.Values);
}

/// <summary>
/// Fits the scored <see cref="IDataView"/> creating a <see cref="CalibratorTransformer{TICalibrator}"/> that can transform the data by adding a
/// <see cref="DefaultColumnNames.Probability"/> column containing the calibrated <see cref="DefaultColumnNames.Score"/>.
/// </summary>
/// <param name="input"></param>
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[](start = 12, length = 28)

with "Label" and "Features" columns.
From what I see, we actually don't use Score column at all, we just run predictor on top of feature column to produce "Score" column. So i'm not sure how necessary is to have that Score role in roleMappedData. #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see line 262. We do use the score column.


In reply to: 248115187 [](ancestors = 248115187)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

During transform stage. During Fit we train calibrator on combination of Predictor + FeatureColumn and LabelColumn.
But it's just an observation. You don't have to do anything about that.


In reply to: 248152774 [](ancestors = 248152774,248115187)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you can still put something into empty param, tho


In reply to: 248410187 [](ancestors = 248410187,248152774,248115187)

/// <returns>A trained <see cref="CalibratorTransformer{TICalibrator}"/> that will transform the data by adding the
/// <see cref="DefaultColumnNames.Probability"/> column.</returns>
public CalibratorTransformer<TICalibrator> Fit(IDataView input)
{
TICalibrator calibrator = null;
Expand Down
17 changes: 14 additions & 3 deletions src/Microsoft.ML.Data/Training/TrainerUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't look right. #Closed

Copy link
Member Author

@sfilipi sfilipi Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This namespace is in Microsoft.ML.Core. There is no circular dependancy


In reply to: 248115282 [](ancestors = 248115282)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It just strange what we have EntryPoints as dependency. But I guess it's related to fact what Arguments class shared with entrypoints.


In reply to: 248152842 [](ancestors = 248152842,248115282)

using Microsoft.ML.Internal.Utilities;

namespace Microsoft.ML.Training
Expand Down Expand Up @@ -382,10 +383,20 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn)
/// The <see cref="SchemaShape.Column"/> for the weight column.
/// </summary>
/// <param name="weightColumn">name of the weight column</param>
/// <param name="isExplicit">whether the column is implicitly, or explicitly defined</param>
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true)
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
{
if (weightColumn == null || !isExplicit)
if (weightColumn == null)
return default;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

/// <summary>
/// The <see cref="SchemaShape.Column"/> for the weight column.
/// </summary>
/// <param name="weightColumn">name of the weight column</param>
public static SchemaShape.Column MakeR4ScalarWeightColumn(Optional<string> weightColumn)
{
if (weightColumn == null || weightColumn.Value == null || !weightColumn.IsExplicit)
return default;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
/// Constructor that is used when invoking the classes deriving from this, through maml.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
Copy link
Contributor

@justinormont justinormont Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is/was args.WeightColumn.IsExplicit used for? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is something that got introduced with EntryPoints. If the user doesn't specify a name for the optional columns, an Implict argument get created for them... not sure of the original thoughts around why not just leave it null. Maybe to avoid dealing with its serialization across languages.. idk.


In reply to: 248127604 [](ancestors = 248127604)

: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
Host.CheckValue(args, nameof(args));
Args = args;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/GamTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ private protected GamTrainerBase(IHostEnvironment env,

private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
Contracts.CheckValue(env, nameof(env));
Host.CheckValue(args, nameof(args));
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public OlsLinearRegressionTrainer(IHostEnvironment env,
/// </summary>
internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
Host.CheckValue(args, nameof(args));
Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative");
Expand All @@ -106,7 +106,7 @@ private static Arguments ArgsInit(string featureColumn,
advancedSettings?.Invoke(args);
args.FeatureColumn = featureColumn;
args.LabelColumn = labelColumn;
args.WeightColumn = weightColumn;
args.WeightColumn = weightColumn != null ? Optional<string>.Explicit(weightColumn) : Optional<string>.Implicit(DefaultColumnNames.Weight);
Copy link
Member

@abgoswam abgoswam Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weightColumn != null [](start = 32, length = 21)

maybe we should have a separate PR just to fix the weights bug across the entire codebase ?

that way we isolate the API changes from the bug related to Weight. Also, less dependency between PRs related to learner API changes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

than we'd need a separate PR to fix the ctor in KMeans .. and this PR does take care of both..


In reply to: 248117752 [](ancestors = 248117752)

Copy link
Member

@abgoswam abgoswam Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my concern is by coupling the weights changes with API changes, we are making ourselves susceptible to inconsistency across PRs etc.

Would be much more convenient if we have consistent story for weights, get that reviewed + checked in and then move on with API fixing .

So I am not sure at this point if I need to make similar changes for the weights bug in my existing PRs or not ? I see we are adding overloads for MakeR4ScalarWeightColumn.


In reply to: 248158333 [](ancestors = 248158333,248117752)

return args;
}

Expand Down
40 changes: 33 additions & 7 deletions src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Trainers.KMeans;

namespace Microsoft.ML
Expand All @@ -16,20 +17,45 @@ public static class KMeansClusteringExtensions
/// <summary>
/// Train a KMeans++ clustering algorithm.
/// </summary>
/// <param name="ctx">The regression context trainer object.</param>
/// <param name="features">The features, or independent variables.</param>
/// <param name="ctx">The clustering context trainer object.</param>
/// <param name="featureColumn">The features, or independent variables.</param>
Copy link
Member

@wschin wschin Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe featureColumnName? In ML.NET, we mix the meanings of column name, column index, and column itself. I feel our naming can be more specific and therefore less ambiguous. #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will leave it to featureColumn, because we have had a loooonnnggg discussion about it, and a PR to standartize those.

#1524 (comment)


In reply to: 248383325 [](ancestors = 248383325)

/// <param name="weights">The optional example weights.</param>
/// <param name="clustersCount">The number of clusters to use for KMeans.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[KMeans](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs)]
/// ]]></format>
/// </example>
public static KMeansPlusPlusTrainer KMeans(this ClusteringContext.ClusteringTrainers ctx,
string features,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
int clustersCount = KMeansPlusPlusTrainer.Defaults.K,
Action<KMeansPlusPlusTrainer.Arguments> advancedSettings = null)
int clustersCount = KMeansPlusPlusTrainer.Defaults.ClustersCount)
{
Contracts.CheckValue(ctx, nameof(ctx));
var env = CatalogUtils.GetEnvironment(ctx);
return new KMeansPlusPlusTrainer(env, features, clustersCount, weights, advancedSettings);

var options = new KMeansPlusPlusTrainer.Options
{
FeatureColumn = featureColumn,
WeightColumn = weights != null ? Optional<string>.Explicit(weights) : Optional<string>.Implicit(DefaultColumnNames.Weight),
ClustersCount = clustersCount
};
return new KMeansPlusPlusTrainer(env, options);
}

/// <summary>
/// Train a KMeans++ clustering algorithm.
/// </summary>
/// <param name="ctx">The clustering context trainer object.</param>
/// <param name="options">Algorithm advanced options.</param>
public static KMeansPlusPlusTrainer KMeans(this ClusteringContext.ClusteringTrainers ctx, KMeansPlusPlusTrainer.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(options, nameof(options));

var env = CatalogUtils.GetEnvironment(ctx);
return new KMeansPlusPlusTrainer(env, options);
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

options [](start = 50, length = 7)

Host.CheckValue(options, nameof(options));
You have this in internal KMeansPlusPlusTrainer(IHostEnvironment env, Options options)
if no one will specify options that check will throw.
Is it works? Do we somehow magically create Options somewhere? #Closed

}
}
}
111 changes: 52 additions & 59 deletions src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
using Microsoft.ML.Trainers.KMeans;
using Microsoft.ML.Training;

[assembly: LoadableClass(KMeansPlusPlusTrainer.Summary, typeof(KMeansPlusPlusTrainer), typeof(KMeansPlusPlusTrainer.Arguments),
[assembly: LoadableClass(KMeansPlusPlusTrainer.Summary, typeof(KMeansPlusPlusTrainer), typeof(KMeansPlusPlusTrainer.Options),
new[] { typeof(SignatureClusteringTrainer), typeof(SignatureTrainer) },
KMeansPlusPlusTrainer.UserNameValue,
KMeansPlusPlusTrainer.LoadNameValue,
Expand All @@ -30,7 +30,7 @@ namespace Microsoft.ML.Trainers.KMeans
/// <include file='./doc.xml' path='doc/members/member[@name="KMeans++"]/*' />
public class KMeansPlusPlusTrainer : TrainerEstimatorBase<ClusteringPredictionTransformer<KMeansModelParameters>, KMeansModelParameters>
{
public const string LoadNameValue = "KMeansPlusPlus";
internal const string LoadNameValue = "KMeansPlusPlus";
internal const string UserNameValue = "KMeans++ Clustering";
internal const string ShortName = "KM";
internal const string Summary = "K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified "
Expand All @@ -45,34 +45,54 @@ public enum InitAlgorithm
}

[BestFriend]
internal static class Defaults{
internal static class Defaults
{
/// <value>The number of clusters.</value>
public const int K = 5;
public const int ClustersCount = 5;
}

public class Arguments : UnsupervisedLearnerInputBaseWithWeight
public class Options : UnsupervisedLearnerInputBaseWithWeight
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Options [](start = 21, length = 7)

Just an observation.
I found it's funny what we diverge during ITransformer/IEstimator conversion and for what used to be transforms we come up with ColumnInfo and we no longer rely on arguments. (We still have conversion from arguments to columnInfo for EntryPoints)

Eventually half of our entrypoints would become options, and other half remain arguments.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point... and a bummer. Does it make sense to have a plan to reconcile?


In reply to: 248412000 [](ancestors = 248412000)

{
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of clusters", SortOrder = 50)]
/// <summary>
/// The number of clusters.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of clusters", SortOrder = 50, Name = "K")]
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name [](start = 100, length = 4)

ShortName maybe? it's a "K" is quite short, right? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name is for backwards compatibility. This is what maml uses.


In reply to: 248405583 [](ancestors = 248405583)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ShortNames for same thing, and as @najeeb-kazmi discover yesterday, Name is not always work, but ShortName do (or maybe combination of ShortName + Name).


In reply to: 248405874 [](ancestors = 248405874,248405583)

[TGUI(SuggestedSweeps = "5,10,20,40")]
[TlcModule.SweepableDiscreteParam("K", new object[] { 5, 10, 20, 40 })]
public int K = Defaults.K;
public int ClustersCount = Defaults.ClustersCount;

/// <summary>
/// Cluster initialization algorithm.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Cluster initialization algorithm", ShortName = "init")]
public InitAlgorithm InitAlgorithm = InitAlgorithm.KMeansParallel;

/// <summary>
/// Tolerance parameter for trainer convergence. Low = slower, more accurate.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance parameter for trainer convergence. Low = slower, more accurate",
ShortName = "ot")]
Name = "OptTol", ShortName = "ot")]
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"ot" [](start = 45, length = 4)

you can enumerate ShortNames via comma "opttol,ot" #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apparently name is for maml, and i set it to the previous name for backwards compatibility.


In reply to: 248405346 [](ancestors = 248405346)

[TGUI(Label = "Optimization Tolerance", Description = "Threshold for trainer convergence")]
public float OptTol = (float)1e-7;
public float OptimizationTolerance = (float)1e-7;

/// <summary>
/// Maximum number of iterations.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of iterations.", ShortName = "maxiter")]
[TGUI(Label = "Max Number of Iterations")]
public int MaxIterations = 1000;

[Argument(ArgumentType.AtMostOnce, HelpText = "Memory budget (in MBs) to use for KMeans acceleration", ShortName = "accelMemBudgetMb")]
/// <summary>
/// Memory budget (in MBs) to use for KMeans acceleration.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Memory budget (in MBs) to use for KMeans acceleration",
Name = "AccelMemBudgetMb", ShortName = "accelMemBudgetMb")]
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jan 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name = "AccelMemBudgetMb" [](start = 16, length = 25)

I think we lowercase shortnames and name anyway, so I don't think you need to add this Name parameter #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep it to clearly documents the previous name? Najeeb has been doing the same on his PR of pluralizing some args names.


In reply to: 248405947 [](ancestors = 248405947)

[TGUI(Label = "Memory Budget (in MBs) for KMeans Acceleration")]
public int AccelMemBudgetMb = 4 * 1024; // by default, use at most 4 GB
public int AccelerationMemoryBudgetMb = 4 * 1024; // by default, use at most 4 GB

/// <summary>
/// Degree of lock-free parallelism. Defaults to automatic. Determinism not guaranteed.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Defaults to automatic. Determinism not guaranteed.", ShortName = "nt,t,threads", SortOrder = 50)]
[TGUI(Label = "Number of threads")]
public int? NumThreads;
Expand All @@ -95,58 +115,31 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight
/// Initializes a new instance of <see cref="KMeansPlusPlusTrainer"/>
/// </summary>
/// <param name="env">The <see cref="IHostEnvironment"/> to use.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="weights">The name for the optional column containing the example weights.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
/// <param name="clustersCount">The number of clusters.</param>
public KMeansPlusPlusTrainer(IHostEnvironment env,
string featureColumn = DefaultColumnNames.Features,
int clustersCount = Defaults.K,
string weights = null,
Action<Arguments> advancedSettings = null)
: this(env, new Arguments
{
FeatureColumn = featureColumn,
WeightColumn = weights,
K = clustersCount
}, advancedSettings)
{
}

internal KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args)
: this(env, args, null)
/// <param name="options">The advanced options of the algorithm.</param>
internal KMeansPlusPlusTrainer(IHostEnvironment env, Options options)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn))
{
Host.CheckValue(options, nameof(options));
Host.CheckUserArg(options.ClustersCount > 0, nameof(options.ClustersCount), "Must be positive");

}

private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, Action<Arguments> advancedSettings = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
Host.CheckValue(args, nameof(args));

// override with the advanced settings.
advancedSettings?.Invoke(args);

Host.CheckUserArg(args.K > 0, nameof(args.K), "Must be positive");

_featureColumn = args.FeatureColumn;
_featureColumn = options.FeatureColumn;

_k = args.K;
_k = options.ClustersCount;

Host.CheckUserArg(args.MaxIterations > 0, nameof(args.MaxIterations), "Must be positive");
_maxIterations = args.MaxIterations;
Host.CheckUserArg(options.MaxIterations > 0, nameof(options.MaxIterations), "Must be positive");
_maxIterations = options.MaxIterations;

Host.CheckUserArg(args.OptTol > 0, nameof(args.OptTol), "Tolerance must be positive");
_convergenceThreshold = args.OptTol;
Host.CheckUserArg(options.OptimizationTolerance > 0, nameof(options.OptimizationTolerance), "Tolerance must be positive");
_convergenceThreshold = options.OptimizationTolerance;

Host.CheckUserArg(args.AccelMemBudgetMb > 0, nameof(args.AccelMemBudgetMb), "Must be positive");
_accelMemBudgetMb = args.AccelMemBudgetMb;
Host.CheckUserArg(options.AccelerationMemoryBudgetMb > 0, nameof(options.AccelerationMemoryBudgetMb), "Must be positive");
_accelMemBudgetMb = options.AccelerationMemoryBudgetMb;

_initAlgorithm = args.InitAlgorithm;
_initAlgorithm = options.InitAlgorithm;

Host.CheckUserArg(!args.NumThreads.HasValue || args.NumThreads > 0, nameof(args.NumThreads),
Host.CheckUserArg(!options.NumThreads.HasValue || options.NumThreads > 0, nameof(options.NumThreads),
"Must be either null or a positive integer.");
_numThreads = ComputeNumThreads(Host, args.NumThreads);
_numThreads = ComputeNumThreads(Host, options.NumThreads);
Info = new TrainerInfo();
}

Expand Down Expand Up @@ -247,14 +240,14 @@ private static int ComputeNumThreads(IHost host, int? argNumThreads)
ShortName = ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.KMeansClustering/doc.xml' path='doc/members/member[@name=""KMeans++""]/*' />",
@"<include file='../Microsoft.ML.KMeansClustering/doc.xml' path='doc/members/example[@name=""KMeans++""]/*' />"})]
public static CommonOutputs.ClusteringOutput TrainKMeans(IHostEnvironment env, Arguments input)
public static CommonOutputs.ClusteringOutput TrainKMeans(IHostEnvironment env, Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainKMeans");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.ClusteringOutput>(host, input,
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.ClusteringOutput>(host, input,
() => new KMeansPlusPlusTrainer(host, input),
getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
}
Expand Down Expand Up @@ -749,10 +742,10 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
host.CheckValue(ch, nameof(ch));
ch.CheckValue(cursorFactory, nameof(cursorFactory));
ch.CheckValue(centroids, nameof(centroids));
ch.CheckUserArg(numThreads > 0, nameof(KMeansPlusPlusTrainer.Arguments.NumThreads), "Must be positive");
ch.CheckUserArg(k > 0, nameof(KMeansPlusPlusTrainer.Arguments.K), "Must be positive");
ch.CheckUserArg(numThreads > 0, nameof(KMeansPlusPlusTrainer.Options.NumThreads), "Must be positive");
ch.CheckUserArg(k > 0, nameof(KMeansPlusPlusTrainer.Options.ClustersCount), "Must be positive");
ch.CheckParam(dimensionality > 0, nameof(dimensionality), "Must be positive");
ch.CheckUserArg(accelMemBudgetMb >= 0, nameof(KMeansPlusPlusTrainer.Arguments.AccelMemBudgetMb), "Must be non-negative");
ch.CheckUserArg(accelMemBudgetMb >= 0, nameof(KMeansPlusPlusTrainer.Options.AccelerationMemoryBudgetMb), "Must be non-negative");

int numRounds;
int numSamplesPerRound;
Expand Down
Loading