Skip to content

Commit acc4ac0

Browse files
authored
One type label policy in trainers (#2804)
1 parent e2464f6 commit acc4ac0

File tree

64 files changed

+670
-596
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+670
-596
lines changed

src/Microsoft.ML.Core/Data/AnnotationUtils.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Column
308308
schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames);
309309
}
310310

311-
public static bool HasKeyValues(this SchemaShape.Column col)
311+
public static bool NeedsSlotNames(this SchemaShape.Column col)
312312
{
313313
return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol)
314314
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector
@@ -442,7 +442,7 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co
442442
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
443443
{
444444
var cols = new List<SchemaShape.Column>();
445-
if (labelColumn != null && labelColumn.Value.IsKey && HasKeyValues(labelColumn.Value))
445+
if (labelColumn != null && labelColumn.Value.IsKey && NeedsSlotNames(labelColumn.Value))
446446
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
447447
cols.AddRange(GetTrainerOutputAnnotation());
448448
return cols;

src/Microsoft.ML.Data/Data/Conversion.cs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ private Conversions()
111111
AddStd<I1, R4>(Convert);
112112
AddStd<I1, R8>(Convert);
113113
AddAux<I1, SB>(Convert);
114+
AddStd<I1, BL>(Convert);
114115

115116
AddStd<I2, I1>(Convert);
116117
AddStd<I2, I2>(Convert);
@@ -119,6 +120,7 @@ private Conversions()
119120
AddStd<I2, R4>(Convert);
120121
AddStd<I2, R8>(Convert);
121122
AddAux<I2, SB>(Convert);
123+
AddStd<I2, BL>(Convert);
122124

123125
AddStd<I4, I1>(Convert);
124126
AddStd<I4, I2>(Convert);
@@ -127,6 +129,7 @@ private Conversions()
127129
AddStd<I4, R4>(Convert);
128130
AddStd<I4, R8>(Convert);
129131
AddAux<I4, SB>(Convert);
132+
AddStd<I4, BL>(Convert);
130133

131134
AddStd<I8, I1>(Convert);
132135
AddStd<I8, I2>(Convert);
@@ -135,6 +138,7 @@ private Conversions()
135138
AddStd<I8, R4>(Convert);
136139
AddStd<I8, R8>(Convert);
137140
AddAux<I8, SB>(Convert);
141+
AddStd<I8, BL>(Convert);
138142

139143
AddStd<U1, U1>(Convert);
140144
AddStd<U1, U2>(Convert);
@@ -144,6 +148,7 @@ private Conversions()
144148
AddStd<U1, R4>(Convert);
145149
AddStd<U1, R8>(Convert);
146150
AddAux<U1, SB>(Convert);
151+
AddStd<U1, BL>(Convert);
147152

148153
AddStd<U2, U1>(Convert);
149154
AddStd<U2, U2>(Convert);
@@ -153,6 +158,7 @@ private Conversions()
153158
AddStd<U2, R4>(Convert);
154159
AddStd<U2, R8>(Convert);
155160
AddAux<U2, SB>(Convert);
161+
AddStd<U2, BL>(Convert);
156162

157163
AddStd<U4, U1>(Convert);
158164
AddStd<U4, U2>(Convert);
@@ -162,6 +168,7 @@ private Conversions()
162168
AddStd<U4, R4>(Convert);
163169
AddStd<U4, R8>(Convert);
164170
AddAux<U4, SB>(Convert);
171+
AddStd<U4, BL>(Convert);
165172

166173
AddStd<U8, U1>(Convert);
167174
AddStd<U8, U2>(Convert);
@@ -171,6 +178,7 @@ private Conversions()
171178
AddStd<U8, R4>(Convert);
172179
AddStd<U8, R8>(Convert);
173180
AddAux<U8, SB>(Convert);
181+
AddStd<U8, BL>(Convert);
174182

175183
AddStd<UG, U1>(Convert);
176184
AddStd<UG, U2>(Convert);
@@ -180,11 +188,13 @@ private Conversions()
180188
AddAux<UG, SB>(Convert);
181189

182190
AddStd<R4, R4>(Convert);
191+
AddStd<R4, BL>(Convert);
183192
AddStd<R4, R8>(Convert);
184193
AddAux<R4, SB>(Convert);
185194

186195
AddStd<R8, R4>(Convert);
187196
AddStd<R8, R8>(Convert);
197+
AddStd<R8, BL>(Convert);
188198
AddAux<R8, SB>(Convert);
189199

190200
AddStd<TX, I1>(Convert);
@@ -901,6 +911,19 @@ public void Convert(in BL src, ref SB dst)
901911
public void Convert(in DZ src, ref SB dst) { ClearDst(ref dst); dst.AppendFormat("{0:o}", src); }
902912
#endregion ToStringBuilder
903913

914+
#region ToBL
915+
public void Convert(in R8 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
916+
public void Convert(in R4 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
917+
public void Convert(in I1 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
918+
public void Convert(in I2 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
919+
public void Convert(in I4 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
920+
public void Convert(in I8 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
921+
public void Convert(in U1 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
922+
public void Convert(in U2 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
923+
public void Convert(in U4 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
924+
public void Convert(in U8 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
925+
#endregion
926+
904927
#region FromR4
905928
public void Convert(in R4 src, ref R4 dst) => dst = src;
906929
public void Convert(in R4 src, ref R8 dst) => dst = src;
@@ -1139,7 +1162,7 @@ private bool TryParseCore(ReadOnlySpan<char> span, out ulong dst)
11391162
dst = res;
11401163
return true;
11411164

1142-
LFail:
1165+
LFail:
11431166
dst = 0;
11441167
return false;
11451168
}
@@ -1246,7 +1269,7 @@ private bool TryParseNonNegative(ReadOnlySpan<char> span, out long result)
12461269
result = res;
12471270
return true;
12481271

1249-
LFail:
1272+
LFail:
12501273
result = 0;
12511274
return false;
12521275
}

src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ internal static bool CanWrap(ISchemaBoundMapper mapper, DataViewType labelNameTy
429429
var scoreType = outSchema[scoreIdx].Type;
430430

431431
// Check that the type is vector, and is of compatible size with the score output.
432-
return labelNameType is VectorType vectorType && vectorType.Size == scoreType.GetVectorSize();
432+
return labelNameType is VectorType vectorType && vectorType.Size == scoreType.GetVectorSize() && vectorType.ItemType == TextDataViewType.Instance;
433433
}
434434

435435
internal static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)

src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,15 @@ private protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol)
145145
private protected TTransformer TrainTransformer(IDataView trainSet,
146146
IDataView validationSet = null, IPredictor initPredictor = null)
147147
{
148+
CheckInputSchema(SchemaShape.Create(trainSet.Schema));
148149
var trainRoleMapped = MakeRoles(trainSet);
149-
var validRoleMapped = validationSet == null ? null : MakeRoles(validationSet);
150+
RoleMappedData validRoleMapped = null;
151+
152+
if (validationSet != null)
153+
{
154+
CheckInputSchema(SchemaShape.Create(validationSet.Schema));
155+
validRoleMapped = MakeRoles(validationSet);
156+
}
150157

151158
var pred = TrainModelCore(new TrainContext(trainRoleMapped, validRoleMapped, null, initPredictor));
152159
return MakeTransformer(pred, trainSet.Schema);

src/Microsoft.ML.FastTree/FastTreeRanking.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ private protected override void CheckLabelCompatible(SchemaShape.Column labelCol
105105
if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single)
106106
error();
107107
}
108+
108109
private protected override float GetMaxLabel()
109110
{
110111
return GetLabelGains().Length - 1;

src/Microsoft.ML.StandardTrainers/Standard/MultiClass/MetaMulticlassTrainer.cs

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
using Microsoft.ML.Calibrators;
99
using Microsoft.ML.CommandLine;
1010
using Microsoft.ML.Data;
11-
using Microsoft.ML.Data.Conversion;
1211
using Microsoft.ML.Internal.Internallearn;
1312
using Microsoft.ML.Runtime;
14-
13+
using Microsoft.ML.Transforms;
1514
namespace Microsoft.ML.Trainers
1615
{
1716
using TScalarTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
@@ -32,7 +31,7 @@ internal abstract class OptionsBase
3231
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", SortOrder = 150, ShortName = "numcali")]
3332
internal int MaxCalibrationExamples = 1000000000;
3433

35-
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", SortOrder = 150, ShortName = "missNeg")]
34+
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, or exclude their rows from dataview.", SortOrder = 150, ShortName = "missNeg")]
3635
public bool ImputeMissingLabelsAsNegative;
3736
}
3837

@@ -98,20 +97,15 @@ private protected IDataView MapLabelsCore<T>(DataViewType type, InPredicate<T> e
9897
Host.AssertValue(data);
9998
Host.Assert(data.Schema.Label.HasValue);
10099

101-
var lab = data.Schema.Label.Value;
100+
var label = data.Schema.Label.Value;
101+
IDataView dataView = data.Data;
102+
if (!Args.ImputeMissingLabelsAsNegative)
103+
dataView = new NAFilter(Host, data.Data, false, label.Name);
102104

103-
InPredicate<T> isMissing;
104-
if (!Args.ImputeMissingLabelsAsNegative && Conversions.Instance.TryGetIsNAPredicate(type, out isMissing))
105-
{
106-
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
107-
lab.Name, lab.Name, type, NumberDataViewType.Single,
108-
(in T src, ref float dst) =>
109-
dst = equalsTarget(in src) ? 1 : (isMissing(in src) ? float.NaN : default(float)));
110-
}
111105
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
112-
lab.Name, lab.Name, type, NumberDataViewType.Single,
113-
(in T src, ref float dst) =>
114-
dst = equalsTarget(in src) ? 1 : default(float));
106+
label.Name, label.Name, type, BooleanDataViewType.Instance,
107+
(in T src, ref bool dst) =>
108+
dst = equalsTarget(in src) ? true : false);
115109
}
116110

117111
private protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, int count);

src/Microsoft.ML.StandardTrainers/Standard/MultiClass/OneVersusAllTrainer.cs

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,28 +145,18 @@ private ISingleFeaturePredictionTransformer<TScalarPredictor> TrainOne(IChannel
145145

146146
private IDataView MapLabels(RoleMappedData data, int cls)
147147
{
148-
var lab = data.Schema.Label.Value;
149-
Host.Assert(!lab.IsHidden);
150-
Host.Assert(lab.Type.GetKeyCount() > 0 || lab.Type == NumberDataViewType.Single || lab.Type == NumberDataViewType.Double);
148+
var label = data.Schema.Label.Value;
149+
Host.Assert(!label.IsHidden);
150+
Host.Assert(label.Type.GetKeyCount() > 0 || label.Type == NumberDataViewType.Single || label.Type == NumberDataViewType.Double);
151151

152-
if (lab.Type.GetKeyCount() > 0)
152+
if (label.Type.GetKeyCount() > 0)
153153
{
154154
// Key values are 1-based.
155155
uint key = (uint)(cls + 1);
156156
return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => key == val, data);
157157
}
158-
if (lab.Type == NumberDataViewType.Single)
159-
{
160-
float key = cls;
161-
return MapLabelsCore(NumberDataViewType.Single, (in float val) => key == val, data);
162-
}
163-
if (lab.Type == NumberDataViewType.Double)
164-
{
165-
double key = cls;
166-
return MapLabelsCore(NumberDataViewType.Double, (in double val) => key == val, data);
167-
}
168158

169-
throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainer: {lab.Type.RawType}");
159+
throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainer: {label.Type.RawType}");
170160
}
171161

172162
/// <summary> Trains a <see cref="MulticlassPredictionTransformer{OneVersusAllModelParameters}"/> model.</summary>

src/Microsoft.ML.StandardTrainers/Standard/MultiClass/PairwiseCouplingTrainer.cs

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -142,31 +142,19 @@ private ISingleFeaturePredictionTransformer<TDistPredictor> TrainOne(IChannel ch
142142

143143
private IDataView MapLabels(RoleMappedData data, int cls1, int cls2)
144144
{
145-
var lab = data.Schema.Label.Value;
146-
Host.Assert(!lab.IsHidden);
147-
Host.Assert(lab.Type.GetKeyCount() > 0 || lab.Type == NumberDataViewType.Single || lab.Type == NumberDataViewType.Double);
145+
var label = data.Schema.Label.Value;
146+
Host.Assert(!label.IsHidden);
147+
Host.Assert(label.Type.GetKeyCount() > 0 || label.Type == NumberDataViewType.Single || label.Type == NumberDataViewType.Double);
148148

149-
if (lab.Type.GetKeyCount() > 0)
149+
if (label.Type.GetKeyCount() > 0)
150150
{
151151
// Key values are 1-based.
152152
uint key1 = (uint)(cls1 + 1);
153153
uint key2 = (uint)(cls2 + 1);
154154
return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => val == key1 || val == key2, data);
155155
}
156-
if (lab.Type == NumberDataViewType.Single)
157-
{
158-
float key1 = cls1;
159-
float key2 = cls2;
160-
return MapLabelsCore(NumberDataViewType.Single, (in float val) => val == key1 || val == key2, data);
161-
}
162-
if (lab.Type == NumberDataViewType.Double)
163-
{
164-
double key1 = cls1;
165-
double key2 = cls2;
166-
return MapLabelsCore(NumberDataViewType.Double, (in double val) => val == key1 || val == key2, data);
167-
}
168156

169-
throw Host.ExceptNotSupp($"Label column type is not supported by nameof(PairwiseCouplingTrainer): {lab.Type.RawType}");
157+
throw Host.ExceptNotSupp($"Label column type is not supported by nameof(PairwiseCouplingTrainer): {label.Type.RawType}");
170158
}
171159

172160
/// <summary>

src/Microsoft.ML.StandardTrainers/Standard/Online/AveragedPerceptron.cs

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,20 +174,6 @@ private protected override void CheckLabels(RoleMappedData data)
174174
data.CheckBinaryLabel();
175175
}
176176

177-
private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
178-
{
179-
Contracts.Assert(labelCol.IsValid);
180-
181-
Action error =
182-
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double, bool or KeyType", labelCol.GetTypeString());
183-
184-
if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
185-
error();
186-
187-
if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single && labelCol.ItemType != NumberDataViewType.Double && !(labelCol.ItemType is BooleanDataViewType))
188-
error();
189-
}
190-
191177
private protected override TrainStateBase MakeState(IChannel ch, int numFeatures, LinearModelParameters predictor)
192178
{
193179
return new TrainState(ch, numFeatures, predictor, this);

src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,20 +1512,6 @@ private protected SdcaBinaryTrainerBase(IHostEnvironment env, BinaryOptionsBase
15121512

15131513
private protected abstract SchemaShape.Column[] ComputeSdcaBinaryClassifierSchemaShape();
15141514

1515-
private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
1516-
{
1517-
Contracts.Assert(labelCol.IsValid);
1518-
1519-
Action error =
1520-
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double, bool or KeyType", labelCol.GetTypeString());
1521-
1522-
if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
1523-
error();
1524-
1525-
if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single && labelCol.ItemType != NumberDataViewType.Double && !(labelCol.ItemType is BooleanDataViewType))
1526-
error();
1527-
}
1528-
15291515
private protected LinearBinaryModelParameters CreateLinearBinaryModelParameters(VBuffer<float>[] weights, float[] bias)
15301516
{
15311517
Host.CheckParam(Utils.Size(weights) == 1, nameof(weights));

0 commit comments

Comments
 (0)