Skip to content

Commit fd30559

Browse files
authored
Modify API for advanced settings (RandomizedPcaTrainer) (#2390)
* RandomizedPcaTrainer constructor made internal * MLContext for PCA * update test example * added evaluation metrics for anomaly detection * make tests work. it seems adding a catalog to MLContext changes some seeds? * also updating baseline file for Release builds * review comments * taking care of review comments
1 parent fc92774 commit fd30559

File tree

10 files changed

+268
-22
lines changed

10 files changed

+268
-22
lines changed

src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.ML;
1111
using Microsoft.ML.CommandLine;
1212
using Microsoft.ML.Data;
13+
using Microsoft.ML.Data.Evaluators.Metrics;
1314
using Microsoft.ML.EntryPoints;
1415
using Microsoft.ML.Internal.Utilities;
1516
using Microsoft.ML.Transforms;
@@ -576,6 +577,44 @@ public void Finish()
576577
FinishOtherMetrics();
577578
}
578579
}
580+
581+
/// <summary>
582+
/// Evaluates scored anomaly detection data and returns the overall metrics as an <see cref="AnomalyDetectionMetrics"/>.
583+
/// </summary>
584+
/// <param name="data">The scored data.</param>
585+
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
586+
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
587+
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
588+
/// <returns>The evaluation results, read from the single row of the overall-metrics output.</returns>
589+
internal AnomalyDetectionMetrics Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score,
590+
string predictedLabel = DefaultColumnNames.PredictedLabel)
591+
{
592+
Host.CheckValue(data, nameof(data));
593+
Host.CheckNonEmpty(label, nameof(label));
594+
Host.CheckNonEmpty(score, nameof(score));
595+
Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));
596+
// Bind the caller-supplied column names to the label, score and predicted-label roles.
597+
var roles = new RoleMappedData(data, opt: false,
598+
RoleMappedSchema.ColumnRole.Label.Bind(label),
599+
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score),
600+
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel));
601+
// Run the evaluation through the IEvaluator implementation; its result is expected to contain the overall-metrics entry.
602+
var resultDict = ((IEvaluator)this).Evaluate(roles);
603+
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
604+
var overall = resultDict[MetricKinds.OverallMetrics];
605+
// Build the metrics object from the first row, then assert that no second row follows.
606+
AnomalyDetectionMetrics result;
607+
using (var cursor = overall.GetRowCursorForAllColumns())
608+
{
609+
var moved = cursor.MoveNext();
610+
Host.Assert(moved);
611+
result = new AnomalyDetectionMetrics(Host, cursor);
612+
moved = cursor.MoveNext();
613+
Host.Assert(!moved);
614+
}
615+
return result;
616+
}
617+
579618
}
580619

581620
[BestFriend]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using Microsoft.Data.DataView;
7+
8+
namespace Microsoft.ML.Data.Evaluators.Metrics
9+
{
10+
/// <summary>
11+
/// Evaluation results for anomaly detection.
12+
/// </summary>
13+
public sealed class AnomalyDetectionMetrics
14+
{
15+
/// <summary>
16+
/// Gets the area under the ROC curve.
17+
/// </summary>
18+
/// <remarks>
19+
/// The area under the ROC curve is equal to the probability that the algorithm ranks
20+
/// a randomly chosen positive instance higher than a randomly chosen negative one
21+
/// (assuming 'positive' ranks higher than 'negative').
22+
/// </remarks>
23+
public double Auc { get; }
24+
25+
/// <summary>
26+
/// Gets the detection rate at K false positives.
27+
/// </summary>
28+
/// <remarks>
29+
/// This is computed as follows:
30+
/// 1. Sort the test examples by the output of the anomaly detector in descending order of scores.
31+
/// 2. Among the top K false positives, compute the ratio: (True Positive @ K) / (Total anomalies in test data).
32+
/// Example confusion matrix for anomaly detection:
33+
/// Anomalies (in test data) | Non-Anomalies (in test data)
34+
/// Predicted Anomalies : TP | FP
35+
/// Predicted Non-Anomalies : FN | TN
36+
/// </remarks>
37+
public double DrAtK { get; }
38+
/// <summary>Initializes <see cref="Auc"/> and <see cref="DrAtK"/> by fetching their values by name from <paramref name="overallResult"/>.</summary>
39+
internal AnomalyDetectionMetrics(IExceptionContext ectx, Row overallResult)
40+
{
41+
double FetchDouble(string name) => RowCursorUtils.Fetch<double>(ectx, overallResult, name);
42+
Auc = FetchDouble(BinaryClassifierEvaluator.Auc);
43+
DrAtK = FetchDouble(AnomalyDetectionEvaluator.OverallMetrics.DrAtK);
44+
}
45+
}
46+
}

