Skip to content

Commit 5db5127

Browse files
Use inline training data in generated Console Project file. (#4907)
* add GenerateSampleData in util
* add test
* fix some bugs
* fix bugs, and add more tests
* fix test
1 parent 5449b91 commit 5db5127

16 files changed

+329
-293
lines changed

src/Microsoft.ML.AutoML/TrainerExtensions/RecommendationTrainerExtensions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ public ITrainerEsitmator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
1818
options.LabelColumnName = columnInfo.LabelColumnName;
1919
options.MatrixColumnIndexColumnName = columnInfo.UserIdColumnName;
2020
options.MatrixRowIndexColumnName = columnInfo.ItemIdColumnName;
21+
options.Quiet = true;
2122
return mlContext.Recommendation().Trainers.MatrixFactorization(options);
2223
}
2324

src/Microsoft.ML.CodeGenerator/CodeGenerator/CSharp/AzureCodeGenerator/AzureAttachConsoleAppCodeGenerator.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
using Microsoft.ML.CodeGenerator.CSharp;
1515
using Microsoft.ML.CodeGenerator.Templates.Azure.Console;
1616
using Microsoft.ML.CodeGenerator.Templates.Console;
17+
using Microsoft.ML.CodeGenerator.Utilities;
1718
using Microsoft.ML.Transforms;
19+
using Tensorflow.Operations.Losses;
1820

