Skip to content

Commit 0c60374

Browse files
authored
Adding training statistics for LR in the HAL learners package. (#1392)
* Creating two separate methods to compute the matrix of standard deviations: one the old MKL way in the HAL Learners package, and the other making use of MathNet.Numerics
1 parent c917428 commit 0c60374

File tree

16 files changed

+431
-46
lines changed

16 files changed

+431
-46
lines changed

build/Dependencies.props

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
<SystemReflectionEmitLightweightPackageVersion>4.3.0</SystemReflectionEmitLightweightPackageVersion>
1010
<SystemThreadingTasksDataflowPackageVersion>4.8.0</SystemThreadingTasksDataflowPackageVersion>
1111
<SystemComponentModelCompositionVersion>4.5.0</SystemComponentModelCompositionVersion>
12+
<MathNumericPackageVersion>4.6.0</MathNumericPackageVersion>
1213
</PropertyGroup>
1314

1415
<!-- Other/Non-Core Product Dependencies -->

pkg/Microsoft.ML/Microsoft.ML.nupkgproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
<ItemGroup>
99
<ProjectReference Include="../Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.nupkgproj" />
1010

11+
<PackageReference Include="MathNet.Numerics.Signed" Version="$(MathNumericPackageVersion)" />
1112
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
1213
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="$(SystemReflectionEmitLightweightPackageVersion)" />
1314
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowPackageVersion)" />
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Data;
6+
using Microsoft.ML.Runtime.Internal.Utilities;
7+
using Microsoft.ML.Trainers.HalLearners;
8+
using System;
9+
10+
namespace Microsoft.ML.Runtime.Learners
11+
{
12+
using Mkl = OlsLinearRegressionTrainer.Mkl;
13+
14+
public sealed class ComputeLRTrainingStdThroughHal : ComputeLRTrainingStd
15+
{
16+
/// <summary>
17+
/// Computes the standart deviation matrix of each of the non-zero training weights, needed to calculate further the standart deviation,
18+
/// p-value and z-Score.
19+
/// If you need faster calculations, use the ComputeStd method from the Microsoft.ML.HALLearners package, which makes use of hardware acceleration.
20+
/// Due to the existence of regularization, an approximation is used to compute the variances of the trained linear coefficients.
21+
/// </summary>
22+
/// <param name="hessian"></param>
23+
/// <param name="weightIndices"></param>
24+
/// <param name="numSelectedParams"></param>
25+
/// <param name="currentWeightsCount"></param>
26+
/// <param name="ch">The <see cref="IChannel"/> used for messaging.</param>
27+
/// <param name="l2Weight">The L2Weight used for training. (Supply the same one that got used during training.)</param>
28+
public override VBuffer<float> ComputeStd(double[] hessian, int[] weightIndices, int numSelectedParams, int currentWeightsCount, IChannel ch, float l2Weight)
29+
{
30+
Contracts.AssertValue(ch);
31+
Contracts.AssertValue(hessian, nameof(hessian));
32+
Contracts.Assert(numSelectedParams > 0);
33+
Contracts.Assert(currentWeightsCount > 0);
34+
Contracts.Assert(l2Weight > 0);
35+
36+
// Apply Cholesky Decomposition to find the inverse of the Hessian.
37+
Double[] invHessian = null;
38+
try
39+
{
40+
// First, find the Cholesky decomposition LL' of the Hessian.
41+
Mkl.Pptrf(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, numSelectedParams, hessian);
42+
// Note that hessian is already modified at this point. It is no longer the original Hessian,
43+
// but instead represents the Cholesky decomposition L.
44+
// Also note that the following routine is supposed to consume the Cholesky decomposition L instead
45+
// of the original information matrix.
46+
Mkl.Pptri(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, numSelectedParams, hessian);
47+
// At this point, hessian should contain the inverse of the original Hessian matrix.
48+
// Swap hessian with invHessian to avoid confusion in the following context.
49+
Utils.Swap(ref hessian, ref invHessian);
50+
Contracts.Assert(hessian == null);
51+
}
52+
catch (DllNotFoundException)
53+
{
54+
throw ch.ExceptNotSupp("The MKL library (MklImports.dll) or one of its dependencies is missing.");
55+
}
56+
57+
float[] stdErrorValues = new float[numSelectedParams];
58+
stdErrorValues[0] = (float)Math.Sqrt(invHessian[0]);
59+
60+
for (int i = 1; i < numSelectedParams; i++)
61+
{
62+
// Initialize with inverse Hessian.
63+
stdErrorValues[i] = (float)invHessian[i * (i + 1) / 2 + i];
64+
}
65+
66+
if (l2Weight > 0)
67+
{
68+
// Iterate through all entries of inverse Hessian to make adjustment to variance.
69+
// A discussion on ridge regularized LR coefficient covariance matrix can be found here:
70+
// http://www.aloki.hu/pdf/0402_171179.pdf (Equations 11 and 25)
71+
// http://www.inf.unibz.it/dis/teaching/DWDM/project2010/LogisticRegression.pdf (Section "Significance testing in ridge logistic regression")
72+
int ioffset = 1;
73+
for (int iRow = 1; iRow < numSelectedParams; iRow++)
74+
{
75+
for (int iCol = 0; iCol <= iRow; iCol++)
76+
{
77+
var entry = (float)invHessian[ioffset++];
78+
AdjustVariance(entry, iRow, iCol, l2Weight, stdErrorValues);
79+
}
80+
}
81+
82+
Contracts.Assert(ioffset == invHessian.Length);
83+
}
84+
85+
for (int i = 1; i < numSelectedParams; i++)
86+
stdErrorValues[i] = (float)Math.Sqrt(stdErrorValues[i]);
87+
88+
// currentWeights vector size is Weights2 + the bias
89+
return new VBuffer<float>(currentWeightsCount, numSelectedParams, stdErrorValues, weightIndices);
90+
}
91+
}
92+
}

