-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Samples for categorical transform estimators #3179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,10 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using Microsoft.ML; | ||
using Microsoft.ML.Data; | ||
using static Microsoft.ML.Transforms.OneHotEncodingEstimator; | ||
|
||
namespace Microsoft.ML.Samples.Dynamic | ||
namespace Samples.Dynamic | ||
{ | ||
public static class OneHotEncoding | ||
{ | ||
|
@@ -17,53 +17,39 @@ public static void Example() | |
// Get a small dataset as an IEnumerable. | ||
var samples = new List<DataPoint>() | ||
{ | ||
new DataPoint(){ Label = 0, Education = "0-5yrs" }, | ||
new DataPoint(){ Label = 1, Education = "0-5yrs" }, | ||
new DataPoint(){ Label = 45, Education = "6-11yrs" }, | ||
new DataPoint(){ Label = 50, Education = "6-11yrs" }, | ||
new DataPoint(){ Label = 50, Education = "11-15yrs" }, | ||
new DataPoint(){ Education = "0-5yrs" }, | ||
new DataPoint(){ Education = "0-5yrs" }, | ||
new DataPoint(){ Education = "6-11yrs" }, | ||
new DataPoint(){ Education = "6-11yrs" }, | ||
new DataPoint(){ Education = "11-15yrs" }, | ||
}; | ||
|
||
// Convert training data to IDataView. | ||
var trainData = mlContext.Data.LoadFromEnumerable(samples); | ||
var data = mlContext.Data.LoadFromEnumerable(samples); | ||
|
||
// A pipeline for one hot encoding the Education column. | ||
var bagPipeline = mlContext.Transforms.Categorical.OneHotEncoding("EducationOneHotEncoded", "Education", OutputKind.Bag); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I would leave it, so that it makes sense why we call it bagPipeline. #Closed There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am using the default (which uses Indicator, not bagging); I also renamed it to `pipeline`. In reply to: 271811206 [](ancestors = 271811206) |
||
// Fit to data. | ||
var bagTransformer = bagPipeline.Fit(trainData); | ||
var pipeline = mlContext.Transforms.Categorical.OneHotEncoding("EducationOneHotEncoded", "Education"); | ||
|
||
// Get transformed data | ||
var bagTransformedData = bagTransformer.Transform(trainData); | ||
// Getting the data of the newly created column, so we can preview it. | ||
var bagEncodedColumn = bagTransformedData.GetColumn<float[]>("EducationOneHotEncoded"); | ||
// Fit and transform the data. | ||
var oneHotEncodedData = pipeline.Fit(data).Transform(data); | ||
|
||
PrintDataColumn(oneHotEncodedData, "EducationOneHotEncoded"); | ||
// We have 3 slots, because there are three categories in the 'Education' column. | ||
// 1 0 0 | ||
// 1 0 0 | ||
// 0 1 0 | ||
// 0 1 0 | ||
// 0 0 1 | ||
|
||
// A pipeline for one hot encoding the Education column (using keying). | ||
var keyPipeline = mlContext.Transforms.Categorical.OneHotEncoding("EducationOneHotEncoded", "Education", OutputKind.Key); | ||
// Fit to data. | ||
var keyTransformer = keyPipeline.Fit(trainData); | ||
|
||
// Get transformed data | ||
var keyTransformedData = keyTransformer.Transform(trainData); | ||
// Getting the data of the newly created column, so we can preview it. | ||
var keyEncodedColumn = keyTransformedData.GetColumn<uint>("EducationOneHotEncoded"); | ||
// Fit and Transform data. | ||
oneHotEncodedData = keyPipeline.Fit(data).Transform(data); | ||
|
||
Console.WriteLine("One Hot Encoding based on the bagging strategy."); | ||
foreach (var row in bagEncodedColumn) | ||
{ | ||
for (var i = 0; i < row.Length; i++) | ||
Console.Write($"{row[i]} "); | ||
} | ||
|
||
// data column obtained post-transformation. | ||
// Since there are only two categories in the Education column of the trainData, the output vector | ||
// for one hot will have two slots. | ||
// | ||
// 0 0 0 | ||
// 0 0 0 | ||
// 0 0 1 | ||
// 0 0 1 | ||
// 0 1 0 | ||
var keyEncodedColumn = oneHotEncodedData.GetColumn<uint>("EducationOneHotEncoded"); | ||
|
||
Console.WriteLine("One Hot Encoding with key type output."); | ||
Console.WriteLine("One Hot Encoding of single column 'Education', with key type output."); | ||
foreach (var element in keyEncodedColumn) | ||
Console.WriteLine(element); | ||
|
||
|
@@ -72,13 +58,20 @@ public static void Example() | |
// 2 | ||
// 2 | ||
// 3 | ||
|
||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make it a separate example, because the multi-output is a different API. #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
/// <summary>
/// Writes every row of a vector-valued column to the console: one row per line,
/// with each value followed by a tab character.
/// </summary>
/// <param name="transformedData">The data view that holds the column to print.</param>
/// <param name="columnName">Name of the column to print; its values are read as <c>float[]</c>.</param>
private static void PrintDataColumn(IDataView transformedData, string columnName)
{
    // Resolve the column through the schema first, then read its values as float vectors.
    var column = transformedData.Schema[columnName];
    var rows = transformedData.GetColumn<float[]>(column);

    foreach (var values in rows)
    {
        foreach (var value in values)
        {
            Console.Write($"{value}\t");
        }

        // Terminate the current row.
        Console.WriteLine();
    }
}
/// <summary>
/// Input row type for the in-memory sample dataset.
/// </summary>
private sealed class DataPoint
{
    /// <summary>Numeric label of the sample.</summary>
    public float Label { get; set; }

