6
6
using System . Collections . Generic ;
7
7
using System . IO ;
8
8
using System . Linq ;
9
- using Microsoft . ML . Runtime ;
10
9
using Microsoft . ML . Runtime . Api ;
11
10
using Microsoft . ML . Runtime . Core . Tests . UnitTests ;
12
11
using Microsoft . ML . Runtime . Data ;
13
12
using Microsoft . ML . Runtime . Data . IO ;
14
13
using Microsoft . ML . Runtime . EntryPoints ;
15
14
using Microsoft . ML . Runtime . EntryPoints . JsonUtils ;
15
+ using Microsoft . ML . Runtime . FastTree ;
16
16
using Microsoft . ML . Runtime . Internal . Utilities ;
17
17
using Microsoft . ML . Runtime . Learners ;
18
18
using Newtonsoft . Json ;
@@ -2521,5 +2521,70 @@ public void EntryPointPrepareLabelConvertPredictedLabel()
2521
2521
}
2522
2522
}
2523
2523
}
2524
+
2525
/// <summary>
/// Trains a small FastTree binary classifier on top of a categorical + concat
/// pipeline, combines it with those transforms into a single predictor model,
/// runs the tree featurizer entry point on the raw data, and checks that the
/// "Trees", "Leaves" and "Paths" output columns have the expected vector sizes.
/// </summary>
[Fact]
public void EntryPointTreeLeafFeaturizer()
{
    // Load the raw data set through the text import entry point.
    var dataPath = GetDataPath(@"adult.tiny.with-schema.txt");
    var inputFile = new SimpleFileHandle(Env, dataPath, false, false);
    var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile }).Data;

    // Featurization: dictionary-encode "Categories", then concatenate it with
    // "NumericFeatures" into a single "Features" column.
    var cat = Categorical.CatTransformDict(Env, new CategoricalTransform.Arguments()
    {
        Data = dataView,
        Column = new[] { new CategoricalTransform.Column { Name = "Categories", Source = "Categories" } }
    });
    var concat = SchemaManipulation.ConcatColumns(Env, new ConcatTransform.Arguments()
    {
        Data = cat.OutputData,
        Column = new[] { new ConcatTransform.Column { Name = "Features", Source = new[] { "Categories", "NumericFeatures" } } }
    });

    // Train a deliberately tiny ensemble (5 trees, 4 leaves each) so that the
    // expected output sizes asserted below are easy to reason about.
    var fastTree = FastTree.FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Arguments
    {
        FeatureColumn = "Features",
        NumTrees = 5,
        NumLeaves = 4,
        LabelColumn = DefaultColumnNames.Label,
        TrainingData = concat.OutputData
    });

    // Bundle the transforms with the predictor so that the featurizer can be
    // applied directly to the raw data view.
    var combine = ModelOperations.CombineModels(Env, new ModelOperations.PredictorModelInput()
    {
        PredictorModel = fastTree.PredictorModel,
        TransformModels = new[] { cat.Model, concat.Model }
    });

    var treeLeaf = TreeFeaturize.Featurizer(Env, new TreeEnsembleFeaturizerTransform.ArgumentsForEntryPoint
    {
        Data = dataView,
        PredictorModel = combine.PredictorModel
    });

    var view = treeLeaf.OutputData;
    Assert.True(view.Schema.TryGetColumnIndex("Trees", out int treesCol));
    Assert.True(view.Schema.TryGetColumnIndex("Leaves", out int leavesCol));
    Assert.True(view.Schema.TryGetColumnIndex("Paths", out int pathsCol));

    VBuffer<float> treeValues = default(VBuffer<float>);
    VBuffer<float> leafIndicators = default(VBuffer<float>);
    VBuffer<float> pathIndicators = default(VBuffer<float>);

    // Counts the rows visited so the test cannot pass vacuously when the
    // featurizer output is empty (all per-row assertions live inside the loop).
    int rowCount = 0;
    using (var curs = view.GetRowCursor(c => c == treesCol || c == leavesCol || c == pathsCol))
    {
        var treesGetter = curs.GetGetter<VBuffer<float>>(treesCol);
        var leavesGetter = curs.GetGetter<VBuffer<float>>(leavesCol);
        var pathsGetter = curs.GetGetter<VBuffer<float>>(pathsCol);
        while (curs.MoveNext())
        {
            treesGetter(ref treeValues);
            leavesGetter(ref leafIndicators);
            pathsGetter(ref pathIndicators);
            rowCount++;

            // One value per tree (NumTrees = 5), all explicitly represented.
            Assert.Equal(5, treeValues.Length);
            Assert.Equal(5, treeValues.Count);
            // 20 = NumTrees * NumLeaves slots with 5 populated entries;
            // presumably one active leaf per tree -- confirm against the featurizer.
            Assert.Equal(20, leafIndicators.Length);
            Assert.Equal(5, leafIndicators.Count);
            // 15 = NumTrees * (NumLeaves - 1); presumably one slot per internal
            // node of each tree -- confirm against the featurizer.
            Assert.Equal(15, pathIndicators.Length);
        }
    }

    Assert.True(rowCount > 0, "Tree featurizer output contained no rows; per-row assertions were never executed.");
}
2524
2589
}
2525
2590
}
0 commit comments