
Commit 5241462

FFM XML Doc And Add One Missing Sample File (#3374)
1 parent e57e572 commit 5241462

File tree

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FactorizationMachine.cs
docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FactorizationMachine.tt
docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj
src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs
src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineTrainer.cs

5 files changed: +223 -10 lines changed
docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FactorizationMachine.cs

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.BinaryClassification
{
    public static class FactorizationMachine
    {
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations, and as the source of randomness.
            // Setting the seed to a fixed number in this example makes outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is consumable by the ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // ML.NET doesn't cache data sets by default. Therefore, if one reads a data set from a file and accesses it many times,
            // it can be slow due to expensive featurization and disk operations. When the considered data can fit into memory,
            // a solution is to cache the data in memory. Caching is especially helpful when working with iterative algorithms
            // which need many data passes.
            trainingData = mlContext.Data.Cache(trainingData);

            // Define the trainer.
            var pipeline = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine();

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use a different random seed to make it different from the training data.
            var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on the test data set.
            var transformedTestData = model.Transform(testData);

            // Convert the IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions.Take(5))
                Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");

            // Expected output:
            // Label: True, Prediction: False
            // Label: False, Prediction: False
            // Label: True, Prediction: False
            // Label: True, Prediction: False
            // Label: False, Prediction: False

            // Evaluate the overall metrics.
            var metrics = mlContext.BinaryClassification.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            // Accuracy: 0.55
            // AUC: 0.54
            // F1 Score: 0.23
            // Negative Precision: 0.54
            // Negative Recall: 0.92
            // Positive Precision: 0.62
            // Positive Recall: 0.14
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = randomFloat() > 0.5f;
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    // For data points with a false label, the feature values are slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50).Select(x => x ? randomFloat() : randomFloat() + 0.1f).ToArray()
                };
            }
        }

        // Example with a label and 50 feature values. A data set is a collection of such examples.
        private class DataPoint
        {
            public bool Label { get; set; }

            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public bool Label { get; set; }

            // Predicted label from the trainer.
            public bool PredictedLabel { get; set; }
        }

        // Pretty-print BinaryClassificationMetrics objects.
        private static void PrintMetrics(BinaryClassificationMetrics metrics)
        {
            Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
            Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F2}");
            Console.WriteLine($"F1 Score: {metrics.F1Score:F2}");
            Console.WriteLine($"Negative Precision: {metrics.NegativePrecision:F2}");
            Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}");
            Console.WriteLine($"Positive Precision: {metrics.PositivePrecision:F2}");
            Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}");
        }
    }
}
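The generated sample above trains on the single default "Features" column. To make the field-aware aspect described by the new XML docs concrete, here is a minimal sketch of the multi-column overload with two hypothetical feature columns (each column becomes a field). The class shape, column names, and vector sizes are illustrative only and are not part of this commit; the fragment is meant to slot into an Example() method like the one above.

// Hypothetical data point with two feature columns; the trainer treats each column as a field.
private class MultiFieldDataPoint
{
    public bool Label { get; set; }

    [VectorType(20)]
    public float[] UserFeatures { get; set; }

    [VectorType(30)]
    public float[] ItemFeatures { get; set; }
}

// The string[] overload documented in FactorizationMachineCatalog.cs takes one column name per field.
var multiFieldPipeline = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(
    featureColumnNames: new[] { nameof(MultiFieldDataPoint.UserFeatures), nameof(MultiFieldDataPoint.ItemFeatures) },
    labelColumnName: nameof(MultiFieldDataPoint.Label));
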
docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FactorizationMachine.tt

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
<#@ include file="BinaryClassification.ttinclude"#>
<#+
string ClassName = "FactorizationMachine";
string Trainer = "FieldAwareFactorizationMachine";
string TrainerOptions = null;
bool IsCalibrated = true;
bool CacheData = true;

string LabelThreshold = "0.5f";
string DataSepValue = "0.1f";
string OptionsInclude = "";
string Comments = "";

string ExpectedOutputPerInstance = @"// Expected output:
// Label: True, Prediction: False
// Label: False, Prediction: False
// Label: True, Prediction: False
// Label: True, Prediction: False
// Label: False, Prediction: False";

string ExpectedOutput = @"// Expected output:
// Accuracy: 0.55
// AUC: 0.54
// F1 Score: 0.23
// Negative Precision: 0.54
// Negative Recall: 0.92
// Positive Precision: 0.62
// Positive Recall: 0.14";
#>

docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj

Lines changed: 9 additions & 0 deletions
@@ -108,6 +108,10 @@
       <Generator>TextTemplatingFileGenerator</Generator>
       <LastGenOutput>FieldAwareFactorizationMachine.cs</LastGenOutput>
     </None>
+    <None Update="Dynamic\Trainers\BinaryClassification\FactorizationMachine.tt">
+      <Generator>TextTemplatingFileGenerator</Generator>
+      <LastGenOutput>FactorizationMachine.cs</LastGenOutput>
+    </None>
     <None Update="Dynamic\Trainers\BinaryClassification\FieldAwareFactorizationMachineWithOptions.tt">
       <Generator>TextTemplatingFileGenerator</Generator>
       <LastGenOutput>FieldAwareFactorizationMachineWithOptions.cs</LastGenOutput>
@@ -466,6 +470,11 @@
       <AutoGen>True</AutoGen>
       <DependentUpon>FieldAwareFactorizationMachine.tt</DependentUpon>
     </Compile>
+    <Compile Update="Dynamic\Trainers\BinaryClassification\FactorizationMachine.cs">
+      <DesignTime>True</DesignTime>
+      <AutoGen>True</AutoGen>
+      <DependentUpon>FactorizationMachine.tt</DependentUpon>
+    </Compile>
     <Compile Update="Dynamic\Trainers\BinaryClassification\FieldAwareFactorizationMachineWithOptions.cs">
       <DesignTime>True</DesignTime>
       <AutoGen>True</AutoGen>

src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs

Lines changed: 9 additions & 9 deletions
@@ -14,19 +14,19 @@ namespace Microsoft.ML
     public static class FactorizationMachineExtensions
     {
         /// <summary>
-        /// Predict a target using a field-aware factorization machine algorithm.
+        /// Create <see cref="FieldAwareFactorizationMachineTrainer"/>, which predicts a target using a field-aware factorization machine trained over boolean label data.
         /// </summary>
         /// <remarks>
         /// Note that because there is only one feature column, the underlying model is equivalent to standard factorization machine.
         /// </remarks>
         /// <param name="catalog">The binary classification catalog trainer object.</param>
-        /// <param name="featureColumnName">The name of the feature column.</param>
-        /// <param name="labelColumnName">The name of the label column.</param>
+        /// <param name="labelColumnName">The name of the label column. The column data must be <see cref="System.Boolean"/>.</param>
+        /// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
         /// <param name="exampleWeightColumnName">The name of the example weight column (optional).</param>
         /// <example>
         /// <format type="text/markdown">
         /// <![CDATA[
-        /// [!code-csharp[FieldAwareFactorizationMachineWithoutArguments](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithoutArguments.cs)]
+        /// [!code-csharp[FieldAwareFactorizationMachineWithoutArguments](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FactorizationMachine.cs)]
         /// ]]></format>
         /// </example>
         public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
@@ -40,11 +40,11 @@ public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachi
         }

         /// <summary>
-        /// Predict a target using a field-aware factorization machine algorithm.
+        /// Create <see cref="FieldAwareFactorizationMachineTrainer"/>, which predicts a target using a field-aware factorization machine trained over boolean label data.
         /// </summary>
         /// <param name="catalog">The binary classification catalog trainer object.</param>
-        /// <param name="featureColumnNames">The name(s) of the feature columns.</param>
-        /// <param name="labelColumnName">The name of the label column.</param>
+        /// <param name="labelColumnName">The name of the label column. The column data must be <see cref="System.Boolean"/>.</param>
+        /// <param name="featureColumnNames">The names of the feature columns. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
         /// <param name="exampleWeightColumnName">The name of the example weight column (optional).</param>
         /// <example>
         /// <format type="text/markdown">
@@ -63,10 +63,10 @@ public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachi
         }

         /// <summary>
-        /// Predict a target using a field-aware factorization machine algorithm.
+        /// Create <see cref="FieldAwareFactorizationMachineTrainer"/> using advanced options, which predicts a target using a field-aware factorization machine trained over boolean label data.
         /// </summary>
         /// <param name="catalog">The binary classification catalog trainer object.</param>
-        /// <param name="options">Advanced arguments to the algorithm.</param>
+        /// <param name="options">Trainer options.</param>
         /// <example>
         /// <format type="text/markdown">
         /// <![CDATA[
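For the third overload above, which takes FieldAwareFactorizationMachineTrainer.Options, a minimal usage sketch follows. The option names are assumed from the existing FieldAwareFactorizationMachineWithOptions sample that this project already generates; the values are arbitrary and should be verified against the Options class.

// Assumed option names (see FieldAwareFactorizationMachineTrainer.Options); values are arbitrary.
var options = new Microsoft.ML.Trainers.FieldAwareFactorizationMachineTrainer.Options
{
    LabelColumnName = "Label",
    FeatureColumnName = "Features",
    LatentDimension = 16,        // dimension k of each latent vector
    LearningRate = 0.5f,
    NumberOfIterations = 50
};

var trainer = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(options);
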

src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineTrainer.cs

Lines changed: 60 additions & 1 deletion
@@ -30,7 +30,62 @@ namespace Microsoft.ML.Trainers
        [2] https://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf
        [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
     */
-    /// <include file='doc.xml' path='doc/members/member[@name="FieldAwareFactorizationMachineBinaryClassifier"]/*' />
+    /// <summary>
+    /// The <see cref="IEstimator{TTransformer}"/> to predict a target using a field-aware factorization machine model trained using a stochastic gradient method.
+    /// </summary>
+    /// <remarks>
+    /// <format type="text/markdown"><![CDATA[
+    /// [!include[io](~/../docs/samples/docs/api-reference/io-columns-binary-classification.md)]
+    /// To create this trainer, use [FieldAwareFactorizationMachine](xref:Microsoft.ML.FactorizationMachineExtensions.FieldAwareFactorizationMachine(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,System.String,System.String,System.String)),
+    /// [FieldAwareFactorizationMachine](xref:Microsoft.ML.FactorizationMachineExtensions.FieldAwareFactorizationMachine(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,System.String[],System.String,System.String)),
+    /// or [FieldAwareFactorizationMachine(Options)](xref:Microsoft.ML.FactorizationMachineExtensions.FieldAwareFactorizationMachine(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,Microsoft.ML.Trainers.FieldAwareFactorizationMachineTrainer.Options)).
+    ///
+    /// In contrast to other binary classifiers, which can only support one feature column, a field-aware factorization machine can consume multiple feature columns.
+    /// Each column is viewed as a container of some features, and such a container is called a field.
+    /// Note that all feature columns must be float vectors, but their dimensions can be different.
+    /// The motivation for splitting features into different fields is to model features from different distributions independently.
+    /// For example, in an online game store, features created from the user profile and those from the game profile can be assigned to two different fields.
+    ///
+    /// ### Trainer Characteristics
+    /// | | |
+    /// | -- | -- |
+    /// | Machine learning task | Binary classification |
+    /// | Is normalization required? | Yes |
+    /// | Is caching required? | No |
+    /// | Required NuGet in addition to Microsoft.ML | None |
+    ///
+    /// ### Background
+    /// The factorization machine family is a powerful model group for supervised learning problems.
+    /// It was first introduced in Steffen Rendle's [Factorization Machines](http://ieeexplore.ieee.org/document/5694074/?reload=true) paper in 2010.
+    /// Later, one of its generalized versions, the field-aware factorization machine, became an important predictive module in recent recommender systems and click-through rate prediction contests.
+    /// For examples, see the winning solutions in Steffen Rendle's KDD-Cup 2012 ([Track 1](http://www.kdd.org/kdd-cup/view/kdd-cup-2012-track-1) and [Track 2](http://www.kdd.org/kdd-cup/view/kdd-cup-2012-track-2)),
+    /// [Criteo's](https://www.kaggle.com/c/criteo-display-ad-challenge), [Avazu's](https://www.kaggle.com/c/avazu-ctr-prediction), and [Outbrain's](https://www.kaggle.com/c/outbrain-click-prediction) click prediction challenges on Kaggle.
+    ///
+    /// Factorization machines are especially powerful when feature conjunctions are highly correlated with the signal you want to predict.
+    /// An example of a feature pair that can form an important conjunction is user ID and music ID in music recommendation.
+    /// When a dataset consists of only dense numerical features, using a factorization machine is not recommended, or some featurization should be performed first.
+    ///
+    /// ### Scoring Function
+    /// A field-aware factorization machine is a scoring function which maps feature vectors from different fields to a scalar score.
+    /// Assume that all $m$ feature columns are concatenated into a long feature vector $\boldsymbol{x}\in {\mathbb R}^n$ and ${\mathcal F}(j)$ denotes the $j$-th feature's field identifier.
+    /// The corresponding score is $\hat{y}\left(\boldsymbol{x}\right) = \left\langle \boldsymbol{w}, \boldsymbol{x} \right\rangle + \sum_{j = 1}^n \sum_{j' = j + 1}^n \left\langle \boldsymbol{v}_{j, {\mathcal F}(j')} , \boldsymbol{v}_{j', {\mathcal F}(j)} \right\rangle x_j x_{j'}$,
+    /// where $\left\langle \cdot, \cdot \right\rangle$ is the inner product operator, $\boldsymbol{w}\in{\mathbb R}^n$ stores the linear coefficients, and $\boldsymbol{v}_{j, f}\in {\mathbb R}^k$ is the $j$-th feature's representation in the $f$-th field's latent space.
+    /// Note that $k$ is the latent dimension specified by the user.
+    /// The predicted label is the sign of $\hat{y}$. If $\hat{y} > 0$, this model predicts true. Otherwise, it predicts false.
+    /// For a systematic introduction to field-aware factorization machines, please see [this paper](https://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf).
+    ///
+    /// ### Training Algorithm Details
+    /// The algorithm implemented in <see cref="FieldAwareFactorizationMachineTrainer"/> is based on [a stochastic gradient method](http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf).
+    /// The algorithm details are described in Algorithm 3 of [an online document](https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf).
+    /// The minimized loss function is [logistic loss](https://en.wikipedia.org/wiki/Loss_functions_for_classification), so the trained model can be viewed as a non-linear logistic regression.
+    ///
+    /// ]]>
+    /// </format>
+    /// </remarks>
+    /// <seealso cref="Microsoft.ML.FactorizationMachineExtensions.FieldAwareFactorizationMachine(BinaryClassificationCatalog.BinaryClassificationTrainers, string, string, string)"/>
+    /// <seealso cref="Microsoft.ML.FactorizationMachineExtensions.FieldAwareFactorizationMachine(BinaryClassificationCatalog.BinaryClassificationTrainers, string[], string, string)"/>
+    /// <seealso cref="Microsoft.ML.FactorizationMachineExtensions.FieldAwareFactorizationMachine(BinaryClassificationCatalog.BinaryClassificationTrainers, FieldAwareFactorizationMachineTrainer.Options)"/>
+    /// <seealso cref="FieldAwareFactorizationMachineTrainer.Options"/>
     public sealed class FieldAwareFactorizationMachineTrainer : ITrainer<FieldAwareFactorizationMachineModelParameters>,
         IEstimator<FieldAwareFactorizationMachinePredictionTransformer>
     {
@@ -39,6 +94,10 @@ public sealed class FieldAwareFactorizationMachineTrainer : ITrainer<FieldAwareF
         internal const string LoadName = "FieldAwareFactorizationMachine";
         internal const string ShortName = "ffm";

+        /// <summary>
+        /// <see cref="Options"/> for <see cref="FieldAwareFactorizationMachineTrainer"/> as used in
+        /// <see cref="Microsoft.ML.FactorizationMachineExtensions.FieldAwareFactorizationMachine(BinaryClassificationCatalog.BinaryClassificationTrainers, FieldAwareFactorizationMachineTrainer.Options)"/>.
+        /// </summary>
         public sealed class Options : TrainerInputBaseWithWeight
         {
             /// <summary>
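To make the scoring function in the new remarks concrete, here is a small standalone sketch that evaluates the formula directly for a toy input. The parameter layout (parallel arrays for the linear weights, latent vectors, and field identifiers) is purely illustrative and is not the trainer's internal representation.

using System;

// Illustration of the field-aware factorization machine scoring function from the remarks above:
//   y(x) = <w, x> + sum_{j} sum_{j' > j} <v[j, F(j')], v[j', F(j)]> * x[j] * x[j']
internal static class FfmScoreSketch
{
    // x:     concatenated feature vector of length n.
    // w:     linear coefficients, length n.
    // v:     latent vectors; v[j][f] is the length-k vector for feature j in field f.
    // field: field identifier F(j) of each feature j.
    public static double Score(float[] x, float[] w, float[][][] v, int[] field)
    {
        double score = 0;

        // Linear term <w, x>.
        for (int j = 0; j < x.Length; j++)
            score += w[j] * x[j];

        // Pairwise term over all feature pairs (j, j').
        for (int j = 0; j < x.Length; j++)
        {
            for (int jp = j + 1; jp < x.Length; jp++)
            {
                var vj = v[j][field[jp]];   // v_{j, F(j')}
                var vjp = v[jp][field[j]];  // v_{j', F(j)}

                double inner = 0;
                for (int d = 0; d < vj.Length; d++)
                    inner += vj[d] * vjp[d];

                score += inner * x[j] * x[jp];
            }
        }

        // The predicted label is the sign of the score: true if positive, false otherwise.
        return score;
    }
}
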