src/Microsoft.ML.Data/MLContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ public sealed class MLContext : IHostEnvironment
4040
/// </summary>
4141
public RankingCatalog Ranking { get; }
4242

43+
/// <summary>
44+
/// Trainers and tasks specific to anomaly detection problems.
45+
/// </summary>
46+
public AnomalyDetectionCatalog AnomalyDetection { get; }
47+
4348
/// <summary>
4449
/// Data processing operations.
4550
/// </summary>
@@ -83,6 +88,7 @@ public MLContext(int? seed = null, int conc = 0)
8388
Regression = new RegressionCatalog(_env);
8489
Clustering = new ClusteringCatalog(_env);
8590
Ranking = new RankingCatalog(_env);
91+
AnomalyDetection = new AnomalyDetectionCatalog(_env);
8692
Transforms = new TransformsCatalog(_env);
8793
Model = new ModelOperationsCatalog(_env);
8894
Data = new DataOperationsCatalog(_env);

src/Microsoft.ML.Data/TrainCatalog.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Linq;
77
using Microsoft.Data.DataView;
88
using Microsoft.ML.Data;
9+
using Microsoft.ML.Data.Evaluators.Metrics;
910
using Microsoft.ML.Transforms;
1011
using Microsoft.ML.Transforms.Conversions;
1112

@@ -646,4 +647,53 @@ public RankerMetrics Evaluate(IDataView data, string label, string groupId, stri
646647
return eval.Evaluate(data, label, groupId, score);
647648
}
648649
}
650+
651+
/// <summary>
652+
/// The central catalog for anomaly detection tasks and trainers.
653+
/// </summary>
654+
public sealed class AnomalyDetectionCatalog : TrainCatalogBase
655+
{
656+
/// <summary>
657+
/// The list of trainers for anomaly detection.
658+
/// </summary>
659+
public AnomalyDetectionTrainers Trainers { get; }
660+
/// <summary>Creates the catalog for <paramref name="env"/> and its <see cref="Trainers"/> instance.</summary>
661+
internal AnomalyDetectionCatalog(IHostEnvironment env)
662+
: base(env, nameof(AnomalyDetectionCatalog))
663+
{
664+
Trainers = new AnomalyDetectionTrainers(this);
665+
}
666+
/// <summary>The type exposed through <see cref="Trainers"/>; it declares no members of its own beyond its constructor.</summary>
667+
public sealed class AnomalyDetectionTrainers : CatalogInstantiatorBase
668+
{
669+
internal AnomalyDetectionTrainers(AnomalyDetectionCatalog catalog)
670+
: base(catalog)
671+
{
672+
}
673+
}
674+
675+
/// <summary>
676+
/// Evaluates scored anomaly detection data.
677+
/// </summary>
678+
/// <param name="data">The scored data.</param>
679+
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
680+
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
681+
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
682+
/// <param name="k">The number of false positives at which the <see cref="AnomalyDetectionMetrics.DrAtK"/> metric is computed; passed to the evaluator as its <c>K</c> argument.</param>
683+
/// <returns>Evaluation results.</returns>
684+
public AnomalyDetectionMetrics Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score,
685+
string predictedLabel = DefaultColumnNames.PredictedLabel, int k = 10)
686+
{
687+
Environment.CheckValue(data, nameof(data));
688+
Environment.CheckNonEmpty(label, nameof(label));
689+
Environment.CheckNonEmpty(score, nameof(score));
690+
Environment.CheckNonEmpty(predictedLabel, nameof(predictedLabel));
691+
// Configure a new evaluator with the requested K, then delegate the evaluation to it.
692+
var args = new AnomalyDetectionEvaluator.Arguments();
693+
args.K = k;
694+
695+
var eval = new AnomalyDetectionEvaluator(Environment, args);
696+
return eval.Evaluate(data, label, score, predictedLabel);
697+
}
698+
}
649699
}

