Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
<ItemGroup>
<PackageReference Include="System.CodeDom" Version="$(SystemCodeDomPackageVersion)" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="$(MicrosoftCodeAnalysisCSharpVersion)" />
<PackageReference Include="System.Collections.Specialized" Version="4.3.0" />
</ItemGroup>

<ItemGroup>
Expand Down
72 changes: 52 additions & 20 deletions src/Microsoft.ML.CodeGenerator/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Globalization;
using System.IO;
using System.Linq;
Expand Down Expand Up @@ -49,23 +51,35 @@ internal static IDictionary<string, string> GenerateSampleData(string inputFile,

internal static IDictionary<string, string> GenerateSampleData(IDataView dataView, ColumnInferenceResults columnInference)
{
var featureColumns = dataView.Schema.AsEnumerable().Where(col => col.Name != columnInference.ColumnInformation.LabelColumnName && !columnInference.ColumnInformation.IgnoredColumnNames.Contains(col.Name));
var featureColumns = dataView.Schema.ToList().FindAll(
col => col.Name != columnInference.ColumnInformation.LabelColumnName &&
!columnInference.ColumnInformation.IgnoredColumnNames.Contains(col.Name));
var rowCursor = dataView.GetRowCursor(featureColumns);

var sampleData = featureColumns.Select(column => new { key = Utils.Normalize(column.Name), val = "null" }).ToDictionary(x => x.key, x => x.val);
OrderedDictionary sampleData = new OrderedDictionary();
// Get normalized and unique column names. If there are duplicate column names, the
// differentiator suffix '_col_x' will be added to each column name, where 'x' is
// the load order for a given column.
List<string> normalizedColumnNames= GenerateColumnNames(featureColumns.Select(column => column.Name).ToList());
foreach (string columnName in normalizedColumnNames)
sampleData[columnName] = null;
if (rowCursor.MoveNext())
{
var getGetGetterMethod = typeof(Utils).GetMethod(nameof(Utils.GetValueFromColumn), BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);

foreach (var column in featureColumns)
// Access each feature column name through its index in featureColumns
// as there may exist duplicate column names. In this case, sampleData
// column names may have the differentiator suffix of '_col_x' added,
// which requires access to each column name in through its index.
for(int i = 0; i < featureColumns.Count(); i++)
{
var getGeneraicGetGetterMethod = getGetGetterMethod.MakeGenericMethod(column.Type.RawType);
string val = getGeneraicGetGetterMethod.Invoke(null, new object[] { rowCursor, column }) as string;
sampleData[Utils.Normalize(column.Name)] = val;
var getGenericGetGetterMethod = getGetGetterMethod.MakeGenericMethod(featureColumns[i].Type.RawType);
string val = getGenericGetGetterMethod.Invoke(null, new object[] { rowCursor, featureColumns[i] }) as string;
sampleData[i] = val;
}
}

return sampleData;
return sampleData.Cast<DictionaryEntry>().ToDictionary(k => (string)k.Key, v => (string)v.Value);
}

internal static string GetValueFromColumn<T>(DataViewRowCursor rowCursor, DataViewSchema.Column column)
Expand Down Expand Up @@ -247,8 +261,7 @@ internal static int CreateSolutionFile(string solutionFile, string outputPath)
internal static IList<string> GenerateClassLabels(ColumnInferenceResults columnInferenceResults, IDictionary<string, CodeGeneratorSettings.ColumnMapping> columnMapping = default)
{
IList<string> result = new List<string>();
List<string> normalizedColumnNames = new List<string>();
bool duplicateColumnNamesExist = false;
List<string> columnNames = new List<string>();
foreach (var column in columnInferenceResults.TextLoaderOptions.Columns)
{
StringBuilder sb = new StringBuilder();
Expand Down Expand Up @@ -284,28 +297,47 @@ internal static IList<string> GenerateClassLabels(ColumnInferenceResults columnI
result.Add($"[ColumnName(\"{columnName}\"), LoadColumn({column.Source[0].Min})]");
}
sb.Append(" ");
string normalizedColumnName = Utils.Normalize(column.Name);
// Put placeholder for normalized and unique version of column name
if (!duplicateColumnNamesExist && normalizedColumnNames.Contains(normalizedColumnName))
duplicateColumnNamesExist = true;
normalizedColumnNames.Add(normalizedColumnName);
columnNames.Add(column.Name);
result.Add(sb.ToString());
result.Add("\r\n");
}
// Get normalized and unique column names. If there are duplicate column names, the
// differentiator suffix '_col_x' will be added to each column name, where 'x' is
// the load order for a given column.
List<string> normalizedColumnNames = GenerateColumnNames(columnNames);
for (int i = 1; i < result.Count; i+=3)
{
// Get normalized column name for correctly typed class property name
// If duplicate column names exist, the only way to ensure all generated column names are unique is to add
// a differentiator depending on the column load order from dataset.
if (duplicateColumnNamesExist)
result[i] += normalizedColumnNames[i/3] + $"_col_{i/3}";
else
result[i] += normalizedColumnNames[i/3];
result[i] += normalizedColumnNames[i/3];
result[i] += "{get; set;}";
}
return result;
}

/// <summary>
/// Take a list of column names that may not be normalized to fit property name standards
/// and contain duplicate column names. Return unique and normalized column names.
/// </summary>
/// <param name="columnNames">Column names to normalize.</param>
/// <returns>A list of strings that contain normalized and unique column names.</returns>
internal static List<string> GenerateColumnNames(List<string> columnNames)
{
for (int i = 0; i < columnNames.Count; i++)
columnNames[i] = Utils.Normalize(columnNames[i]);
// Check if there are any duplicates in columnNames by obtaining its set
// and seeing whether or not they are the same size.
HashSet<String> columnNamesSet = new HashSet<String>(columnNames);
// If there are duplicates, add the differentiator suffix '_col_x'
// to each normalized column name, where 'x' is the load
// order for a given column from dataset.
if (columnNamesSet.Count != columnNames.Count)
{
for (int i = 0; i < columnNames.Count; i++)
columnNames[i] += String.Concat("_col_", i);
}
return columnNames;
}

internal static string GetSymbolOfDataKind(DataKind dataKind)
{
switch (dataKind)
Expand Down
89 changes: 89 additions & 0 deletions test/Microsoft.ML.CodeGenerator.Tests/UtilTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,57 @@ class TestClass
public bool T { get; set; }
}

class TestClassContainsDuplicates
{
[LoadColumn(0)]
public string Label_col_0 { get; set; }

[LoadColumn(1)]
public string STR_col_1 { get; set; }

[LoadColumn(2)]
public string STR_col_2 { get; set; }

[LoadColumn(3)]
public string PATH_col_3 { get; set; }

[LoadColumn(4)]
public int INT_col_4 { get; set; }

[LoadColumn(5)]
public Double DOUBLE_col_5 { get; set; }

[LoadColumn(6)]
public float FLOAT_col_6 { get; set; }

[LoadColumn(7)]
public float FLOAT_col_7 { get; set; }

[LoadColumn(8)]
public string TrickySTR_col_8 { get; set; }

[LoadColumn(9)]
public float SingleNan_col_9 { get; set; }

[LoadColumn(10)]
public float SinglePositiveInfinity_col_10 { get; set; }

[LoadColumn(11)]
public float SingleNegativeInfinity_col_11 { get; set; }

[LoadColumn(12)]
public float SingleNegativeInfinity_col_12 { get; set; }

[LoadColumn(13)]
public string EmptyString_col_13 { get; set; }

[LoadColumn(14)]
public bool One_col_14 { get; set; }

[LoadColumn(15)]
public bool T_col_15 { get; set; }
}

public class UtilTest : BaseTestClass
{
public UtilTest(ITestOutputHelper output) : base(output)
Expand Down Expand Up @@ -97,6 +148,44 @@ public async Task TestGenerateSampleDataAsync()
}
}

