Skip to content

Commit b8a70ac

Browse files
authored
Fix a value-mapping bug (#3180)
1 parent 9d79ab3 commit b8a70ac

File tree

3 files changed

+315
-158
lines changed

3 files changed

+315
-158
lines changed

src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,14 @@ public static ValueMappingEstimator<TInputType, TOutputType> MapValue<TInputType
264264
{
265265
var keys = keyValuePairs.Select(pair => pair.Key);
266266
var values = keyValuePairs.Select(pair => pair.Value);
267-
return new ValueMappingEstimator<TInputType, TOutputType>(CatalogUtils.GetEnvironment(catalog), keys, values, treatValuesAsKeyType,
267+
268+
var lookupMap = DataViewHelper.CreateDataView(catalog.GetEnvironment(), keys, values,
269+
ValueMappingTransformer.DefaultKeyColumnName,
270+
ValueMappingTransformer.DefaultValueColumnName, treatValuesAsKeyType);
271+
272+
return new ValueMappingEstimator<TInputType, TOutputType>(catalog.GetEnvironment(), lookupMap,
273+
lookupMap.Schema[ValueMappingTransformer.DefaultKeyColumnName],
274+
lookupMap.Schema[ValueMappingTransformer.DefaultValueColumnName],
268275
new[] { (outputColumnName, inputColumnName ?? outputColumnName) });
269276
}
270277

@@ -287,7 +294,15 @@ internal static ValueMappingEstimator<TInputType, TOutputType> MapValue<TInputTy
287294
env.CheckValue(columns, nameof(columns));
288295
var keys = keyValuePairs.Select(pair => pair.Key);
289296
var values = keyValuePairs.Select(pair => pair.Value);
290-
return new ValueMappingEstimator<TInputType, TOutputType>(env, keys, values, InputOutputColumnPair.ConvertToValueTuples(columns));
297+
298+
var lookupMap = DataViewHelper.CreateDataView(catalog.GetEnvironment(), keys, values,
299+
ValueMappingTransformer.DefaultKeyColumnName,
300+
ValueMappingTransformer.DefaultValueColumnName, false);
301+
302+
return new ValueMappingEstimator<TInputType, TOutputType>(catalog.GetEnvironment(), lookupMap,
303+
lookupMap.Schema[ValueMappingTransformer.DefaultKeyColumnName],
304+
lookupMap.Schema[ValueMappingTransformer.DefaultValueColumnName],
305+
InputOutputColumnPair.ConvertToValueTuples(columns));
291306
}
292307

293308
/// <summary>
@@ -311,8 +326,15 @@ internal static ValueMappingEstimator<TInputType, TOutputType> MapValue<TInputTy
311326
env.CheckValue(columns, nameof(columns));
312327
var keys = keyValuePairs.Select(pair => pair.Key);
313328
var values = keyValuePairs.Select(pair => pair.Value);
314-
return new ValueMappingEstimator<TInputType, TOutputType>(env, keys, values, treatValuesAsKeyType,
315-
InputOutputColumnPair.ConvertToValueTuples(columns));
329+
330+
var lookupMap = DataViewHelper.CreateDataView(catalog.GetEnvironment(), keys, values,
331+
ValueMappingTransformer.DefaultKeyColumnName,
332+
ValueMappingTransformer.DefaultValueColumnName, treatValuesAsKeyType);
333+
334+
return new ValueMappingEstimator<TInputType, TOutputType>(catalog.GetEnvironment(), lookupMap,
335+
lookupMap.Schema[ValueMappingTransformer.DefaultKeyColumnName],
336+
lookupMap.Schema[ValueMappingTransformer.DefaultValueColumnName],
337+
InputOutputColumnPair.ConvertToValueTuples(columns));
316338
}
317339

318340
/// <summary>
@@ -339,7 +361,15 @@ public static ValueMappingEstimator<TInputType, TOutputType> MapValue<TInputType
339361
{
340362
var keys = keyValuePairs.Select(pair => pair.Key);
341363
var values = keyValuePairs.Select(pair => pair.Value);
342-
return new ValueMappingEstimator<TInputType, TOutputType>(CatalogUtils.GetEnvironment(catalog), keys, values,
364+
365+
// Convert parallel key and value lists to IDataView with two columns, so that the underlying infra can use it.
366+
var lookupMap = DataViewHelper.CreateDataView(catalog.GetEnvironment(), keys, values,
367+
ValueMappingTransformer.DefaultKeyColumnName,
368+
ValueMappingTransformer.DefaultValueColumnName);
369+
370+
return new ValueMappingEstimator<TInputType, TOutputType>(catalog.GetEnvironment(), lookupMap,
371+
lookupMap.Schema[ValueMappingTransformer.DefaultKeyColumnName],
372+
lookupMap.Schema[ValueMappingTransformer.DefaultValueColumnName],
343373
new[] { (outputColumnName, inputColumnName ?? outputColumnName) });
344374
}
345375

@@ -362,8 +392,15 @@ internal static ValueMappingEstimator<TInputType, TOutputType> MapValue<TInputTy
362392
env.CheckValue(columns, nameof(columns));
363393
var keys = keyValuePairs.Select(pair => pair.Key);
364394
var values = keyValuePairs.Select(pair => pair.Value);
365-
return new ValueMappingEstimator<TInputType, TOutputType>(env, keys, values,
366-
InputOutputColumnPair.ConvertToValueTuples(columns));
395+
396+
var lookupMap = DataViewHelper.CreateDataView(catalog.GetEnvironment(), keys, values,
397+
ValueMappingTransformer.DefaultKeyColumnName,
398+
ValueMappingTransformer.DefaultValueColumnName);
399+
400+
return new ValueMappingEstimator<TInputType, TOutputType>(catalog.GetEnvironment(), lookupMap,
401+
lookupMap.Schema[ValueMappingTransformer.DefaultKeyColumnName],
402+
lookupMap.Schema[ValueMappingTransformer.DefaultValueColumnName],
403+
InputOutputColumnPair.ConvertToValueTuples(columns));
367404
}
368405

369406
/// <summary>
@@ -386,7 +423,7 @@ public static ValueMappingEstimator MapValue(
386423
this TransformsCatalog.ConversionTransforms catalog,
387424
string outputColumnName, IDataView lookupMap, DataViewSchema.Column keyColumn, DataViewSchema.Column valueColumn, string inputColumnName = null)
388425
{
389-
return new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), lookupMap, keyColumn.Name, valueColumn.Name,
426+
return new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), lookupMap, keyColumn, valueColumn,
390427
new[] { (outputColumnName, inputColumnName ?? outputColumnName) });
391428
}
392429

@@ -406,8 +443,7 @@ internal static ValueMappingEstimator MapValue(
406443
{
407444
var env = CatalogUtils.GetEnvironment(catalog);
408445
env.CheckValue(columns, nameof(columns));
409-
return new ValueMappingEstimator(env, lookupMap, keyColumn.Name, valueColumn.Name,
410-
InputOutputColumnPair.ConvertToValueTuples(columns));
446+
return new ValueMappingEstimator(env, lookupMap, keyColumn, valueColumn, InputOutputColumnPair.ConvertToValueTuples(columns));
411447
}
412448
}
413449
}

0 commit comments

Comments
 (0)