Skip to content

Commit

Permalink
Make accessor of linear coefficients unique to the public
Browse files Browse the repository at this point in the history
1. Internalize GetFeatureWeights(ref VBuffer<float> weights)
2. Internalize IHaveFeatureWeights
  • Loading branch information
wschin committed Mar 1, 2019
1 parent 6e9023f commit e972912
Show file tree
Hide file tree
Showing 10 changed files with 16 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ public static void Example()
var outData = featureContributionCalculator.Fit(scoredData).Transform(scoredData);

// Let's extract the weights from the linear model to use as a comparison
var weights = new VBuffer<float>();
model.Model.GetFeatureWeights(ref weights);
var weights = model.Model.Weights;

// Let's now walk through the first ten records and see which feature drove the values the most
// Get prediction scores and contributions
Expand All @@ -63,7 +62,7 @@ public static void Example()
var value = row.Features[featureOfInterest];
var contribution = row.FeatureContributions[featureOfInterest];
var name = data.Schema[featureOfInterest + 1].Name;
var weight = weights.GetValues()[featureOfInterest];
var weight = weights[featureOfInterest];

Console.WriteLine("{0:0.00}\t{1:0.00}\t{2}\t{3:0.00}\t{4:0.00}\t{5:0.00}",
row.MedianHomeValue,
Expand Down
8 changes: 3 additions & 5 deletions docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,10 @@ public static void SdcaRegression()
var model = learningPipeline.Fit(trainData);

// Check the weights that the model learned
VBuffer<float> weights = default;
pred.GetFeatureWeights(ref weights);
var weights = pred.Weights;

var weightsValues = weights.GetValues();
Console.WriteLine($"weight 0 - {weightsValues[0]}");
Console.WriteLine($"weight 1 - {weightsValues[1]}");
Console.WriteLine($"weight 0 - {weights[0]}");
Console.WriteLine($"weight 1 - {weights[1]}");

// Evaluate how the model is doing on the test data
var dataWithPredictions = model.Transform(testData);
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ internal interface ICanSaveInSourceCode
/// <summary>
/// Interface implemented by components that can assign weights to features.
/// </summary>
public interface IHaveFeatureWeights
[BestFriend]
internal interface IHaveFeatureWeights
{
/// <summary>
/// Returns the weights for the features.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParame
float bias = 0.0f;
if (predictor != null)
{
predictor.GetFeatureWeights(ref weights);
((IHaveFeatureWeights)predictor).GetFeatureWeights(ref weights);
VBufferUtils.Densify(ref weights);
bias = predictor.Bias;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
Expand Down Expand Up @@ -384,7 +383,7 @@ private protected virtual DataViewRow GetSummaryIRowOrNull(RoleMappedSchema sche

void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) => SaveAsIni(writer, schema, calibrator);

public void GetFeatureWeights(ref VBuffer<float> weights)
void IHaveFeatureWeights.GetFeatureWeights(ref VBuffer<float> weights)
{
Weight.CopyTo(ref weights);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ internal DataViewSchema.Annotations MakeStatisticsMetadata(LinearBinaryModelPara
builder.AddPrimitiveValue("BiasPValue", NumberDataViewType.Single, biasPValue);

var weights = default(VBuffer<float>);
parent.GetFeatureWeights(ref weights);
((IHaveFeatureWeights)parent).GetFeatureWeights(ref weights);
var estimate = default(VBuffer<float>);
var stdErr = default(VBuffer<float>);
var zScore = default(VBuffer<float>);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Numeric;

namespace Microsoft.ML.Trainers
Expand Down Expand Up @@ -130,7 +131,7 @@ protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters pre
// unless we have a lot of features.
if (predictor != null)
{
predictor.GetFeatureWeights(ref Weights);
((IHaveFeatureWeights)parent).GetFeatureWeights(ref Weights);
VBufferUtils.Densify(ref Weights);
Bias = predictor.Bias;
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1937,7 +1937,7 @@ private protected override TModel TrainCore(IChannel ch, RoleMappedData data, Li
float bias = 0.0f;
if (predictor != null)
{
predictor.GetFeatureWeights(ref weights);
((IHaveFeatureWeights)predictor).GetFeatureWeights(ref weights);
VBufferUtils.Densify(ref weights);
bias = predictor.Bias;
}
Expand Down
8 changes: 2 additions & 6 deletions test/Microsoft.ML.StaticPipelineTesting/Training.cs
Original file line number Diff line number Diff line change
Expand Up @@ -627,9 +627,7 @@ public void PoissonRegression()
var model = pipe.Fit(dataSource);
Assert.NotNull(pred);
// 11 input features, so we ought to have 11 weights.
VBuffer<float> weights = new VBuffer<float>();
pred.GetFeatureWeights(ref weights);
Assert.Equal(11, weights.Length);
Assert.Equal(11, pred.Weights.Count);

var data = model.Load(dataSource);

Expand Down Expand Up @@ -751,9 +749,7 @@ public void OnlineGradientDescent()
var model = pipe.Fit(dataSource);
Assert.NotNull(pred);
// 11 input features, so we ought to have 11 weights.
VBuffer<float> weights = new VBuffer<float>();
pred.GetFeatureWeights(ref weights);
Assert.Equal(11, weights.Length);
Assert.Equal(11, pred.Weights.Count);

var data = model.Load(dataSource);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ public void IntrospectiveTraining()
var model = pipeline.Fit(data);

// Get feature weights.
VBuffer<float> weights = default;
model.LastTransformer.Model.GetFeatureWeights(ref weights);
var weights = model.LastTransformer.Model.Weights;
}

[Fact]
Expand Down

0 comments on commit e972912

Please sign in to comment.