src/Microsoft.ML.StandardLearners/AssemblyInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
using Microsoft.ML;
77

88
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)]
9+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners" + PublicKey.Value)]
910

1011
[assembly: WantsToBeBestFriends]

src/Microsoft.ML.StandardLearners/Microsoft.ML.StandardLearners.csproj

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
<Project Sdk="Microsoft.NET.Sdk">
1+
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
44
<TargetFramework>netstandard2.0</TargetFramework>
55
<IncludeInPackage>Microsoft.ML</IncludeInPackage>
66
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
77
</PropertyGroup>
88

9+
<ItemGroup>
10+
<PackageReference Include="MathNet.Numerics.Signed" Version="$(MathNumericPackageVersion)" />
11+
</ItemGroup>
12+
913
<ItemGroup>
1014
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1115
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />

src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs

Lines changed: 152 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using MathNet.Numerics.LinearAlgebra;
78
using Microsoft.ML.Core.Data;
89
using Microsoft.ML.Runtime;
910
using Microsoft.ML.Runtime.CommandLine;
@@ -40,11 +41,27 @@ public sealed partial class LogisticRegression : LbfgsTrainerBase<LogisticRegres
4041

4142
public sealed class Arguments : ArgumentsBase
4243
{
44+
/// <summary>
45+
/// If set to <value>true</value>training statistics will be generated at the end of training.
46+
/// If you have a large number of learned training parameters(more than 500),
47+
/// generating the training statistics might take a few seconds.
48+
/// More than 1000 weights might take a few minutes. For those cases consider using the instance of <see cref="ComputeLRTrainingStd"/>
49+
/// present in the Microsoft.ML.HalLearners package. That computes the statistics using hardware acceleration.
50+
/// </summary>
4351
[Argument(ArgumentType.AtMostOnce, HelpText = "Show statistics of training examples.", ShortName = "stat", SortOrder = 50)]
4452
public bool ShowTrainingStats = false;
53+
54+
/// <summary>
55+
/// The instance of <see cref="ComputeLRTrainingStd"/> that computes the training statistics at the end of training.
56+
/// If you have a large number of learned training parameters(more than 500),
57+
/// generating the training statistics might take a few seconds.
58+
/// More than 1000 weights might take a few minutes. For those cases consider using the instance of <see cref="ComputeLRTrainingStd"/>
59+
/// present in the Microsoft.ML.HalLearners package. That computes the statistics using hardware acceleration.
60+
/// </summary>
61+
public ComputeLRTrainingStd StdComputer;
4562
}
4663

