diff --git a/src/Microsoft.ML.Api/ComponentCreation.cs b/src/Microsoft.ML.Api/ComponentCreation.cs index 73cfbf91a5..0d164b6124 100644 --- a/src/Microsoft.ML.Api/ComponentCreation.cs +++ b/src/Microsoft.ML.Api/ComponentCreation.cs @@ -52,7 +52,7 @@ public static RoleMappedData CreateExamples(this IHostEnvironment env, IDataView env.CheckValueOrNull(weight); env.CheckValueOrNull(custom); - return TrainUtils.CreateExamples(data, label, features, group, weight, name: null, custom: custom); + return new RoleMappedData(data, label, features, group, weight, name: null, custom: custom); } /// diff --git a/src/Microsoft.ML.Api/GenerateCodeCommand.cs b/src/Microsoft.ML.Api/GenerateCodeCommand.cs index 0b45bcc4bb..0bca5edfb3 100644 --- a/src/Microsoft.ML.Api/GenerateCodeCommand.cs +++ b/src/Microsoft.ML.Api/GenerateCodeCommand.cs @@ -108,8 +108,8 @@ public void Run() { var roles = ModelFileUtils.LoadRoleMappingsOrNull(_host, fs); scorer = roles != null - ? _host.CreateDefaultScorer(RoleMappedData.CreateOpt(transformPipe, roles), pred) - : _host.CreateDefaultScorer(_host.CreateExamples(transformPipe, "Features"), pred); + ? _host.CreateDefaultScorer(new RoleMappedData(transformPipe, roles, opt: true), pred) + : _host.CreateDefaultScorer(new RoleMappedData(transformPipe, label: null, "Features"), pred); } var nonScoreSb = new StringBuilder(); diff --git a/src/Microsoft.ML.Api/PredictionEngine.cs b/src/Microsoft.ML.Api/PredictionEngine.cs index eacf8d2218..14e2498c93 100644 --- a/src/Microsoft.ML.Api/PredictionEngine.cs +++ b/src/Microsoft.ML.Api/PredictionEngine.cs @@ -49,8 +49,8 @@ internal BatchPredictionEngine(IHostEnvironment env, Stream modelStream, bool ig { var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, modelStream); pipe = roles != null - ? env.CreateDefaultScorer(RoleMappedData.CreateOpt(pipe, roles), predictor) - : env.CreateDefaultScorer(env.CreateExamples(pipe, "Features"), predictor); + ? env.CreateDefaultScorer(new RoleMappedData(pipe, roles, opt: true), predictor) + : env.CreateDefaultScorer(new RoleMappedData(pipe, label: null, "Features"), predictor); } _pipeEngine = new PipeEngine(env, pipe, ignoreMissingColumns, outputSchemaDefinition); diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index ca13d03ab3..b0a18f6d18 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -312,7 +312,6 @@ public static bool HasSlotNames(this ISchema schema, int col, int vectorSize) public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer slotNames) { Contracts.CheckValueOrNull(schema); - Contracts.CheckValue(role.Value, nameof(role)); Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize)); IReadOnlyList list; diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs index 302d369489..2dab48fc58 100644 --- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs +++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Collections.Generic; using Microsoft.ML.Runtime.Internal.Utilities; @@ -34,9 +33,8 @@ private ColumnInfo(string name, int index, ColumnType type) /// public static ColumnInfo CreateFromName(ISchema schema, string name, string descName) { - ColumnInfo colInfo; - if (!TryCreateFromName(schema, name, out colInfo)) - throw Contracts.ExceptParam(nameof(name), "{0} column '{1}' not found", descName, name); + if (!TryCreateFromName(schema, name, out var colInfo)) + throw Contracts.ExceptParam(nameof(name), $"{descName} column '{name}' not found"); return colInfo; } @@ -51,8 +49,7 @@ public static bool TryCreateFromName(ISchema schema, string name, out ColumnInfo Contracts.CheckNonEmpty(name, nameof(name)); colInfo = null; - int index; - if (!schema.TryGetColumnIndex(name, out index)) + if (!schema.TryGetColumnIndex(name, out int index)) return false; colInfo = new ColumnInfo(name, index, schema.GetColumnType(index)); @@ -83,12 +80,25 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index) /// multiple features columns to consume that information. /// /// This class has convenience fields for several common column roles (se.g., , ), but can hold an arbitrary set of column infos. The convenience fields are non-null iff there is - /// a unique column with the corresponding role. When there are no such columns or more than one such column, the - /// field is null. The , , and methods provide - /// some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden in this - /// schema. + /// cref="Label"/>), but can hold an arbitrary set of column infos. The convenience fields are non-null if and only + /// if there is a unique column with the corresponding role. When there are no such columns or more than one such + /// column, the field is null. The , , and + /// methods provide some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden + /// in this schema. /// + /// + /// Note that instances of this class are, like instances of , immutable. + /// + /// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For + /// that case, please use the class. + /// + /// Note that there is no need for components consuming a or + /// to make use of every defined mapping. Consuming components are also expected to ignore any + /// they do not handle. They may very well however complain if a mapping they wanted to see is not present, or the column(s) + /// mapped from the role are not of the form they require. + /// + /// + /// public sealed class RoleMappedSchema { private const string FeatureString = "Feature"; @@ -98,17 +108,57 @@ public sealed class RoleMappedSchema private const string NameString = "Name"; private const string FeatureContributionsString = "FeatureContributions"; + /// + /// Instances of this are the keys of a . This class also holds some important + /// commonly used pre-defined instances available (e.g., , ) that should + /// be used when possible for consistency reasons. However, practitioners should not be afraid to declare custom + /// roles if approppriate for their task. + /// public struct ColumnRole { + /// + /// Role for features. Commonly used as the independent variables given to trainers, and scorers. + /// public static ColumnRole Feature => FeatureString; + + /// + /// Role for labels. Commonly used as the dependent variables given to trainers, and evaluators. + /// public static ColumnRole Label => LabelString; + + /// + /// Role for group ID. Commonly used in ranking applications, for defining query boundaries, or + /// sequence classification, for defining the boundaries of an utterance. + /// public static ColumnRole Group => GroupString; + + /// + /// Role for sample weights. Commonly used to point to a number to make trainers give more weight + /// to a particular example. + /// public static ColumnRole Weight => WeightString; + + /// + /// Role for sample names. Useful for informational and tracking purposes when scoring, but typically + /// without affecting results. + /// public static ColumnRole Name => NameString; + + // REVIEW: Does this really belong here? + /// + /// Role for feature contributions. Useful for specific diagnostic functionality. + /// public static ColumnRole FeatureContributions => FeatureContributionsString; + /// + /// The string value for the role. Guaranteed to be non-empty. + /// public readonly string Value; + /// + /// Constructor for the column role. + /// + /// The value for the role. Must be non-empty. public ColumnRole(string value) { Contracts.CheckNonEmpty(value, nameof(value)); @@ -116,50 +166,51 @@ public ColumnRole(string value) } public static implicit operator ColumnRole(string value) - { - return new ColumnRole(value); - } - + => new ColumnRole(value); + + /// + /// Convenience method for creating a mapping pair from a role to a column name + /// for giving to constructors of and . + /// + /// The column name to map to. Can be null, in which case when used + /// to construct a role mapping structure this pair will be ignored + /// A key-value pair with this instance as the key and as the value public KeyValuePair Bind(string name) - { - return new KeyValuePair(this, name); - } + => new KeyValuePair(this, name); } public static KeyValuePair CreatePair(ColumnRole role, string name) - { - return new KeyValuePair(role, name); - } + => new KeyValuePair(role, name); /// - /// The source ISchema. + /// The source . /// - public readonly ISchema Schema; + public ISchema Schema { get; } /// - /// The Feature column, when there is exactly one (null otherwise). + /// The column, when there is exactly one (null otherwise). /// - public readonly ColumnInfo Feature; + public ColumnInfo Feature { get; } /// - /// The Label column, when there is exactly one (null otherwise). + /// The column, when there is exactly one (null otherwise). /// - public readonly ColumnInfo Label; + public ColumnInfo Label { get; } /// - /// The Group column, when there is exactly one (null otherwise). + /// The column, when there is exactly one (null otherwise). /// - public readonly ColumnInfo Group; + public ColumnInfo Group { get; } /// - /// The Weight column, when there is exactly one (null otherwise). + /// The column, when there is exactly one (null otherwise). /// - public readonly ColumnInfo Weight; + public ColumnInfo Weight { get; } /// - /// The Name column, when there is exactly one (null otherwise). + /// The column, when there is exactly one (null otherwise). /// - public readonly ColumnInfo Name; + public ColumnInfo Name { get; } // Maps from role to the associated column infos. private readonly Dictionary> _map; @@ -183,21 +234,21 @@ private RoleMappedSchema(ISchema schema, Dictionary> map, ColumnRole rol Contracts.AssertNonEmpty(role.Value); Contracts.AssertValue(info); - List list; - if (!map.TryGetValue(role.Value, out list)) + if (!map.TryGetValue(role.Value, out var list)) { list = new List(); map.Add(role.Value, list); @@ -223,7 +273,7 @@ private static void Add(Dictionary> map, ColumnRole rol list.Add(info); } - private static Dictionary> MapFromNames(ISchema schema, IEnumerable> roles) + private static Dictionary> MapFromNames(ISchema schema, IEnumerable> roles, bool opt = false) { Contracts.AssertValue(schema); Contracts.AssertValue(roles); @@ -231,28 +281,13 @@ private static Dictionary> MapFromNames(ISchema schema, var map = new Dictionary>(); foreach (var kvp in roles) { - Contracts.CheckNonEmpty(kvp.Key.Value, nameof(roles), "Bad column role"); - if (string.IsNullOrEmpty(kvp.Value)) - continue; - var info = ColumnInfo.CreateFromName(schema, kvp.Value, kvp.Key.Value); - Add(map, kvp.Key.Value, info); - } - return map; - } - - private static Dictionary> MapFromNamesOpt(ISchema schema, IEnumerable> roles) - { - Contracts.AssertValue(schema); - Contracts.AssertValue(roles); - - var map = new Dictionary>(); - foreach (var kvp in roles) - { - Contracts.CheckNonEmpty(kvp.Key.Value, nameof(roles), "Bad column role"); + Contracts.AssertNonEmpty(kvp.Key.Value); if (string.IsNullOrEmpty(kvp.Value)) continue; ColumnInfo info; - if (!ColumnInfo.TryCreateFromName(schema, kvp.Value, out info)) + if (!opt) + info = ColumnInfo.CreateFromName(schema, kvp.Value, kvp.Key.Value); + else if (!ColumnInfo.TryCreateFromName(schema, kvp.Value, out info)) continue; Add(map, kvp.Key.Value, info); } @@ -263,39 +298,26 @@ private static Dictionary> MapFromNamesOpt(ISchema sche /// Returns whether there are any columns with the given column role. /// public bool Has(ColumnRole role) - { - return role.Value != null && _map.ContainsKey(role.Value); - } + => _map.ContainsKey(role.Value); /// /// Returns whether there is exactly one column of the given role. /// public bool HasUnique(ColumnRole role) - { - IReadOnlyList cols; - return role.Value != null && _map.TryGetValue(role.Value, out cols) && cols.Count == 1; - } + => _map.TryGetValue(role.Value, out var cols) && cols.Count == 1; /// /// Returns whether there are two or more columns of the given role. /// public bool HasMultiple(ColumnRole role) - { - IReadOnlyList cols; - return role.Value != null && _map.TryGetValue(role.Value, out cols) && cols.Count > 1; - } + => _map.TryGetValue(role.Value, out var cols) && cols.Count > 1; /// /// If there are columns of the given role, this returns the infos as a readonly list. Otherwise, /// it returns null. /// public IReadOnlyList GetColumns(ColumnRole role) - { - IReadOnlyList list; - if (role.Value != null && _map.TryGetValue(role.Value, out list)) - return list; - return null; - } + => _map.TryGetValue(role.Value, out var list) ? list : null; /// /// An enumerable over all role-column associations within this object. @@ -327,8 +349,7 @@ public IEnumerable> GetColumnRoleNames() /// public IEnumerable> GetColumnRoleNames(ColumnRole role) { - IReadOnlyList list; - if (role.Value != null && _map.TryGetValue(role.Value, out list)) + if (_map.TryGetValue(role.Value, out var list)) { foreach (var info in list) yield return new KeyValuePair(role, info.Name); @@ -363,45 +384,82 @@ private static Dictionary> Copy(Dictionary - /// Creates a RoleMappedSchema from the given schema with no column role assignments. + /// Constructor given a schema, and mapping pairs of roles to columns in the schema. + /// This skips null or empty column-names. It will also skip column-names that are not + /// found in the schema if is true. /// - public static RoleMappedSchema Create(ISchema schema) + /// The schema over which roles are defined + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in will result in an exception being thrown, + /// but if true such values will be ignored + /// The column role to column name mappings + public RoleMappedSchema(ISchema schema, bool opt = false, params KeyValuePair[] roles) + : this(Contracts.CheckRef(schema, nameof(schema)), Contracts.CheckRef(roles, nameof(roles)), opt) { - Contracts.CheckValue(schema, nameof(schema)); - return new RoleMappedSchema(schema, new Dictionary>()); } /// - /// Creates a RoleMappedSchema from the given schema and role/column-name pairs. - /// This skips null or empty column-names. + /// Constructor given a schema, and mapping pairs of roles to columns in the schema. + /// This skips null or empty column names. It will also skip column-names that are not + /// found in the schema if is true. /// - public static RoleMappedSchema Create(ISchema schema, params KeyValuePair[] roles) + /// The schema over which roles are defined + /// The column role to column name mappings + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in will result in an exception being thrown, + /// but if true such values will be ignored + public RoleMappedSchema(ISchema schema, IEnumerable> roles, bool opt = false) + : this(Contracts.CheckRef(schema, nameof(schema)), + MapFromNames(schema, Contracts.CheckRef(roles, nameof(roles)), opt)) { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckValue(roles, nameof(roles)); - return new RoleMappedSchema(schema, MapFromNames(schema, roles)); } - /// - /// Creates a RoleMappedSchema from the given schema and role/column-name pairs. - /// This skips null or empty column-names. - /// - public static RoleMappedSchema Create(ISchema schema, IEnumerable> roles) + private static IEnumerable> PredefinedRolesHelper( + string label, string feature, string group, string weight, string name, + IEnumerable> custom = null) { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckValue(roles, nameof(roles)); - return new RoleMappedSchema(schema, MapFromNames(schema, roles)); + if (!string.IsNullOrWhiteSpace(label)) + yield return ColumnRole.Label.Bind(label); + if (!string.IsNullOrWhiteSpace(feature)) + yield return ColumnRole.Feature.Bind(feature); + if (!string.IsNullOrWhiteSpace(group)) + yield return ColumnRole.Group.Bind(group); + if (!string.IsNullOrWhiteSpace(weight)) + yield return ColumnRole.Weight.Bind(weight); + if (!string.IsNullOrWhiteSpace(name)) + yield return ColumnRole.Name.Bind(name); + if (custom != null) + { + foreach (var role in custom) + yield return role; + } } /// - /// Creates a RoleMappedSchema from the given schema and role/column-name pairs. - /// This skips null or empty column-names, or column-names that are not found in the schema. + /// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified + /// is null or whitespace, it is ignored. /// - public static RoleMappedSchema CreateOpt(ISchema schema, IEnumerable> roles) + /// The schema over which roles are defined + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// Any additional desired custom column role mappings + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in will result in an exception being thrown, + /// but if true such values will be ignored + public RoleMappedSchema(ISchema schema, string label, string feature, + string group = null, string weight = null, string name = null, + IEnumerable> custom = null, bool opt = false) + : this(Contracts.CheckRef(schema, nameof(schema)), PredefinedRolesHelper(label, feature, group, weight, name, custom), opt) { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckValue(roles, nameof(roles)); - return new RoleMappedSchema(schema, MapFromNamesOpt(schema, roles)); + Contracts.CheckValueOrNull(label); + Contracts.CheckValueOrNull(feature); + Contracts.CheckValueOrNull(group); + Contracts.CheckValueOrNull(weight); + Contracts.CheckValueOrNull(name); + Contracts.CheckValueOrNull(custom); } } @@ -415,12 +473,13 @@ public sealed class RoleMappedData /// /// The data. /// - public readonly IDataView Data; + public IDataView Data { get; } /// - /// The role mapped schema. Note that Schema.Schema is guaranteed to be the same as Data.Schema. + /// The role mapped schema. Note that 's is + /// guaranteed to be the same as 's . /// - public readonly RoleMappedSchema Schema; + public RoleMappedSchema Schema { get; } private RoleMappedData(IDataView data, RoleMappedSchema schema) { @@ -432,45 +491,61 @@ private RoleMappedData(IDataView data, RoleMappedSchema schema) } /// - /// Creates a RoleMappedData from the given data with no column role assignments. - /// - public static RoleMappedData Create(IDataView data) - { - Contracts.CheckValue(data, nameof(data)); - return new RoleMappedData(data, RoleMappedSchema.Create(data.Schema)); - } - - /// - /// Creates a RoleMappedData from the given schema and role/column-name pairs. - /// This skips null or empty column-names. + /// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema. + /// This skips null or empty column-names. It will also skip column-names that are not + /// found in the schema if is true. /// - public static RoleMappedData Create(IDataView data, params KeyValuePair[] roles) + /// The data over which roles are defined + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in 's schema will result in an exception being thrown, + /// but if true such values will be ignored + /// The column role to column name mappings + public RoleMappedData(IDataView data, bool opt = false, params KeyValuePair[] roles) + : this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt)) { - Contracts.CheckValue(data, nameof(data)); - Contracts.CheckValue(roles, nameof(roles)); - return new RoleMappedData(data, RoleMappedSchema.Create(data.Schema, roles)); } /// - /// Creates a RoleMappedData from the given schema and role/column-name pairs. - /// This skips null or empty column-names. + /// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema. + /// This skips null or empty column-names. It will also skip column-names that are not + /// found in the schema if is true. /// - public static RoleMappedData Create(IDataView data, IEnumerable> roles) + /// The schema over which roles are defined + /// The column role to column name mappings + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in 's schema will result in an exception being thrown, + /// but if true such values will be ignored + public RoleMappedData(IDataView data, IEnumerable> roles, bool opt = false) + : this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt)) { - Contracts.CheckValue(data, nameof(data)); - Contracts.CheckValue(roles, nameof(roles)); - return new RoleMappedData(data, RoleMappedSchema.Create(data.Schema, roles)); } /// - /// Creates a RoleMappedData from the given schema and role/column-name pairs. - /// This skips null or empty column-names, or column-names that are not found in the schema. + /// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified + /// is null or whitespace, it is ignored. /// - public static RoleMappedData CreateOpt(IDataView data, IEnumerable> roles) + /// The data over which roles are defined + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// Any additional desired custom column role mappings + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in 's schema will result in an exception being thrown, + /// but if true such values will be ignored + public RoleMappedData(IDataView data, string label, string feature, + string group = null, string weight = null, string name = null, + IEnumerable> custom = null, bool opt = false) + : this(Contracts.CheckRef(data, nameof(data)), + new RoleMappedSchema(data.Schema, label, feature, group, weight, name, custom, opt)) { - Contracts.CheckValue(data, nameof(data)); - Contracts.CheckValue(roles, nameof(roles)); - return new RoleMappedData(data, RoleMappedSchema.CreateOpt(data.Schema, roles)); + Contracts.CheckValueOrNull(label); + Contracts.CheckValueOrNull(feature); + Contracts.CheckValueOrNull(group); + Contracts.CheckValueOrNull(weight); + Contracts.CheckValueOrNull(name); + Contracts.CheckValueOrNull(custom); } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index fc78e72c53..23b1c601c9 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -254,7 +254,7 @@ private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel c RoleMappedData srcData, IDataView marker) { var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, srcData.Data, dstData, marker); - return RoleMappedData.Create(pipe, srcData.Schema.GetColumnRoleNames()); + return new RoleMappedData(pipe, srcData.Schema.GetColumnRoleNames()); } /// @@ -277,7 +277,7 @@ private RoleMappedData CreateRoleMappedData(IHostEnvironment env, IChannel ch, I // Training pipe and examples. var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn); - return TrainUtils.CreateExamples(data, label, features, group, weight, name, customCols); + return new RoleMappedData(data, label, features, group, weight, name, customCols); } private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output) @@ -568,7 +568,7 @@ private FoldResult RunFold(int fold) { using (var file = host.CreateOutputFile(modelFileName)) { - var rmd = RoleMappedData.Create( + var rmd = new RoleMappedData( CompositeDataLoader.ApplyTransform(host, _loader, null, null, (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)), trainData.Schema.GetColumnRoleNames()); @@ -581,17 +581,17 @@ private FoldResult RunFold(int fold) if (!evalComp.IsGood()) evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema); var eval = evalComp.CreateInstance(host); - // Note that this doesn't require the provided columns to exist (because of "Opt"). + // Note that this doesn't require the provided columns to exist (because of the "opt" parameter). // We don't normally expect the scorer to drop columns, but if it does, we should not require // all the columns in the test pipeline to still be present. - var dataEval = RoleMappedData.CreateOpt(scorePipe, testData.Schema.GetColumnRoleNames()); + var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true); var dict = eval.Evaluate(dataEval); RoleMappedData perInstance = null; if (_savePerInstance) { var perInst = eval.GetPerInstanceMetrics(dataEval); - perInstance = RoleMappedData.CreateOpt(perInst, dataEval.Schema.GetColumnRoleNames()); + perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true); } ch.Done(); return new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema); diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs index 435c25bf5b..2a62d78901 100644 --- a/src/Microsoft.ML.Data/Commands/DataCommand.cs +++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs @@ -305,7 +305,7 @@ protected void LoadModelObjects( // can be loaded with no data at all, to get their schemas. if (trainPipe == null) trainPipe = ModelFileUtils.LoadLoader(Host, rep, new MultiFileSource(null), loadTransforms: true); - trainSchema = RoleMappedSchema.Create(trainPipe.Schema, trainRoleMappings); + trainSchema = new RoleMappedSchema(trainPipe.Schema, trainRoleMappings); } // If the role mappings are null, an alternative would be to fail. However the idea // is that the scorer should always still succeed, although perhaps with reduced diff --git a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs index d0e066d789..cd2eb464af 100644 --- a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs +++ b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs @@ -158,7 +158,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV evalComp = EvaluateUtils.GetEvaluatorType(ch, input.Schema); var eval = evalComp.CreateInstance(env); - var data = TrainUtils.CreateExamples(input, label, null, group, weight, null, customCols); + var data = new RoleMappedData(input, label, null, group, weight, null, customCols); return eval.GetPerInstanceMetrics(data); } } @@ -236,7 +236,7 @@ private void RunCore(IChannel ch) if (!evalComp.IsGood()) evalComp = EvaluateUtils.GetEvaluatorType(ch, view.Schema); var evaluator = evalComp.CreateInstance(Host); - var data = TrainUtils.CreateExamples(view, label, null, group, weight, name, customCols); + var data = new RoleMappedData(view, label, null, group, weight, name, customCols); var metrics = evaluator.Evaluate(data); MetricWriter.PrintWarnings(ch, metrics); evaluator.PrintFoldResults(ch, metrics); @@ -248,7 +248,7 @@ private void RunCore(IChannel ch) if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) { var perInst = evaluator.GetPerInstanceMetrics(data); - var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols); + var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols); var idv = evaluator.GetPerInstanceDataViewToSave(perInstData); MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv); } diff --git a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs index 505d3a28e6..e1057d18b4 100644 --- a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs +++ b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs @@ -219,7 +219,7 @@ public static void LoadModel(IHostEnvironment env, Stream modelStream, bool load if (roles != null) { var emptyView = ModelFileUtils.LoadPipeline(env, rep, new MultiFileSource(null)); - schema = RoleMappedSchema.CreateOpt(emptyView.Schema, roles); + schema = new RoleMappedSchema(emptyView.Schema, roles, opt: true); } else { diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index 02d655b48b..c353e4a4ec 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -97,10 +97,7 @@ private void RunCore(IChannel ch) ch.Trace("Creating loader"); - IPredictor predictor; - IDataLoader loader; - RoleMappedSchema trainSchema; - LoadModelObjects(ch, true, out predictor, true, out trainSchema, out loader); + LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader); ch.AssertValue(predictor); ch.AssertValueOrNull(trainSchema); ch.AssertValue(loader); @@ -116,7 +113,7 @@ private void RunCore(IChannel ch) string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId); var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn); - var schema = TrainUtils.CreateRoleMappedSchemaOpt(loader.Schema, feat, group, customCols); + var schema = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true); var mapper = bindable.Bind(Host, schema); if (!scorer.IsGood()) @@ -153,22 +150,20 @@ private void RunCore(IChannel ch) Args.OutputAllColumns == true || Utils.Size(Args.OutputColumn) == 0; if (Args.OutputAllColumns == true && Utils.Size(Args.OutputColumn) != 0) - ch.Warning("outputAllColumns=+ always writes all columns irrespective of outputColumn specified."); + ch.Warning(nameof(Args.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(Args.OutputColumn) + " specified."); if (!outputAllColumns && Utils.Size(Args.OutputColumn) != 0) { foreach (var outCol in Args.OutputColumn) { - int dummyColIndex; - if (!loader.Schema.TryGetColumnIndex(outCol, out dummyColIndex)) + if (!loader.Schema.TryGetColumnIndex(outCol, out int dummyColIndex)) throw ch.ExceptUserArg(nameof(Arguments.OutputColumn), "Column '{0}' not found.", outCol); } } - int colMax; uint maxScoreId = 0; if (!outputAllColumns) - maxScoreId = loader.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId); + maxScoreId = loader.Schema.GetMaxMetadataKind(out int colMax, MetadataUtils.Kinds.ScoreColumnSetId); ch.Assert(outputAllColumns || maxScoreId > 0); // score set IDs are one-based var cols = new List(); for (int i = 0; i < loader.Schema.ColumnCount; i++) @@ -211,12 +206,12 @@ private bool ShouldAddColumn(ISchema schema, int i, uint scoreSet, bool outputNa { switch (schema.GetColumnName(i)) { - case "Label": - case "Name": - case "Names": - return true; - default: - break; + case "Label": + case "Name": + case "Names": + return true; + default: + break; } } if (Args.OutputColumn != null && Array.FindIndex(Args.OutputColumn, schema.GetColumnName(i).Equals) >= 0) @@ -229,8 +224,7 @@ public static class ScoreUtils { public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema) { - ISchemaBoundMapper mapper; - var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, out mapper); + var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, out var mapper); return sc.CreateInstance(env, data.Data, mapper, trainSchema); } @@ -247,9 +241,8 @@ public static IDataScorerTransform GetScorer(SubComponent GetScorerC Contracts.AssertValue(mapper); string loadName = null; - DvText scoreKind = default(DvText); + DvText scoreKind = default; if (mapper.OutputSchema.ColumnCount > 0 && mapper.OutputSchema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreColumnKind, 0, ref scoreKind) && scoreKind.HasChars) @@ -311,10 +304,8 @@ public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env env.CheckValue(predictor, nameof(predictor)); env.CheckValueOrNull(scorerSettings); - ISchemaBindableMapper bindable; - // See if we can instantiate a mapper using scorer arguments. - if (scorerSettings.IsGood() && TryCreateBindableFromScorer(env, predictor, scorerSettings, out bindable)) + if (scorerSettings.IsGood() && TryCreateBindableFromScorer(env, predictor, scorerSettings, out var bindable)) return bindable; // The easy case is that the predictor implements the interface. diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs index 79e7bd5458..d0ebbd5a05 100644 --- a/src/Microsoft.ML.Data/Commands/TestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs @@ -114,7 +114,7 @@ private void RunCore(IChannel ch) if (!evalComp.IsGood()) evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema); var evaluator = evalComp.CreateInstance(Host); - var data = TrainUtils.CreateExamples(scorePipe, label, null, group, weight, name, customCols); + var data = new RoleMappedData(scorePipe, label, null, group, weight, name, customCols); var metrics = evaluator.Evaluate(data); MetricWriter.PrintWarnings(ch, metrics); evaluator.PrintFoldResults(ch, metrics); @@ -128,7 +128,7 @@ private void RunCore(IChannel ch) if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) { var perInst = evaluator.GetPerInstanceMetrics(data); - var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols); + var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols); var idv = evaluator.GetPerInstanceDataViewToSave(perInstData); MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv); } diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index b5b23f9020..b5e3964157 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -157,7 +157,7 @@ private void RunCore(IChannel ch, string cmd) ch.Trace("Binding columns"); var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn); - var data = TrainUtils.CreateExamples(view, label, feature, group, weight, name, customCols); + var data = new RoleMappedData(view, label, feature, group, weight, name, customCols); // REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands. RoleMappedData validData = null; @@ -172,7 +172,7 @@ private void RunCore(IChannel ch, string cmd) ch.Trace("Constructing the validation pipeline"); IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile); validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, validPipe); - validData = RoleMappedData.Create(validPipe, data.Schema.GetColumnRoleNames()); + validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames()); } } @@ -550,7 +550,7 @@ private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer var prefetch = data.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray(); var cacheView = new CacheDataView(env, data.Data, prefetch); // Because the prefetching worked, we know that these are valid columns. - data = RoleMappedData.Create(cacheView, data.Schema.GetColumnRoleNames()); + data = new RoleMappedData(cacheView, data.Schema.GetColumnRoleNames()); } else ch.Trace("Not caching"); @@ -571,97 +571,5 @@ public static IEnumerable> CheckAndGenerateCust } return customColumnArg.Select(kindName => new ColumnRole(kindName.Key).Bind(kindName.Value)); } - - /// - /// Given a schema and a bunch of column names, create the BoundSchema object. Any or all of the column - /// names may be null or whitespace, in which case they are ignored. Any columns that are specified but not - /// valid columns of the schema are also ignored. - /// - public static RoleMappedSchema CreateRoleMappedSchemaOpt(ISchema schema, string feature, string group, IEnumerable> custom = null) - { - Contracts.CheckValueOrNull(feature); - Contracts.CheckValueOrNull(custom); - - var list = new List>(); - if (!string.IsNullOrWhiteSpace(feature)) - list.Add(ColumnRole.Feature.Bind(feature)); - if (!string.IsNullOrWhiteSpace(group)) - list.Add(ColumnRole.Group.Bind(group)); - if (custom != null) - list.AddRange(custom); - - return RoleMappedSchema.CreateOpt(schema, list); - } - - /// - /// Given a view and a bunch of column names, create the RoleMappedData object. Any or all of the column - /// names may be null or whitespace, in which case they are ignored. Any columns that are specified must - /// be valid columns of the schema. - /// - public static RoleMappedData CreateExamples(IDataView view, string label, string feature, - string group = null, string weight = null, string name = null, - IEnumerable> custom = null) - { - Contracts.CheckValueOrNull(label); - Contracts.CheckValueOrNull(feature); - Contracts.CheckValueOrNull(group); - Contracts.CheckValueOrNull(weight); - Contracts.CheckValueOrNull(name); - Contracts.CheckValueOrNull(custom); - - var list = new List>(); - if (!string.IsNullOrWhiteSpace(label)) - list.Add(ColumnRole.Label.Bind(label)); - if (!string.IsNullOrWhiteSpace(feature)) - list.Add(ColumnRole.Feature.Bind(feature)); - if (!string.IsNullOrWhiteSpace(group)) - list.Add(ColumnRole.Group.Bind(group)); - if (!string.IsNullOrWhiteSpace(weight)) - list.Add(ColumnRole.Weight.Bind(weight)); - if (!string.IsNullOrWhiteSpace(name)) - list.Add(ColumnRole.Name.Bind(name)); - if (custom != null) - list.AddRange(custom); - - return RoleMappedData.Create(view, list); - } - - /// - /// Given a view and a bunch of column names, create the RoleMappedData object. Any or all of the column - /// names may be null or whitespace, in which case they are ignored. Any columns that are specified but not - /// valid columns of the schema are also ignored. - /// - public static RoleMappedData CreateExamplesOpt(IDataView view, string label, string feature, - string group = null, string weight = null, string name = null, - IEnumerable> custom = null) - { - Contracts.CheckValueOrNull(label); - Contracts.CheckValueOrNull(feature); - Contracts.CheckValueOrNull(group); - Contracts.CheckValueOrNull(weight); - Contracts.CheckValueOrNull(name); - Contracts.CheckValueOrNull(custom); - - var list = new List>(); - if (!string.IsNullOrWhiteSpace(label)) - list.Add(ColumnRole.Label.Bind(label)); - if (!string.IsNullOrWhiteSpace(feature)) - list.Add(ColumnRole.Feature.Bind(feature)); - if (!string.IsNullOrWhiteSpace(group)) - list.Add(ColumnRole.Group.Bind(group)); - if (!string.IsNullOrWhiteSpace(weight)) - list.Add(ColumnRole.Weight.Bind(weight)); - if (!string.IsNullOrWhiteSpace(name)) - list.Add(ColumnRole.Name.Bind(name)); - if (custom != null) - list.AddRange(custom); - - return RoleMappedData.CreateOpt(view, list); - } - - private static KeyValuePair Pair(ColumnRole kind, T value) - { - return new KeyValuePair(kind, value); - } } } diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs index f6ffa772f9..7c4249c6ee 100644 --- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs @@ -147,7 +147,7 @@ private void RunCore(IChannel ch, string cmd) ch.Trace("Binding columns"); var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn); - var data = TrainUtils.CreateExamples(trainPipe, label, features, group, weight, name, customCols); + var data = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols); RoleMappedData validData = null; if (!string.IsNullOrWhiteSpace(Args.ValidationFile)) @@ -161,7 +161,7 @@ private void RunCore(IChannel ch, string cmd) ch.Trace("Constructing the validation pipeline"); IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile); validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe); - validData = RoleMappedData.Create(validPipe, data.Schema.GetColumnRoleNames()); + validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames()); } } @@ -189,8 +189,8 @@ private void RunCore(IChannel ch, string cmd) if (!evalComp.IsGood()) evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema); var evaluator = evalComp.CreateInstance(Host); - var dataEval = TrainUtils.CreateExamplesOpt(scorePipe, label, features, - group, weight, name, customCols); + var dataEval = new RoleMappedData(scorePipe, label, features, + group, weight, name, customCols, opt: true); var metrics = evaluator.Evaluate(dataEval); MetricWriter.PrintWarnings(ch, metrics); evaluator.PrintFoldResults(ch, metrics); @@ -204,7 +204,7 @@ private void RunCore(IChannel ch, string cmd) if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) { var perInst = evaluator.GetPerInstanceMetrics(dataEval); - var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols); + var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols); var idv = evaluator.GetPerInstanceDataViewToSave(perInstData); MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv); } diff --git a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs index 3dd16f141b..f08d52fe85 100644 --- a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs +++ b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs @@ -364,9 +364,7 @@ private sealed class Dense : FeatureNameCollection private readonly int _count; private readonly string[] _names; - private readonly RoleMappedSchema _schema; - - public override RoleMappedSchema Schema => _schema; + public override RoleMappedSchema Schema { get; } public Dense(int count, string[] names) { @@ -379,8 +377,9 @@ public Dense(int count, string[] names) if (size > 0) Array.Copy(names, _names, size); - _schema = RoleMappedSchema.Create(new FeatureNameCollectionSchema(this), - RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, RoleMappedSchema.ColumnRole.Feature.Value)); + // REVIEW: This seems wrong. The default feature column name is "Features" yet the role is named "Feature". + Schema = new RoleMappedSchema(new FeatureNameCollectionSchema(this), + roles: RoleMappedSchema.ColumnRole.Feature.Bind(RoleMappedSchema.ColumnRole.Feature.Value)); } public override int Count => _count; @@ -470,8 +469,9 @@ public Sparse(int count, string[] names, int cnn) } Contracts.Assert(cv == cnn); - _schema = RoleMappedSchema.Create(new FeatureNameCollectionSchema(this), - RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, RoleMappedSchema.ColumnRole.Feature.Value)); + // REVIEW: This seems wrong. The default feature column name is "Features" yet the role is named "Feature". + _schema = new RoleMappedSchema(new FeatureNameCollectionSchema(this), + roles: RoleMappedSchema.ColumnRole.Feature.Bind(RoleMappedSchema.ColumnRole.Feature.Value)); } /// diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs index 5583c66df0..bc45a929d4 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs @@ -146,7 +146,7 @@ public static TOut Train(IHost host, TArg input, TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref view, feature, input.NormalizeFeatures); ch.Trace("Binding columns"); - var roleMappedData = TrainUtils.CreateExamples(view, label, feature, group, weight, name, custom); + var roleMappedData = new RoleMappedData(view, label, feature, group, weight, name, custom); RoleMappedData cachedRoleMappedData = roleMappedData; Cache.CachingType? cachingType = null; @@ -184,7 +184,7 @@ public static TOut Train(IHost host, TArg input, Data = roleMappedData.Data, Caching = cachingType.Value }).OutputData; - cachedRoleMappedData = RoleMappedData.Create(cacheView, roleMappedData.Schema.GetColumnRoleNames()); + cachedRoleMappedData = new RoleMappedData(cacheView, roleMappedData.Schema.GetColumnRoleNames()); } var predictor = TrainUtils.Train(host, ch, cachedRoleMappedData, trainer, "Train", calibrator, maxCalibrationExamples); diff --git a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs b/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs index af726fa758..4b474f847c 100644 --- a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs @@ -80,7 +80,7 @@ public void Save(IHostEnvironment env, Stream stream) // Create the chain of transforms for saving. IDataView data = new EmptyDataView(env, _transformModel.InputSchema); data = _transformModel.Apply(env, data); - var roleMappedData = RoleMappedData.CreateOpt(data, _roleMappings); + var roleMappedData = new RoleMappedData(data, _roleMappings, opt: true); TrainUtils.SaveModel(env, ch, stream, _predictor, roleMappedData); ch.Done(); @@ -102,7 +102,7 @@ public void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedDat env.CheckValue(input, nameof(input)); input = _transformModel.Apply(env, input); - roleMappedData = RoleMappedData.CreateOpt(input, _roleMappings); + roleMappedData = new RoleMappedData(input, _roleMappings, opt: true); predictor = _predictor; } @@ -141,7 +141,7 @@ public RoleMappedSchema GetTrainingSchema(IHostEnvironment env) { Contracts.CheckValue(env, nameof(env)); var predInput = _transformModel.Apply(env, new EmptyDataView(env, _transformModel.InputSchema)); - var trainRms = RoleMappedSchema.CreateOpt(predInput.Schema, _roleMappings); + var trainRms = new RoleMappedSchema(predInput.Schema, _roleMappings, opt: true); return trainRms; } } diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index 39a5f31c38..74f7ca0068 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -796,7 +796,7 @@ public static CommonOutputs.CommonEvaluateOutput AnomalyDetection(IHostEnvironme string name; MatchColumns(host, input, out label, out weight, out name); var evaluator = new AnomalyDetectionMamlEvaluator(host, input); - var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name); + var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); var warnings = ExtractWarnings(host, metrics); diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index 90078da9ee..7161f66439 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -1455,7 +1455,7 @@ public static CommonOutputs.ClassificationEvaluateOutput Binary(IHostEnvironment string name; MatchColumns(host, input, out label, out weight, out name); var evaluator = new BinaryClassifierMamlEvaluator(host, input); - var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name); + var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); var warnings = ExtractWarnings(host, metrics); diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs index 907760649f..bec1ac144a 100644 --- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs @@ -776,7 +776,7 @@ public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args) string feat = EvaluateUtils.GetColName(_featureCol, schema.Feature, DefaultColumnNames.Features); if (!schema.Schema.TryGetColumnIndex(feat, out int featCol)) throw Host.ExceptUserArg(nameof(Arguments.FeatureColumn), "Features column '{0}' not found", feat); - yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, feat); + yield return RoleMappedSchema.ColumnRole.Feature.Bind(feat); } } @@ -867,7 +867,7 @@ public static CommonOutputs.CommonEvaluateOutput Clustering(IHostEnvironment env nameof(ClusteringMamlEvaluator.Arguments.FeatureColumn), input.FeatureColumn, DefaultColumnNames.Features); var evaluator = new ClusteringMamlEvaluator(host, input); - var data = TrainUtils.CreateExamples(input.Data, label, features, null, weight, name); + var data = new RoleMappedData(input.Data, label, features, null, weight, name); var metrics = evaluator.Evaluate(data); var warnings = ExtractWarnings(host, metrics); diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs index bbb53ba631..1a5a2177f3 100644 --- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs @@ -95,7 +95,7 @@ protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string sco public Dictionary Evaluate(RoleMappedData data) { - data = RoleMappedData.Create(data.Data, GetInputColumnRoles(data.Schema, needStrat: true)); + data = new RoleMappedData(data.Data, GetInputColumnRoles(data.Schema, needStrat: true)); return Evaluator.Evaluate(data); } @@ -108,7 +108,7 @@ public Dictionary Evaluate(RoleMappedData data) : StratCols.Select(col => RoleMappedSchema.CreatePair(Strat, col)); if (needName && schema.Name != null) - roles = roles.Prepend(RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Name, schema.Name.Name)); + roles = roles.Prepend(RoleMappedSchema.ColumnRole.Name.Bind(schema.Name.Name)); return roles.Concat(GetInputColumnRolesCore(schema)); } @@ -126,12 +126,12 @@ public Dictionary Evaluate(RoleMappedData data) yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name); // Get the label column information. - string lab = EvaluateUtils.GetColName(LabelCol, schema.Label, DefaultColumnNames.Label); - yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, lab); + string label = EvaluateUtils.GetColName(LabelCol, schema.Label, DefaultColumnNames.Label); + yield return RoleMappedSchema.ColumnRole.Label.Bind(label); - var weight = EvaluateUtils.GetColName(WeightCol, schema.Weight, null); + string weight = EvaluateUtils.GetColName(WeightCol, schema.Weight, null); if (!string.IsNullOrEmpty(weight)) - yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, weight); + yield return RoleMappedSchema.ColumnRole.Weight.Bind(weight); } public virtual IEnumerable GetOverallMetricColumns() @@ -203,7 +203,7 @@ public IDataTransform GetPerInstanceMetrics(RoleMappedData scoredData) Host.AssertValue(scoredData); var schema = scoredData.Schema; - var dataEval = RoleMappedData.Create(scoredData.Data, GetInputColumnRoles(schema)); + var dataEval = new RoleMappedData(scoredData.Data, GetInputColumnRoles(schema)); return Evaluator.GetPerInstanceMetrics(dataEval); } @@ -260,7 +260,7 @@ protected virtual IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMap public IDataView GetPerInstanceDataViewToSave(RoleMappedData perInstance) { Host.CheckValue(perInstance, nameof(perInstance)); - var data = RoleMappedData.Create(perInstance.Data, GetInputColumnRoles(perInstance.Schema, needName: true)); + var data = new RoleMappedData(perInstance.Data, GetInputColumnRoles(perInstance.Schema, needName: true)); return WrapPerInstance(data); } diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs index 3b5e5fc910..d57835b168 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs @@ -784,7 +784,7 @@ public static CommonOutputs.CommonEvaluateOutput MultiOutputRegression(IHostEnvi string name; MatchColumns(host, input, out label, out weight, out name); var evaluator = new MultiOutputRegressionMamlEvaluator(host, input); - var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name); + var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); var warnings = ExtractWarnings(host, metrics); diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index 94c920c3c6..fd23e7c3b0 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -1070,7 +1070,7 @@ public static CommonOutputs.ClassificationEvaluateOutput MultiClass(IHostEnviron MatchColumns(host, input, out string label, out string weight, out string name); var evaluator = new MultiClassMamlEvaluator(host, input); - var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name); + var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); var warnings = ExtractWarnings(host, metrics); diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs index 6d61f6b965..fb8d9c1249 100644 --- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs @@ -556,7 +556,7 @@ public static CommonOutputs.CommonEvaluateOutput QuantileRegression(IHostEnviron string name; MatchColumns(host, input, out label, out weight, out name); var evaluator = new QuantileRegressionMamlEvaluator(host, input); - var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name); + var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); var warnings = ExtractWarnings(host, metrics); diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index ae9c2a8594..cdf3f9c57e 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -851,7 +851,7 @@ public RankerMamlEvaluator(IHostEnvironment env, Arguments args) { var cols = base.GetInputColumnRolesCore(schema); var groupIdCol = EvaluateUtils.GetColName(_groupIdCol, schema.Group, DefaultColumnNames.GroupId); - return cols.Prepend(RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, groupIdCol)); + return cols.Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupIdCol)); } protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) @@ -1039,7 +1039,7 @@ public static CommonOutputs.CommonEvaluateOutput Ranking(IHostEnvironment env, R nameof(RankerMamlEvaluator.Arguments.GroupIdColumn), input.GroupIdColumn, DefaultColumnNames.GroupId); var evaluator = new RankerMamlEvaluator(host, input); - var data = TrainUtils.CreateExamples(input.Data, label, null, groupId, weight, name); + var data = new RoleMappedData(input.Data, label, null, groupId, weight, name); var metrics = evaluator.Evaluate(data); var warnings = ExtractWarnings(host, metrics); diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs index 4292e13b8d..1804ce429f 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs @@ -354,7 +354,7 @@ public static CommonOutputs.CommonEvaluateOutput Regression(IHostEnvironment env string name; MatchColumns(host, input, out label, out weight, out name); var evaluator = new RegressionMamlEvaluator(host, input); - var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name); + var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); var warnings = ExtractWarnings(host, metrics); diff --git a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs index 6c28dac997..dfec0913ca 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs @@ -147,13 +147,13 @@ private void Run(IChannel ch) { RoleMappedData data; if (trainSchema != null) - data = RoleMappedData.Create(end, trainSchema.GetColumnRoleNames()); + data = new RoleMappedData(end, trainSchema.GetColumnRoleNames()); else { // We had a predictor, but no roles stored in the model. Just suppose // default column names are OK, if present. - data = TrainUtils.CreateExamplesOpt(end, DefaultColumnNames.Label, - DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name); + data = new RoleMappedData(end, DefaultColumnNames.Label, + DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true); } var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema); diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index 91c84a0734..a407873bac 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -70,7 +70,7 @@ private static Bindings Create(IHostEnvironment env, ISchemaBindableMapper binda Contracts.AssertValue(roles); Contracts.AssertValueOrNull(suffix); - var mapper = bindable.Bind(env, RoleMappedSchema.Create(input, roles)); + var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles)); // We don't actually depend on this invariant, but if this assert fires it means the bindable // did the wrong thing. Contracts.Assert(mapper.InputSchema.Schema == input); diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index fe69585b78..2fd039897a 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -117,7 +117,7 @@ public BindingsImpl ApplyToSchema(ISchema input, ISchemaBindableMapper bindable, env.AssertValue(bindable); string scoreCol = RowMapper.OutputSchema.GetColumnName(ScoreColumnIndex); - var schema = RoleMappedSchema.Create(input, RowMapper.GetInputColumnRoles()); + var schema = new RoleMappedSchema(input, RowMapper.GetInputColumnRoles()); // Checks compatibility of the predictor input types. var mapper = bindable.Bind(env, schema); @@ -148,7 +148,7 @@ public static BindingsImpl Create(ModelLoadContext ctx, ISchema input, string scoreKind = ctx.LoadNonEmptyString(); string scoreCol = ctx.LoadNonEmptyString(); - var mapper = bindable.Bind(env, RoleMappedSchema.Create(input, roles)); + var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles)); var rowMapper = mapper as ISchemaBoundRowMapper; env.CheckParam(rowMapper != null, nameof(bindable), "Bindable expected to be an " + nameof(ISchemaBindableMapper) + "!"); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs index 22fa9686ed..f3c560c8dc 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs @@ -225,7 +225,7 @@ public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, env.AssertValue(featInfo); // Should be defined, if FEaturesAreNormalized returned a definite value. var view = CreateMinMaxNormalizer(env, data.Data, name: featInfo.Name); - data = RoleMappedData.Create(view, data.Schema.GetColumnRoleNames()); + data = new RoleMappedData(view, data.Schema.GetColumnRoleNames()); return true; } diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index bdf1a36d41..bba2f58256 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -181,7 +181,7 @@ private static RoleMappedData CreateDataFromArgs(IExceptionContext var name = TrainUtils.MatchNameOrDefaultOrNull(ectx, schema, nameof(args.NameColumn), args.NameColumn, DefaultColumnNames.Name); var customCols = TrainUtils.CheckAndGenerateCustomColumns(ectx, args.CustomColumn); - return TrainUtils.CreateExamples(input, label, feat, group, weight, name, customCols); + return new RoleMappedData(input, label, feat, group, weight, name, customCols); } } } diff --git a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs index 3e39a0008a..900778cd31 100644 --- a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs @@ -338,7 +338,7 @@ public static RoleMappedSchema LoadRoleMappedSchemaOrNull(IHostEnvironment env, if (roleMappings == null) return null; var pipe = ModelFileUtils.LoadLoader(h, rep, new MultiFileSource(null), loadTransforms: true); - return RoleMappedSchema.Create(pipe.Schema, roleMappings); + return new RoleMappedSchema(pipe.Schema, roleMappings); } /// diff --git a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs index d5134e48c9..6366321c48 100644 --- a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs +++ b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs @@ -33,7 +33,7 @@ public static RoleMappedData SelectFeatures(IHost host, RoleMappedData data, Bit host, "FeatureSelector", data.Data, name, name, type, type, (ref VBuffer src, ref VBuffer dst) => SelectFeatures(ref src, features, card, ref dst)); - var res = RoleMappedData.Create(view, data.Schema.GetColumnRoleNames()); + var res = new RoleMappedData(view, data.Schema.GetColumnRoleNames()); return res; } diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs index e8ea19684e..f512114bbb 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs @@ -307,7 +307,7 @@ private static TOut CreatePipelineEnsemble(IHostEnvironment env, IPredicto var dv = new EmptyDataView(env, inputSchema); // The role mappings are specific to the individual predictors. - var rmd = RoleMappedData.Create(dv); + var rmd = new RoleMappedData(dv); var predictorModel = new PredictorModel(env, rmd, dv, ensemble); var output = new TOut { PredictorModel = predictorModel }; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index 7435958e62..f49e3af81c 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -181,11 +181,11 @@ public void Train(List>> models, var bldr = new ArrayDataViewBuilder(host); Array.Resize(ref labels, count); Array.Resize(ref features, count); - bldr.AddColumn("Label", NumberType.Float, labels); - bldr.AddColumn("Features", NumberType.Float, features); + bldr.AddColumn(DefaultColumnNames.Label, NumberType.Float, labels); + bldr.AddColumn(DefaultColumnNames.Features, NumberType.Float, features); var view = bldr.GetDataView(); - var rmd = RoleMappedData.Create(view, ColumnRole.Label.Bind("Label"), ColumnRole.Feature.Bind("Features")); + var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features); var trainer = BasePredictorType.CreateInstance(host); if (trainer is ITrainerEx ex && ex.NeedNormalization) diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index 43956c608f..3cf30a3211 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -318,7 +318,7 @@ public IPredictor Calibrate(IChannel ch, IDataView data, ICalibratorTrainer cali if (caliTrainer.NeedsTraining) { - var bound = new Bound(this, RoleMappedSchema.Create(data.Schema)); + var bound = new Bound(this, new RoleMappedSchema(data.Schema)); using (var curs = data.GetRowCursor(col => true)) { var scoreGetter = (ValueGetter)bound.CreateScoreGetter(curs, col => true, out Action disposer); diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs index 2976214cfe..8465e1a5e8 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs @@ -84,7 +84,7 @@ public virtual void CalculateMetrics(FeatureSubsetModel GetBatches(IRandom rand) var view = new GenerateNumberTransform(Host, args, Data.Data); var viewTest = new RangeFilter(Host, new RangeFilter.Arguments() { Column = name, Max = ValidationDatasetProportion }, view); var viewTrain = new RangeFilter(Host, new RangeFilter.Arguments() { Column = name, Max = ValidationDatasetProportion, Complement = true }, view); - dataTest = RoleMappedData.Create(viewTest, Data.Schema.GetColumnRoleNames()); - dataTrain = RoleMappedData.Create(viewTrain, Data.Schema.GetColumnRoleNames()); + dataTest = new RoleMappedData(viewTest, Data.Schema.GetColumnRoleNames()); + dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames()); } if (BatchSize > 0) diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs index 25e7f8d64d..97dda8aeba 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs @@ -46,7 +46,7 @@ public override IEnumerable GetSubsets(Batch batch, IRandom rand) { // REVIEW: Consider ways to reintroduce "balanced" samples. var viewTrain = new BootstrapSampleTransform(Host, new BootstrapSampleTransform.Arguments(), Data.Data); - var dataTrain = RoleMappedData.Create(viewTrain, Data.Schema.GetColumnRoleNames()); + var dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames()); yield return FeatureSelector.SelectFeatures(dataTrain, rand); } } diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs index 322a2133dc..0fca7ac55b 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs @@ -45,7 +45,7 @@ public override IEnumerable GetSubsets(Batch batch, IRandom rand) for (int i = 0; i < Size; i++) { var viewTrain = new RangeFilter(Host, new RangeFilter.Arguments() { Column = name, Min = (Double)i / Size, Max = (Double)(i + 1) / Size }, view); - var dataTrain = RoleMappedData.Create(viewTrain, Data.Schema.GetColumnRoleNames()); + var dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames()); yield return FeatureSelector.SelectFeatures(dataTrain, rand); } } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 01668eb3f2..22a4e145b8 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -1386,7 +1386,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB // Since we've passed it through a few transforms, reconstitute the mapping on the // newly transformed data. - examples = RoleMappedData.Create(data, examples.Schema.GetColumnRoleNames()); + examples = new RoleMappedData(data, examples.Schema.GetColumnRoleNames()); // Get the index of the columns in the transposed view, while we're at it composing // the list of the columns we want to transpose. diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index bc16a07338..931dd2335b 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -990,10 +990,11 @@ public Context(IChannel ch, GamPredictorBase pred, RoleMappedData data, IEvaluat { _eval = eval; var builder = new ArrayDataViewBuilder(pred.Host); - builder.AddColumn("Label", NumberType.Float, _labels); - builder.AddColumn("Score", NumberType.Float, _scores); - _dataForEvaluator = RoleMappedData.Create(builder.GetDataView(), RoleMappedSchema.ColumnRole.Label.Bind("Label"), - RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, "Score")); + builder.AddColumn(DefaultColumnNames.Label, NumberType.Float, _labels); + builder.AddColumn(DefaultColumnNames.Score, NumberType.Float, _scores); + _dataForEvaluator = new RoleMappedData(builder.GetDataView(), opt: false, + RoleMappedSchema.ColumnRole.Label.Bind(DefaultColumnNames.Label), + new RoleMappedSchema.ColumnRole(MetadataUtils.Const.ScoreValueKind.Score).Bind(DefaultColumnNames.Score)); } _data.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex); @@ -1196,7 +1197,7 @@ private Context Init(IChannel ch) } var pred = rawPred as GamPredictorBase; ch.CheckUserArg(pred != null, nameof(Args.InputModelFile), "Predictor was not a " + nameof(GamPredictorBase)); - var data = RoleMappedData.CreateOpt(loader, schema.GetColumnRoleNames()); + var data = new RoleMappedData(loader, schema.GetColumnRoleNames(), opt: true); if (hadCalibrator && !string.IsNullOrWhiteSpace(Args.OutputModelFile)) ch.Warning("If you save the GAM model, only the GAM model, not the wrapping calibrator, will be saved."); diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 253d76d654..a7044d2502 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -393,8 +393,7 @@ private void EnsureCachedPosition() public IEnumerable> GetInputColumnRoles() { - yield return new KeyValuePair( - RoleMappedSchema.ColumnRole.Feature, _inputSchema.Feature.Name); + yield return RoleMappedSchema.ColumnRole.Feature.Bind(_inputSchema.Feature.Name); } public Func GetDependencies(Func predicate) diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index df386d0dc1..ff663a8d6e 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -97,54 +97,37 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight public KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { - Contracts.CheckValue(args, nameof(args)); - Contracts.CheckUserArg(args.K > 0, nameof(args.K), "Number of means must be positive"); + Host.CheckValue(args, nameof(args)); + Host.CheckUserArg(args.K > 0, nameof(args.K), "Must be positive"); _k = args.K; - Contracts.CheckUserArg(args.MaxIterations > 0, nameof(args.MaxIterations), "Number of iterations must be positive"); + Host.CheckUserArg(args.MaxIterations > 0, nameof(args.MaxIterations), "Must be positive"); _maxIterations = args.MaxIterations; - Contracts.CheckUserArg(args.OptTol > 0, nameof(args.OptTol), "Tolerance must be positive"); + Host.CheckUserArg(args.OptTol > 0, nameof(args.OptTol), "Tolerance must be positive"); _convergenceThreshold = args.OptTol; _centroids = new VBuffer[_k]; - Contracts.CheckUserArg(args.AccelMemBudgetMb > 0, nameof(args.AccelMemBudgetMb), "Memory budget must be positive"); + Host.CheckUserArg(args.AccelMemBudgetMb > 0, nameof(args.AccelMemBudgetMb), "Must be positive"); _accelMemBudgetMb = args.AccelMemBudgetMb; _initAlgorithm = args.InitAlgorithm; - if (args.NumThreads.HasValue) - { - Contracts.CheckUserArg(args.NumThreads.Value > 0, nameof(args.NumThreads), "The number of threads must be either null or a positive integer."); - } + Host.CheckUserArg(!args.NumThreads.HasValue || args.NumThreads > 0, nameof(args.NumThreads), + "Must be either null or a positive integer."); _numThreads = ComputeNumThreads(Host, args.NumThreads); } - public override bool NeedNormalization - { - get { return true; } - } - - public override bool NeedCalibration - { - get { return false; } - } - - public override bool WantCaching - { - get { return true; } - } - - public override PredictionKind PredictionKind - { - get { return PredictionKind.Clustering; } - } + public override bool NeedNormalization => true; + public override bool NeedCalibration => false; + public override bool WantCaching => true; + public override PredictionKind PredictionKind => PredictionKind.Clustering; public override void Train(RoleMappedData data) { - Contracts.CheckValue(data, nameof(data)); + Host.CheckValue(data, nameof(data)); data.CheckFeatureFloatVector(out _dimensionality); Contracts.Assert(_dimensionality > 0); @@ -311,7 +294,7 @@ public static void Initialize( // This check is only performed once, at the first pass of initialization if (dimensionality != cursor.Features.Length) { - throw Contracts.Except( + throw ch.Except( "Dimensionality doesn't match, expected {0}, got {1}", dimensionality, cursor.Features.Length); @@ -330,7 +313,7 @@ public static void Initialize( probabilityWeight = Math.Min(probabilityWeight, distance); } - Contracts.Assert(FloatUtils.IsFinite(probabilityWeight)); + ch.Assert(FloatUtils.IsFinite(probabilityWeight)); } if (probabilityWeight > 0) @@ -362,7 +345,7 @@ public static void Initialize( // persist the candidate as a new centroid if (!haveCandidate) { - throw Contracts.Except( + throw ch.Except( "Not enough distinct instances to populate {0} clusters (only found {1} distinct instances)", k, i); } @@ -716,13 +699,13 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl out long missingFeatureCount, out long totalTrainingInstances) { Contracts.CheckValue(host, nameof(host)); - host.CheckValue(cursorFactory, nameof(cursorFactory)); host.CheckValue(ch, nameof(ch)); - host.CheckValue(centroids, nameof(centroids)); - host.CheckUserArg(numThreads > 0, nameof(KMeansPlusPlusTrainer.Arguments.NumThreads), "Must be positive"); - host.CheckUserArg(k > 0, nameof(KMeansPlusPlusTrainer.Arguments.K), "Must be positive"); - host.CheckParam(dimensionality > 0, nameof(dimensionality), "Must be positive"); - host.CheckUserArg(accelMemBudgetMb >= 0, nameof(KMeansPlusPlusTrainer.Arguments.AccelMemBudgetMb), "Must be non-negative"); + ch.CheckValue(cursorFactory, nameof(cursorFactory)); + ch.CheckValue(centroids, nameof(centroids)); + ch.CheckUserArg(numThreads > 0, nameof(KMeansPlusPlusTrainer.Arguments.NumThreads), "Must be positive"); + ch.CheckUserArg(k > 0, nameof(KMeansPlusPlusTrainer.Arguments.K), "Must be positive"); + ch.CheckParam(dimensionality > 0, nameof(dimensionality), "Must be positive"); + ch.CheckUserArg(accelMemBudgetMb >= 0, nameof(KMeansPlusPlusTrainer.Arguments.AccelMemBudgetMb), "Must be non-negative"); int numRounds; int numSamplesPerRound; @@ -787,7 +770,7 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl VBufferUtils.Densify(ref clusters[clusterCount]); clustersL2s[clusterCount] = VectorUtils.NormSquared(clusters[clusterCount]); clusterPrevCount = clusterCount; - Contracts.Assert(clusterCount - clusterPrevCount <= numSamplesPerRound); + ch.Assert(clusterCount - clusterPrevCount <= numSamplesPerRound); clusterCount++; logicalExternalRounds++; pCh.Checkpoint(logicalExternalRounds, numRounds + 2); @@ -828,11 +811,11 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl clusterCount++; } - Contracts.Assert(clusterCount - clusterPrevCount <= numSamplesPerRound); + ch.Assert(clusterCount - clusterPrevCount <= numSamplesPerRound); logicalExternalRounds++; pCh.Checkpoint(logicalExternalRounds, numRounds + 2); } - Contracts.Assert(clusterCount == clusters.Length); + ch.Assert(clusterCount == clusters.Length); } // Finally, we do one last pass through the dataset, finding for @@ -851,7 +834,7 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl clustersL2s, false, false, out discardBestWeight, out bestCluster); #if DEBUG int debugBestCluster = KMeansUtils.FindBestCluster(ref point, clusters, clustersL2s); - Contracts.Assert(bestCluster == debugBestCluster); + ch.Assert(bestCluster == debugBestCluster); #endif weights[bestCluster]++; }, @@ -881,9 +864,9 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl ref debugWeightBuffer, ref debugTotalWeights); for (int i = 0; i < totalWeights.Length; i++) - Contracts.Assert(totalWeights[i] == debugTotalWeights[i]); + ch.Assert(totalWeights[i] == debugTotalWeights[i]); #endif - Contracts.Assert(totalWeights.Length == clusters.Length); + ch.Assert(totalWeights.Length == clusters.Length); logicalExternalRounds++; // If we sampled exactly the right number of points then we can @@ -899,10 +882,10 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl else { ArrayDataViewBuilder arrDv = new ArrayDataViewBuilder(host); - arrDv.AddColumn("Features", PrimitiveType.FromKind(DataKind.R4), clusters); - arrDv.AddColumn("Weights", PrimitiveType.FromKind(DataKind.R4), totalWeights); + arrDv.AddColumn(DefaultColumnNames.Features, PrimitiveType.FromKind(DataKind.R4), clusters); + arrDv.AddColumn(DefaultColumnNames.Weight, PrimitiveType.FromKind(DataKind.R4), totalWeights); var subDataViewCursorFactory = new FeatureFloatVectorCursor.Factory( - TrainUtils.CreateExamples(arrDv.GetDataView(), null, "Features", weight: "Weights"), CursOpt.Weight | CursOpt.Features); + new RoleMappedData(arrDv.GetDataView(), null, DefaultColumnNames.Features, weight: DefaultColumnNames.Weight), CursOpt.Weight | CursOpt.Features); long discard1; long discard2; KMeansPlusPlusInit.Initialize(host, numThreads, ch, subDataViewCursorFactory, k, dimensionality, centroids, out discard1, out discard2, false); @@ -1198,13 +1181,13 @@ public int GetBestCluster(int idx) public SharedState(FeatureFloatVectorCursor.Factory factory, IChannel ch, long baseMaxInstancesToAccelerate, int k, bool isParallel, long totalTrainingInstances) { - Contracts.AssertValue(factory); Contracts.AssertValue(ch); - Contracts.Assert(k > 0); - Contracts.Assert(totalTrainingInstances > 0); + ch.AssertValue(factory); + ch.Assert(k > 0); + ch.Assert(totalTrainingInstances > 0); _acceleratedRowMap = new KMeansAcceleratedRowMap(factory, ch, baseMaxInstancesToAccelerate, totalTrainingInstances, isParallel); - Contracts.Assert(MaxInstancesToAccelerate >= 0, + ch.Assert(MaxInstancesToAccelerate >= 0, "MaxInstancesToAccelerate cannot be negative as KMeansAcceleratedRowMap sets it to 0 when baseMaxInstancesToAccelerate is negative"); if (MaxInstancesToAccelerate > 0) @@ -1576,12 +1559,12 @@ public static RowStats ParallelWeightedReservoirSample( }, (Heap[] heaps, IRandom rand, ref Heap finalHeap) => { - Contracts.Assert(finalHeap == null); + host.Assert(finalHeap == null); finalHeap = new Heap((x, y) => x.Weight > y.Weight, numSamples); for (int i = 0; i < heaps.Length; i++) { - Contracts.AssertValue(heaps[i]); - Contracts.Assert(heaps[i].Count <= numSamples, "heaps[i].Count must not be greater than numSamples"); + host.AssertValue(heaps[i]); + host.Assert(heaps[i].Count <= numSamples, "heaps[i].Count must not be greater than numSamples"); while (heaps[i].Count > 0) { var row = heaps[i].Pop(); @@ -1597,7 +1580,7 @@ public static RowStats ParallelWeightedReservoirSample( }, ref buffer, ref outHeap); if (outHeap.Count != numSamples) - throw Contracts.Except("Failed to initialize clusters: too few examples"); + throw host.Except("Failed to initialize clusters: too few examples"); // Keep in mind that the distribution of samples in dst will not be random. It will // have the residual minHeap ordering. diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index 6b72d64af5..9c40e61ca2 100644 --- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -172,13 +172,13 @@ private void Run(IChannel ch) { RoleMappedData data; if (trainSchema != null) - data = RoleMappedData.Create(end, trainSchema.GetColumnRoleNames()); + data = new RoleMappedData(end, trainSchema.GetColumnRoleNames()); else { // We had a predictor, but no roles stored in the model. Just suppose // default column names are OK, if present. - data = TrainUtils.CreateExamplesOpt(end, DefaultColumnNames.Label, - DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name); + data = new RoleMappedData(end, DefaultColumnNames.Label, + DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true); } var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema); diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs index 73b3f1051f..67c53223d7 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs @@ -82,7 +82,7 @@ public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleM _pred = pred; var inputFeatureColumns = _columns.Select(c => new KeyValuePair(RoleMappedSchema.ColumnRole.Feature, c.Name)).ToList(); - InputSchema = RoleMappedSchema.Create(schema.Schema, inputFeatureColumns); + InputSchema = new RoleMappedSchema(schema.Schema, inputFeatureColumns); OutputSchema = outputSchema; _inputColumnIndexes = new List(); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 738c0197cb..a48a53ae65 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -118,7 +118,7 @@ protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMapped ch.Assert(idvToFeedTrain.CanShuffle); var roles = examples.Schema.GetColumnRoleNames(); - var examplesToFeedTrain = RoleMappedData.Create(idvToFeedTrain, roles); + var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles); ch.Assert(examplesToFeedTrain.Schema.Label != null); ch.Assert(examplesToFeedTrain.Schema.Feature != null); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index 8aa6e4b6e0..7b5fcc8a93 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -79,7 +79,7 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe var roles = data.Schema.GetColumnRoleNames() .Where(kvp => kvp.Key.Value != CR.Label.Value) .Prepend(CR.Label.Bind(dstName)); - var td = RoleMappedData.Create(view, roles); + var td = new RoleMappedData(view, roles); trainer.Train(td); @@ -214,7 +214,7 @@ public static ModelOperations.PredictorModelOutput CombineOvaModels(IHostEnviron input.FeatureColumn, DefaultColumnNames.Features); var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.WeightColumn), input.WeightColumn, DefaultColumnNames.Weight); - var data = TrainUtils.CreateExamples(normalizedView, label, feature, null, weight); + var data = new RoleMappedData(normalizedView, label, feature, null, weight); return new ModelOperations.PredictorModelOutput { diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index c710b3149d..cf1e7c062b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -76,7 +76,7 @@ private TDistPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedD var roles = data.Schema.GetColumnRoleNames() .Where(kvp => kvp.Key.Value != CR.Label.Value) .Prepend(CR.Label.Bind(dstName)); - var td = RoleMappedData.Create(view, roles); + var td = new RoleMappedData(view, roles); trainer.Train(td); diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index 87420a265e..1d2f466ad5 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -126,12 +126,12 @@ private FastForestRegressionPredictor FitModel(IEnumerable previousR } ArrayDataViewBuilder dvBuilder = new ArrayDataViewBuilder(_host); - dvBuilder.AddColumn("Label", NumberType.Float, targets); - dvBuilder.AddColumn("Features", NumberType.Float, features); + dvBuilder.AddColumn(DefaultColumnNames.Label, NumberType.Float, targets); + dvBuilder.AddColumn(DefaultColumnNames.Features, NumberType.Float, features); IDataView view = dvBuilder.GetDataView(); _host.Assert(view.GetRowCount() == targets.Length, "This data view will have as many rows as there have been evaluations"); - RoleMappedData data = TrainUtils.CreateExamples(view, "Label", "Features"); + RoleMappedData data = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features); using (IChannel ch = _host.Start("Single training")) { diff --git a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs index 8448a406db..637d75250b 100644 --- a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs @@ -299,7 +299,7 @@ private static void TrainCore(IHost host, IDataView input, Arguments args, ref V ch.Trace("Binding columns"); var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, args.CustomColumn); - var data = TrainUtils.CreateExamples(view, label, feature, group, weight, name, customCols); + var data = new RoleMappedData(view, label, feature, group, weight, name, customCols); var predictor = TrainUtils.Train(host, ch, data, trainer, args.Filter.Kind, null, null, 0, args.CacheData); diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index f6c99d061c..631b7b6946 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -433,13 +433,11 @@ public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetrics var eval = GetEvaluator(env, input.Kind); var perInst = EvaluateUtils.ConcatenatePerInstanceDataViews(env, eval, true, true, input.PerInstanceMetrics.Select( - idv => RoleMappedData.CreateOpt(idv, new[] - { - RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn), - RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, input.WeightColumn.Value), - RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value), - RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Name, input.NameColumn.Value) - })).ToArray(), + idv => new RoleMappedData(idv, opt: true, + RoleMappedSchema.ColumnRole.Label.Bind(input.LabelColumn), + RoleMappedSchema.ColumnRole.Weight.Bind(input.WeightColumn.Value), + RoleMappedSchema.ColumnRole.Group.Bind(input.GroupColumn), + RoleMappedSchema.ColumnRole.Name.Bind(input.NameColumn.Value))).ToArray(), out var variableSizeVectorColumnNames); var warnings = input.Warnings != null ? new List(input.Warnings) : new List(); diff --git a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs index b26960c561..6502fc2afa 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs @@ -22,7 +22,7 @@ public sealed class FeatureCombinerInput : TransformInputBase [Argument(ArgumentType.Multiple, HelpText = "Features", SortOrder = 2)] public string[] Features; - public IEnumerable> GetRoles() + internal IEnumerable> GetRoles() { if (Utils.Size(Features) > 0) { @@ -49,7 +49,7 @@ public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env using (var ch = host.Start(featureCombiner)) { var viewTrain = input.Data; - var rms = RoleMappedSchema.Create(viewTrain.Schema, input.GetRoles()); + var rms = new RoleMappedSchema(viewTrain.Schema, input.GetRoles()); var feats = rms.GetColumns(RoleMappedSchema.ColumnRole.Feature); if (Utils.Size(feats) == 0) throw ch.Except("No feature columns specified"); diff --git a/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs index c95f384ad6..e4a54de040 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs @@ -128,7 +128,7 @@ private static int GetNumberOfClasses(IHostEnvironment env, Arguments input, out input.WeightColumn, DefaultColumnNames.Weight); // Get number of classes - var data = TrainUtils.CreateExamples(input.TrainingData, label, feature, null, weight); + var data = new RoleMappedData(input.TrainingData, label, feature, null, weight); data.CheckMultiClassLabel(out var numClasses); return numClasses; } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 43e208ad3c..ec8c0e12d7 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -684,9 +684,7 @@ public void EntryPointCalibrate() // This tests that the SchemaBindableCalibratedPredictor doesn't get confused if its sub-predictor is already calibrated. var fastForest = new FastForestClassification(Env, new FastForestClassification.Arguments()); - var rmd = RoleMappedData.Create(splitOutput.TrainData[0], - RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, "Features"), - RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, "Label")); + var rmd = new RoleMappedData(splitOutput.TrainData[0], "Label", "Features"); fastForest.Train(rmd); var ffModel = new PredictorModel(Env, rmd, splitOutput.TrainData[0], fastForest.CreatePredictor()); var calibratedFfModel = Calibrate.Platt(Env, @@ -1220,9 +1218,7 @@ public void EntryPointMulticlassPipelineEnsemble() }, data); var mlr = new MulticlassLogisticRegression(Env, new MulticlassLogisticRegression.Arguments()); - RoleMappedData rmd = RoleMappedData.Create(data, - RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, "Features"), - RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, "Label")); + var rmd = new RoleMappedData(data, "Label", "Features"); mlr.Train(rmd); predictorModels[i] = new PredictorModel(Env, rmd, data, mlr.CreatePredictor()); diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index 928f740b57..f9b08ef72a 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -599,7 +599,7 @@ private void CombineAndTestTreeEnsembles(IDataView idv, IPredictorModel[] fastTr var fastTree = combiner.CombineModels(fastTrees.Select(pm => pm.Predictor as IPredictorProducing)); - var data = RoleMappedData.Create(idv, RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, "Features")); + var data = new RoleMappedData(idv, label: null, feature: "Features"); var scored = ScoreModel.Score(Env, new ScoreModel.Input() { Data = idv, PredictorModel = new PredictorModel(Env, data, idv, fastTree) }).ScoredData; Assert.True(scored.Schema.TryGetColumnIndex("Score", out int scoreCol)); Assert.True(scored.Schema.TryGetColumnIndex("Probability", out int probCol)); @@ -1537,10 +1537,10 @@ public void CompareSvmPredictorResultsToLibSvm() Column = new[] { new NormalizeTransform.AffineColumn() { Name = "Features", Source = "Features" } } }, trainView); - var trainData = TrainUtils.CreateExamples(trainView, "Label", "Features"); + var trainData = new RoleMappedData(trainView, "Label", "Features"); IDataView testView = new TextLoader(env, new TextLoader.Arguments(), new MultiFileSource(GetDataPath(TestDatasets.mnistOneClass.testFilename))); ApplyTransformUtils.ApplyAllTransformsToData(env, trainView, testView); - var testData = TrainUtils.CreateExamples(testView, "Label", "Features"); + var testData = new RoleMappedData(testView, "Label", "Features"); CompareSvmToLibSvmCore("linear kernel", "LinearKernel", env, trainData, testData); CompareSvmToLibSvmCore("polynomial kernel", "PolynomialKernel{d=2}", env, trainData, testData); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs index c1291cd52e..7c371082de 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs @@ -74,7 +74,7 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest() // Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto var cached = new CacheDataView(env, trans, prefetch: null); - var trainRoles = TrainUtils.CreateExamples(cached, label: "Label", feature: "Features"); + var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); trainer.Train(trainRoles); // Get scorer and evaluate the predictions from test data @@ -176,7 +176,7 @@ private void CompareMatrics(ClassificationMetrics metrics) private ClassificationMetrics Evaluate(IHostEnvironment env, IDataView scoredData) { - var dataEval = TrainUtils.CreateExamplesOpt(scoredData, label: "Label", feature: "Features"); + var dataEval = new RoleMappedData(scoredData, label: "Label", feature: "Features", opt: true); // Evaluate. // It does not work. It throws error "Failed to find 'Score' column" when Evaluate is called @@ -193,7 +193,7 @@ private IDataScorerTransform GetScorer(IHostEnvironment env, IDataView transform using (var ch = env.Start("Saving model")) using (var memoryStream = new MemoryStream()) { - var trainRoles = TrainUtils.CreateExamples(transforms, label: "Label", feature: "Features"); + var trainRoles = new RoleMappedData(transforms, label: "Label", feature: "Features"); // Model cannot be saved with CacheDataView TrainUtils.SaveModel(env, ch, memoryStream, pred, trainRoles); @@ -201,7 +201,7 @@ private IDataScorerTransform GetScorer(IHostEnvironment env, IDataView transform using (var rep = RepositoryReader.Open(memoryStream, ch)) { IDataLoader testPipe = ModelFileUtils.LoadLoader(env, rep, new MultiFileSource(testDataPath), true); - RoleMappedData testRoles = TrainUtils.CreateExamples(testPipe, label: "Label", feature: "Features"); + RoleMappedData testRoles = new RoleMappedData(testPipe, label: "Label", feature: "Features"); return ScoreUtils.GetScorer(pred, testRoles, env, testRoles.Schema); } } diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs index 2266f4d1f0..da208cb3f0 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs @@ -78,7 +78,7 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest() MinDocumentsInLeafs = 2 }); - var trainRoles = TrainUtils.CreateExamples(trans, label: "Label", feature: "Features"); + var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); trainer.Train(trainRoles); // Get scorer and evaluate the predictions from test data @@ -103,7 +103,7 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest() private BinaryClassificationMetrics EvaluateBinary(IHostEnvironment env, IDataView scoredData) { - var dataEval = TrainUtils.CreateExamplesOpt(scoredData, label: "Label", feature: "Features"); + var dataEval = new RoleMappedData(scoredData, label: "Label", feature: "Features", opt: true); // Evaluate. // It does not work. It throws error "Failed to find 'Score' column" when Evaluate is called