[Fact]
public async Task TestGenerateSampleDataAsyncDuplicateColumnNames()
{
var filePath = "sample2.txt";
using (var file = new StreamWriter(filePath))
{
await file.WriteLineAsync("Label,STR,STR,PATH,INT,DOUBLE,FLOAT,FLOAT,TrickySTR,SingleNan,SinglePositiveInfinity,SingleNegativeInfinity,SingleNegativeInfinity,EmptyString,One,T");
await file.WriteLineAsync("label1,feature1,feature2,/path/to/file,2,1.2,1.223E+10,1.223E+11,ab\"\';@#$%^&-++==,NaN,Infinity,-Infinity,-Infinity,,1,T");
await file.FlushAsync();
file.Close();
var context = new MLContext();
var dataView = context.Data.LoadFromTextFile<TestClassContainsDuplicates>(filePath, separatorChar: ',', hasHeader: true);
var columnInference = new ColumnInferenceResults()
{
ColumnInformation = new ColumnInformation()
{
LabelColumnName = "Label_col_0"
}
};
var sampleData = Utils.GenerateSampleData(dataView, columnInference);
Assert.Equal("@\"feature1\"", sampleData["STR_col_1"]);
Assert.Equal("@\"feature2\"", sampleData["STR_col_2"]);
Assert.Equal("@\"/path/to/file\"", sampleData["PATH_col_3"]);
Assert.Equal("2", sampleData["INT_col_4"]);
Assert.Equal("1.2", sampleData["DOUBLE_col_5"]);
Assert.Equal("1.223E+10F", sampleData["FLOAT_col_6"]);
Assert.Equal("1.223E+11F", sampleData["FLOAT_col_7"]);
Assert.Equal("@\"ab\\\"\';@#$%^&-++==\"", sampleData["TrickySTR_col_8"]);
Assert.Equal($"Single.NaN", sampleData["SingleNan_col_9"]);
Assert.Equal($"Single.PositiveInfinity", sampleData["SinglePositiveInfinity_col_10"]);
Assert.Equal($"Single.NegativeInfinity", sampleData["SingleNegativeInfinity_col_11"]);
Assert.Equal($"Single.NegativeInfinity", sampleData["SingleNegativeInfinity_col_12"]);
Assert.Equal("@\"\"", sampleData["EmptyString_col_13"]);
Assert.Equal($"true", sampleData["One_col_14"]);
Assert.Equal($"true", sampleData["T_col_15"]);
}
}

[Fact]
public void NormalizeTest()
{
Expand Down