Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dedicated column attributes added to input type fields #224

Closed
wants to merge 11 commits into from
58 changes: 57 additions & 1 deletion src/Microsoft.ML.Api/SchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ public VectorTypeAttribute(params int[] dims)
/// column encapsulates.
/// </summary>
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
public sealed class ColumnAttribute : Attribute
public class ColumnAttribute : Attribute
{

public ColumnAttribute(string ordinal, string name = null)
{
Name = name;
Expand All @@ -93,6 +94,61 @@ public ColumnAttribute(string ordinal, string name = null)
public string Ordinal { get; }
}

/// <summary>
/// Describes 'Label' column with indicies.
/// </summary>
public sealed class LabelColumnAttribute : ColumnAttribute
{
public LabelColumnAttribute(string ordinal):
base(ordinal, "Label")
{
}
}

/// <summary>
/// Describes 'Features' column with indicies.
/// </summary>
public sealed class FeaturesColumnAttribute : ColumnAttribute
{
public FeaturesColumnAttribute(string ordinal) :
base(ordinal, "Features")
{
}
}

/// <summary>
/// Describes 'Weight' column with indicies.
/// </summary>
public sealed class WeightColumnAttribute : ColumnAttribute
{
public WeightColumnAttribute(string ordinal) :
base(ordinal, "Weight")
{
}
}

/// <summary>
/// Describes 'GroupId' column with indicies.
/// </summary>
public sealed class GroupColumnAttribute : ColumnAttribute
{
public GroupColumnAttribute(string ordinal) :
base(ordinal, "GroupId")
{
}
}

/// <summary>
/// Describes 'Name' column with indicies.
/// </summary>
public sealed class NameColumnAttribute : ColumnAttribute
{
public NameColumnAttribute(string ordinal) :
base(ordinal, "Name")
{
}
}

/// <summary>
/// Allows a member to specify its column name directly, as opposed to the default
/// behavior of using the member name as the column name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public class HousePriceData
[Column(ordinal: "1")]
public string Date;

[Column(ordinal: "2", name: "Label")]
[LabelColumn(ordinal: "2")]
public float Price;

[Column(ordinal: "3")]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using Xunit;

namespace Microsoft.ML.Scenarios
{
public partial class ScenariosTests
{
[Fact]
public void TrainAndPredictIrisModelWithFeatureVectorTest()
{
string dataPath = GetDataPath("iris.data");

var pipeline = new LearningPipeline();

pipeline.Add(new TextLoader(dataPath).CreateFrom<IrisDataWithFeatureVector>(useHeader: false, separator: ','));

pipeline.Add(new Dictionarizer("Label")); // "IrisPlantType" is used as "Label" because of column attribute name on the field.

pipeline.Add(new StochasticDualCoordinateAscentClassifier());

PredictionModel<IrisDataWithFeatureVector, IrisPrediction> model = pipeline.Train<IrisDataWithFeatureVector, IrisPrediction>();

IrisPrediction prediction = model.Predict(new IrisDataWithFeatureVector()
{
Feat = new float[] { 5.1f, 3.3f, 1.6f, 0.2f }
});

Assert.Equal(1, prediction.PredictedLabels[0], 2);
Assert.Equal(0, prediction.PredictedLabels[1], 2);
Assert.Equal(0, prediction.PredictedLabels[2], 2);

prediction = model.Predict(new IrisDataWithFeatureVector()
{
Feat = new float[] { 6.4f, 3.1f, 5.5f, 2.2f }
});

Assert.Equal(0, prediction.PredictedLabels[0], 2);
Assert.Equal(0, prediction.PredictedLabels[1], 2);
Assert.Equal(1, prediction.PredictedLabels[2], 2);

prediction = model.Predict(new IrisDataWithFeatureVector()
{
Feat = new float[] { 4.4f, 3.1f, 2.5f, 1.2f }
});

Assert.Equal(.2, prediction.PredictedLabels[0], 1);
Assert.Equal(.8, prediction.PredictedLabels[1], 1);
Assert.Equal(0, prediction.PredictedLabels[2], 2);

// Note: Testing against the same data set as a simple way to test evaluation.
// This isn't appropriate in real-world scenarios.
string testDataPath = GetDataPath("iris.data");
var testData = new TextLoader(testDataPath).CreateFrom<IrisDataWithFeatureVector>(useHeader: false, separator: ',');

var evaluator = new ClassificationEvaluator();
evaluator.OutputTopKAcc = 3;
ClassificationMetrics metrics = evaluator.Evaluate(model, testData);

Assert.Equal(.98, metrics.AccuracyMacro);
Assert.Equal(.98, metrics.AccuracyMicro, 2);
Assert.Equal(.06, metrics.LogLoss, 2);
Assert.InRange(metrics.LogLossReduction, 94, 96);
Assert.Equal(1, metrics.TopKAccuracy);

Assert.Equal(3, metrics.PerClassLogLoss.Length);
Assert.Equal(0, metrics.PerClassLogLoss[0], 1);
Assert.Equal(.1, metrics.PerClassLogLoss[1], 1);
Assert.Equal(.1, metrics.PerClassLogLoss[2], 1);

ConfusionMatrix matrix = metrics.ConfusionMatrix;
Assert.Equal(3, matrix.Order);
Assert.Equal(3, matrix.ClassNames.Count);
Assert.Equal("Iris-setosa", matrix.ClassNames[0]);
Assert.Equal("Iris-versicolor", matrix.ClassNames[1]);
Assert.Equal("Iris-virginica", matrix.ClassNames[2]);

Assert.Equal(50, matrix[0, 0]);
Assert.Equal(50, matrix["Iris-setosa", "Iris-setosa"]);
Assert.Equal(0, matrix[0, 1]);
Assert.Equal(0, matrix["Iris-setosa", "Iris-versicolor"]);
Assert.Equal(0, matrix[0, 2]);
Assert.Equal(0, matrix["Iris-setosa", "Iris-virginica"]);

Assert.Equal(0, matrix[1, 0]);
Assert.Equal(0, matrix["Iris-versicolor", "Iris-setosa"]);
Assert.Equal(48, matrix[1, 1]);
Assert.Equal(48, matrix["Iris-versicolor", "Iris-versicolor"]);
Assert.Equal(2, matrix[1, 2]);
Assert.Equal(2, matrix["Iris-versicolor", "Iris-virginica"]);

Assert.Equal(0, matrix[2, 0]);
Assert.Equal(0, matrix["Iris-virginica", "Iris-setosa"]);
Assert.Equal(1, matrix[2, 1]);
Assert.Equal(1, matrix["Iris-virginica", "Iris-versicolor"]);
Assert.Equal(49, matrix[2, 2]);
Assert.Equal(49, matrix["Iris-virginica", "Iris-virginica"]);
}

public class IrisDataWithFeatureVector
{
[FeaturesColumn("0-3")]
[VectorType(4)]
public float[] Feat;

[LabelColumn("4")]
public string IrisPlantType;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public class IrisDataWithStringLabel
[Column("3")]
public float PetalLength;

[Column("4", name: "Label")]
[LabelColumn("4")]
public string IrisPlantType;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public void TrainAndPredictSentimentModelTest()

public class SentimentData
{
[Column(ordinal: "0", name: "Label")]
[LabelColumn(ordinal: "0")]
public float Sentiment;
[Column(ordinal: "1")]
public string SentimentText;
Expand Down
91 changes: 91 additions & 0 deletions test/Microsoft.ML.Tests/TextLoaderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,79 @@ public void ThrowsExceptionWithPropertyName()
Assert.StartsWith("String1 is missing ColumnAttribute", ex.Message);
}


[Fact]
public void CanSuccessfullyNamedColumns()
{
string dataPath = GetDataPath("SparseData.txt");
var loader = new Data.TextLoader(dataPath).CreateFrom<SparseInputWithNamedColumns>(useHeader: true, allowQuotedStrings: false, supportSparse: true);

using (var environment = new TlcEnvironment())
{
Experiment experiment = environment.CreateExperiment();
ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as ILearningPipelineDataStep;

experiment.Compile();
loader.SetInput(environment, experiment);
experiment.Run();

IDataLoader data = experiment.GetOutput(output.Data) as IDataLoader;
Assert.NotNull(data);

Assert.Equal(5, data.Schema.ColumnCount);
Assert.Equal("Name", data.Schema.GetColumnName(0));
Assert.Equal("GroupId", data.Schema.GetColumnName(1));
Assert.Equal("Weight", data.Schema.GetColumnName(2));
Assert.Equal("Features", data.Schema.GetColumnName(3));
Assert.Equal("Label", data.Schema.GetColumnName(4));

using (var cursor = data.GetRowCursor((a => true)))
{
var getters = new ValueGetter<float>[]{
cursor.GetGetter<float>(0),
cursor.GetGetter<float>(1),
cursor.GetGetter<float>(2),
cursor.GetGetter<float>(3),
cursor.GetGetter<float>(4)
};


Assert.True(cursor.MoveNext());

float[] targets = new float[] { 1, 2, 3, 4, 5 };
for (int i = 0; i < getters.Length; i++)
{
float value = 0;
getters[i](ref value);
Assert.Equal(targets[i], value);
}

Assert.True(cursor.MoveNext());

targets = new float[] { 0, 0, 0, 4, 5 };
for (int i = 0; i < getters.Length; i++)
{
float value = 0;
getters[i](ref value);
Assert.Equal(targets[i], value);
}

Assert.True(cursor.MoveNext());

targets = new float[] { 0, 2, 0, 0, 0 };
for (int i = 0; i < getters.Length; i++)
{
float value = 0;
getters[i](ref value);
Assert.Equal(targets[i], value);
}

Assert.False(cursor.MoveNext());
}
}

}

public class QuoteInput
{
[Column("0")]
Expand Down Expand Up @@ -268,5 +341,23 @@ public class ModelWithoutColumnAttribute
{
public string String1;
}

public class SparseInputWithNamedColumns
{
[NameColumn("0")]
public float C1;

[GroupColumn("1")]
public float C2;

[WeightColumn("2")]
public float C3;

[FeaturesColumn("3")]
public float C4;

[LabelColumn("4")]
public float C5;
}
}
}