47-
private Double _posWeight;
64+
private double _posWeight;
4865
private LinearModelStatistics _stats;
4966

5067
/// <summary>
@@ -78,6 +95,9 @@ public LogisticRegression(IHostEnvironment env,
7895

7996
_posWeight = 0;
8097
ShowTrainingStats = Args.ShowTrainingStats;
98+
99+
if (ShowTrainingStats && Args.StdComputer == null)
100+
Args.StdComputer = new ComputeLRTrainingStdImpl();
81101
}
82102

83103
/// <summary>
@@ -88,6 +108,9 @@ internal LogisticRegression(IHostEnvironment env, Arguments args)
88108
{
89109
_posWeight = 0;
90110
ShowTrainingStats = Args.ShowTrainingStats;
111+
112+
if (ShowTrainingStats && Args.StdComputer == null)
113+
Args.StdComputer = new ComputeLRTrainingStdImpl();
91114
}
92115

93116
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
@@ -330,7 +353,13 @@ protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.
330353
}
331354
}
332355

333-
_stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance);
356+
if (Args.StdComputer == null)
357+
_stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance);
358+
else
359+
{
360+
var std = Args.StdComputer.ComputeStd(hessian, weightIndices, numParams, CurrentWeights.Length, ch, L2Weight);
361+
_stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance, std);
362+
}
334363
}
335364

