Add NaiveBayes sample & docs #3246
Changes from all commits: 2726f13, 8037fd1, 8831b0f, 7046cfd, 64a8f5a, d056196, b28d900
Changes to the shared sample template (T4):

```diff
@@ -3,7 +3,6 @@ using System.Collections.Generic;
 using System.Linq;
 using Microsoft.ML;
 using Microsoft.ML.Data;
-using Microsoft.ML.SamplesUtils;
 <# if (TrainerOptions != null) { #>
 <#=OptionsInclude#>
 <# } #>
@@ -67,15 +66,20 @@ namespace Samples.Dynamic.Trainers.MulticlassClassification

             // Evaluate the overall metrics
             var metrics = mlContext.MulticlassClassification.Evaluate(transformedTestData);
-            ConsoleUtils.PrintMetrics(metrics);
+            Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
+            Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
+            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
+            Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");

+            <#=ExpectedOutput#>
         }

+        <#=DataGenerationComments#>
         private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed=0)
         {
             var random = new Random(seed);
-            float randomFloat() => (float)random.NextDouble();
+            float randomFloat() => (float)(random.NextDouble() - 0.5);
             for (int i = 0; i < count; i++)
             {
                 // Generate Labels that are integers 1, 2 or 3
```

Review thread on the `randomFloat` change:

Reviewer (natke): Why do we do this? #Resolved

Reviewer: @natke It's to make sure feature values are evenly distributed between -0.5 and +0.5. This gives an even number of positive and negative examples. Naive Bayes considers two types of feature values: 1) greater than zero and 2) less than or equal to zero, and you want the sample to contain both kinds of feature values to get a sensible prediction. I believe @ganik has talked about it briefly in the doc he has attached here. #Resolved

Reviewer (natke): Ok, great. Is it worth adding a comment to the code? Also, which doc? #ByDesign

Reviewer (natke): Ok, so yes, it is spelt out in the trainer code comments. I wonder if we should add a comment to this sample code too, to be absolutely clear. #Resolved

Author: I can't add it to the sample code, since this code is shared (generated from the .tt, which is shared) by 3 other trainers that don't have this NaiveBayes requirement.

Author: ...so we have random values in the -0.5 to 0.5 range; some trainers like NaiveBayes need that, others like OVA don't, but will be fine with it.

Author: Actually, I think I know how to do it; I'll send the next iteration.
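The review discussion above can be checked with a quick sketch (Python here rather than the sample's C#, and independent of ML.NET): shifting a uniform `[0, 1)` draw down by 0.5 yields values in `[-0.5, 0.5)`, so roughly half of the raw feature draws fall on each side of zero before the per-label offset is added, which is what the binarize-at-zero behavior needs.

```python
import random

random.seed(0)

def random_float():
    # Mirrors the sample's shifted generator: uniform in [-0.5, 0.5).
    return random.random() - 0.5

draws = [random_float() for _ in range(10000)]
positive = sum(1 for x in draws if x > 0)

# Naive Bayes binarizes at zero, so both "true" (> 0) and
# "false" (<= 0) feature values should be well represented.
print(positive / len(draws))  # close to 0.5
```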
New sample file (111 added lines):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.MulticlassClassification
{
    public static class NaiveBayes
    {
        // Naive Bayes classifier is based on Bayes' theorem.
        // It assumes independence among the presence of features in a class even though they may be dependent on each other.
        // It is a multi-class trainer that accepts binary feature values of type float, i.e., feature values are either true or false.
        // Specifically, a feature value greater than zero is treated as true, zero or less is treated as false.
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            // Setting the seed to a fixed number in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline =
                // Convert the string labels into key types.
                mlContext.Transforms.Conversion.MapValueToKey("Label")
                // Apply NaiveBayes multiclass trainer.
                .Append(mlContext.MulticlassClassification.Trainers.NaiveBayes());

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use a different random seed to make it different from training data.
            var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on the test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();

            // Look at 5 predictions.
            foreach (var p in predictions.Take(5))
                Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");

            // Expected output:
            //   Label: 1, Prediction: 1
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 3
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 3

            // Evaluate the overall metrics.
            var metrics = mlContext.MulticlassClassification.Evaluate(transformedTestData);
            Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
            Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
            Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");

            // Expected output:
            //   Micro Accuracy: 0.88
            //   Macro Accuracy: 0.88
            //   Log Loss: 34.54
            //   Log Loss Reduction: -30.47
        }

        // Generates random uniform doubles in [-0.5, 0.5) range with labels 1, 2 or 3.
        // For NaiveBayes, values greater than zero are treated as true, zero or less are treated as false.
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)(random.NextDouble() - 0.5);
            for (int i = 0; i < count; i++)
            {
                // Generate labels that are integers 1, 2 or 3.
                var label = random.Next(1, 4);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    // Create random features that are correlated with the label.
                    // The feature values are slightly increased by adding a constant multiple of label.
                    Features = Enumerable.Repeat(label, 20).Select(x => randomFloat() + label * 0.2f).ToArray()
                };
            }
        }

        // Example with label and 20 feature values. A data set is a collection of such examples.
        private class DataPoint
        {
            public uint Label { get; set; }
            [VectorType(20)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Predicted label from the trainer.
            public uint PredictedLabel { get; set; }
        }
    }
}
```

Reviewer (on the per-instance expected output): NICE! #Resolved
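For a conceptual cross-check of what this sample exercises, here is an independent Python sketch (not ML.NET's implementation) of a binarized Naive Bayes classifier: each feature is thresholded at zero, per-class frequencies of "true" features are counted with Laplace smoothing, and classes are scored by summed log-probabilities. The data generator mirrors the sample's, with the hypothetical constants chosen to match it.

```python
import math
import random

def generate(count, seed=0):
    # Mirrors the sample's generator: 20 features per point,
    # uniform in [-0.5, 0.5) shifted up by 0.2 * label.
    rng = random.Random(seed)
    data = []
    for _ in range(count):
        label = rng.randint(1, 3)
        feats = [(rng.random() - 0.5) + label * 0.2 for _ in range(20)]
        data.append((label, feats))
    return data

def train(data, n_features=20, alpha=1.0):
    # Per-class counts of "true" (> 0) feature values, Laplace-smoothed.
    counts = {c: [0] * n_features for c in (1, 2, 3)}
    totals = {c: 0 for c in (1, 2, 3)}
    for label, feats in data:
        totals[label] += 1
        for j, x in enumerate(feats):
            if x > 0:  # binarize at zero, as the trainer's doc comment describes
                counts[label][j] += 1
    n = len(data)
    model = {}
    for c in (1, 2, 3):
        p_true = [(counts[c][j] + alpha) / (totals[c] + 2 * alpha)
                  for j in range(n_features)]
        model[c] = (math.log(totals[c] / n), p_true)
    return model

def predict(model, feats):
    # Score each class by log-prior plus per-feature log-likelihoods.
    def score(c):
        log_prior, p_true = model[c]
        s = log_prior
        for j, x in enumerate(feats):
            p = p_true[j]
            s += math.log(p) if x > 0 else math.log(1 - p)
        return s
    return max(model, key=score)

model = train(generate(1000))
test = generate(500, seed=123)
accuracy = sum(predict(model, f) == y for y, f in test) / len(test)
print(accuracy)  # well above chance (1/3) on this separable data
```

This is only a sketch of the general technique the sample's header comments describe; the actual trainer's counting and scoring details may differ.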
New T4 template (31 added lines):

```
<#@ include file="MulticlassClassification.ttinclude"#>
<#+
string ClassName="NaiveBayes";
string Trainer = "NaiveBayes";
string MetaTrainer = null;
string TrainerOptions = null;

string OptionsInclude = "";
string Comments= @"
        // Naive Bayes classifier is based on Bayes' theorem.
        // It assumes independence among the presence of features in a class even though they may be dependent on each other.
        // It is a multi-class trainer that accepts binary feature values of type float, i.e., feature values are either true or false.
        // Specifically a feature value greater than zero is treated as true, zero or less is treated as false.";

string DataGenerationComments= @"
        // Generates random uniform doubles in [-0.5, 0.5) range with labels 1, 2 or 3.
        // For NaiveBayes values greater than zero are treated as true, zero or less are treated as false.";

string ExpectedOutputPerInstance= @"// Expected output:
            //   Label: 1, Prediction: 1
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 3
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 3";

string ExpectedOutput = @"// Expected output:
            //   Micro Accuracy: 0.88
            //   Macro Accuracy: 0.88
            //   Log Loss: 34.54
            //   Log Loss Reduction: -30.47";
#>
```
Changes to another generated multiclass sample:

```diff
@@ -3,7 +3,6 @@
 using System.Linq;
 using Microsoft.ML;
 using Microsoft.ML.Data;
-using Microsoft.ML.SamplesUtils;

 namespace Samples.Dynamic.Trainers.MulticlassClassification

 {
@@ -54,7 +53,11 @@ public static void Example()

             // Evaluate the overall metrics
             var metrics = mlContext.MulticlassClassification.Evaluate(transformedTestData);
-            ConsoleUtils.PrintMetrics(metrics);
+            Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
+            Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
+            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
+            Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");

             // Expected output:
             //   Micro Accuracy: 0.90
@@ -63,10 +66,12 @@ public static void Example()
             //   Log Loss Reduction: 0.67
         }

+        // Generates random uniform doubles in [-0.5, 0.5) range with labels 1, 2 or 3.
         private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed=0)
         {
             var random = new Random(seed);
-            float randomFloat() => (float)random.NextDouble();
+            float randomFloat() => (float)(random.NextDouble() - 0.5);
             for (int i = 0; i < count; i++)
             {
                 // Generate Labels that are integers 1, 2 or 3
@@ -76,7 +81,7 @@ private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int se
                 Label = (uint)label,
                 // Create random features that are correlated with the label.
                 // The feature values are slightly increased by adding a constant multiple of label.
-                Features = Enumerable.Repeat(label, 20).Select(x => randomFloat() + label * 0.1f).ToArray()
+                Features = Enumerable.Repeat(label, 20).Select(x => randomFloat() + label * 0.2f).ToArray()
             };
         }
     }
```

Reviewer (on the blank line after the namespace declaration): extra line? #Resolved
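A quick way to see the effect of bumping the label multiplier from 0.1f to 0.2f in the diff above (a sketch under the generator's stated distribution, not ML.NET code): a feature is uniform in `[-0.5 + m*label, 0.5 + m*label)`, so the chance it lands above zero, i.e., is binarized to "true", is `0.5 + m*label`, capped at 1. The larger multiplier spreads the classes further apart.

```python
def p_true(multiplier, label):
    # Feature = uniform[-0.5, 0.5) + multiplier * label, so the probability of
    # a "true" (> 0) value is 0.5 + multiplier * label, capped at 1.
    return min(1.0, 0.5 + multiplier * label)

old = [p_true(0.1, c) for c in (1, 2, 3)]  # roughly [0.6, 0.7, 0.8]
new = [p_true(0.2, c) for c in (1, 2, 3)]  # roughly [0.7, 0.9, 1.0]
```

With the old multiplier the three classes' "true" rates sit close together; with 0.2 they are further apart (label 3's features are always positive), which should make the binarized classes easier to tell apart.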
Changes to the NaiveBayesMulticlassTrainer source:

```diff
@@ -17,7 +17,7 @@
     new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) },
     NaiveBayesMulticlassTrainer.UserName,
     NaiveBayesMulticlassTrainer.LoadName,
-    NaiveBayesMulticlassTrainer.ShortName, DocName = "trainer/NaiveBayes.md")]
+    NaiveBayesMulticlassTrainer.ShortName)]

 [assembly: LoadableClass(typeof(NaiveBayesMulticlassModelParameters), null, typeof(SignatureLoadModel),
     "Multi Class Naive Bayes predictor", NaiveBayesMulticlassModelParameters.LoaderSignature)]
@@ -26,6 +26,12 @@

 namespace Microsoft.ML.Trainers
 {
+    /// <summary>
+    /// Naive Bayes classifier is based on Bayes' theorem. It assumes independence among the presence of features
+    /// in a class even though they may be dependent on each other. It is a multi-class trainer that accepts
+    /// binary feature values of type float, i.e., feature values are either true or false, specifically a
+    /// feature value greater than zero is treated as true.
+    /// </summary>
     public sealed class NaiveBayesMulticlassTrainer : TrainerEstimatorBase<MulticlassPredictionTransformer<NaiveBayesMulticlassModelParameters>, NaiveBayesMulticlassModelParameters>
     {
         internal const string LoadName = "MultiClassNaiveBayes";
```

Reviewer (on the new XML doc summary): info: this is good for the 1st pass for docs. Please leave the 2nd pass empty, so that we improve this next week. #Resolved
Reviewer: This is great, but did you re-generate all the TT by running the custom tool, to make sure the samples are not broken? #Resolved

Author: Yes, only 3 .tt files depend on this one; they are regenerated.