Skip to content

Update ConsumeModel.cs to enhance it's performance when being called for multiple times #4913

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/Microsoft.ML.CodeGenerator/Templates/Console/ConsumeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,32 +39,39 @@ public virtual string TransformText()
{
public class ConsumeModel
{
// For more info on consuming ML.NET models, visit https://aka.ms/model-builder-consume
private static Lazy<PredictionEngine<ModelInput, ModelOutput>> PredictionEngine = new Lazy<PredictionEngine<ModelInput, ModelOutput>>(CreatePredictionEngine);

// For more info on consuming ML.NET models, visit https://aka.ms/mlnet-consume
// Method for consuming model in your app
public static ModelOutput Predict(ModelInput input)
{

ModelOutput result = PredictionEngine.Value.Predict(input);
return result;
}

public static PredictionEngine<ModelInput, ModelOutput> CreatePredictionEngine()
{
// Create new MLContext
MLContext mlContext = new MLContext();
");
if(HasNormalizeMapping){
this.Write(" \r\n\t\t\t// Register NormalizeMapping\r\n mlContext.ComponentCatalog.Regist" +
"erAssembly(typeof(NormalizeMapping).Assembly);\r\n");
this.Write(" \r\n\t\t\t// Register NormalizeMapping to calculate probabilities for each Label.\r\n " +
" mlContext.ComponentCatalog.RegisterAssembly(typeof(NormalizeMapping).A" +
"ssembly);\r\n");
}
if(HasLabelMapping){
this.Write(" \r\n\t\t\t// Register LabelMapping\r\n mlContext.ComponentCatalog.RegisterAs" +
"sembly(typeof(LabelMapping).Assembly);\r\n");
this.Write(" \r\n\t\t\t// Register LabelMapping to map predicted Labels to their corresponding pro" +
"babilities (likelihood of specified Labels)\r\n mlContext.ComponentCata" +
"log.RegisterAssembly(typeof(LabelMapping).Assembly);\r\n");
}
this.Write("\r\n // Load model & create prediction engine\r\n string modelP" +
"ath = @\"");
this.Write(this.ToStringHelper.ToStringWithCulture(MLNetModelpath));
this.Write(@""";
ITransformer mlModel = mlContext.Model.Load(modelPath, out var modelInputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);

// Use model to make prediction on input data
ModelOutput result = predEngine.Predict(input);
return result;

return predEngine;
}
}
}
Expand Down
21 changes: 13 additions & 8 deletions src/Microsoft.ML.CodeGenerator/Templates/Console/ConsumeModel.tt
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,35 @@ namespace <#= Namespace #>.Model
{
public class ConsumeModel
{
// For more info on consuming ML.NET models, visit https://aka.ms/model-builder-consume
private static Lazy<PredictionEngine<ModelInput, ModelOutput>> PredictionEngine = new Lazy<PredictionEngine<ModelInput, ModelOutput>>(CreatePredictionEngine);

// For more info on consuming ML.NET models, visit https://aka.ms/mlnet-consume
// Method for consuming model in your app
public static ModelOutput Predict(ModelInput input)
{

ModelOutput result = PredictionEngine.Value.Predict(input);
return result;
}

public static PredictionEngine<ModelInput, ModelOutput> CreatePredictionEngine()
{
// Create new MLContext
MLContext mlContext = new MLContext();
<#if(HasNormalizeMapping){ #>
// Register NormalizeMapping
// Register NormalizeMapping to calculate probabilities for each Label.
mlContext.ComponentCatalog.RegisterAssembly(typeof(NormalizeMapping).Assembly);
<#} #>
<#if(HasLabelMapping){ #>
// Register LabelMapping
// Register LabelMapping to map predicted Labels to their corresponding probabilities (likelihood of specified Labels)
mlContext.ComponentCatalog.RegisterAssembly(typeof(LabelMapping).Assembly);
<#} #>

// Load model & create prediction engine
string modelPath = @"<#= MLNetModelpath #>";
ITransformer mlModel = mlContext.Model.Load(modelPath, out var modelInputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);

// Use model to make prediction on input data
ModelOutput result = predEngine.Predict(input);
return result;

return predEngine;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,30 @@ namespace Test.Model
{
public class ConsumeModel
{
// For more info on consuming ML.NET models, visit https://aka.ms/model-builder-consume
private static Lazy<PredictionEngine<ModelInput, ModelOutput>> PredictionEngine = new Lazy<PredictionEngine<ModelInput, ModelOutput>>(CreatePredictionEngine);

// For more info on consuming ML.NET models, visit https://aka.ms/mlnet-consume
// Method for consuming model in your app
public static ModelOutput Predict(ModelInput input)
{
ModelOutput result = PredictionEngine.Value.Predict(input);
return result;
}

public static PredictionEngine<ModelInput, ModelOutput> CreatePredictionEngine()
{
// Create new MLContext
MLContext mlContext = new MLContext();

// Register LabelMapping
// Register LabelMapping to map predicted Labels to their corresponding probabilities (likelihood of specified Labels)
mlContext.ComponentCatalog.RegisterAssembly(typeof(LabelMapping).Assembly);

// Load model & create prediction engine
string modelPath = @"\path\to\model";
ITransformer mlModel = mlContext.Model.Load(modelPath, out var modelInputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);

// Use model to make prediction on input data
ModelOutput result = predEngine.Predict(input);
return result;
return predEngine;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,33 @@ namespace CodeGenTest.Model
{
public class ConsumeModel
{
// For more info on consuming ML.NET models, visit https://aka.ms/model-builder-consume
private static Lazy<PredictionEngine<ModelInput, ModelOutput>> PredictionEngine = new Lazy<PredictionEngine<ModelInput, ModelOutput>>(CreatePredictionEngine);

// For more info on consuming ML.NET models, visit https://aka.ms/mlnet-consume
// Method for consuming model in your app
public static ModelOutput Predict(ModelInput input)
{
ModelOutput result = PredictionEngine.Value.Predict(input);
return result;
}

public static PredictionEngine<ModelInput, ModelOutput> CreatePredictionEngine()
{
// Create new MLContext
MLContext mlContext = new MLContext();

// Register NormalizeMapping
// Register NormalizeMapping to calculate probabilities for each Label.
mlContext.ComponentCatalog.RegisterAssembly(typeof(NormalizeMapping).Assembly);

// Register LabelMapping
// Register LabelMapping to map predicted Labels to their corresponding probabilities (likelihood of specified Labels)
mlContext.ComponentCatalog.RegisterAssembly(typeof(LabelMapping).Assembly);

// Load model & create prediction engine
string modelPath = @"/path/to/model";
ITransformer mlModel = mlContext.Model.Load(modelPath, out var modelInputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);

// Use model to make prediction on input data
ModelOutput result = predEngine.Predict(input);
return result;
return predEngine;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,18 @@ namespace TestNamespace.Model
{
public class ConsumeModel
{
// For more info on consuming ML.NET models, visit https://aka.ms/model-builder-consume
private static Lazy<PredictionEngine<ModelInput, ModelOutput>> PredictionEngine = new Lazy<PredictionEngine<ModelInput, ModelOutput>>(CreatePredictionEngine);

// For more info on consuming ML.NET models, visit https://aka.ms/mlnet-consume
// Method for consuming model in your app
public static ModelOutput Predict(ModelInput input)
{
ModelOutput result = PredictionEngine.Value.Predict(input);
return result;
}

public static PredictionEngine<ModelInput, ModelOutput> CreatePredictionEngine()
{
// Create new MLContext
MLContext mlContext = new MLContext();

Expand All @@ -28,9 +35,7 @@ namespace TestNamespace.Model
ITransformer mlModel = mlContext.Model.Load(modelPath, out var modelInputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);

// Use model to make prediction on input data
ModelOutput result = predEngine.Predict(input);
return result;
return predEngine;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,18 @@ namespace TestNamespace.Model
{
public class ConsumeModel
{
// For more info on consuming ML.NET models, visit https://aka.ms/model-builder-consume
private static Lazy<PredictionEngine<ModelInput, ModelOutput>> PredictionEngine = new Lazy<PredictionEngine<ModelInput, ModelOutput>>(CreatePredictionEngine);

// For more info on consuming ML.NET models, visit https://aka.ms/mlnet-consume
// Method for consuming model in your app
public static ModelOutput Predict(ModelInput input)
{
ModelOutput result = PredictionEngine.Value.Predict(input);
return result;
}

public static PredictionEngine<ModelInput, ModelOutput> CreatePredictionEngine()
{
// Create new MLContext
MLContext mlContext = new MLContext();

Expand All @@ -28,9 +35,7 @@ namespace TestNamespace.Model
ITransformer mlModel = mlContext.Model.Load(modelPath, out var modelInputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);

// Use model to make prediction on input data
ModelOutput result = predEngine.Predict(input);
return result;
return predEngine;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// This file was auto-generated by ML.NET Model Builder.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Namespace.Model;

namespace Namespace.Model
{
public class ConsumeModel
{
private static Lazy<PredictionEngine<ModelInput, ModelOutput>> PredictionEngine = new Lazy<PredictionEngine<ModelInput, ModelOutput>>(CreatePredictionEngine);

// For more info on consuming ML.NET models, visit https://aka.ms/mlnet-consume
// Method for consuming model in your app
public static ModelOutput Predict(ModelInput input)
{
ModelOutput result = PredictionEngine.Value.Predict(input);
return result;
}

public static PredictionEngine<ModelInput, ModelOutput> CreatePredictionEngine()
{
// Create new MLContext
MLContext mlContext = new MLContext();

// Register NormalizeMapping to calculate probabilities for each Label.
mlContext.ComponentCatalog.RegisterAssembly(typeof(NormalizeMapping).Assembly);

// Register LabelMapping to map predicted Labels to their corresponding probabilities (likelihood of specified Labels)
mlContext.ComponentCatalog.RegisterAssembly(typeof(LabelMapping).Assembly);

// Load model & create prediction engine
string modelPath = @"/path/to/model";
ITransformer mlModel = mlContext.Model.Load(modelPath, out var modelInputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);

return predEngine;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,21 @@ public void TestPredictProgram_WithSampleData()
};
Approvals.Verify(predictProgram.TransformText());
}

[Fact]
[UseReporter(typeof(DiffReporter))]
[MethodImpl(MethodImplOptions.NoInlining)]
public void TestConsumeModel()
{
var consumeModel = new ConsumeModel()
{
Namespace = "Namespace",
HasNormalizeMapping = true,
HasLabelMapping = true,
MLNetModelpath = @"/path/to/model",
};

Approvals.Verify(consumeModel.TransformText());
}
}
}