1921
namespace Microsoft.ML.CodeGenerator.CodeGenerator.CSharp
2022
{
@@ -79,21 +81,21 @@ public AzureAttachConsoleAppCodeGenerator(Pipeline pipeline, ColumnInferenceResu
7981

8082
var columns = _columnInferenceResult.TextLoaderOptions.Columns;
8183
var featuresList = columns.Where((str) => str.Name != _settings.LabelName).Select((str) => str.Name).ToList();
84+
var sampleResult = Utils.GenerateSampleData(_settings.TrainDataset, _columnInferenceResult);
8285
PredictProgram = new CSharpCodeFile()
8386
{
8487
File = new PredictProgram()
8588
{
8689
TaskType = _settings.MlTask.ToString(),
8790
LabelName = _settings.LabelName,
8891
Namespace = _nameSpaceValue,
89-
TestDataPath = _settings.TestDataset,
90-
TrainDataPath = _settings.TrainDataset,
9192
AllowQuoting = _columnInferenceResult.TextLoaderOptions.AllowQuoting,
9293
AllowSparse = _columnInferenceResult.TextLoaderOptions.AllowSparse,
9394
HasHeader = _columnInferenceResult.TextLoaderOptions.HasHeader,
9495
Separator = _columnInferenceResult.TextLoaderOptions.Separators.FirstOrDefault(),
9596
Target = _settings.Target,
9697
Features = featuresList,
98+
SampleData = sampleResult,
9799
}.TransformText(),
98100
Name = "Program.cs",
99101
};

src/Microsoft.ML.CodeGenerator/CodeGenerator/CSharp/CodeGenerator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,19 +434,19 @@ private string GeneratePredictProgramCSFileContent(string namespaceValue)
434434
{
435435
var columns = _columnInferenceResult.TextLoaderOptions.Columns;
436436
var featuresList = columns.Where((str) => str.Name != _settings.LabelName).Select((str) => str.Name).ToList();
437+
var sampleData = Utils.GenerateSampleData(_settings.TrainDataset, _columnInferenceResult);
437438
PredictProgram predictProgram = new PredictProgram()
438439
{
439440
TaskType = _settings.MlTask.ToString(),
440441
LabelName = _settings.LabelName,
441442
Namespace = namespaceValue,
442-
TestDataPath = _settings.TestDataset,
443-
TrainDataPath = _settings.TrainDataset,
444443
HasHeader = _columnInferenceResult.TextLoaderOptions.HasHeader,
445444
Separator = _columnInferenceResult.TextLoaderOptions.Separators.FirstOrDefault(),
446445
AllowQuoting = _columnInferenceResult.TextLoaderOptions.AllowQuoting,
447446
AllowSparse = _columnInferenceResult.TextLoaderOptions.AllowSparse,
448447
Features = featuresList,
449448
Target = _settings.Target,
449+
SampleData = sampleData,
450450
};
451451
return predictProgram.TransformText();
452452
}

src/Microsoft.ML.CodeGenerator/Templates/Console/PredictProgram.cs

Lines changed: 26 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -32,32 +32,29 @@ public virtual string TransformText()
3232
} else if(Target == CSharp.GenerateTarget.ModelBuilder){
3333
MB_Annotation();
3434
}
35-
this.Write("\r\nusing System;\r\nusing System.IO;\r\nusing System.Linq;\r\nusing Microsoft.ML;\r\nusing" +
36-
" ");
35+
this.Write("\r\nusing System;\r\nusing ");
3736
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
3837
this.Write(".Model;\r\n\r\nnamespace ");
3938
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
40-
this.Write(".ConsoleApp\r\n{\r\n class Program\r\n {\r\n //Dataset to use for prediction" +
41-
"s \r\n");
42-
if(string.IsNullOrEmpty(TestDataPath)){
43-
this.Write(" private const string DATA_FILEPATH = @\"");
44-
this.Write(this.ToStringHelper.ToStringWithCulture(TrainDataPath));
45-
this.Write("\";\r\n");
46-
} else{
47-
this.Write(" private const string DATA_FILEPATH = @\"");
48-
this.Write(this.ToStringHelper.ToStringWithCulture(TestDataPath));
49-
this.Write("\";\r\n");
50-
}
51-
this.Write(@"
52-
static void Main(string[] args)
53-
{
54-
// Create single instance of sample data from first line of dataset for model input
55-
ModelInput sampleData = CreateSingleDataSample(DATA_FILEPATH);
56-
57-
// Make a single prediction on the sample data and print results
58-
var predictionResult = ConsumeModel.Predict(sampleData);
59-
60-
Console.WriteLine(""Using model to make single prediction -- Comparing actual ");
39+
this.Write(".ConsoleApp\r\n{\r\n class Program\r\n {\r\n static void Main(string[] args)" +
40+
"\r\n {\r\n // Create single instance of sample data from first lin" +
41+
"e of dataset for model input\r\n");
42+
if(SampleData != null) {
43+
this.Write(" ModelInput sampleData = new ModelInput()\r\n {\r\n");
44+
foreach(var kv in SampleData){
45+
this.Write(" ");
46+
this.Write(this.ToStringHelper.ToStringWithCulture(kv.Key));
47+
this.Write("=");
48+
this.Write(this.ToStringHelper.ToStringWithCulture(kv.Value));
49+
this.Write(",\r\n");
50+
}
51+
this.Write(" };\r\n");
52+
}else{
53+
this.Write(" ModelInput sampleData = new ModelInput();\r\n");
54+
}
55+
this.Write("\r\n\t\t\t// Make a single prediction on the sample data and print results\r\n\t\t\tvar pre" +
56+
"dictionResult = ConsumeModel.Predict(sampleData);\r\n\r\n\t\t\tConsole.WriteLine(\"Using" +
57+
" model to make single prediction -- Comparing actual ");
6158
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
6259
this.Write(" with predicted ");
6360
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
@@ -70,81 +67,35 @@ static void Main(string[] args)
7067
this.Write("}\");\r\n");
7168
}
7269
if("BinaryClassification".Equals(TaskType) ){
73-
this.Write("\t\t\tConsole.WriteLine($\"\\n\\nActual ");
74-
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
75-
this.Write(": {sampleData.");
76-
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
77-
this.Write("} \\nPredicted ");
70+
this.Write("\t\t\tConsole.WriteLine($\"\\n\\nPredicted ");
7871
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
7972
this.Write(": {predictionResult.Prediction}\\n\\n\");\r\n");
8073
} else if("Regression".Equals(TaskType) || "Recommendation".Equals(TaskType)){
81-
this.Write("\t\t\tConsole.WriteLine($\"\\n\\nActual ");
82-
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
83-
this.Write(": {sampleData.");
84-
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
85-
this.Write("} \\nPredicted ");
74+
this.Write("\t\t\tConsole.WriteLine($\"\\n\\nPredicted ");
8675
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
8776
this.Write(": {predictionResult.Score}\\n\\n\");\r\n");
8877
} else if("MulticlassClassification".Equals(TaskType)){
89-
this.Write("\t\t\tConsole.WriteLine($\"\\n\\nActual ");
90-
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
91-
this.Write(": {sampleData.");
92-
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
93-
this.Write("} \\nPredicted ");
78+
this.Write("\t\t\tConsole.WriteLine($\"\\n\\nPredicted ");
9479
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
9580
this.Write(" value {predictionResult.Prediction} \\nPredicted ");
9681
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
9782
this.Write(" scores: [{String.Join(\",\", predictionResult.Score)}]\\n\\n\");\r\n");
9883
}
99-
this.Write(@" Console.WriteLine(""=============== End of process, hit any key to finish ==============="");
100-
Console.ReadKey();
101-
}
102-
103-
// Change this code to create your own sample data
104-
#region CreateSingleDataSample
105-
// Method to load single row of dataset to try a single prediction
106-
private static ModelInput CreateSingleDataSample(string dataFilePath)
107-
{
108-
// Create MLContext
109-
MLContext mlContext = new MLContext();
110-
111-
// Load dataset
112-
IDataView dataView = mlContext.Data.LoadFromTextFile<ModelInput>(
113-
path: dataFilePath,
114-
hasHeader : ");
115-
this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant()));
116-
this.Write(",\r\n separatorChar : \'");
117-
this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString())));
118-
this.Write("\',\r\n allowQuoting : ");
119-
this.Write(this.ToStringHelper.ToStringWithCulture(AllowQuoting.ToString().ToLowerInvariant()));
120-
this.Write(",\r\n allowSparse: ");
121-
this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant()));
122-
this.Write(@");
123-
124-
// Use first line of dataset as model input
125-
// You can replace this with new test data (hardcoded or from end-user application)
126-
ModelInput sampleForPrediction = mlContext.Data.CreateEnumerable<ModelInput>(dataView, false)
127-
.First();
128-
return sampleForPrediction;
129-
}
130-
#endregion
131-
}
132-
}
133-
");
84+
this.Write(" Console.WriteLine(\"=============== End of process, hit any key to fin" +
85+
"ish ===============\");\r\n Console.ReadKey();\r\n }\r\n }\r\n}\r\n");
13486
return this.GenerationEnvironment.ToString();
13587
}
13688

