Skip to content

Commit 23433c0

Browse files
yaeldMS authored and eerhardt committed
Fix a bug in Tree leaf featurizer entry point, and add a test for it. (dotnet#131)
* Fix a bug in Tree leaf featurizer entry point, and add a test for it.
* Improve unit test
* Update unit test
* Decrease number of trees and leaves in unit test
1 parent 256b8c8 commit 23433c0

File tree

3 files changed: +75 -7 lines changed

src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -703,10 +703,12 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments
703703
using (var ch = host.Start("Create Tree Ensemble Scorer"))
704704
{
705705
var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments() { Suffix = args.Suffix };
706-
var predictor = args.PredictorModel?.Predictor;
706+
var predictor = args.PredictorModel.Predictor;
707707
ch.Trace("Prepare data");
708708
RoleMappedData data = null;
709-
args.PredictorModel?.PrepareData(env, input, out data, out var predictor2);
709+
args.PredictorModel.PrepareData(env, input, out data, out var predictor2);
710+
ch.AssertValue(data);
711+
ch.Assert(predictor == predictor2);
710712

711713
// Make sure that the given predictor has the correct number of input features.
712714
if (predictor is CalibratedPredictorBase)
@@ -715,16 +717,16 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments
715717
// be non-null.
716718
var vm = predictor as IValueMapper;
717719
ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type");
718-
if (data != null && vm?.InputType.VectorSize != data.Schema.Feature.Type.VectorSize)
720+
if (data != null && vm.InputType.VectorSize != data.Schema.Feature.Type.VectorSize)
719721
{
720722
throw ch.ExceptUserArg(nameof(args.PredictorModel),
721723
"Predictor expects {0} features, but data has {1} features",
722-
vm?.InputType.VectorSize, data.Schema.Feature.Type.VectorSize);
724+
vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize);
723725
}
724726

725727
var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
726-
var bound = bindable.Bind(env, data?.Schema);
727-
xf = new GenericScorer(env, scorerArgs, input, bound, data?.Schema);
728+
var bound = bindable.Bind(env, data.Schema);
729+
xf = new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema);
728730
ch.Done();
729731
}
730732
return xf;

test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818

1919
<ItemGroup>
2020
<NativeAssemblyReference Include="CpuMathNative" />
21+
<NativeAssemblyReference Include="FastTreeNative" />
2122
</ItemGroup>
2223

2324
</Project>

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

