Open
Description
I created the model using the Visual Studio graphical assistant (ML.NET Model Builder), which generated a project and code for me. Everything works for me, except that the model's input has to be the path to an image file, passed as a string. Can I pass an image to the neural network as a Bitmap instead?
I have found many examples, but they all differ from my code, as if I had a different version of the library. When I try to adapt the code I found, I run into various errors.
How can I do this with the current version, Microsoft.ML 1.5? How should I adapt the generated code below so that it accepts Bitmap images instead of file paths?
My ModelInput.cs
// This file was auto-generated by ML.NET Model Builder.
using Microsoft.ML.Data;
namespace MLTestAppML.Model
{
    /// <summary>
    /// One row of training data as loaded by ML.NET: a class label plus the image it labels.
    /// Properties are bound to text-file columns by position via <c>LoadColumn</c>.
    /// </summary>
    public class ModelInput
    {
        /// <summary>Class label of the image; read from column 0 and exposed as the "Label" column.</summary>
        [LoadColumn(0)]
        [ColumnName("Label")]
        public string Label { get; set; }

        /// <summary>
        /// Image reference; read from column 1 and exposed as the "ImageSource" column.
        /// NOTE(review): the generated pipeline treats this string as an image file path.
        /// </summary>
        [LoadColumn(1)]
        [ColumnName("ImageSource")]
        public string ImageSource { get; set; }
    }
}
My ModelBuilder.cs
// This file was auto-generated by ML.NET Model Builder.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using MLTestAppML.Model;
using Microsoft.ML.Vision;
namespace MLTestAppML.ConsoleApp
{
    /// <summary>
    /// Trains, cross-validates and saves the multiclass image-classification model
    /// generated by ML.NET Model Builder.
    /// </summary>
    public static class ModelBuilder
    {
        // NOTE(review): both paths point into one developer's %TEMP% folder, as generated by Model Builder;
        // they will likely not exist on other machines — consider moving them to configuration.
        private static readonly string TRAIN_DATA_FILEPATH = @"C:\Users\aaa\AppData\Local\Temp\e43005d1-d83d-4f35-ab8d-7dbc3e693583.tsv";
        private static readonly string MODEL_FILEPATH = @"C:\Users\aaa\AppData\Local\Temp\MLVSTools\MLTestAppML\MLTestAppML.Model\MLModel.zip";

        // Create MLContext to be shared across the model creation workflow objects
        // Set a random seed for repeatable/deterministic results across multiple trainings.
        private static readonly MLContext mlContext = new MLContext(seed: 1);

        /// <summary>
        /// Runs the whole workflow: loads the tab-separated training file (with header) into
        /// <see cref="ModelInput"/> rows, builds the pipeline, trains the model, cross-validates it,
        /// and saves the trained model together with the input schema.
        /// </summary>
        public static void CreateModel()
        {
            // Load Data
            IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
                path: TRAIN_DATA_FILEPATH,
                hasHeader: true,
                separatorChar: '\t',
                allowQuoting: true,
                allowSparse: false);

            // Build training pipeline
            IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);