    /// <summary>Education bracket as text, for example "0-5yrs".</summary>
    public string Education { get; set; }
}
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using Microsoft.ML; | ||
|
||
namespace Samples.Dynamic | ||
{ | ||
/// <summary>
/// Sample showing how to one hot encode several columns with a single estimator.
/// </summary>
public static class OneHotEncodingMultiColumn
{
    /// <summary>
    /// Builds a five-row in-memory dataset, one hot encodes its 'Education' and 'ZipCode'
    /// columns, and prints the encoded vectors to the console.
    /// </summary>
    public static void Example()
    {
        // The ML context is the entry point for ML.NET operations. It can be used for
        // exception tracking and logging, as well as the source of randomness.
        var mlContext = new MLContext();

        // A small in-memory dataset with three education brackets and two zip codes.
        var samples = new List<DataPoint>
        {
            new DataPoint { Education = "0-5yrs", ZipCode = "98005" },
            new DataPoint { Education = "0-5yrs", ZipCode = "98052" },
            new DataPoint { Education = "6-11yrs", ZipCode = "98005" },
            new DataPoint { Education = "6-11yrs", ZipCode = "98052" },
            new DataPoint { Education = "11-15yrs", ZipCode = "98005" },
        };

        // Expose the in-memory rows as an IDataView.
        var data = mlContext.Data.LoadFromEnumerable(samples);

        // Multi column example: one estimator that one hot encodes both 'Education' and 'ZipCode'.
        // Each pair names a single column; the encoded values are read back below under
        // those same names (see TransformedData).
        var columnPairs = new[]
        {
            new InputOutputColumnPair("Education"),
            new InputOutputColumnPair("ZipCode"),
        };
        var pipeline = mlContext.Transforms.Categorical.OneHotEncoding(columnPairs);

        // Fit the estimator on the data, then apply the fitted transformer to the same data.
        var transformer = pipeline.Fit(data);
        var transformedData = transformer.Transform(data);

        // Read the rows back as TransformedData objects so they can be printed.
        var encodedRows = mlContext.Data.CreateEnumerable<TransformedData>(transformedData, reuseRowObject: true);

        Console.WriteLine("One Hot Encoding of two columns 'Education' and 'ZipCode'.");
        foreach (var encodedRow in encodedRows)
        {
            var education = string.Join(" ", encodedRow.Education);
            var zipCode = string.Join(" ", encodedRow.ZipCode);
            Console.WriteLine($"{education}\t\t\t{zipCode}");
        }

        // Expected output ('Education' has three categories, 'ZipCode' has two):
        // 1 0 0   1 0
        // 1 0 0   0 1
        // 0 1 0   1 0
        // 0 1 0   0 1
        // 0 0 1   1 0
    }

    /// <summary>
    /// Input row: both columns hold categorical values as text.
    /// </summary>
    private sealed class DataPoint
    {
        /// <summary>Education bracket, for example "0-5yrs".</summary>
        public string Education { get; set; }

        /// <summary>Zip code, for example "98005".</summary>
        public string ZipCode { get; set; }
    }

    /// <summary>
    /// Output row: the same column names, now holding the one hot encoded vectors.
    /// </summary>
    private sealed class TransformedData
    {
        /// <summary>Encoded vector for the 'Education' column.</summary>
        public float[] Education { get; set; }

        /// <summary>Encoded vector for the 'ZipCode' column.</summary>
        public float[] ZipCode { get; set; }
    }
}
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using Microsoft.ML; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Transforms; | ||
|
||
namespace Samples.Dynamic | ||
{ | ||
public static class OneHotHashEncoding | ||
{ | ||
/// <summary>
/// One hot hash encodes the 'Education' column of a small in-memory dataset twice:
/// first without specifying an output kind, then with key type output.
/// Both results are printed to the console.
/// </summary>
public static void Example()
{
    // The ML context is the entry point for ML.NET operations. It can be used for
    // exception tracking and logging, as well as the source of randomness.
    var mlContext = new MLContext();

    // A small in-memory dataset with three distinct education brackets.
    var samples = new List<DataPoint>
    {
        new DataPoint { Education = "0-5yrs" },
        new DataPoint { Education = "0-5yrs" },
        new DataPoint { Education = "6-11yrs" },
        new DataPoint { Education = "6-11yrs" },
        new DataPoint { Education = "11-15yrs" },
    };

    // Expose the in-memory rows as an IDataView.
    var data = mlContext.Data.LoadFromEnumerable(samples);

    // First pipeline: one hot hash encode 'Education' into a new column, hashing into 3 bits.
    var pipeline = mlContext.Transforms.Categorical.OneHotHashEncoding(
        "EducationOneHotHashEncoded",
        "Education",
        numberOfBits: 3);

    // Fit the estimator on the data, then apply the fitted transformer to the same data.
    var transformer = pipeline.Fit(data);
    var hashEncodedData = transformer.Transform(data);

    PrintDataColumn(hashEncodedData, "EducationOneHotHashEncoded");
    // We have 8 slots, because we used numberOfBits = 3.

    // 0 0 0 1 0 0 0 0
    // 0 0 0 1 0 0 0 0
    // 0 0 0 0 1 0 0 0
    // 0 0 0 0 1 0 0 0
    // 0 0 0 0 0 0 0 1

    // Second pipeline: the same encoding, but emitting key type output (keying strategy).
    var keyPipeline = mlContext.Transforms.Categorical.OneHotHashEncoding(
        "EducationOneHotHashEncoded",
        "Education",
        outputKind: OneHotEncodingEstimator.OutputKind.Key,
        numberOfBits: 3);

    // Fit and transform the data again with the key-output pipeline.
    var keyTransformer = keyPipeline.Fit(data);
    var hashKeyEncodedData = keyTransformer.Transform(data);

    // Pull out the newly created column so it can be previewed.
    var keyEncodedColumn = hashKeyEncodedData.GetColumn<uint>("EducationOneHotHashEncoded");

    Console.WriteLine("One Hot Hash Encoding of single column 'Education', with key type output.");
    foreach (var key in keyEncodedColumn)
    {
        Console.WriteLine(key);
    }

    // 4
    // 4
    // 5
    // 5
    // 8
}
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. separate example. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
/// <summary>
/// Writes every row of a vector-valued column to the console: one row per line,
/// with each value followed by a tab character.
/// </summary>
/// <param name="transformedData">The data view that holds the column to print.</param>
/// <param name="columnName">Name of the column to print; its values are read as <c>float[]</c>.</param>
private static void PrintDataColumn(IDataView transformedData, string columnName)
{
    // Resolve the column through the schema first, then read its values as float vectors.
    var column = transformedData.Schema[columnName];
    var rows = transformedData.GetColumn<float[]>(column);

    foreach (var values in rows)
    {
        foreach (var value in values)
        {
            Console.Write($"{value}\t");
        }

        // Terminate the current row.
        Console.WriteLine();
    }
}
|
||
/// <summary>
/// Input row type for the in-memory sample dataset.
/// </summary>
private sealed class DataPoint
{
    /// <summary>Education bracket as text, for example "0-5yrs".</summary>
    public string Education { get; set; }
}
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using Microsoft.ML; | ||
|
||
namespace Samples.Dynamic | ||
{ | ||
/// <summary>
/// Sample showing how to one hot hash encode several columns with a single estimator.
/// </summary>
public static class OneHotHashEncodingMultiColumn
{
    /// <summary>
    /// Builds a five-row in-memory dataset, one hot hash encodes its 'Education' and
    /// 'ZipCode' columns, and prints the encoded vectors to the console.
    /// </summary>
    public static void Example()
    {
        // The ML context is the entry point for ML.NET operations. It can be used for
        // exception tracking and logging, as well as the source of randomness.
        var mlContext = new MLContext();

        // A small in-memory dataset with three education brackets and two zip codes.
        var samples = new List<DataPoint>
        {
            new DataPoint { Education = "0-5yrs", ZipCode = "98005" },
            new DataPoint { Education = "0-5yrs", ZipCode = "98052" },
            new DataPoint { Education = "6-11yrs", ZipCode = "98005" },
            new DataPoint { Education = "6-11yrs", ZipCode = "98052" },
            new DataPoint { Education = "11-15yrs", ZipCode = "98005" },
        };

        // Expose the in-memory rows as an IDataView.
        var data = mlContext.Data.LoadFromEnumerable(samples);

        // Multi column example: one estimator that one hot hash encodes both 'Education' and 'ZipCode'.
        // Each pair names a single column; the encoded values are read back below under
        // those same names (see TransformedData).
        var columnPairs = new[]
        {
            new InputOutputColumnPair("Education"),
            new InputOutputColumnPair("ZipCode"),
        };
        var pipeline = mlContext.Transforms.Categorical.OneHotHashEncoding(columnPairs, numberOfBits: 3);

        // Fit the estimator on the data, then apply the fitted transformer to the same data.
        var transformer = pipeline.Fit(data);
        var transformedData = transformer.Transform(data);

        // Read the rows back as TransformedData objects so they can be printed.
        var encodedRows = mlContext.Data.CreateEnumerable<TransformedData>(transformedData, reuseRowObject: true);

        Console.WriteLine("One Hot Hash Encoding of two columns 'Education' and 'ZipCode'.");
        foreach (var encodedRow in encodedRows)
        {
            var education = string.Join(" ", encodedRow.Education);
            var zipCode = string.Join(" ", encodedRow.ZipCode);
            Console.WriteLine($"{education}\t\t\t{zipCode}");
        }

        // We have 8 slots, because we used numberOfBits = 3.

        // 0 0 0 1 0 0 0 0   0 0 0 0 0 0 0 1
        // 0 0 0 1 0 0 0 0   1 0 0 0 0 0 0 0
        // 0 0 0 0 1 0 0 0   0 0 0 0 0 0 0 1
        // 0 0 0 0 1 0 0 0   1 0 0 0 0 0 0 0
        // 0 0 0 0 0 0 0 1   0 0 0 0 0 0 0 1
    }

    /// <summary>
    /// Input row: both columns hold categorical values as text.
    /// </summary>
    private sealed class DataPoint
    {
        /// <summary>Education bracket, for example "0-5yrs".</summary>
        public string Education { get; set; }

        /// <summary>Zip code, for example "98005".</summary>
        public string ZipCode { get; set; }
    }

    /// <summary>
    /// Output row: the same column names, now holding the one hot hash encoded vectors.
    /// </summary>
    private sealed class TransformedData
    {
        /// <summary>Encoded vector for the 'Education' column.</summary>
        public float[] Education { get; set; }

        /// <summary>Encoded vector for the 'ZipCode' column.</summary>
        public float[] ZipCode { get; set; }
    }
}
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is this for? can we remove it?
#Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is for the
OutputKind.Key
parameter that we use in the example below. In reply to: 272292647 [](ancestors = 272292647)