Lines changed: 66 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,13 +6,13 @@
66
using System.Collections.Generic;
77
using System.IO;
88
using System.Linq;
9-
using Microsoft.ML.Runtime;
109
using Microsoft.ML.Runtime.Api;
1110
using Microsoft.ML.Runtime.Core.Tests.UnitTests;
1211
using Microsoft.ML.Runtime.Data;
1312
using Microsoft.ML.Runtime.Data.IO;
1413
using Microsoft.ML.Runtime.EntryPoints;
1514
using Microsoft.ML.Runtime.EntryPoints.JsonUtils;
15+
using Microsoft.ML.Runtime.FastTree;
1616
using Microsoft.ML.Runtime.Internal.Utilities;
1717
using Microsoft.ML.Runtime.Learners;
1818
using Newtonsoft.Json;
@@ -2521,5 +2521,70 @@ public void EntryPointPrepareLabelConvertPredictedLabel()
25212521
}
25222522
}
25232523
}
2524+
2525+
/// <summary>
/// Trains a small FastTree binary classifier through the entry-point API, combines it with
/// its preprocessing transform models, runs the tree leaf featurizer entry point on the
/// imported data, and checks that the "Trees", "Leaves" and "Paths" output columns exist
/// and have the expected vector sizes on every row.
/// </summary>
[Fact]
2526+
public void EntryPointTreeLeafFeaturizer()
2527+
{
2528+
// Import the text data set that the rest of the pipeline is built on.
var dataPath = GetDataPath(@"adult.tiny.with-schema.txt");
2529+
var inputFile = new SimpleFileHandle(Env, dataPath, false, false);
2530+
var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile }).Data;
2531+
// Preprocessing step 1: categorical transform applied in place to the "Categories" column.
var cat = Categorical.CatTransformDict(Env, new CategoricalTransform.Arguments()
2532+
{
2533+
Data = dataView,
2534+
Column = new[] { new CategoricalTransform.Column { Name = "Categories", Source = "Categories" } }
2535+
});
2536+
// Preprocessing step 2: concatenate "Categories" and "NumericFeatures" into a single "Features" column.
var concat = SchemaManipulation.ConcatColumns(Env, new ConcatTransform.Arguments()
2537+
{
2538+
Data = cat.OutputData,
2539+
Column = new[] { new ConcatTransform.Column { Name = "Features", Source = new[] { "Categories", "NumericFeatures" } } }
2540+
});
2541+
2542+
// Train a deliberately small ensemble (5 trees, 4 leaves); the size assertions below depend on these values.
var fastTree = FastTree.FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Arguments
2543+
{
2544+
FeatureColumn = "Features",
2545+
NumTrees = 5,
2546+
NumLeaves = 4,
2547+
LabelColumn = DefaultColumnNames.Label,
2548+
TrainingData = concat.OutputData
2549+
});
2550+
2551+
// Bundle the two transform models with the trained predictor into one predictor model.
var combine = ModelOperations.CombineModels(Env, new ModelOperations.PredictorModelInput()
2552+
{
2553+
PredictorModel = fastTree.PredictorModel,
2554+
TransformModels = new[] { cat.Model, concat.Model }
2555+
});
2556+
2557+
// The featurizer is given the imported data (dataView), not the transformed data.
// NOTE(review): presumably the combined model is expected to apply its own transforms
// before scoring - confirm against the entry point implementation.
var treeLeaf = TreeFeaturize.Featurizer(Env, new TreeEnsembleFeaturizerTransform.ArgumentsForEntryPoint
2558+
{
2559+
Data = dataView,
2560+
PredictorModel = combine.PredictorModel
2561+
});
2562+
2563+
var view = treeLeaf.OutputData;
2564+
// All three featurizer output columns must be present in the output schema.
Assert.True(view.Schema.TryGetColumnIndex("Trees", out int treesCol));
2565+
Assert.True(view.Schema.TryGetColumnIndex("Leaves", out int leavesCol));
2566+
Assert.True(view.Schema.TryGetColumnIndex("Paths", out int pathsCol));
2567+
VBuffer<float> treeValues = default(VBuffer<float>);
2568+
VBuffer<float> leafIndicators = default(VBuffer<float>);
2569+
VBuffer<float> pathIndicators = default(VBuffer<float>);
2570+
// Only the three featurizer columns are requested as active on the cursor.
using (var curs = view.GetRowCursor(c => c == treesCol || c == leavesCol || c == pathsCol))
2571+
{
2572+
var treesGetter = curs.GetGetter<VBuffer<float>>(treesCol);
2573+
var leavesGetter = curs.GetGetter<VBuffer<float>>(leavesCol);
2574+
var pathsGetter = curs.GetGetter<VBuffer<float>>(pathsCol);
2575+
while (curs.MoveNext())
2576+
{
2577+
treesGetter(ref treeValues);
2578+
leavesGetter(ref leafIndicators);
2579+
pathsGetter(ref pathIndicators);
2580+
2581+
// Expected sizes follow from NumTrees = 5 and NumLeaves = 4 above.
// NOTE(review): 20 is presumably 5 trees x 4 leaves with one active leaf per tree (Count == 5),
// and 15 is presumably 5 trees x 3 internal nodes - confirm against the featurizer.
Assert.Equal(5, treeValues.Length);
2582+
Assert.Equal(5, treeValues.Count);
2583+
Assert.Equal(20, leafIndicators.Length);
2584+
Assert.Equal(5, leafIndicators.Count);
2585+
Assert.Equal(15, pathIndicators.Length);
2586+
}
2587+
}
2588+
}
25242589
}
25252590
}

0 commit comments

Comments
 (0)