src/Microsoft.ML.PCA/PCACatalog.cs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.ML.Data;
6+
using Microsoft.ML.Trainers.PCA;
67
using Microsoft.ML.Transforms.Projections;
8+
using static Microsoft.ML.Trainers.PCA.RandomizedPcaTrainer;
79

810
namespace Microsoft.ML
911
{
1012
public static class PcaCatalog
1113
{
12-
1314
/// <summary>Initializes a new instance of <see cref="PrincipalComponentAnalysisEstimator"/>.</summary>
1415
/// <param name="catalog">The transform's catalog.</param>
1516
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
@@ -35,5 +36,40 @@ public static PrincipalComponentAnalysisEstimator ProjectToPrincipalComponents(t
3536
/// <param name="columns">Input columns to apply PrincipalComponentAnalysis on.</param>
3637
public static PrincipalComponentAnalysisEstimator ProjectToPrincipalComponents(this TransformsCatalog.ProjectionTransforms catalog, params PrincipalComponentAnalysisEstimator.ColumnInfo[] columns)
3738
=> new PrincipalComponentAnalysisEstimator(CatalogUtils.GetEnvironment(catalog), columns);
39+
40+
/// <summary>
41+
/// Creates a <see cref="RandomizedPcaTrainer"/>, which trains an approximate PCA using the Randomized SVD algorithm.
42+
/// </summary>
43+
/// <param name="catalog">The anomaly detection catalog trainer object.</param>
44+
/// <param name="featureColumn">The name of the column holding the features, or independent variables.</param>
45+
/// <param name="weights">The name of the optional example weights column.</param>
46+
/// <param name="rank">The number of components in the PCA.</param>
47+
/// <param name="oversampling">Oversampling parameter for randomized PCA training.</param>
48+
/// <param name="center">If enabled, data is centered to be zero mean.</param>
49+
/// <param name="seed">The seed for random number generation.</param>
/// <returns>A new <see cref="RandomizedPcaTrainer"/> created in the catalog's environment with the given settings.</returns>
50+
public static RandomizedPcaTrainer RandomizedPca(this AnomalyDetectionCatalog.AnomalyDetectionTrainers catalog,
51+
string featureColumn = DefaultColumnNames.Features,
52+
string weights = null,
53+
int rank = Options.Defaults.NumComponents,
54+
int oversampling = Options.Defaults.OversamplingParameters,
55+
bool center = Options.Defaults.IsCenteredZeroMean,
56+
int? seed = null)
57+
{
58+
Contracts.CheckValue(catalog, nameof(catalog));
59+
var env = CatalogUtils.GetEnvironment(catalog);
60+
return new RandomizedPcaTrainer(env, featureColumn, weights, rank, oversampling, center, seed);
61+
}
62+
63+
/// <summary>
64+
/// Creates a <see cref="RandomizedPcaTrainer"/>, which trains an approximate PCA using the Randomized SVD algorithm, from an options object.
65+
/// </summary>
66+
/// <param name="catalog">The anomaly detection catalog trainer object.</param>
67+
/// <param name="options">Advanced options to the algorithm.</param>
/// <returns>A new <see cref="RandomizedPcaTrainer"/> created in the catalog's environment with <paramref name="options"/>.</returns>
68+
public static RandomizedPcaTrainer RandomizedPca(this AnomalyDetectionCatalog.AnomalyDetectionTrainers catalog, Options options)
69+
{
70+
Contracts.CheckValue(catalog, nameof(catalog));
71+
var env = CatalogUtils.GetEnvironment(catalog);
72+
return new RandomizedPcaTrainer(env, options);
73+
}
3874
}
3975
}

src/Microsoft.ML.PCA/PcaTrainer.cs

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
using Microsoft.ML.Trainers.PCA;
1919
using Microsoft.ML.Training;
2020

21-
[assembly: LoadableClass(RandomizedPcaTrainer.Summary, typeof(RandomizedPcaTrainer), typeof(RandomizedPcaTrainer.Arguments),
21+
[assembly: LoadableClass(RandomizedPcaTrainer.Summary, typeof(RandomizedPcaTrainer), typeof(RandomizedPcaTrainer.Options),
2222
new[] { typeof(SignatureAnomalyDetectorTrainer), typeof(SignatureTrainer) },
2323
RandomizedPcaTrainer.UserNameValue,
2424
RandomizedPcaTrainer.LoadNameValue,
@@ -48,24 +48,31 @@ public sealed class RandomizedPcaTrainer : TrainerEstimatorBase<AnomalyPredictio
4848
internal const string Summary = "This algorithm trains an approximate PCA using Randomized SVD algorithm. "
4949
+ "This PCA can be made into Kernel PCA by using Random Fourier Features transform.";
5050

51-
public class Arguments : UnsupervisedLearnerInputBaseWithWeight
51+
public class Options : UnsupervisedLearnerInputBaseWithWeight
5252
{
5353
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of components in the PCA", ShortName = "k", SortOrder = 50)]
5454
[TGUI(SuggestedSweeps = "10,20,40,80")]
5555
[TlcModule.SweepableDiscreteParam("Rank", new object[] { 10, 20, 40, 80 })]
56-
public int Rank = 20;
56+
public int Rank = Defaults.NumComponents;
5757

5858
[Argument(ArgumentType.AtMostOnce, HelpText = "Oversampling parameter for randomized PCA training", SortOrder = 50)]
5959
[TGUI(SuggestedSweeps = "10,20,40")]
6060
[TlcModule.SweepableDiscreteParam("Oversampling", new object[] { 10, 20, 40 })]
61-
public int Oversampling = 20;
61+
public int Oversampling = Defaults.OversamplingParameters;
6262

6363
[Argument(ArgumentType.AtMostOnce, HelpText = "If enabled, data is centered to be zero mean", ShortName = "center")]
6464
[TlcModule.SweepableDiscreteParam("Center", null, isBool: true)]
65-
public bool Center = true;
65+
public bool Center = Defaults.IsCenteredZeroMean;
6666

6767
[Argument(ArgumentType.AtMostOnce, HelpText = "The seed for random number generation", ShortName = "seed")]
6868
public int? Seed;
69+
70+
internal static class Defaults
71+
{
72+
public const int NumComponents = 20;
73+
public const int OversamplingParameters = 20;
74+
public const bool IsCenteredZeroMean = true;
75+
}
6976
}
7077

7178
private readonly int _rank;
@@ -90,35 +97,35 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight
9097
/// <param name="oversampling">Oversampling parameter for randomized PCA training.</param>
9198
/// <param name="center">If enabled, data is centered to be zero mean.</param>
9299
/// <param name="seed">The seed for random number generation.</param>
93-
public RandomizedPcaTrainer(IHostEnvironment env,
100+
internal RandomizedPcaTrainer(IHostEnvironment env,
94101
string features,
95102
string weights = null,
96-
int rank = 20,
97-
int oversampling = 20,
98-
bool center = true,
103+
int rank = Options.Defaults.NumComponents,
104+
int oversampling = Options.Defaults.OversamplingParameters,
105+
bool center = Options.Defaults.IsCenteredZeroMean,
99106
int? seed = null)
100107
: this(env, null, features, weights, rank, oversampling, center, seed)
101108
{
102109

103110
}
104111

105-
internal RandomizedPcaTrainer(IHostEnvironment env, Arguments args)
106-
:this(env, args, args.FeatureColumn, args.WeightColumn)
112+
internal RandomizedPcaTrainer(IHostEnvironment env, Options options)
113+
:this(env, options, options.FeatureColumn, options.WeightColumn)
107114
{
108115

109116
}
110117

111-
private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn,
118+
private RandomizedPcaTrainer(IHostEnvironment env, Options options, string featureColumn, string weightColumn,
112119
int rank = 20, int oversampling = 20, bool center = true, int? seed = null)
113120
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
114121
{
115122
// if the args are not null, we got here from maml, and the internal ctor.
116-
if (args != null)
123+
if (options != null)
117124
{
118-
_rank = args.Rank;
119-
_center = args.Center;
120-
_oversampling = args.Oversampling;
121-
_seed = args.Seed ?? Host.Rand.Next();
125+
_rank = options.Rank;
126+
_center = options.Center;
127+
_oversampling = options.Oversampling;
128+
_seed = options.Seed ?? Host.Rand.Next();
122129
}
123130
else
124131
{
@@ -346,14 +353,14 @@ protected override AnomalyPredictionTransformer<PcaModelParameters> MakeTransfor
346353
Desc = "Train an PCA Anomaly model.",
347354
UserName = UserNameValue,
348355
ShortName = ShortName)]
349-
internal static CommonOutputs.AnomalyDetectionOutput TrainPcaAnomaly(IHostEnvironment env, Arguments input)
356+
internal static CommonOutputs.AnomalyDetectionOutput TrainPcaAnomaly(IHostEnvironment env, Options input)
350357
{
351358
Contracts.CheckValue(env, nameof(env));
352359
var host = env.Register("TrainPCAAnomaly");
353360
host.CheckValue(input, nameof(input));
354361
EntryPointUtils.CheckInputArgs(host, input);
355362

356-
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.AnomalyDetectionOutput>(host, input,
363+
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.AnomalyDetectionOutput>(host, input,
357364
() => new RandomizedPcaTrainer(host, input),
358365
getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
359366
}

src/Microsoft.ML.PCA/Properties/AssemblyInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Runtime.CompilerServices;
66
using Microsoft.ML;
77

8+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
89
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)]
910
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
1011

test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ Trainers.LogisticRegressionClassifier Logistic Regression is a method in statist
6363
Trainers.NaiveBayesClassifier Train a MultiClassNaiveBayesTrainer. Microsoft.ML.Trainers.MultiClassNaiveBayesTrainer TrainMultiClassNaiveBayesTrainer Microsoft.ML.Trainers.MultiClassNaiveBayesTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
6464
Trainers.OnlineGradientDescentRegressor Train a Online gradient descent perceptron. Microsoft.ML.Trainers.Online.OnlineGradientDescentTrainer TrainRegression Microsoft.ML.Trainers.Online.OnlineGradientDescentTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
6565
Trainers.OrdinaryLeastSquaresRegressor Train an OLS regression model. Microsoft.ML.Trainers.HalLearners.OlsLinearRegressionTrainer TrainRegression Microsoft.ML.Trainers.HalLearners.OlsLinearRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
66-
Trainers.PcaAnomalyDetector Train an PCA Anomaly model. Microsoft.ML.Trainers.PCA.RandomizedPcaTrainer TrainPcaAnomaly Microsoft.ML.Trainers.PCA.RandomizedPcaTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+AnomalyDetectionOutput
66+
Trainers.PcaAnomalyDetector Train an PCA Anomaly model. Microsoft.ML.Trainers.PCA.RandomizedPcaTrainer TrainPcaAnomaly Microsoft.ML.Trainers.PCA.RandomizedPcaTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+AnomalyDetectionOutput
6767
Trainers.PoissonRegressor Train an Poisson regression model. Microsoft.ML.Trainers.PoissonRegression TrainRegression Microsoft.ML.Trainers.PoissonRegression+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
6868
Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary model. Microsoft.ML.Trainers.Sdca TrainBinary Microsoft.ML.Trainers.SdcaBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
6969
Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Trainers.Sdca TrainMultiClass Microsoft.ML.Trainers.SdcaMultiClassTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3414,7 +3414,7 @@ public void EntryPointPcaPredictorSummary()
34143414
InputFile = inputFile,
34153415
}).Data;
34163416

3417-
var pcaInput = new RandomizedPcaTrainer.Arguments
3417+
var pcaInput = new RandomizedPcaTrainer.Options
34183418
{
34193419
TrainingData = dataView,
34203420
};

0 commit comments

Comments
 (0)