336365
protected override void ProcessPriorDistribution(float label, float weight)
@@ -397,4 +426,125 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm
397426
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
398427
}
399428
}
429+
430+
/// <summary>
431+
/// Computes the standard deviation matrix of each of the non-zero training weights, needed to calculate further the standard deviation,
432+
/// p-value and z-Score.
433+
/// If you need fast calculations, use the <see cref="ComputeLRTrainingStd"/> implementation in the Microsoft.ML.HALLearners package,
434+
/// which makes use of hardware acceleration.
435+
/// Due to the existence of regularization, an approximation is used to compute the variances of the trained linear coefficients.
436+
/// </summary>
437+
public abstract class ComputeLRTrainingStd
438+
{
439+
/// <summary>
440+
/// Computes the standard deviation matrix of each of the non-zero training weights, needed to calculate further the standard deviation,
441+
/// p-value and z-Score.
442+
/// If you need fast calculations, use the ComputeStd method from the Microsoft.ML.HALLearners package, which makes use of hardware acceleration.
443+
/// Due to the existence of regularization, an approximation is used to compute the variances of the trained linear coefficients.
444+
/// </summary>
445+
public abstract VBuffer<float> ComputeStd(double[] hessian, int[] weightIndices, int parametersCount, int currentWeightsCount, IChannel ch, float l2Weight);
446+
447+
/// <summary>
448+
/// Adjust the variance for regularized cases.
449+
/// </summary>
450+
[BestFriend]
451+
internal void AdjustVariance(float inverseEntry, int iRow, int iCol, float l2Weight, float[] stdErrorValues2)
452+
{
453+
var adjustment = l2Weight * inverseEntry * inverseEntry;
454+
stdErrorValues2[iRow] -= adjustment;
455+
456+
if (0 < iCol && iCol < iRow)
457+
stdErrorValues2[iCol] -= adjustment;
458+
}
459+
}
460+
461+
/// <summary>
462+
/// Extends the <see cref="ComputeLRTrainingStd"/> implementing <see cref="ComputeLRTrainingStd.ComputeStd(double[], int[], int, int, IChannel, float)"/> making use of Math.Net numeric
463+
/// If you need faster calculations(have non-sparse weight vectors of more than 300 features), use the instance of ComputeLRTrainingStd from the Microsoft.ML.HALLearners package, which makes use of hardware acceleration
464+
/// for those computations.
465+
/// </summary>
466+
public sealed class ComputeLRTrainingStdImpl : ComputeLRTrainingStd
467+
{
468+
/// <summary>
469+
/// Computes the standard deviation matrix of each of the non-zero training weights, needed to calculate further the standard deviation,
470+
/// p-value and z-Score.
471+
/// If you need faster calculations, use the ComputeStd method from the Microsoft.ML.HALLearners package, which makes use of hardware acceleration.
472+
/// Due to the existence of regularization, an approximation is used to compute the variances of the trained linear coefficients.
473+
/// </summary>
474+
/// <param name="hessian"></param>
475+
/// <param name="weightIndices"></param>
476+
/// <param name="numSelectedParams"></param>
477+
/// <param name="currentWeightsCount"></param>
478+
/// <param name="ch">The <see cref="IChannel"/> used for messaging.</param>
479+
/// <param name="l2Weight">The L2Weight used for training. (Supply the same one that got used during training.)</param>
480+
public override VBuffer<float> ComputeStd(double[] hessian, int[] weightIndices, int numSelectedParams, int currentWeightsCount, IChannel ch, float l2Weight)
481+
{
482+
Contracts.AssertValue(ch);
483+
Contracts.AssertValue(hessian, nameof(hessian));
484+
Contracts.Assert(numSelectedParams > 0);
485+
Contracts.Assert(currentWeightsCount > 0);
486+
Contracts.Assert(l2Weight > 0);
487+
488+
double[,] matrixHessian = new double[numSelectedParams, numSelectedParams];
489+
490+
int hessianLength = 0;
491+
int dimension = numSelectedParams - 1;
492+
493+
for (int row = dimension; row >= 0; row--)
494+
{
495+
for (int col = 0; col <= dimension; col++)
496+
{
497+
if ((row + col) <= dimension)
498+
{
499+
if ((row + col) == dimension)
500+
{
501+
matrixHessian[row, col] = hessian[hessianLength];
502+
}
503+
else
504+
{
505+
matrixHessian[row, col] = hessian[hessianLength];
506+
matrixHessian[dimension - col, dimension - row] = hessian[hessianLength];
507+
}
508+
hessianLength++;
509+
}
510+
else
511+
continue;
512+
}
513+
}
514+
515+
var h = Matrix<double>.Build.DenseOfArray(matrixHessian);
516+
var invers = h.Inverse();
517+
518+
float[] stdErrorValues = new float[numSelectedParams];
519+
stdErrorValues[0] = (float)Math.Sqrt(invers[0, numSelectedParams - 1]);
520+
521+
for (int i = 1; i < numSelectedParams; i++)
522+
{
523+
// Initialize with inverse Hessian.
524+
// The diagonal of the inverse Hessian.
525+
stdErrorValues[i] = (float)invers[i, numSelectedParams - i - 1];
526+
}
527+
528+
if (l2Weight > 0)
529+
{
530+
// Iterate through all entries of inverse Hessian to make adjustment to variance.
531+
// A discussion on ridge regularized LR coefficient covariance matrix can be found here:
532+
// http://www.aloki.hu/pdf/0402_171179.pdf (Equations 11 and 25)
533+
// http://www.inf.unibz.it/dis/teaching/DWDM/project2010/LogisticRegression.pdf (Section "Significance testing in ridge logistic regression")
534+
for (int iRow = 1; iRow < numSelectedParams; iRow++)
535+
{
536+
for (int iCol = 0; iCol <= iRow; iCol++)
537+
{
538+
float entry = (float)invers[iRow, numSelectedParams - iCol - 1];
539+
AdjustVariance(entry, iRow, iCol, l2Weight, stdErrorValues);
540+
}
541+
}
542+
}
543+
544+
for (int i = 1; i < numSelectedParams; i++)
545+
stdErrorValues[i] = (float)Math.Sqrt(stdErrorValues[i]);
546+
547+
return new VBuffer<float>(currentWeightsCount, numSelectedParams, stdErrorValues, weightIndices);
548+
}
549+
}
400550
}

0 commit comments

Comments
 (0)