diff --git a/src/Microsoft.ML.Api/ComponentCreation.cs b/src/Microsoft.ML.Api/ComponentCreation.cs
index 73cfbf91a5..0d164b6124 100644
--- a/src/Microsoft.ML.Api/ComponentCreation.cs
+++ b/src/Microsoft.ML.Api/ComponentCreation.cs
@@ -52,7 +52,7 @@ public static RoleMappedData CreateExamples(this IHostEnvironment env, IDataView
env.CheckValueOrNull(weight);
env.CheckValueOrNull(custom);
- return TrainUtils.CreateExamples(data, label, features, group, weight, name: null, custom: custom);
+ return new RoleMappedData(data, label, features, group, weight, name: null, custom: custom);
}
///
diff --git a/src/Microsoft.ML.Api/GenerateCodeCommand.cs b/src/Microsoft.ML.Api/GenerateCodeCommand.cs
index 0b45bcc4bb..0bca5edfb3 100644
--- a/src/Microsoft.ML.Api/GenerateCodeCommand.cs
+++ b/src/Microsoft.ML.Api/GenerateCodeCommand.cs
@@ -108,8 +108,8 @@ public void Run()
{
var roles = ModelFileUtils.LoadRoleMappingsOrNull(_host, fs);
scorer = roles != null
- ? _host.CreateDefaultScorer(RoleMappedData.CreateOpt(transformPipe, roles), pred)
- : _host.CreateDefaultScorer(_host.CreateExamples(transformPipe, "Features"), pred);
+ ? _host.CreateDefaultScorer(new RoleMappedData(transformPipe, roles, opt: true), pred)
+ : _host.CreateDefaultScorer(new RoleMappedData(transformPipe, label: null, "Features"), pred);
}
var nonScoreSb = new StringBuilder();
diff --git a/src/Microsoft.ML.Api/PredictionEngine.cs b/src/Microsoft.ML.Api/PredictionEngine.cs
index eacf8d2218..14e2498c93 100644
--- a/src/Microsoft.ML.Api/PredictionEngine.cs
+++ b/src/Microsoft.ML.Api/PredictionEngine.cs
@@ -49,8 +49,8 @@ internal BatchPredictionEngine(IHostEnvironment env, Stream modelStream, bool ig
{
var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, modelStream);
pipe = roles != null
- ? env.CreateDefaultScorer(RoleMappedData.CreateOpt(pipe, roles), predictor)
- : env.CreateDefaultScorer(env.CreateExamples(pipe, "Features"), predictor);
+ ? env.CreateDefaultScorer(new RoleMappedData(pipe, roles, opt: true), predictor)
+ : env.CreateDefaultScorer(new RoleMappedData(pipe, label: null, "Features"), predictor);
}
_pipeEngine = new PipeEngine(env, pipe, ignoreMissingColumns, outputSchemaDefinition);
diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
index ca13d03ab3..b0a18f6d18 100644
--- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs
+++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
@@ -312,7 +312,6 @@ public static bool HasSlotNames(this ISchema schema, int col, int vectorSize)
public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer slotNames)
{
Contracts.CheckValueOrNull(schema);
- Contracts.CheckValue(role.Value, nameof(role));
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
IReadOnlyList list;
diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
index 302d369489..2dab48fc58 100644
--- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
+++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
@@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Collections.Generic;
using Microsoft.ML.Runtime.Internal.Utilities;
@@ -34,9 +33,8 @@ private ColumnInfo(string name, int index, ColumnType type)
///
public static ColumnInfo CreateFromName(ISchema schema, string name, string descName)
{
- ColumnInfo colInfo;
- if (!TryCreateFromName(schema, name, out colInfo))
- throw Contracts.ExceptParam(nameof(name), "{0} column '{1}' not found", descName, name);
+ if (!TryCreateFromName(schema, name, out var colInfo))
+ throw Contracts.ExceptParam(nameof(name), $"{descName} column '{name}' not found");
return colInfo;
}
@@ -51,8 +49,7 @@ public static bool TryCreateFromName(ISchema schema, string name, out ColumnInfo
Contracts.CheckNonEmpty(name, nameof(name));
colInfo = null;
- int index;
- if (!schema.TryGetColumnIndex(name, out index))
+ if (!schema.TryGetColumnIndex(name, out int index))
return false;
colInfo = new ColumnInfo(name, index, schema.GetColumnType(index));
@@ -83,12 +80,25 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index)
/// multiple features columns to consume that information.
///
/// This class has convenience fields for several common column roles (se.g., , ), but can hold an arbitrary set of column infos. The convenience fields are non-null iff there is
- /// a unique column with the corresponding role. When there are no such columns or more than one such column, the
- /// field is null. The , , and methods provide
- /// some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden in this
- /// schema.
+ /// cref="Label"/>), but can hold an arbitrary set of column infos. The convenience fields are non-null if and only
+ /// if there is a unique column with the corresponding role. When there are no such columns or more than one such
+ /// column, the field is null. The , , and
+ /// methods provide some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden
+ /// in this schema.
///
+ ///
+ /// Note that instances of this class are, like instances of , immutable.
+ ///
+ /// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For
+ /// that case, please use the class.
+ ///
+ /// Note that there is no need for components consuming a or
+ /// to make use of every defined mapping. Consuming components are also expected to ignore any
+ /// they do not handle. They may very well however complain if a mapping they wanted to see is not present, or the column(s)
+ /// mapped from the role are not of the form they require.
+ ///
+ ///
+ ///
public sealed class RoleMappedSchema
{
private const string FeatureString = "Feature";
@@ -98,17 +108,57 @@ public sealed class RoleMappedSchema
private const string NameString = "Name";
private const string FeatureContributionsString = "FeatureContributions";
+ ///
+ /// Instances of this are the keys of a . This class also holds some important
+ /// commonly used pre-defined instances available (e.g., , ) that should
+ /// be used when possible for consistency reasons. However, practitioners should not be afraid to declare custom
+ /// roles if approppriate for their task.
+ ///
public struct ColumnRole
{
+ ///
+ /// Role for features. Commonly used as the independent variables given to trainers, and scorers.
+ ///
public static ColumnRole Feature => FeatureString;
+
+ ///
+ /// Role for labels. Commonly used as the dependent variables given to trainers, and evaluators.
+ ///
public static ColumnRole Label => LabelString;
+
+ ///
+ /// Role for group ID. Commonly used in ranking applications, for defining query boundaries, or
+ /// sequence classification, for defining the boundaries of an utterance.
+ ///
public static ColumnRole Group => GroupString;
+
+ ///
+ /// Role for sample weights. Commonly used to point to a number to make trainers give more weight
+ /// to a particular example.
+ ///
public static ColumnRole Weight => WeightString;
+
+ ///
+ /// Role for sample names. Useful for informational and tracking purposes when scoring, but typically
+ /// without affecting results.
+ ///
public static ColumnRole Name => NameString;
+
+ // REVIEW: Does this really belong here?
+ ///
+ /// Role for feature contributions. Useful for specific diagnostic functionality.
+ ///
public static ColumnRole FeatureContributions => FeatureContributionsString;
+ ///
+ /// The string value for the role. Guaranteed to be non-empty.
+ ///
public readonly string Value;
+ ///
+ /// Constructor for the column role.
+ ///
+ /// The value for the role. Must be non-empty.
public ColumnRole(string value)
{
Contracts.CheckNonEmpty(value, nameof(value));
@@ -116,50 +166,51 @@ public ColumnRole(string value)
}
public static implicit operator ColumnRole(string value)
- {
- return new ColumnRole(value);
- }
-
+ => new ColumnRole(value);
+
+ ///
+ /// Convenience method for creating a mapping pair from a role to a column name
+ /// for giving to constructors of and .
+ ///
+ /// The column name to map to. Can be null, in which case when used
+ /// to construct a role mapping structure this pair will be ignored
+ /// A key-value pair with this instance as the key and as the value
public KeyValuePair Bind(string name)
- {
- return new KeyValuePair(this, name);
- }
+ => new KeyValuePair(this, name);
}
public static KeyValuePair CreatePair(ColumnRole role, string name)
- {
- return new KeyValuePair(role, name);
- }
+ => new KeyValuePair(role, name);
///
- /// The source ISchema.
+ /// The source .
///
- public readonly ISchema Schema;
+ public ISchema Schema { get; }
///
- /// The Feature column, when there is exactly one (null otherwise).
+ /// The column, when there is exactly one (null otherwise).
///
- public readonly ColumnInfo Feature;
+ public ColumnInfo Feature { get; }
///
- /// The Label column, when there is exactly one (null otherwise).
+ /// The column, when there is exactly one (null otherwise).
///
- public readonly ColumnInfo Label;
+ public ColumnInfo Label { get; }
///
- /// The Group column, when there is exactly one (null otherwise).
+ /// The column, when there is exactly one (null otherwise).
///
- public readonly ColumnInfo Group;
+ public ColumnInfo Group { get; }
///
- /// The Weight column, when there is exactly one (null otherwise).
+ /// The column, when there is exactly one (null otherwise).
///
- public readonly ColumnInfo Weight;
+ public ColumnInfo Weight { get; }
///
- /// The Name column, when there is exactly one (null otherwise).
+ /// The column, when there is exactly one (null otherwise).
///
- public readonly ColumnInfo Name;
+ public ColumnInfo Name { get; }
// Maps from role to the associated column infos.
private readonly Dictionary> _map;
@@ -183,21 +234,21 @@ private RoleMappedSchema(ISchema schema, Dictionary> map, ColumnRole rol
Contracts.AssertNonEmpty(role.Value);
Contracts.AssertValue(info);
- List list;
- if (!map.TryGetValue(role.Value, out list))
+ if (!map.TryGetValue(role.Value, out var list))
{
list = new List();
map.Add(role.Value, list);
@@ -223,7 +273,7 @@ private static void Add(Dictionary> map, ColumnRole rol
list.Add(info);
}
- private static Dictionary> MapFromNames(ISchema schema, IEnumerable> roles)
+ private static Dictionary> MapFromNames(ISchema schema, IEnumerable> roles, bool opt = false)
{
Contracts.AssertValue(schema);
Contracts.AssertValue(roles);
@@ -231,28 +281,13 @@ private static Dictionary> MapFromNames(ISchema schema,
var map = new Dictionary>();
foreach (var kvp in roles)
{
- Contracts.CheckNonEmpty(kvp.Key.Value, nameof(roles), "Bad column role");
- if (string.IsNullOrEmpty(kvp.Value))
- continue;
- var info = ColumnInfo.CreateFromName(schema, kvp.Value, kvp.Key.Value);
- Add(map, kvp.Key.Value, info);
- }
- return map;
- }
-
- private static Dictionary> MapFromNamesOpt(ISchema schema, IEnumerable> roles)
- {
- Contracts.AssertValue(schema);
- Contracts.AssertValue(roles);
-
- var map = new Dictionary>();
- foreach (var kvp in roles)
- {
- Contracts.CheckNonEmpty(kvp.Key.Value, nameof(roles), "Bad column role");
+ Contracts.AssertNonEmpty(kvp.Key.Value);
if (string.IsNullOrEmpty(kvp.Value))
continue;
ColumnInfo info;
- if (!ColumnInfo.TryCreateFromName(schema, kvp.Value, out info))
+ if (!opt)
+ info = ColumnInfo.CreateFromName(schema, kvp.Value, kvp.Key.Value);
+ else if (!ColumnInfo.TryCreateFromName(schema, kvp.Value, out info))
continue;
Add(map, kvp.Key.Value, info);
}
@@ -263,39 +298,26 @@ private static Dictionary> MapFromNamesOpt(ISchema sche
/// Returns whether there are any columns with the given column role.
///
public bool Has(ColumnRole role)
- {
- return role.Value != null && _map.ContainsKey(role.Value);
- }
+ => _map.ContainsKey(role.Value);
///
/// Returns whether there is exactly one column of the given role.
///
public bool HasUnique(ColumnRole role)
- {
- IReadOnlyList cols;
- return role.Value != null && _map.TryGetValue(role.Value, out cols) && cols.Count == 1;
- }
+ => _map.TryGetValue(role.Value, out var cols) && cols.Count == 1;
///
/// Returns whether there are two or more columns of the given role.
///
public bool HasMultiple(ColumnRole role)
- {
- IReadOnlyList cols;
- return role.Value != null && _map.TryGetValue(role.Value, out cols) && cols.Count > 1;
- }
+ => _map.TryGetValue(role.Value, out var cols) && cols.Count > 1;
///
/// If there are columns of the given role, this returns the infos as a readonly list. Otherwise,
/// it returns null.
///
public IReadOnlyList GetColumns(ColumnRole role)
- {
- IReadOnlyList list;
- if (role.Value != null && _map.TryGetValue(role.Value, out list))
- return list;
- return null;
- }
+ => _map.TryGetValue(role.Value, out var list) ? list : null;
///
/// An enumerable over all role-column associations within this object.
@@ -327,8 +349,7 @@ public IEnumerable> GetColumnRoleNames()
///
public IEnumerable> GetColumnRoleNames(ColumnRole role)
{
- IReadOnlyList list;
- if (role.Value != null && _map.TryGetValue(role.Value, out list))
+ if (_map.TryGetValue(role.Value, out var list))
{
foreach (var info in list)
yield return new KeyValuePair(role, info.Name);
@@ -363,45 +384,82 @@ private static Dictionary> Copy(Dictionary
- /// Creates a RoleMappedSchema from the given schema with no column role assignments.
+ /// Constructor given a schema, and mapping pairs of roles to columns in the schema.
+ /// This skips null or empty column-names. It will also skip column-names that are not
+ /// found in the schema if is true.
///
- public static RoleMappedSchema Create(ISchema schema)
+ /// The schema over which roles are defined
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that does not appear in will result in an exception being thrown,
+ /// but if true such values will be ignored
+ /// The column role to column name mappings
+ public RoleMappedSchema(ISchema schema, bool opt = false, params KeyValuePair[] roles)
+ : this(Contracts.CheckRef(schema, nameof(schema)), Contracts.CheckRef(roles, nameof(roles)), opt)
{
- Contracts.CheckValue(schema, nameof(schema));
- return new RoleMappedSchema(schema, new Dictionary>());
}
///
- /// Creates a RoleMappedSchema from the given schema and role/column-name pairs.
- /// This skips null or empty column-names.
+ /// Constructor given a schema, and mapping pairs of roles to columns in the schema.
+ /// This skips null or empty column names. It will also skip column-names that are not
+ /// found in the schema if is true.
///
- public static RoleMappedSchema Create(ISchema schema, params KeyValuePair[] roles)
+ /// The schema over which roles are defined
+ /// The column role to column name mappings
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that does not appear in will result in an exception being thrown,
+ /// but if true such values will be ignored
+ public RoleMappedSchema(ISchema schema, IEnumerable> roles, bool opt = false)
+ : this(Contracts.CheckRef(schema, nameof(schema)),
+ MapFromNames(schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
- Contracts.CheckValue(schema, nameof(schema));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedSchema(schema, MapFromNames(schema, roles));
}
- ///
- /// Creates a RoleMappedSchema from the given schema and role/column-name pairs.
- /// This skips null or empty column-names.
- ///
- public static RoleMappedSchema Create(ISchema schema, IEnumerable> roles)
+ private static IEnumerable> PredefinedRolesHelper(
+ string label, string feature, string group, string weight, string name,
+ IEnumerable> custom = null)
{
- Contracts.CheckValue(schema, nameof(schema));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedSchema(schema, MapFromNames(schema, roles));
+ if (!string.IsNullOrWhiteSpace(label))
+ yield return ColumnRole.Label.Bind(label);
+ if (!string.IsNullOrWhiteSpace(feature))
+ yield return ColumnRole.Feature.Bind(feature);
+ if (!string.IsNullOrWhiteSpace(group))
+ yield return ColumnRole.Group.Bind(group);
+ if (!string.IsNullOrWhiteSpace(weight))
+ yield return ColumnRole.Weight.Bind(weight);
+ if (!string.IsNullOrWhiteSpace(name))
+ yield return ColumnRole.Name.Bind(name);
+ if (custom != null)
+ {
+ foreach (var role in custom)
+ yield return role;
+ }
}
///
- /// Creates a RoleMappedSchema from the given schema and role/column-name pairs.
- /// This skips null or empty column-names, or column-names that are not found in the schema.
+ /// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
+ /// is null or whitespace, it is ignored.
///
- public static RoleMappedSchema CreateOpt(ISchema schema, IEnumerable> roles)
+ /// The schema over which roles are defined
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// Any additional desired custom column role mappings
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that does not appear in will result in an exception being thrown,
+ /// but if true such values will be ignored
+ public RoleMappedSchema(ISchema schema, string label, string feature,
+ string group = null, string weight = null, string name = null,
+ IEnumerable> custom = null, bool opt = false)
+ : this(Contracts.CheckRef(schema, nameof(schema)), PredefinedRolesHelper(label, feature, group, weight, name, custom), opt)
{
- Contracts.CheckValue(schema, nameof(schema));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedSchema(schema, MapFromNamesOpt(schema, roles));
+ Contracts.CheckValueOrNull(label);
+ Contracts.CheckValueOrNull(feature);
+ Contracts.CheckValueOrNull(group);
+ Contracts.CheckValueOrNull(weight);
+ Contracts.CheckValueOrNull(name);
+ Contracts.CheckValueOrNull(custom);
}
}
@@ -415,12 +473,13 @@ public sealed class RoleMappedData
///
/// The data.
///
- public readonly IDataView Data;
+ public IDataView Data { get; }
///
- /// The role mapped schema. Note that Schema.Schema is guaranteed to be the same as Data.Schema.
+ /// The role mapped schema. Note that 's is
+ /// guaranteed to be the same as 's .
///
- public readonly RoleMappedSchema Schema;
+ public RoleMappedSchema Schema { get; }
private RoleMappedData(IDataView data, RoleMappedSchema schema)
{
@@ -432,45 +491,61 @@ private RoleMappedData(IDataView data, RoleMappedSchema schema)
}
///
- /// Creates a RoleMappedData from the given data with no column role assignments.
- ///
- public static RoleMappedData Create(IDataView data)
- {
- Contracts.CheckValue(data, nameof(data));
- return new RoleMappedData(data, RoleMappedSchema.Create(data.Schema));
- }
-
- ///
- /// Creates a RoleMappedData from the given schema and role/column-name pairs.
- /// This skips null or empty column-names.
+ /// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
+ /// This skips null or empty column-names. It will also skip column-names that are not
+ /// found in the schema if is true.
///
- public static RoleMappedData Create(IDataView data, params KeyValuePair[] roles)
+ /// The data over which roles are defined
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that does not appear in 's schema will result in an exception being thrown,
+ /// but if true such values will be ignored
+ /// The column role to column name mappings
+ public RoleMappedData(IDataView data, bool opt = false, params KeyValuePair[] roles)
+ : this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
- Contracts.CheckValue(data, nameof(data));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedData(data, RoleMappedSchema.Create(data.Schema, roles));
}
///
- /// Creates a RoleMappedData from the given schema and role/column-name pairs.
- /// This skips null or empty column-names.
+ /// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
+ /// This skips null or empty column-names. It will also skip column-names that are not
+ /// found in the schema if is true.
///
- public static RoleMappedData Create(IDataView data, IEnumerable> roles)
+ /// The schema over which roles are defined
+ /// The column role to column name mappings
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that does not appear in 's schema will result in an exception being thrown,
+ /// but if true such values will be ignored
+ public RoleMappedData(IDataView data, IEnumerable> roles, bool opt = false)
+ : this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
- Contracts.CheckValue(data, nameof(data));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedData(data, RoleMappedSchema.Create(data.Schema, roles));
}
///
- /// Creates a RoleMappedData from the given schema and role/column-name pairs.
- /// This skips null or empty column-names, or column-names that are not found in the schema.
+ /// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
+ /// is null or whitespace, it is ignored.
///
- public static RoleMappedData CreateOpt(IDataView data, IEnumerable> roles)
+ /// The data over which roles are defined
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// Any additional desired custom column role mappings
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that does not appear in 's schema will result in an exception being thrown,
+ /// but if true such values will be ignored
+ public RoleMappedData(IDataView data, string label, string feature,
+ string group = null, string weight = null, string name = null,
+ IEnumerable> custom = null, bool opt = false)
+ : this(Contracts.CheckRef(data, nameof(data)),
+ new RoleMappedSchema(data.Schema, label, feature, group, weight, name, custom, opt))
{
- Contracts.CheckValue(data, nameof(data));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedData(data, RoleMappedSchema.CreateOpt(data.Schema, roles));
+ Contracts.CheckValueOrNull(label);
+ Contracts.CheckValueOrNull(feature);
+ Contracts.CheckValueOrNull(group);
+ Contracts.CheckValueOrNull(weight);
+ Contracts.CheckValueOrNull(name);
+ Contracts.CheckValueOrNull(custom);
}
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
index fc78e72c53..23b1c601c9 100644
--- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
@@ -254,7 +254,7 @@ private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel c
RoleMappedData srcData, IDataView marker)
{
var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, srcData.Data, dstData, marker);
- return RoleMappedData.Create(pipe, srcData.Schema.GetColumnRoleNames());
+ return new RoleMappedData(pipe, srcData.Schema.GetColumnRoleNames());
}
///
@@ -277,7 +277,7 @@ private RoleMappedData CreateRoleMappedData(IHostEnvironment env, IChannel ch, I
// Training pipe and examples.
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
- return TrainUtils.CreateExamples(data, label, features, group, weight, name, customCols);
+ return new RoleMappedData(data, label, features, group, weight, name, customCols);
}
private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
@@ -568,7 +568,7 @@ private FoldResult RunFold(int fold)
{
using (var file = host.CreateOutputFile(modelFileName))
{
- var rmd = RoleMappedData.Create(
+ var rmd = new RoleMappedData(
CompositeDataLoader.ApplyTransform(host, _loader, null, null,
(e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
trainData.Schema.GetColumnRoleNames());
@@ -581,17 +581,17 @@ private FoldResult RunFold(int fold)
if (!evalComp.IsGood())
evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
var eval = evalComp.CreateInstance(host);
- // Note that this doesn't require the provided columns to exist (because of "Opt").
+ // Note that this doesn't require the provided columns to exist (because of the "opt" parameter).
// We don't normally expect the scorer to drop columns, but if it does, we should not require
// all the columns in the test pipeline to still be present.
- var dataEval = RoleMappedData.CreateOpt(scorePipe, testData.Schema.GetColumnRoleNames());
+ var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);
var dict = eval.Evaluate(dataEval);
RoleMappedData perInstance = null;
if (_savePerInstance)
{
var perInst = eval.GetPerInstanceMetrics(dataEval);
- perInstance = RoleMappedData.CreateOpt(perInst, dataEval.Schema.GetColumnRoleNames());
+ perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true);
}
ch.Done();
return new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema);
diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs
index 435c25bf5b..2a62d78901 100644
--- a/src/Microsoft.ML.Data/Commands/DataCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs
@@ -305,7 +305,7 @@ protected void LoadModelObjects(
// can be loaded with no data at all, to get their schemas.
if (trainPipe == null)
trainPipe = ModelFileUtils.LoadLoader(Host, rep, new MultiFileSource(null), loadTransforms: true);
- trainSchema = RoleMappedSchema.Create(trainPipe.Schema, trainRoleMappings);
+ trainSchema = new RoleMappedSchema(trainPipe.Schema, trainRoleMappings);
}
// If the role mappings are null, an alternative would be to fail. However the idea
// is that the scorer should always still succeed, although perhaps with reduced
diff --git a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
index d0e066d789..cd2eb464af 100644
--- a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
@@ -158,7 +158,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
evalComp = EvaluateUtils.GetEvaluatorType(ch, input.Schema);
var eval = evalComp.CreateInstance(env);
- var data = TrainUtils.CreateExamples(input, label, null, group, weight, null, customCols);
+ var data = new RoleMappedData(input, label, null, group, weight, null, customCols);
return eval.GetPerInstanceMetrics(data);
}
}
@@ -236,7 +236,7 @@ private void RunCore(IChannel ch)
if (!evalComp.IsGood())
evalComp = EvaluateUtils.GetEvaluatorType(ch, view.Schema);
var evaluator = evalComp.CreateInstance(Host);
- var data = TrainUtils.CreateExamples(view, label, null, group, weight, name, customCols);
+ var data = new RoleMappedData(view, label, null, group, weight, name, customCols);
var metrics = evaluator.Evaluate(data);
MetricWriter.PrintWarnings(ch, metrics);
evaluator.PrintFoldResults(ch, metrics);
@@ -248,7 +248,7 @@ private void RunCore(IChannel ch)
if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
{
var perInst = evaluator.GetPerInstanceMetrics(data);
- var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
+ var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
}
diff --git a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
index 505d3a28e6..e1057d18b4 100644
--- a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
@@ -219,7 +219,7 @@ public static void LoadModel(IHostEnvironment env, Stream modelStream, bool load
if (roles != null)
{
var emptyView = ModelFileUtils.LoadPipeline(env, rep, new MultiFileSource(null));
- schema = RoleMappedSchema.CreateOpt(emptyView.Schema, roles);
+ schema = new RoleMappedSchema(emptyView.Schema, roles, opt: true);
}
else
{
diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
index 02d655b48b..c353e4a4ec 100644
--- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
@@ -97,10 +97,7 @@ private void RunCore(IChannel ch)
ch.Trace("Creating loader");
- IPredictor predictor;
- IDataLoader loader;
- RoleMappedSchema trainSchema;
- LoadModelObjects(ch, true, out predictor, true, out trainSchema, out loader);
+ LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader);
ch.AssertValue(predictor);
ch.AssertValueOrNull(trainSchema);
ch.AssertValue(loader);
@@ -116,7 +113,7 @@ private void RunCore(IChannel ch)
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
- var schema = TrainUtils.CreateRoleMappedSchemaOpt(loader.Schema, feat, group, customCols);
+ var schema = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true);
var mapper = bindable.Bind(Host, schema);
if (!scorer.IsGood())
@@ -153,22 +150,20 @@ private void RunCore(IChannel ch)
Args.OutputAllColumns == true || Utils.Size(Args.OutputColumn) == 0;
if (Args.OutputAllColumns == true && Utils.Size(Args.OutputColumn) != 0)
- ch.Warning("outputAllColumns=+ always writes all columns irrespective of outputColumn specified.");
+ ch.Warning(nameof(Args.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(Args.OutputColumn) + " specified.");
if (!outputAllColumns && Utils.Size(Args.OutputColumn) != 0)
{
foreach (var outCol in Args.OutputColumn)
{
- int dummyColIndex;
- if (!loader.Schema.TryGetColumnIndex(outCol, out dummyColIndex))
+ if (!loader.Schema.TryGetColumnIndex(outCol, out int dummyColIndex))
throw ch.ExceptUserArg(nameof(Arguments.OutputColumn), "Column '{0}' not found.", outCol);
}
}
- int colMax;
uint maxScoreId = 0;
if (!outputAllColumns)
- maxScoreId = loader.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId);
+ maxScoreId = loader.Schema.GetMaxMetadataKind(out int colMax, MetadataUtils.Kinds.ScoreColumnSetId);
ch.Assert(outputAllColumns || maxScoreId > 0); // score set IDs are one-based
var cols = new List();
for (int i = 0; i < loader.Schema.ColumnCount; i++)
@@ -211,12 +206,12 @@ private bool ShouldAddColumn(ISchema schema, int i, uint scoreSet, bool outputNa
{
switch (schema.GetColumnName(i))
{
- case "Label":
- case "Name":
- case "Names":
- return true;
- default:
- break;
+ case "Label":
+ case "Name":
+ case "Names":
+ return true;
+ default:
+ break;
}
}
if (Args.OutputColumn != null && Array.FindIndex(Args.OutputColumn, schema.GetColumnName(i).Equals) >= 0)
@@ -229,8 +224,7 @@ public static class ScoreUtils
{
public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema)
{
- ISchemaBoundMapper mapper;
- var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, out mapper);
+ var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, out var mapper);
return sc.CreateInstance(env, data.Data, mapper, trainSchema);
}
@@ -247,9 +241,8 @@ public static IDataScorerTransform GetScorer(SubComponent GetScorerC
Contracts.AssertValue(mapper);
string loadName = null;
- DvText scoreKind = default(DvText);
+ DvText scoreKind = default;
if (mapper.OutputSchema.ColumnCount > 0 &&
mapper.OutputSchema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreColumnKind, 0, ref scoreKind) &&
scoreKind.HasChars)
@@ -311,10 +304,8 @@ public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env
env.CheckValue(predictor, nameof(predictor));
env.CheckValueOrNull(scorerSettings);
- ISchemaBindableMapper bindable;
-
// See if we can instantiate a mapper using scorer arguments.
- if (scorerSettings.IsGood() && TryCreateBindableFromScorer(env, predictor, scorerSettings, out bindable))
+ if (scorerSettings.IsGood() && TryCreateBindableFromScorer(env, predictor, scorerSettings, out var bindable))
return bindable;
// The easy case is that the predictor implements the interface.
diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs
index 79e7bd5458..d0ebbd5a05 100644
--- a/src/Microsoft.ML.Data/Commands/TestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs
@@ -114,7 +114,7 @@ private void RunCore(IChannel ch)
if (!evalComp.IsGood())
evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
var evaluator = evalComp.CreateInstance(Host);
- var data = TrainUtils.CreateExamples(scorePipe, label, null, group, weight, name, customCols);
+ var data = new RoleMappedData(scorePipe, label, null, group, weight, name, customCols);
var metrics = evaluator.Evaluate(data);
MetricWriter.PrintWarnings(ch, metrics);
evaluator.PrintFoldResults(ch, metrics);
@@ -128,7 +128,7 @@ private void RunCore(IChannel ch)
if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
{
var perInst = evaluator.GetPerInstanceMetrics(data);
- var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
+ var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
}
diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
index b5b23f9020..b5e3964157 100644
--- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
@@ -157,7 +157,7 @@ private void RunCore(IChannel ch, string cmd)
ch.Trace("Binding columns");
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
- var data = TrainUtils.CreateExamples(view, label, feature, group, weight, name, customCols);
+ var data = new RoleMappedData(view, label, feature, group, weight, name, customCols);
// REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands.
RoleMappedData validData = null;
@@ -172,7 +172,7 @@ private void RunCore(IChannel ch, string cmd)
ch.Trace("Constructing the validation pipeline");
IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, validPipe);
- validData = RoleMappedData.Create(validPipe, data.Schema.GetColumnRoleNames());
+ validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
}
}
@@ -550,7 +550,7 @@ private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer
var prefetch = data.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
var cacheView = new CacheDataView(env, data.Data, prefetch);
// Because the prefetching worked, we know that these are valid columns.
- data = RoleMappedData.Create(cacheView, data.Schema.GetColumnRoleNames());
+ data = new RoleMappedData(cacheView, data.Schema.GetColumnRoleNames());
}
else
ch.Trace("Not caching");
@@ -571,97 +571,5 @@ public static IEnumerable> CheckAndGenerateCust
}
return customColumnArg.Select(kindName => new ColumnRole(kindName.Key).Bind(kindName.Value));
}
-
- ///
- /// Given a schema and a bunch of column names, create the BoundSchema object. Any or all of the column
- /// names may be null or whitespace, in which case they are ignored. Any columns that are specified but not
- /// valid columns of the schema are also ignored.
- ///
- public static RoleMappedSchema CreateRoleMappedSchemaOpt(ISchema schema, string feature, string group, IEnumerable> custom = null)
- {
- Contracts.CheckValueOrNull(feature);
- Contracts.CheckValueOrNull(custom);
-
- var list = new List>();
- if (!string.IsNullOrWhiteSpace(feature))
- list.Add(ColumnRole.Feature.Bind(feature));
- if (!string.IsNullOrWhiteSpace(group))
- list.Add(ColumnRole.Group.Bind(group));
- if (custom != null)
- list.AddRange(custom);
-
- return RoleMappedSchema.CreateOpt(schema, list);
- }
-
- ///
- /// Given a view and a bunch of column names, create the RoleMappedData object. Any or all of the column
- /// names may be null or whitespace, in which case they are ignored. Any columns that are specified must
- /// be valid columns of the schema.
- ///
- public static RoleMappedData CreateExamples(IDataView view, string label, string feature,
- string group = null, string weight = null, string name = null,
- IEnumerable> custom = null)
- {
- Contracts.CheckValueOrNull(label);
- Contracts.CheckValueOrNull(feature);
- Contracts.CheckValueOrNull(group);
- Contracts.CheckValueOrNull(weight);
- Contracts.CheckValueOrNull(name);
- Contracts.CheckValueOrNull(custom);
-
- var list = new List>();
- if (!string.IsNullOrWhiteSpace(label))
- list.Add(ColumnRole.Label.Bind(label));
- if (!string.IsNullOrWhiteSpace(feature))
- list.Add(ColumnRole.Feature.Bind(feature));
- if (!string.IsNullOrWhiteSpace(group))
- list.Add(ColumnRole.Group.Bind(group));
- if (!string.IsNullOrWhiteSpace(weight))
- list.Add(ColumnRole.Weight.Bind(weight));
- if (!string.IsNullOrWhiteSpace(name))
- list.Add(ColumnRole.Name.Bind(name));
- if (custom != null)
- list.AddRange(custom);
-
- return RoleMappedData.Create(view, list);
- }
-
- ///
- /// Given a view and a bunch of column names, create the RoleMappedData object. Any or all of the column
- /// names may be null or whitespace, in which case they are ignored. Any columns that are specified but not
- /// valid columns of the schema are also ignored.
- ///
- public static RoleMappedData CreateExamplesOpt(IDataView view, string label, string feature,
- string group = null, string weight = null, string name = null,
- IEnumerable> custom = null)
- {
- Contracts.CheckValueOrNull(label);
- Contracts.CheckValueOrNull(feature);
- Contracts.CheckValueOrNull(group);
- Contracts.CheckValueOrNull(weight);
- Contracts.CheckValueOrNull(name);
- Contracts.CheckValueOrNull(custom);
-
- var list = new List>();
- if (!string.IsNullOrWhiteSpace(label))
- list.Add(ColumnRole.Label.Bind(label));
- if (!string.IsNullOrWhiteSpace(feature))
- list.Add(ColumnRole.Feature.Bind(feature));
- if (!string.IsNullOrWhiteSpace(group))
- list.Add(ColumnRole.Group.Bind(group));
- if (!string.IsNullOrWhiteSpace(weight))
- list.Add(ColumnRole.Weight.Bind(weight));
- if (!string.IsNullOrWhiteSpace(name))
- list.Add(ColumnRole.Name.Bind(name));
- if (custom != null)
- list.AddRange(custom);
-
- return RoleMappedData.CreateOpt(view, list);
- }
-
- private static KeyValuePair Pair(ColumnRole kind, T value)
- {
- return new KeyValuePair(kind, value);
- }
}
}
diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
index f6ffa772f9..7c4249c6ee 100644
--- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
@@ -147,7 +147,7 @@ private void RunCore(IChannel ch, string cmd)
ch.Trace("Binding columns");
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
- var data = TrainUtils.CreateExamples(trainPipe, label, features, group, weight, name, customCols);
+ var data = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);
RoleMappedData validData = null;
if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
@@ -161,7 +161,7 @@ private void RunCore(IChannel ch, string cmd)
ch.Trace("Constructing the validation pipeline");
IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
- validData = RoleMappedData.Create(validPipe, data.Schema.GetColumnRoleNames());
+ validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
}
}
@@ -189,8 +189,8 @@ private void RunCore(IChannel ch, string cmd)
if (!evalComp.IsGood())
evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
var evaluator = evalComp.CreateInstance(Host);
- var dataEval = TrainUtils.CreateExamplesOpt(scorePipe, label, features,
- group, weight, name, customCols);
+ var dataEval = new RoleMappedData(scorePipe, label, features,
+ group, weight, name, customCols, opt: true);
var metrics = evaluator.Evaluate(dataEval);
MetricWriter.PrintWarnings(ch, metrics);
evaluator.PrintFoldResults(ch, metrics);
@@ -204,7 +204,7 @@ private void RunCore(IChannel ch, string cmd)
if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
{
var perInst = evaluator.GetPerInstanceMetrics(dataEval);
- var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
+ var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
}
diff --git a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
index 3dd16f141b..f08d52fe85 100644
--- a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
+++ b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
@@ -364,9 +364,7 @@ private sealed class Dense : FeatureNameCollection
private readonly int _count;
private readonly string[] _names;
- private readonly RoleMappedSchema _schema;
-
- public override RoleMappedSchema Schema => _schema;
+ public override RoleMappedSchema Schema { get; }
public Dense(int count, string[] names)
{
@@ -379,8 +377,9 @@ public Dense(int count, string[] names)
if (size > 0)
Array.Copy(names, _names, size);
- _schema = RoleMappedSchema.Create(new FeatureNameCollectionSchema(this),
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, RoleMappedSchema.ColumnRole.Feature.Value));
+ // REVIEW: This seems wrong. The default feature column name is "Features" yet the role is named "Feature".
+ Schema = new RoleMappedSchema(new FeatureNameCollectionSchema(this),
+ roles: RoleMappedSchema.ColumnRole.Feature.Bind(RoleMappedSchema.ColumnRole.Feature.Value));
}
public override int Count => _count;
@@ -470,8 +469,9 @@ public Sparse(int count, string[] names, int cnn)
}
Contracts.Assert(cv == cnn);
- _schema = RoleMappedSchema.Create(new FeatureNameCollectionSchema(this),
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, RoleMappedSchema.ColumnRole.Feature.Value));
+ // REVIEW: This seems wrong. The default feature column name is "Features" yet the role is named "Feature".
+ _schema = new RoleMappedSchema(new FeatureNameCollectionSchema(this),
+ roles: RoleMappedSchema.ColumnRole.Feature.Bind(RoleMappedSchema.ColumnRole.Feature.Value));
}
///
diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
index 5583c66df0..bc45a929d4 100644
--- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
@@ -146,7 +146,7 @@ public static TOut Train(IHost host, TArg input,
TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref view, feature, input.NormalizeFeatures);
ch.Trace("Binding columns");
- var roleMappedData = TrainUtils.CreateExamples(view, label, feature, group, weight, name, custom);
+ var roleMappedData = new RoleMappedData(view, label, feature, group, weight, name, custom);
RoleMappedData cachedRoleMappedData = roleMappedData;
Cache.CachingType? cachingType = null;
@@ -184,7 +184,7 @@ public static TOut Train(IHost host, TArg input,
Data = roleMappedData.Data,
Caching = cachingType.Value
}).OutputData;
- cachedRoleMappedData = RoleMappedData.Create(cacheView, roleMappedData.Schema.GetColumnRoleNames());
+ cachedRoleMappedData = new RoleMappedData(cacheView, roleMappedData.Schema.GetColumnRoleNames());
}
var predictor = TrainUtils.Train(host, ch, cachedRoleMappedData, trainer, "Train", calibrator, maxCalibrationExamples);
diff --git a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs b/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs
index af726fa758..4b474f847c 100644
--- a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs
@@ -80,7 +80,7 @@ public void Save(IHostEnvironment env, Stream stream)
// Create the chain of transforms for saving.
IDataView data = new EmptyDataView(env, _transformModel.InputSchema);
data = _transformModel.Apply(env, data);
- var roleMappedData = RoleMappedData.CreateOpt(data, _roleMappings);
+ var roleMappedData = new RoleMappedData(data, _roleMappings, opt: true);
TrainUtils.SaveModel(env, ch, stream, _predictor, roleMappedData);
ch.Done();
@@ -102,7 +102,7 @@ public void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedDat
env.CheckValue(input, nameof(input));
input = _transformModel.Apply(env, input);
- roleMappedData = RoleMappedData.CreateOpt(input, _roleMappings);
+ roleMappedData = new RoleMappedData(input, _roleMappings, opt: true);
predictor = _predictor;
}
@@ -141,7 +141,7 @@ public RoleMappedSchema GetTrainingSchema(IHostEnvironment env)
{
Contracts.CheckValue(env, nameof(env));
var predInput = _transformModel.Apply(env, new EmptyDataView(env, _transformModel.InputSchema));
- var trainRms = RoleMappedSchema.CreateOpt(predInput.Schema, _roleMappings);
+ var trainRms = new RoleMappedSchema(predInput.Schema, _roleMappings, opt: true);
return trainRms;
}
}
diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
index 39a5f31c38..74f7ca0068 100644
--- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
@@ -796,7 +796,7 @@ public static CommonOutputs.CommonEvaluateOutput AnomalyDetection(IHostEnvironme
string name;
MatchColumns(host, input, out label, out weight, out name);
var evaluator = new AnomalyDetectionMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
index 90078da9ee..7161f66439 100644
--- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
@@ -1455,7 +1455,7 @@ public static CommonOutputs.ClassificationEvaluateOutput Binary(IHostEnvironment
string name;
MatchColumns(host, input, out label, out weight, out name);
var evaluator = new BinaryClassifierMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
index 907760649f..bec1ac144a 100644
--- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
@@ -776,7 +776,7 @@ public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args)
string feat = EvaluateUtils.GetColName(_featureCol, schema.Feature, DefaultColumnNames.Features);
if (!schema.Schema.TryGetColumnIndex(feat, out int featCol))
throw Host.ExceptUserArg(nameof(Arguments.FeatureColumn), "Features column '{0}' not found", feat);
- yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, feat);
+ yield return RoleMappedSchema.ColumnRole.Feature.Bind(feat);
}
}
@@ -867,7 +867,7 @@ public static CommonOutputs.CommonEvaluateOutput Clustering(IHostEnvironment env
nameof(ClusteringMamlEvaluator.Arguments.FeatureColumn),
input.FeatureColumn, DefaultColumnNames.Features);
var evaluator = new ClusteringMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, features, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, features, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs
index bbb53ba631..1a5a2177f3 100644
--- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs
@@ -95,7 +95,7 @@ protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string sco
public Dictionary Evaluate(RoleMappedData data)
{
- data = RoleMappedData.Create(data.Data, GetInputColumnRoles(data.Schema, needStrat: true));
+ data = new RoleMappedData(data.Data, GetInputColumnRoles(data.Schema, needStrat: true));
return Evaluator.Evaluate(data);
}
@@ -108,7 +108,7 @@ public Dictionary Evaluate(RoleMappedData data)
: StratCols.Select(col => RoleMappedSchema.CreatePair(Strat, col));
if (needName && schema.Name != null)
- roles = roles.Prepend(RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Name, schema.Name.Name));
+ roles = roles.Prepend(RoleMappedSchema.ColumnRole.Name.Bind(schema.Name.Name));
return roles.Concat(GetInputColumnRolesCore(schema));
}
@@ -126,12 +126,12 @@ public Dictionary Evaluate(RoleMappedData data)
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name);
// Get the label column information.
- string lab = EvaluateUtils.GetColName(LabelCol, schema.Label, DefaultColumnNames.Label);
- yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, lab);
+ string label = EvaluateUtils.GetColName(LabelCol, schema.Label, DefaultColumnNames.Label);
+ yield return RoleMappedSchema.ColumnRole.Label.Bind(label);
- var weight = EvaluateUtils.GetColName(WeightCol, schema.Weight, null);
+ string weight = EvaluateUtils.GetColName(WeightCol, schema.Weight, null);
if (!string.IsNullOrEmpty(weight))
- yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, weight);
+ yield return RoleMappedSchema.ColumnRole.Weight.Bind(weight);
}
public virtual IEnumerable GetOverallMetricColumns()
@@ -203,7 +203,7 @@ public IDataTransform GetPerInstanceMetrics(RoleMappedData scoredData)
Host.AssertValue(scoredData);
var schema = scoredData.Schema;
- var dataEval = RoleMappedData.Create(scoredData.Data, GetInputColumnRoles(schema));
+ var dataEval = new RoleMappedData(scoredData.Data, GetInputColumnRoles(schema));
return Evaluator.GetPerInstanceMetrics(dataEval);
}
@@ -260,7 +260,7 @@ protected virtual IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMap
public IDataView GetPerInstanceDataViewToSave(RoleMappedData perInstance)
{
Host.CheckValue(perInstance, nameof(perInstance));
- var data = RoleMappedData.Create(perInstance.Data, GetInputColumnRoles(perInstance.Schema, needName: true));
+ var data = new RoleMappedData(perInstance.Data, GetInputColumnRoles(perInstance.Schema, needName: true));
return WrapPerInstance(data);
}
diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
index 3b5e5fc910..d57835b168 100644
--- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
@@ -784,7 +784,7 @@ public static CommonOutputs.CommonEvaluateOutput MultiOutputRegression(IHostEnvi
string name;
MatchColumns(host, input, out label, out weight, out name);
var evaluator = new MultiOutputRegressionMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
index 94c920c3c6..fd23e7c3b0 100644
--- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
@@ -1070,7 +1070,7 @@ public static CommonOutputs.ClassificationEvaluateOutput MultiClass(IHostEnviron
MatchColumns(host, input, out string label, out string weight, out string name);
var evaluator = new MultiClassMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
index 6d61f6b965..fb8d9c1249 100644
--- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
@@ -556,7 +556,7 @@ public static CommonOutputs.CommonEvaluateOutput QuantileRegression(IHostEnviron
string name;
MatchColumns(host, input, out label, out weight, out name);
var evaluator = new QuantileRegressionMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
index ae9c2a8594..cdf3f9c57e 100644
--- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
@@ -851,7 +851,7 @@ public RankerMamlEvaluator(IHostEnvironment env, Arguments args)
{
var cols = base.GetInputColumnRolesCore(schema);
var groupIdCol = EvaluateUtils.GetColName(_groupIdCol, schema.Group, DefaultColumnNames.GroupId);
- return cols.Prepend(RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, groupIdCol));
+ return cols.Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupIdCol));
}
protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics)
@@ -1039,7 +1039,7 @@ public static CommonOutputs.CommonEvaluateOutput Ranking(IHostEnvironment env, R
nameof(RankerMamlEvaluator.Arguments.GroupIdColumn),
input.GroupIdColumn, DefaultColumnNames.GroupId);
var evaluator = new RankerMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, groupId, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, groupId, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
index 4292e13b8d..1804ce429f 100644
--- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
@@ -354,7 +354,7 @@ public static CommonOutputs.CommonEvaluateOutput Regression(IHostEnvironment env
string name;
MatchColumns(host, input, out label, out weight, out name);
var evaluator = new RegressionMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs
index 6c28dac997..dfec0913ca 100644
--- a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs
+++ b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs
@@ -147,13 +147,13 @@ private void Run(IChannel ch)
{
RoleMappedData data;
if (trainSchema != null)
- data = RoleMappedData.Create(end, trainSchema.GetColumnRoleNames());
+ data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
else
{
// We had a predictor, but no roles stored in the model. Just suppose
// default column names are OK, if present.
- data = TrainUtils.CreateExamplesOpt(end, DefaultColumnNames.Label,
- DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name);
+ data = new RoleMappedData(end, DefaultColumnNames.Label,
+ DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
}
var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
index 91c84a0734..a407873bac 100644
--- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
@@ -70,7 +70,7 @@ private static Bindings Create(IHostEnvironment env, ISchemaBindableMapper binda
Contracts.AssertValue(roles);
Contracts.AssertValueOrNull(suffix);
- var mapper = bindable.Bind(env, RoleMappedSchema.Create(input, roles));
+ var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles));
// We don't actually depend on this invariant, but if this assert fires it means the bindable
// did the wrong thing.
Contracts.Assert(mapper.InputSchema.Schema == input);
diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
index fe69585b78..2fd039897a 100644
--- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
+++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
@@ -117,7 +117,7 @@ public BindingsImpl ApplyToSchema(ISchema input, ISchemaBindableMapper bindable,
env.AssertValue(bindable);
string scoreCol = RowMapper.OutputSchema.GetColumnName(ScoreColumnIndex);
- var schema = RoleMappedSchema.Create(input, RowMapper.GetInputColumnRoles());
+ var schema = new RoleMappedSchema(input, RowMapper.GetInputColumnRoles());
// Checks compatibility of the predictor input types.
var mapper = bindable.Bind(env, schema);
@@ -148,7 +148,7 @@ public static BindingsImpl Create(ModelLoadContext ctx, ISchema input,
string scoreKind = ctx.LoadNonEmptyString();
string scoreCol = ctx.LoadNonEmptyString();
- var mapper = bindable.Bind(env, RoleMappedSchema.Create(input, roles));
+ var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles));
var rowMapper = mapper as ISchemaBoundRowMapper;
env.CheckParam(rowMapper != null, nameof(bindable), "Bindable expected to be an " + nameof(ISchemaBindableMapper) + "!");
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
index 22fa9686ed..f3c560c8dc 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
@@ -225,7 +225,7 @@ public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data,
env.AssertValue(featInfo); // Should be defined, if FEaturesAreNormalized returned a definite value.
var view = CreateMinMaxNormalizer(env, data.Data, name: featInfo.Name);
- data = RoleMappedData.Create(view, data.Schema.GetColumnRoleNames());
+ data = new RoleMappedData(view, data.Schema.GetColumnRoleNames());
return true;
}
diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
index bdf1a36d41..bba2f58256 100644
--- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
@@ -181,7 +181,7 @@ private static RoleMappedData CreateDataFromArgs(IExceptionContext
var name = TrainUtils.MatchNameOrDefaultOrNull(ectx, schema, nameof(args.NameColumn), args.NameColumn,
DefaultColumnNames.Name);
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ectx, args.CustomColumn);
- return TrainUtils.CreateExamples(input, label, feat, group, weight, name, customCols);
+ return new RoleMappedData(input, label, feat, group, weight, name, customCols);
}
}
}
diff --git a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs
index 3e39a0008a..900778cd31 100644
--- a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs
+++ b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs
@@ -338,7 +338,7 @@ public static RoleMappedSchema LoadRoleMappedSchemaOrNull(IHostEnvironment env,
if (roleMappings == null)
return null;
var pipe = ModelFileUtils.LoadLoader(h, rep, new MultiFileSource(null), loadTransforms: true);
- return RoleMappedSchema.Create(pipe.Schema, roleMappings);
+ return new RoleMappedSchema(pipe.Schema, roleMappings);
}
///
diff --git a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs
index d5134e48c9..6366321c48 100644
--- a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs
+++ b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs
@@ -33,7 +33,7 @@ public static RoleMappedData SelectFeatures(IHost host, RoleMappedData data, Bit
host, "FeatureSelector", data.Data, name, name, type, type,
(ref VBuffer src, ref VBuffer dst) => SelectFeatures(ref src, features, card, ref dst));
- var res = RoleMappedData.Create(view, data.Schema.GetColumnRoleNames());
+ var res = new RoleMappedData(view, data.Schema.GetColumnRoleNames());
return res;
}
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
index e8ea19684e..f512114bbb 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
@@ -307,7 +307,7 @@ private static TOut CreatePipelineEnsemble(IHostEnvironment env, IPredicto
var dv = new EmptyDataView(env, inputSchema);
// The role mappings are specific to the individual predictors.
- var rmd = RoleMappedData.Create(dv);
+ var rmd = new RoleMappedData(dv);
var predictorModel = new PredictorModel(env, rmd, dv, ensemble);
var output = new TOut { PredictorModel = predictorModel };
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
index 7435958e62..f49e3af81c 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
@@ -181,11 +181,11 @@ public void Train(List>> models,
var bldr = new ArrayDataViewBuilder(host);
Array.Resize(ref labels, count);
Array.Resize(ref features, count);
- bldr.AddColumn("Label", NumberType.Float, labels);
- bldr.AddColumn("Features", NumberType.Float, features);
+ bldr.AddColumn(DefaultColumnNames.Label, NumberType.Float, labels);
+ bldr.AddColumn(DefaultColumnNames.Features, NumberType.Float, features);
var view = bldr.GetDataView();
- var rmd = RoleMappedData.Create(view, ColumnRole.Label.Bind("Label"), ColumnRole.Feature.Bind("Features"));
+ var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);
var trainer = BasePredictorType.CreateInstance(host);
if (trainer is ITrainerEx ex && ex.NeedNormalization)
diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
index 43956c608f..3cf30a3211 100644
--- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
+++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
@@ -318,7 +318,7 @@ public IPredictor Calibrate(IChannel ch, IDataView data, ICalibratorTrainer cali
if (caliTrainer.NeedsTraining)
{
- var bound = new Bound(this, RoleMappedSchema.Create(data.Schema));
+ var bound = new Bound(this, new RoleMappedSchema(data.Schema));
using (var curs = data.GetRowCursor(col => true))
{
var scoreGetter = (ValueGetter)bound.CreateScoreGetter(curs, col => true, out Action disposer);
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
index 2976214cfe..8465e1a5e8 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
@@ -84,7 +84,7 @@ public virtual void CalculateMetrics(FeatureSubsetModel GetBatches(IRandom rand)
var view = new GenerateNumberTransform(Host, args, Data.Data);
var viewTest = new RangeFilter(Host, new RangeFilter.Arguments() { Column = name, Max = ValidationDatasetProportion }, view);
var viewTrain = new RangeFilter(Host, new RangeFilter.Arguments() { Column = name, Max = ValidationDatasetProportion, Complement = true }, view);
- dataTest = RoleMappedData.Create(viewTest, Data.Schema.GetColumnRoleNames());
- dataTrain = RoleMappedData.Create(viewTrain, Data.Schema.GetColumnRoleNames());
+ dataTest = new RoleMappedData(viewTest, Data.Schema.GetColumnRoleNames());
+ dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames());
}
if (BatchSize > 0)
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs
index 25e7f8d64d..97dda8aeba 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs
@@ -46,7 +46,7 @@ public override IEnumerable GetSubsets(Batch batch, IRandom rand)
{
// REVIEW: Consider ways to reintroduce "balanced" samples.
var viewTrain = new BootstrapSampleTransform(Host, new BootstrapSampleTransform.Arguments(), Data.Data);
- var dataTrain = RoleMappedData.Create(viewTrain, Data.Schema.GetColumnRoleNames());
+ var dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames());
yield return FeatureSelector.SelectFeatures(dataTrain, rand);
}
}
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs
index 322a2133dc..0fca7ac55b 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs
@@ -45,7 +45,7 @@ public override IEnumerable GetSubsets(Batch batch, IRandom rand)
for (int i = 0; i < Size; i++)
{
var viewTrain = new RangeFilter(Host, new RangeFilter.Arguments() { Column = name, Min = (Double)i / Size, Max = (Double)(i + 1) / Size }, view);
- var dataTrain = RoleMappedData.Create(viewTrain, Data.Schema.GetColumnRoleNames());
+ var dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames());
yield return FeatureSelector.SelectFeatures(dataTrain, rand);
}
}
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index 01668eb3f2..22a4e145b8 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -1386,7 +1386,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB
// Since we've passed it through a few transforms, reconstitute the mapping on the
// newly transformed data.
- examples = RoleMappedData.Create(data, examples.Schema.GetColumnRoleNames());
+ examples = new RoleMappedData(data, examples.Schema.GetColumnRoleNames());
// Get the index of the columns in the transposed view, while we're at it composing
// the list of the columns we want to transpose.
diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs
index bc16a07338..931dd2335b 100644
--- a/src/Microsoft.ML.FastTree/GamTrainer.cs
+++ b/src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -990,10 +990,11 @@ public Context(IChannel ch, GamPredictorBase pred, RoleMappedData data, IEvaluat
{
_eval = eval;
var builder = new ArrayDataViewBuilder(pred.Host);
- builder.AddColumn("Label", NumberType.Float, _labels);
- builder.AddColumn("Score", NumberType.Float, _scores);
- _dataForEvaluator = RoleMappedData.Create(builder.GetDataView(), RoleMappedSchema.ColumnRole.Label.Bind("Label"),
- RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, "Score"));
+ builder.AddColumn(DefaultColumnNames.Label, NumberType.Float, _labels);
+ builder.AddColumn(DefaultColumnNames.Score, NumberType.Float, _scores);
+ _dataForEvaluator = new RoleMappedData(builder.GetDataView(), opt: false,
+ RoleMappedSchema.ColumnRole.Label.Bind(DefaultColumnNames.Label),
+ new RoleMappedSchema.ColumnRole(MetadataUtils.Const.ScoreValueKind.Score).Bind(DefaultColumnNames.Score));
}
_data.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex);
@@ -1196,7 +1197,7 @@ private Context Init(IChannel ch)
}
var pred = rawPred as GamPredictorBase;
ch.CheckUserArg(pred != null, nameof(Args.InputModelFile), "Predictor was not a " + nameof(GamPredictorBase));
- var data = RoleMappedData.CreateOpt(loader, schema.GetColumnRoleNames());
+ var data = new RoleMappedData(loader, schema.GetColumnRoleNames(), opt: true);
if (hadCalibrator && !string.IsNullOrWhiteSpace(Args.OutputModelFile))
ch.Warning("If you save the GAM model, only the GAM model, not the wrapping calibrator, will be saved.");
diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
index 253d76d654..a7044d2502 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
@@ -393,8 +393,7 @@ private void EnsureCachedPosition()
public IEnumerable> GetInputColumnRoles()
{
- yield return new KeyValuePair(
- RoleMappedSchema.ColumnRole.Feature, _inputSchema.Feature.Name);
+ yield return RoleMappedSchema.ColumnRole.Feature.Bind(_inputSchema.Feature.Name);
}
public Func GetDependencies(Func predicate)
diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
index df386d0dc1..ff663a8d6e 100644
--- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
+++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
@@ -97,54 +97,37 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight
public KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadNameValue)
{
- Contracts.CheckValue(args, nameof(args));
- Contracts.CheckUserArg(args.K > 0, nameof(args.K), "Number of means must be positive");
+ Host.CheckValue(args, nameof(args));
+ Host.CheckUserArg(args.K > 0, nameof(args.K), "Must be positive");
_k = args.K;
- Contracts.CheckUserArg(args.MaxIterations > 0, nameof(args.MaxIterations), "Number of iterations must be positive");
+ Host.CheckUserArg(args.MaxIterations > 0, nameof(args.MaxIterations), "Must be positive");
_maxIterations = args.MaxIterations;
- Contracts.CheckUserArg(args.OptTol > 0, nameof(args.OptTol), "Tolerance must be positive");
+ Host.CheckUserArg(args.OptTol > 0, nameof(args.OptTol), "Tolerance must be positive");
_convergenceThreshold = args.OptTol;
_centroids = new VBuffer[_k];
- Contracts.CheckUserArg(args.AccelMemBudgetMb > 0, nameof(args.AccelMemBudgetMb), "Memory budget must be positive");
+ Host.CheckUserArg(args.AccelMemBudgetMb > 0, nameof(args.AccelMemBudgetMb), "Must be positive");
_accelMemBudgetMb = args.AccelMemBudgetMb;
_initAlgorithm = args.InitAlgorithm;
- if (args.NumThreads.HasValue)
- {
- Contracts.CheckUserArg(args.NumThreads.Value > 0, nameof(args.NumThreads), "The number of threads must be either null or a positive integer.");
- }
+ Host.CheckUserArg(!args.NumThreads.HasValue || args.NumThreads > 0, nameof(args.NumThreads),
+ "Must be either null or a positive integer.");
_numThreads = ComputeNumThreads(Host, args.NumThreads);
}
- public override bool NeedNormalization
- {
- get { return true; }
- }
-
- public override bool NeedCalibration
- {
- get { return false; }
- }
-
- public override bool WantCaching
- {
- get { return true; }
- }
-
- public override PredictionKind PredictionKind
- {
- get { return PredictionKind.Clustering; }
- }
+ public override bool NeedNormalization => true;
+ public override bool NeedCalibration => false;
+ public override bool WantCaching => true;
+ public override PredictionKind PredictionKind => PredictionKind.Clustering;
public override void Train(RoleMappedData data)
{
- Contracts.CheckValue(data, nameof(data));
+ Host.CheckValue(data, nameof(data));
data.CheckFeatureFloatVector(out _dimensionality);
Contracts.Assert(_dimensionality > 0);
@@ -311,7 +294,7 @@ public static void Initialize(
// This check is only performed once, at the first pass of initialization
if (dimensionality != cursor.Features.Length)
{
- throw Contracts.Except(
+ throw ch.Except(
"Dimensionality doesn't match, expected {0}, got {1}",
dimensionality,
cursor.Features.Length);
@@ -330,7 +313,7 @@ public static void Initialize(
probabilityWeight = Math.Min(probabilityWeight, distance);
}
- Contracts.Assert(FloatUtils.IsFinite(probabilityWeight));
+ ch.Assert(FloatUtils.IsFinite(probabilityWeight));
}
if (probabilityWeight > 0)
@@ -362,7 +345,7 @@ public static void Initialize(
// persist the candidate as a new centroid
if (!haveCandidate)
{
- throw Contracts.Except(
+ throw ch.Except(
"Not enough distinct instances to populate {0} clusters (only found {1} distinct instances)", k, i);
}
@@ -716,13 +699,13 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
out long missingFeatureCount, out long totalTrainingInstances)
{
Contracts.CheckValue(host, nameof(host));
- host.CheckValue(cursorFactory, nameof(cursorFactory));
host.CheckValue(ch, nameof(ch));
- host.CheckValue(centroids, nameof(centroids));
- host.CheckUserArg(numThreads > 0, nameof(KMeansPlusPlusTrainer.Arguments.NumThreads), "Must be positive");
- host.CheckUserArg(k > 0, nameof(KMeansPlusPlusTrainer.Arguments.K), "Must be positive");
- host.CheckParam(dimensionality > 0, nameof(dimensionality), "Must be positive");
- host.CheckUserArg(accelMemBudgetMb >= 0, nameof(KMeansPlusPlusTrainer.Arguments.AccelMemBudgetMb), "Must be non-negative");
+ ch.CheckValue(cursorFactory, nameof(cursorFactory));
+ ch.CheckValue(centroids, nameof(centroids));
+ ch.CheckUserArg(numThreads > 0, nameof(KMeansPlusPlusTrainer.Arguments.NumThreads), "Must be positive");
+ ch.CheckUserArg(k > 0, nameof(KMeansPlusPlusTrainer.Arguments.K), "Must be positive");
+ ch.CheckParam(dimensionality > 0, nameof(dimensionality), "Must be positive");
+ ch.CheckUserArg(accelMemBudgetMb >= 0, nameof(KMeansPlusPlusTrainer.Arguments.AccelMemBudgetMb), "Must be non-negative");
int numRounds;
int numSamplesPerRound;
@@ -787,7 +770,7 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
VBufferUtils.Densify(ref clusters[clusterCount]);
clustersL2s[clusterCount] = VectorUtils.NormSquared(clusters[clusterCount]);
clusterPrevCount = clusterCount;
- Contracts.Assert(clusterCount - clusterPrevCount <= numSamplesPerRound);
+ ch.Assert(clusterCount - clusterPrevCount <= numSamplesPerRound);
clusterCount++;
logicalExternalRounds++;
pCh.Checkpoint(logicalExternalRounds, numRounds + 2);
@@ -828,11 +811,11 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
clusterCount++;
}
- Contracts.Assert(clusterCount - clusterPrevCount <= numSamplesPerRound);
+ ch.Assert(clusterCount - clusterPrevCount <= numSamplesPerRound);
logicalExternalRounds++;
pCh.Checkpoint(logicalExternalRounds, numRounds + 2);
}
- Contracts.Assert(clusterCount == clusters.Length);
+ ch.Assert(clusterCount == clusters.Length);
}
// Finally, we do one last pass through the dataset, finding for
@@ -851,7 +834,7 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
clustersL2s, false, false, out discardBestWeight, out bestCluster);
#if DEBUG
int debugBestCluster = KMeansUtils.FindBestCluster(ref point, clusters, clustersL2s);
- Contracts.Assert(bestCluster == debugBestCluster);
+ ch.Assert(bestCluster == debugBestCluster);
#endif
weights[bestCluster]++;
},
@@ -881,9 +864,9 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
ref debugWeightBuffer, ref debugTotalWeights);
for (int i = 0; i < totalWeights.Length; i++)
- Contracts.Assert(totalWeights[i] == debugTotalWeights[i]);
+ ch.Assert(totalWeights[i] == debugTotalWeights[i]);
#endif
- Contracts.Assert(totalWeights.Length == clusters.Length);
+ ch.Assert(totalWeights.Length == clusters.Length);
logicalExternalRounds++;
// If we sampled exactly the right number of points then we can
@@ -899,10 +882,10 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
else
{
ArrayDataViewBuilder arrDv = new ArrayDataViewBuilder(host);
- arrDv.AddColumn("Features", PrimitiveType.FromKind(DataKind.R4), clusters);
- arrDv.AddColumn("Weights", PrimitiveType.FromKind(DataKind.R4), totalWeights);
+ arrDv.AddColumn(DefaultColumnNames.Features, PrimitiveType.FromKind(DataKind.R4), clusters);
+ arrDv.AddColumn(DefaultColumnNames.Weight, PrimitiveType.FromKind(DataKind.R4), totalWeights);
var subDataViewCursorFactory = new FeatureFloatVectorCursor.Factory(
- TrainUtils.CreateExamples(arrDv.GetDataView(), null, "Features", weight: "Weights"), CursOpt.Weight | CursOpt.Features);
+ new RoleMappedData(arrDv.GetDataView(), null, DefaultColumnNames.Features, weight: DefaultColumnNames.Weight), CursOpt.Weight | CursOpt.Features);
long discard1;
long discard2;
KMeansPlusPlusInit.Initialize(host, numThreads, ch, subDataViewCursorFactory, k, dimensionality, centroids, out discard1, out discard2, false);
@@ -1198,13 +1181,13 @@ public int GetBestCluster(int idx)
public SharedState(FeatureFloatVectorCursor.Factory factory, IChannel ch, long baseMaxInstancesToAccelerate, int k,
bool isParallel, long totalTrainingInstances)
{
- Contracts.AssertValue(factory);
Contracts.AssertValue(ch);
- Contracts.Assert(k > 0);
- Contracts.Assert(totalTrainingInstances > 0);
+ ch.AssertValue(factory);
+ ch.Assert(k > 0);
+ ch.Assert(totalTrainingInstances > 0);
_acceleratedRowMap = new KMeansAcceleratedRowMap(factory, ch, baseMaxInstancesToAccelerate, totalTrainingInstances, isParallel);
- Contracts.Assert(MaxInstancesToAccelerate >= 0,
+ ch.Assert(MaxInstancesToAccelerate >= 0,
"MaxInstancesToAccelerate cannot be negative as KMeansAcceleratedRowMap sets it to 0 when baseMaxInstancesToAccelerate is negative");
if (MaxInstancesToAccelerate > 0)
@@ -1576,12 +1559,12 @@ public static RowStats ParallelWeightedReservoirSample(
},
(Heap[] heaps, IRandom rand, ref Heap finalHeap) =>
{
- Contracts.Assert(finalHeap == null);
+ host.Assert(finalHeap == null);
finalHeap = new Heap((x, y) => x.Weight > y.Weight, numSamples);
for (int i = 0; i < heaps.Length; i++)
{
- Contracts.AssertValue(heaps[i]);
- Contracts.Assert(heaps[i].Count <= numSamples, "heaps[i].Count must not be greater than numSamples");
+ host.AssertValue(heaps[i]);
+ host.Assert(heaps[i].Count <= numSamples, "heaps[i].Count must not be greater than numSamples");
while (heaps[i].Count > 0)
{
var row = heaps[i].Pop();
@@ -1597,7 +1580,7 @@ public static RowStats ParallelWeightedReservoirSample(
}, ref buffer, ref outHeap);
if (outHeap.Count != numSamples)
- throw Contracts.Except("Failed to initialize clusters: too few examples");
+ throw host.Except("Failed to initialize clusters: too few examples");
// Keep in mind that the distribution of samples in dst will not be random. It will
// have the residual minHeap ordering.
diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs
index 6b72d64af5..9c40e61ca2 100644
--- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs
+++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs
@@ -172,13 +172,13 @@ private void Run(IChannel ch)
{
RoleMappedData data;
if (trainSchema != null)
- data = RoleMappedData.Create(end, trainSchema.GetColumnRoleNames());
+ data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
else
{
// We had a predictor, but no roles stored in the model. Just suppose
// default column names are OK, if present.
- data = TrainUtils.CreateExamplesOpt(end, DefaultColumnNames.Label,
- DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name);
+ data = new RoleMappedData(end, DefaultColumnNames.Label,
+ DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
}
var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs
index 73b3f1051f..67c53223d7 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs
@@ -82,7 +82,7 @@ public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleM
_pred = pred;
var inputFeatureColumns = _columns.Select(c => new KeyValuePair(RoleMappedSchema.ColumnRole.Feature, c.Name)).ToList();
- InputSchema = RoleMappedSchema.Create(schema.Schema, inputFeatureColumns);
+ InputSchema = new RoleMappedSchema(schema.Schema, inputFeatureColumns);
OutputSchema = outputSchema;
_inputColumnIndexes = new List();
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
index 738c0197cb..a48a53ae65 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
@@ -118,7 +118,7 @@ protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMapped
ch.Assert(idvToFeedTrain.CanShuffle);
var roles = examples.Schema.GetColumnRoleNames();
- var examplesToFeedTrain = RoleMappedData.Create(idvToFeedTrain, roles);
+ var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles);
ch.Assert(examplesToFeedTrain.Schema.Label != null);
ch.Assert(examplesToFeedTrain.Schema.Feature != null);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
index 8aa6e4b6e0..7b5fcc8a93 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
@@ -79,7 +79,7 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe
var roles = data.Schema.GetColumnRoleNames()
.Where(kvp => kvp.Key.Value != CR.Label.Value)
.Prepend(CR.Label.Bind(dstName));
- var td = RoleMappedData.Create(view, roles);
+ var td = new RoleMappedData(view, roles);
trainer.Train(td);
@@ -214,7 +214,7 @@ public static ModelOperations.PredictorModelOutput CombineOvaModels(IHostEnviron
input.FeatureColumn, DefaultColumnNames.Features);
var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.WeightColumn),
input.WeightColumn, DefaultColumnNames.Weight);
- var data = TrainUtils.CreateExamples(normalizedView, label, feature, null, weight);
+ var data = new RoleMappedData(normalizedView, label, feature, null, weight);
return new ModelOperations.PredictorModelOutput
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
index c710b3149d..cf1e7c062b 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
@@ -76,7 +76,7 @@ private TDistPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedD
var roles = data.Schema.GetColumnRoleNames()
.Where(kvp => kvp.Key.Value != CR.Label.Value)
.Prepend(CR.Label.Bind(dstName));
- var td = RoleMappedData.Create(view, roles);
+ var td = new RoleMappedData(view, roles);
trainer.Train(td);
diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
index 87420a265e..1d2f466ad5 100644
--- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
+++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
@@ -126,12 +126,12 @@ private FastForestRegressionPredictor FitModel(IEnumerable previousR
}
ArrayDataViewBuilder dvBuilder = new ArrayDataViewBuilder(_host);
- dvBuilder.AddColumn("Label", NumberType.Float, targets);
- dvBuilder.AddColumn("Features", NumberType.Float, features);
+ dvBuilder.AddColumn(DefaultColumnNames.Label, NumberType.Float, targets);
+ dvBuilder.AddColumn(DefaultColumnNames.Features, NumberType.Float, features);
IDataView view = dvBuilder.GetDataView();
_host.Assert(view.GetRowCount() == targets.Length, "This data view will have as many rows as there have been evaluations");
- RoleMappedData data = TrainUtils.CreateExamples(view, "Label", "Features");
+ RoleMappedData data = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);
using (IChannel ch = _host.Start("Single training"))
{
diff --git a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs
index 8448a406db..637d75250b 100644
--- a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs
+++ b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs
@@ -299,7 +299,7 @@ private static void TrainCore(IHost host, IDataView input, Arguments args, ref V
ch.Trace("Binding columns");
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, args.CustomColumn);
- var data = TrainUtils.CreateExamples(view, label, feature, group, weight, name, customCols);
+ var data = new RoleMappedData(view, label, feature, group, weight, name, customCols);
var predictor = TrainUtils.Train(host, ch, data, trainer, args.Filter.Kind, null,
null, 0, args.CacheData);
diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs
index f6c99d061c..631b7b6946 100644
--- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs
+++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs
@@ -433,13 +433,11 @@ public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetrics
var eval = GetEvaluator(env, input.Kind);
var perInst = EvaluateUtils.ConcatenatePerInstanceDataViews(env, eval, true, true, input.PerInstanceMetrics.Select(
- idv => RoleMappedData.CreateOpt(idv, new[]
- {
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn),
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, input.WeightColumn.Value),
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value),
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Name, input.NameColumn.Value)
- })).ToArray(),
+ idv => new RoleMappedData(idv, opt: true,
+ RoleMappedSchema.ColumnRole.Label.Bind(input.LabelColumn),
+ RoleMappedSchema.ColumnRole.Weight.Bind(input.WeightColumn.Value),
+ RoleMappedSchema.ColumnRole.Group.Bind(input.GroupColumn),
+ RoleMappedSchema.ColumnRole.Name.Bind(input.NameColumn.Value))).ToArray(),
out var variableSizeVectorColumnNames);
var warnings = input.Warnings != null ? new List(input.Warnings) : new List();
diff --git a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs
index b26960c561..6502fc2afa 100644
--- a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs
+++ b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs
@@ -22,7 +22,7 @@ public sealed class FeatureCombinerInput : TransformInputBase
[Argument(ArgumentType.Multiple, HelpText = "Features", SortOrder = 2)]
public string[] Features;
- public IEnumerable> GetRoles()
+ internal IEnumerable> GetRoles()
{
if (Utils.Size(Features) > 0)
{
@@ -49,7 +49,7 @@ public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env
using (var ch = host.Start(featureCombiner))
{
var viewTrain = input.Data;
- var rms = RoleMappedSchema.Create(viewTrain.Schema, input.GetRoles());
+ var rms = new RoleMappedSchema(viewTrain.Schema, input.GetRoles());
var feats = rms.GetColumns(RoleMappedSchema.ColumnRole.Feature);
if (Utils.Size(feats) == 0)
throw ch.Except("No feature columns specified");
diff --git a/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs
index c95f384ad6..e4a54de040 100644
--- a/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs
+++ b/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs
@@ -128,7 +128,7 @@ private static int GetNumberOfClasses(IHostEnvironment env, Arguments input, out
input.WeightColumn, DefaultColumnNames.Weight);
// Get number of classes
- var data = TrainUtils.CreateExamples(input.TrainingData, label, feature, null, weight);
+ var data = new RoleMappedData(input.TrainingData, label, feature, null, weight);
data.CheckMultiClassLabel(out var numClasses);
return numClasses;
}
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
index 43e208ad3c..ec8c0e12d7 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
@@ -684,9 +684,7 @@ public void EntryPointCalibrate()
// This tests that the SchemaBindableCalibratedPredictor doesn't get confused if its sub-predictor is already calibrated.
var fastForest = new FastForestClassification(Env, new FastForestClassification.Arguments());
- var rmd = RoleMappedData.Create(splitOutput.TrainData[0],
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, "Features"),
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, "Label"));
+ var rmd = new RoleMappedData(splitOutput.TrainData[0], "Label", "Features");
fastForest.Train(rmd);
var ffModel = new PredictorModel(Env, rmd, splitOutput.TrainData[0], fastForest.CreatePredictor());
var calibratedFfModel = Calibrate.Platt(Env,
@@ -1220,9 +1218,7 @@ public void EntryPointMulticlassPipelineEnsemble()
}, data);
var mlr = new MulticlassLogisticRegression(Env, new MulticlassLogisticRegression.Arguments());
- RoleMappedData rmd = RoleMappedData.Create(data,
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, "Features"),
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, "Label"));
+ var rmd = new RoleMappedData(data, "Label", "Features");
mlr.Train(rmd);
predictorModels[i] = new PredictorModel(Env, rmd, data, mlr.CreatePredictor());
diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
index 928f740b57..f9b08ef72a 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
@@ -599,7 +599,7 @@ private void CombineAndTestTreeEnsembles(IDataView idv, IPredictorModel[] fastTr
var fastTree = combiner.CombineModels(fastTrees.Select(pm => pm.Predictor as IPredictorProducing));
- var data = RoleMappedData.Create(idv, RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, "Features"));
+ var data = new RoleMappedData(idv, label: null, feature: "Features");
var scored = ScoreModel.Score(Env, new ScoreModel.Input() { Data = idv, PredictorModel = new PredictorModel(Env, data, idv, fastTree) }).ScoredData;
Assert.True(scored.Schema.TryGetColumnIndex("Score", out int scoreCol));
Assert.True(scored.Schema.TryGetColumnIndex("Probability", out int probCol));
@@ -1537,10 +1537,10 @@ public void CompareSvmPredictorResultsToLibSvm()
Column = new[] { new NormalizeTransform.AffineColumn() { Name = "Features", Source = "Features" } }
},
trainView);
- var trainData = TrainUtils.CreateExamples(trainView, "Label", "Features");
+ var trainData = new RoleMappedData(trainView, "Label", "Features");
IDataView testView = new TextLoader(env, new TextLoader.Arguments(), new MultiFileSource(GetDataPath(TestDatasets.mnistOneClass.testFilename)));
ApplyTransformUtils.ApplyAllTransformsToData(env, trainView, testView);
- var testData = TrainUtils.CreateExamples(testView, "Label", "Features");
+ var testData = new RoleMappedData(testView, "Label", "Features");
CompareSvmToLibSvmCore("linear kernel", "LinearKernel", env, trainData, testData);
CompareSvmToLibSvmCore("polynomial kernel", "PolynomialKernel{d=2}", env, trainData, testData);
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs
index c1291cd52e..7c371082de 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs
@@ -74,7 +74,7 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
// Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto
var cached = new CacheDataView(env, trans, prefetch: null);
- var trainRoles = TrainUtils.CreateExamples(cached, label: "Label", feature: "Features");
+ var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
trainer.Train(trainRoles);
// Get scorer and evaluate the predictions from test data
@@ -176,7 +176,7 @@ private void CompareMatrics(ClassificationMetrics metrics)
private ClassificationMetrics Evaluate(IHostEnvironment env, IDataView scoredData)
{
- var dataEval = TrainUtils.CreateExamplesOpt(scoredData, label: "Label", feature: "Features");
+ var dataEval = new RoleMappedData(scoredData, label: "Label", feature: "Features", opt: true);
// Evaluate.
// It does not work. It throws error "Failed to find 'Score' column" when Evaluate is called
@@ -193,7 +193,7 @@ private IDataScorerTransform GetScorer(IHostEnvironment env, IDataView transform
using (var ch = env.Start("Saving model"))
using (var memoryStream = new MemoryStream())
{
- var trainRoles = TrainUtils.CreateExamples(transforms, label: "Label", feature: "Features");
+ var trainRoles = new RoleMappedData(transforms, label: "Label", feature: "Features");
// Model cannot be saved with CacheDataView
TrainUtils.SaveModel(env, ch, memoryStream, pred, trainRoles);
@@ -201,7 +201,7 @@ private IDataScorerTransform GetScorer(IHostEnvironment env, IDataView transform
using (var rep = RepositoryReader.Open(memoryStream, ch))
{
IDataLoader testPipe = ModelFileUtils.LoadLoader(env, rep, new MultiFileSource(testDataPath), true);
- RoleMappedData testRoles = TrainUtils.CreateExamples(testPipe, label: "Label", feature: "Features");
+ RoleMappedData testRoles = new RoleMappedData(testPipe, label: "Label", feature: "Features");
return ScoreUtils.GetScorer(pred, testRoles, env, testRoles.Schema);
}
}
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs
index 2266f4d1f0..da208cb3f0 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs
@@ -78,7 +78,7 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
MinDocumentsInLeafs = 2
});
- var trainRoles = TrainUtils.CreateExamples(trans, label: "Label", feature: "Features");
+ var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
trainer.Train(trainRoles);
// Get scorer and evaluate the predictions from test data
@@ -103,7 +103,7 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
private BinaryClassificationMetrics EvaluateBinary(IHostEnvironment env, IDataView scoredData)
{
- var dataEval = TrainUtils.CreateExamplesOpt(scoredData, label: "Label", feature: "Features");
+ var dataEval = new RoleMappedData(scoredData, label: "Label", feature: "Features", opt: true);
// Evaluate.
// It does not work. It throws error "Failed to find 'Score' column" when Evaluate is called