13789
public string TaskType {get;set;}
13890
public string Namespace {get;set;}
13991
public string LabelName {get;set;}
140-
public string TestDataPath {get;set;}
141-
public string TrainDataPath {get;set;}
14292
public char Separator {get;set;}
14393
public bool AllowQuoting {get;set;}
14494
public bool AllowSparse {get;set;}
14595
public bool HasHeader {get;set;}
14696
public IList<string> Features {get;set;}
14797
internal CSharp.GenerateTarget Target {get;set;}
98+
public IDictionary<string, string> SampleData {get;set;}
14899

149100

150101
void CLI_Annotation()

src/Microsoft.ML.CodeGenerator/Templates/Console/PredictProgram.tt

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,25 @@
1212
<# } #>
1313

1414
using System;
15-
using System.IO;
16-
using System.Linq;
17-
using Microsoft.ML;
1815
using <#= Namespace #>.Model;
1916

2017
namespace <#= Namespace #>.ConsoleApp
2118
{
2219
class Program
2320
{
24-
//Dataset to use for predictions
25-
<#if(string.IsNullOrEmpty(TestDataPath)){ #>
26-
private const string DATA_FILEPATH = @"<#= TrainDataPath #>";
27-
<# } else{ #>
28-
private const string DATA_FILEPATH = @"<#= TestDataPath #>";
29-
<# } #>
30-
3121
static void Main(string[] args)
3222
{
33-
// Create single instance of sample data from first line of dataset for model input
34-
ModelInput sampleData = CreateSingleDataSample(DATA_FILEPATH);
23+
// Create single instance of sample data from first line of dataset for model input
24+
<# if(SampleData != null) {#>
25+
ModelInput sampleData = new ModelInput()
26+
{
27+
<# foreach(var kv in SampleData){ #>
28+
<#= kv.Key #>=<#= kv.Value #>,
29+
<#}#>
30+
};
31+
<#}else{#>
32+
ModelInput sampleData = new ModelInput();
33+
<#}#>
3534

3635
// Make a single prediction on the sample data and print results
3736
var predictionResult = ConsumeModel.Predict(sampleData);
@@ -41,51 +40,26 @@ namespace <#= Namespace #>.ConsoleApp
4140
Console.WriteLine($"<#= label #>: {sampleData.<#= Utils.Normalize(label) #>}");
4241
<#}#>
4342
<#if("BinaryClassification".Equals(TaskType) ){ #>
44-
Console.WriteLine($"\n\nActual <#= Utils.Normalize(LabelName) #>: {sampleData.<#= Utils.Normalize(LabelName) #>} \nPredicted <#= Utils.Normalize(LabelName) #>: {predictionResult.Prediction}\n\n");
43+
Console.WriteLine($"\n\nPredicted <#= Utils.Normalize(LabelName) #>: {predictionResult.Prediction}\n\n");
4544
<#} else if("Regression".Equals(TaskType) || "Recommendation".Equals(TaskType)){#>
46-
Console.WriteLine($"\n\nActual <#= Utils.Normalize(LabelName) #>: {sampleData.<#= Utils.Normalize(LabelName) #>} \nPredicted <#= Utils.Normalize(LabelName) #>: {predictionResult.Score}\n\n");
45+
Console.WriteLine($"\n\nPredicted <#= Utils.Normalize(LabelName) #>: {predictionResult.Score}\n\n");
4746
<#} else if("MulticlassClassification".Equals(TaskType)){#>
48-
Console.WriteLine($"\n\nActual <#= Utils.Normalize(LabelName) #>: {sampleData.<#= Utils.Normalize(LabelName) #>} \nPredicted <#= Utils.Normalize(LabelName) #> value {predictionResult.Prediction} \nPredicted <#= Utils.Normalize(LabelName) #> scores: [{String.Join(",", predictionResult.Score)}]\n\n");
47+
Console.WriteLine($"\n\nPredicted <#= Utils.Normalize(LabelName) #> value {predictionResult.Prediction} \nPredicted <#= Utils.Normalize(LabelName) #> scores: [{String.Join(",", predictionResult.Score)}]\n\n");
4948
<#} #>
5049
Console.WriteLine("=============== End of process, hit any key to finish ===============");
5150
Console.ReadKey();
5251
}
53-
54-
// Change this code to create your own sample data
55-
#region CreateSingleDataSample
56-
// Method to load single row of dataset to try a single prediction
57-
private static ModelInput CreateSingleDataSample(string dataFilePath)
58-
{
59-
// Create MLContext
60-
MLContext mlContext = new MLContext();
61-
62-
// Load dataset
63-
IDataView dataView = mlContext.Data.LoadFromTextFile<ModelInput>(
64-
path: dataFilePath,
65-
hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>,
66-
separatorChar : '<#= Regex.Escape(Separator.ToString()) #>',
67-
allowQuoting : <#= AllowQuoting.ToString().ToLowerInvariant() #>,
68-
allowSparse: <#= AllowSparse.ToString().ToLowerInvariant() #>);
69-
70-
// Use first line of dataset as model input
71-
// You can replace this with new test data (hardcoded or from end-user application)
72-
ModelInput sampleForPrediction = mlContext.Data.CreateEnumerable<ModelInput>(dataView, false)
73-
.First();
74-
return sampleForPrediction;
75-
}
76-
#endregion
7752
}
7853
}
7954
<#+
8055
public string TaskType {get;set;}
8156
public string Namespace {get;set;}
8257
public string LabelName {get;set;}
83-
public string TestDataPath {get;set;}
84-
public string TrainDataPath {get;set;}
8558
public char Separator {get;set;}
8659
public bool AllowQuoting {get;set;}
8760
public bool AllowSparse {get;set;}
8861
public bool HasHeader {get;set;}
8962
public IList<string> Features {get;set;}
9063
internal CSharp.GenerateTarget Target {get;set;}
64+
public IDictionary<string, string> SampleData {get;set;}
9165
#>

0 commit comments

Comments
 (0)