8
8
using System . Collections . Immutable ;
9
9
using System . IO ;
10
10
using System . Linq ;
11
+ using System . Reflection ;
11
12
using Microsoft . ML ;
12
13
using Microsoft . ML . Calibrators ;
13
14
using Microsoft . ML . CommandLine ;
@@ -396,6 +397,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
396
397
}
397
398
398
399
[ BestFriend ]
400
+ [ PredictionTransformerLoadType ( typeof ( CalibratedModelParametersBase < , > ) ) ]
399
401
internal sealed class ValueMapperCalibratedModelParameters < TSubModel , TCalibrator > :
400
402
ValueMapperCalibratedModelParametersBase < TSubModel , TCalibrator > , ICanSaveModel
401
403
where TSubModel : class
@@ -430,8 +432,8 @@ private static VersionInfo GetVersionInfoBulk()
430
432
loaderAssemblyName : typeof ( ValueMapperCalibratedModelParameters < TSubModel , TCalibrator > ) . Assembly . FullName ) ;
431
433
}
432
434
433
- private ValueMapperCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx )
434
- : base ( env , RegistrationName , GetPredictor ( env , ctx ) , GetCalibrator ( env , ctx ) )
435
+ private ValueMapperCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx , TSubModel predictor , TCalibrator calibrator )
436
+ : base ( env , RegistrationName , predictor , calibrator )
435
437
{
436
438
}
437
439
@@ -443,7 +445,16 @@ private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelL
443
445
var ver2 = GetVersionInfoBulk ( ) ;
444
446
var ver = ctx . Header . ModelSignature == ver2 . ModelSignature ? ver2 : ver1 ;
445
447
ctx . CheckAtModel ( ver ) ;
446
- return new ValueMapperCalibratedModelParameters < TSubModel , TCalibrator > ( env , ctx ) ;
448
+
449
+ // Load first the predictor and calibrator
450
+ var predictor = GetPredictor ( env , ctx ) ;
451
+ var calibrator = GetCalibrator ( env , ctx ) ;
452
+
453
+ // Create a generic type using the correct parameter types of predictor and calibrator
454
+ Type genericType = typeof ( ValueMapperCalibratedModelParameters < , > ) ;
455
+ var genericInstance = CreateCalibratedModelParameters . Create ( env , ctx , predictor , calibrator , genericType ) ;
456
+
457
+ return ( CalibratedModelParametersBase ) genericInstance ;
447
458
}
448
459
449
460
void ICanSaveModel . Save ( ModelSaveContext ctx )
@@ -456,6 +467,7 @@ void ICanSaveModel.Save(ModelSaveContext ctx)
456
467
}
457
468
458
469
[ BestFriend ]
470
+ [ PredictionTransformerLoadType ( typeof ( CalibratedModelParametersBase < , > ) ) ]
459
471
internal sealed class FeatureWeightsCalibratedModelParameters < TSubModel , TCalibrator > :
460
472
ValueMapperCalibratedModelParametersBase < TSubModel , TCalibrator > ,
461
473
IPredictorWithFeatureWeights < float > ,
@@ -487,8 +499,9 @@ private static VersionInfo GetVersionInfo()
487
499
loaderAssemblyName : typeof ( FeatureWeightsCalibratedModelParameters < TSubModel , TCalibrator > ) . Assembly . FullName ) ;
488
500
}
489
501
490
- private FeatureWeightsCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx )
491
- : base ( env , RegistrationName , GetPredictor ( env , ctx ) , GetCalibrator ( env , ctx ) )
502
+ private FeatureWeightsCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx ,
503
+ TSubModel predictor , TCalibrator calibrator )
504
+ : base ( env , RegistrationName , predictor , calibrator )
492
505
{
493
506
Host . Check ( SubModel is IPredictorWithFeatureWeights < float > , "Predictor does not implement " + nameof ( IPredictorWithFeatureWeights < float > ) ) ;
494
507
_featureWeights = ( IPredictorWithFeatureWeights < float > ) SubModel ;
@@ -499,7 +512,16 @@ private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelL
499
512
Contracts . CheckValue ( env , nameof ( env ) ) ;
500
513
env . CheckValue ( ctx , nameof ( ctx ) ) ;
501
514
ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
502
- return new FeatureWeightsCalibratedModelParameters < TSubModel , TCalibrator > ( env , ctx ) ;
515
+
516
+ // Load first the predictor and calibrator
517
+ var predictor = GetPredictor ( env , ctx ) ;
518
+ var calibrator = GetCalibrator ( env , ctx ) ;
519
+
520
+ // Create a generic type using the correct parameter types of predictor and calibrator
521
+ Type genericType = typeof ( FeatureWeightsCalibratedModelParameters < , > ) ;
522
+ var genericInstance = CreateCalibratedModelParameters . Create ( env , ctx , predictor , calibrator , genericType ) ;
523
+
524
+ return ( CalibratedModelParametersBase ) genericInstance ;
503
525
}
504
526
505
527
void ICanSaveModel . Save ( ModelSaveContext ctx )
@@ -520,6 +542,7 @@ public void GetFeatureWeights(ref VBuffer<float> weights)
520
542
/// Encapsulates a predictor and a calibrator that implement <see cref="IParameterMixer"/>.
521
543
/// Its implementation of <see cref="IParameterMixer.CombineParameters"/> combines both the predictors and the calibrators.
522
544
/// </summary>
545
+ [ PredictionTransformerLoadType ( typeof ( CalibratedModelParametersBase < , > ) ) ]
523
546
internal sealed class ParameterMixingCalibratedModelParameters < TSubModel , TCalibrator > :
524
547
ValueMapperCalibratedModelParametersBase < TSubModel , TCalibrator > ,
525
548
IParameterMixer < float > ,
@@ -553,8 +576,8 @@ private static VersionInfo GetVersionInfo()
553
576
loaderAssemblyName : typeof ( ParameterMixingCalibratedModelParameters < TSubModel , TCalibrator > ) . Assembly . FullName ) ;
554
577
}
555
578
556
- private ParameterMixingCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx )
557
- : base ( env , RegistrationName , GetPredictor ( env , ctx ) , GetCalibrator ( env , ctx ) )
579
+ private ParameterMixingCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx , TSubModel predictor , TCalibrator calibrator )
580
+ : base ( env , RegistrationName , predictor , calibrator )
558
581
{
559
582
Host . Check ( SubModel is IParameterMixer < float > , "Predictor does not implement " + nameof ( IParameterMixer ) ) ;
560
583
Host . Check ( SubModel is IPredictorWithFeatureWeights < float > , "Predictor does not implement " + nameof ( IPredictorWithFeatureWeights < float > ) ) ;
@@ -566,7 +589,16 @@ private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelL
566
589
Contracts . CheckValue ( env , nameof ( env ) ) ;
567
590
env . CheckValue ( ctx , nameof ( ctx ) ) ;
568
591
ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
569
- return new ParameterMixingCalibratedModelParameters < TSubModel , TCalibrator > ( env , ctx ) ;
592
+
593
+ // Load first the predictor and calibrator
594
+ var predictor = GetPredictor ( env , ctx ) ;
595
+ var calibrator = GetCalibrator ( env , ctx ) ;
596
+
597
+ // Create a generic type using the correct parameter types of predictor and calibrator
598
+ Type genericType = typeof ( ParameterMixingCalibratedModelParameters < , > ) ;
599
+ object genericInstance = CreateCalibratedModelParameters . Create ( env , ctx , predictor , calibrator , genericType ) ;
600
+
601
+ return ( CalibratedModelParametersBase ) genericInstance ;
570
602
}
571
603
572
604
void ICanSaveModel . Save ( ModelSaveContext ctx )
@@ -777,6 +809,28 @@ ValueMapper<TSrc, VBuffer<float>> IFeatureContributionMapper.GetFeatureContribut
777
809
}
778
810
}
779
811
812
+ internal static class CreateCalibratedModelParameters
813
+ {
814
+ internal static object Create ( IHostEnvironment env , ModelLoadContext ctx , object predictor , ICalibrator calibrator , Type calibratedModelParametersType )
815
+ {
816
+ Type [ ] genericTypeArgs = { predictor . GetType ( ) , calibrator . GetType ( ) } ;
817
+ Type constructed = calibratedModelParametersType . MakeGenericType ( genericTypeArgs ) ;
818
+
819
+ Type [ ] constructorArgs = {
820
+ typeof ( IHostEnvironment ) ,
821
+ typeof ( ModelLoadContext ) ,
822
+ predictor . GetType ( ) ,
823
+ calibrator . GetType ( )
824
+ } ;
825
+
826
+ // Call the appropiate constructor of the created generic type passing on the previously loaded predictor and calibrator
827
+ var genericCtor = constructed . GetConstructor ( BindingFlags . NonPublic | BindingFlags . Instance , null , constructorArgs , null ) ;
828
+ object genericInstance = genericCtor . Invoke ( new object [ ] { env , ctx , predictor , calibrator } ) ;
829
+
830
+ return genericInstance ;
831
+ }
832
+ }
833
+
780
834
[ BestFriend ]
781
835
internal static class CalibratorUtils
782
836
{
0 commit comments