            // Train Model
            ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);

            // Evaluate quality of Model
            Evaluate(mlContext, trainingDataView, trainingPipeline);

            // Save model
            SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);
        }

        /// <summary>
        /// Builds the data-processing transforms followed by the ImageClassification trainer.
        /// </summary>
        /// <param name="mlContext">Context used to create the transforms and the trainer.</param>
        /// <param name="imageBytesColumnName">
        /// Optional. When null or empty (the default) the pipeline is exactly the generated one: every image
        /// is loaded from the file path stored in the "ImageSource" column. When set, it names an input column
        /// that already holds the encoded image bytes (e.g. a <c>byte[]</c> property on the input class, filled
        /// by saving a Bitmap into a MemoryStream), so the model is trained and used on in-memory images
        /// instead of file paths. The same column must then be supplied at prediction time.
        /// </param>
        /// <returns>
        /// The untrained estimator chain. Its "PredictedLabel" output is mapped back from the key type
        /// to the original label values.
        /// </returns>
        public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext, string imageBytesColumnName = null)
        {
            if (string.IsNullOrEmpty(imageBytesColumnName))
            {
                // Data process configuration with pipeline data transformations:
                // label string -> key, image file path -> raw encoded bytes, bytes -> "Features".
                var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Label")
                    .Append(mlContext.Transforms.LoadRawImageBytes("ImageSource_featurized", null, "ImageSource"))
                    .Append(mlContext.Transforms.CopyColumns("Features", "ImageSource_featurized"));
                return dataProcessPipeline.Append(CreateImageClassificationTrainer(mlContext));
            }

            // In-memory variant: the encoded bytes are already part of the input, so LoadRawImageBytes
            // (which needs a path) is skipped and the byte column is fed to the trainer directly.
            var inMemoryDataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Label")
                .Append(mlContext.Transforms.CopyColumns("Features", imageBytesColumnName));
            return inMemoryDataProcessPipeline.Append(CreateImageClassificationTrainer(mlContext));
        }

        /// <summary>
        /// Creates the ImageClassification trainer (reading "Label" and "Features"), followed by a transform
        /// that maps the predicted key back to the original label value.
        /// </summary>
        private static IEstimator<ITransformer> CreateImageClassificationTrainer(MLContext mlContext)
        {
            // Set the training algorithm
            return mlContext.MulticlassClassification.Trainers.ImageClassification(new ImageClassificationTrainer.Options() { LabelColumnName = "Label", FeatureColumnName = "Features" })
                .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));
        }

        /// <summary>Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/> and returns the trained model.</summary>
        public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            Console.WriteLine("=============== Training model ===============");
            ITransformer model = trainingPipeline.Fit(trainingDataView);
            Console.WriteLine("=============== End of training process ===============");
            return model;
        }

        /// <summary>
        /// Estimates model quality with 5-fold cross-validation on the training data and prints the
        /// averaged metrics. Note that this re-trains the pipeline once per fold.
        /// </summary>
        private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            // Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate)
            // in order to evaluate and get the model's accuracy metrics
            Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
            var crossValidationResults = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "Label");
            PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults);
        }

        /// <summary>
        /// Saves the trained model and its input schema to a .zip file. The path is resolved with
        /// <see cref="GetAbsolutePath"/>.
        /// </summary>
        private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
        {
            // Save/persist the trained model to a .ZIP file
            Console.WriteLine($"=============== Saving the model ===============");
            string modelPath = GetAbsolutePath(modelRelativePath);
            mlContext.Model.Save(mlModel, modelInputSchema, modelPath);
            Console.WriteLine("The model is saved to {0}", modelPath);
        }

        /// <summary>
        /// Resolves <paramref name="relativePath"/> against the folder of the assembly containing <c>Program</c>.
        /// An already-absolute path is returned unchanged by <see cref="Path.Combine(string, string)"/>.
        /// </summary>
        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
            string assemblyFolderPath = _dataRoot.Directory.FullName;
            string fullPath = Path.Combine(assemblyFolderPath, relativePath);
            return fullPath;
        }

        /// <summary>Prints macro/micro accuracy, log-loss and per-class log-loss for a single evaluation.</summary>
        public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"************************************************************");
            Console.WriteLine($"* Metrics for multi-class classification model ");
            Console.WriteLine($"*-----------------------------------------------------------");
            Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
            Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
            Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
            for (int i = 0; i < metrics.PerClassLogLoss.Count; i++)
            {
                Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
            }
            Console.WriteLine($"************************************************************");
        }

        /// <summary>
        /// Prints, for each metric, the average across folds, the sample standard deviation and
        /// the half-width of the 95% confidence interval of the mean.
        /// </summary>
        public static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
        {
            // Materialize once so every statistic below works on the same snapshot instead of
            // re-running the LINQ projections for each Average/Count/Select.
            var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics).ToList();

            var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy).ToList();
            var microAccuracyAverage = microAccuracyValues.Average();
            var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
            var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

            var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy).ToList();
            var macroAccuracyAverage = macroAccuracyValues.Average();
            var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
            var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

            var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss).ToList();
            var logLossAverage = logLossValues.Average();
            var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
            var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

            var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction).ToList();
            var logLossReductionAverage = logLossReductionValues.Average();
            var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
            var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

            // "0.###" (not "#.###"): the '#' placeholder prints nothing for 0 and drops the leading zero of values below 1.
            Console.WriteLine($"*************************************************************************************************************");
            Console.WriteLine($"* Metrics for Multi-class Classification model ");
            Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
            Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:0.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:0.###})");
            Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:0.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:0.###})");
            Console.WriteLine($"* Average LogLoss: {logLossAverage:0.###} - Standard deviation: ({logLossStdDeviation:0.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:0.###})");
            Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:0.###} - Standard deviation: ({logLossReductionStdDeviation:0.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:0.###})");
            Console.WriteLine($"*************************************************************************************************************");
        }

        /// <summary>
        /// Sample standard deviation (divides by n - 1).
        /// </summary>
        /// <param name="values">The sample; enumerated exactly once.</param>
        /// <returns>The standard deviation; NaN when only one value is supplied (0 / 0).</returns>
        /// <exception cref="InvalidOperationException">When <paramref name="values"/> is empty.</exception>
        public static double CalculateStandardDeviation(IEnumerable<double> values)
        {
            // Materialize once: a lazy query would otherwise be re-evaluated by Average, Sum and Count.
            double[] samples = values.ToArray();
            double average = samples.Average();
            double sumOfSquaresOfDifferences = samples.Select(val => (val - average) * (val - average)).Sum();
            double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (samples.Length - 1));
            return standardDeviation;
        }

        /// <summary>
        /// Half-width of the 95% confidence interval of the mean, using the normal approximation
        /// (z = 1.96): 1.96 * s / sqrt(n), where s is the sample standard deviation.
        /// </summary>
        /// <param name="values">The sample; enumerated exactly once.</param>
        /// <exception cref="InvalidOperationException">When <paramref name="values"/> is empty.</exception>
        public static double CalculateConfidenceInterval95(IEnumerable<double> values)
        {
            double[] samples = values.ToArray();
            // Standard error of the mean is s / sqrt(n); the n - 1 correction is already inside s,
            // so dividing by sqrt(n - 1) here would overstate the interval.
            double confidenceInterval95 = 1.96 * CalculateStandardDeviation(samples) / Math.Sqrt(samples.Length);
            return confidenceInterval95;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment