@@ -475,7 +475,7 @@ public float Cost
475
475
476
476
private EntryPointNode ( IHostEnvironment env , IChannel ch , ModuleCatalog moduleCatalog , RunContext context ,
477
477
string id , string entryPointName , JObject inputs , JObject outputs , bool checkpoint = false ,
478
- string stageId = "" , float cost = float . NaN , string label = null , string group = null , string weight = null )
478
+ string stageId = "" , float cost = float . NaN , string label = null , string group = null , string weight = null , string name = null )
479
479
{
480
480
Contracts . AssertValue ( env ) ;
481
481
env . AssertNonEmpty ( id ) ;
@@ -510,49 +510,10 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCa
510
510
throw _host . Except ( $ "The following required inputs were not provided: { String . Join ( ", " , missing ) } ") ;
511
511
512
512
var inputInstance = _inputBuilder . GetInstance ( ) ;
513
- var warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." +
514
- " Using column '{2}'. To column use '{1}' instead, please specify this name in" +
515
- "the trainer node arguments." ;
516
- if ( ! string . IsNullOrEmpty ( label ) && Utils . Size ( _entryPoint . InputKinds ) > 0 &&
517
- _entryPoint . InputKinds . Contains ( typeof ( CommonInputs . ITrainerInputWithLabel ) ) )
518
- {
519
- var labelColField = _inputBuilder . GetFieldNameOrNull ( "LabelColumn" ) ;
520
- ch . AssertNonEmpty ( labelColField ) ;
521
- var labelColFieldType = _inputBuilder . GetFieldTypeOrNull ( labelColField ) ;
522
- ch . Assert ( labelColFieldType == typeof ( string ) ) ;
523
- var inputLabel = inputInstance . GetType ( ) . GetField ( labelColField ) . GetValue ( inputInstance ) ;
524
- if ( label != ( string ) inputLabel )
525
- ch . Warning ( warning , "label" , label , inputLabel ) ;
526
- else
527
- _inputBuilder . TrySetValue ( labelColField , label ) ;
528
- }
529
- if ( ! string . IsNullOrEmpty ( group ) && Utils . Size ( _entryPoint . InputKinds ) > 0 &&
530
- _entryPoint . InputKinds . Contains ( typeof ( CommonInputs . ITrainerInputWithGroupId ) ) )
531
- {
532
- var groupColField = _inputBuilder . GetFieldNameOrNull ( "GroupIdColumn" ) ;
533
- ch . AssertNonEmpty ( groupColField ) ;
534
- var groupColFieldType = _inputBuilder . GetFieldTypeOrNull ( groupColField ) ;
535
- ch . Assert ( groupColFieldType == typeof ( string ) ) ;
536
- var inputGroup = inputInstance . GetType ( ) . GetField ( groupColField ) . GetValue ( inputInstance ) ;
537
- if ( group != ( Optional < string > ) inputGroup )
538
- ch . Warning ( warning , "group Id" , label , inputGroup ) ;
539
- else
540
- _inputBuilder . TrySetValue ( groupColField , label ) ;
541
- }
542
- if ( ! string . IsNullOrEmpty ( weight ) && Utils . Size ( _entryPoint . InputKinds ) > 0 &&
543
- ( _entryPoint . InputKinds . Contains ( typeof ( CommonInputs . ITrainerInputWithWeight ) ) ||
544
- _entryPoint . InputKinds . Contains ( typeof ( CommonInputs . IUnsupervisedTrainerWithWeight ) ) ) )
545
- {
546
- var weightColField = _inputBuilder . GetFieldNameOrNull ( "WeightColumn" ) ;
547
- ch . AssertNonEmpty ( weightColField ) ;
548
- var weightColFieldType = _inputBuilder . GetFieldTypeOrNull ( weightColField ) ;
549
- ch . Assert ( weightColFieldType == typeof ( string ) ) ;
550
- var inputWeight = inputInstance . GetType ( ) . GetField ( weightColField ) . GetValue ( inputInstance ) ;
551
- if ( weight != ( Optional < string > ) inputWeight )
552
- ch . Warning ( warning , "weight" , label , inputWeight ) ;
553
- else
554
- _inputBuilder . TrySetValue ( weightColField , label ) ;
555
- }
513
+ SetColumnArgument ( ch , inputInstance , "LabelColumn" , label , "label" , typeof ( CommonInputs . ITrainerInputWithLabel ) ) ;
514
+ SetColumnArgument ( ch , inputInstance , "GroupIdColumn" , group , "group Id" , typeof ( CommonInputs . ITrainerInputWithGroupId ) ) ;
515
+ SetColumnArgument ( ch , inputInstance , "WeightColumn" , weight , "weight" , typeof ( CommonInputs . ITrainerInputWithWeight ) , typeof ( CommonInputs . IUnsupervisedTrainerWithWeight ) ) ;
516
+ SetColumnArgument ( ch , inputInstance , "NameColumn" , name , "name" ) ;
556
517
557
518
// Validate outputs.
558
519
_outputHelper = new OutputHelper ( _host , _entryPoint . OutputType ) ;
@@ -568,6 +529,38 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCa
568
529
Cost = cost ;
569
530
}
570
531
532
+ private void SetColumnArgument ( IChannel ch , object inputInstance , string argName , string colName , string columnRole , params Type [ ] inputKinds )
533
+ {
534
+ Contracts . AssertValue ( ch ) ;
535
+ ch . AssertValue ( inputInstance ) ;
536
+ ch . AssertNonEmpty ( argName ) ;
537
+ ch . AssertValueOrNull ( colName ) ;
538
+ ch . AssertNonEmpty ( columnRole ) ;
539
+ ch . AssertValueOrNull ( inputKinds ) ;
540
+
541
+ var colField = _inputBuilder . GetFieldNameOrNull ( argName ) ;
542
+ if ( string . IsNullOrEmpty ( colField ) )
543
+ return ;
544
+
545
+ const string warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." +
546
+ " Using column '{2}'. To column use '{1}' instead, please specify this name in" +
547
+ "the trainer node arguments." ;
548
+ if ( ! string . IsNullOrEmpty ( colName ) && Utils . Size ( _entryPoint . InputKinds ) > 0 &&
549
+ ( Utils . Size ( inputKinds ) == 0 || _entryPoint . InputKinds . Intersect ( inputKinds ) . Any ( ) ) )
550
+ {
551
+ ch . AssertNonEmpty ( colField ) ;
552
+ var colFieldType = _inputBuilder . GetFieldTypeOrNull ( colField ) ;
553
+ ch . Assert ( colFieldType == typeof ( string ) ) ;
554
+ var inputColName = inputInstance . GetType ( ) . GetField ( colField ) . GetValue ( inputInstance ) ;
555
+ ch . Assert ( inputColName is string || inputColName is Optional < string > ) ;
556
+ var str = inputColName is string ? ( string ) inputColName : ( ( Optional < string > ) inputColName ) . Value ;
557
+ if ( colName != str )
558
+ ch . Warning ( warning , columnRole , colName , inputColName ) ;
559
+ else
560
+ _inputBuilder . TrySetValue ( colField , colName ) ;
561
+ }
562
+ }
563
+
571
564
public static EntryPointNode Create (
572
565
IHostEnvironment env ,
573
566
string entryPointName ,
@@ -902,7 +895,7 @@ private object BuildParameterValue(List<ParameterBinding> bindings)
902
895
}
903
896
904
897
public static List < EntryPointNode > ValidateNodes ( IHostEnvironment env , RunContext context , JArray nodes ,
905
- ModuleCatalog moduleCatalog , string label = null , string group = null , string weight = null )
898
+ ModuleCatalog moduleCatalog , string label = null , string group = null , string weight = null , string name = null )
906
899
{
907
900
Contracts . AssertValue ( env ) ;
908
901
env . AssertValue ( context ) ;
@@ -918,7 +911,7 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
918
911
if ( node == null )
919
912
throw env . Except ( "Unexpected node token: '{0}'" , nodes [ i ] ) ;
920
913
921
- string name = node [ FieldNames . Name ] . Value < string > ( ) ;
914
+ string nodeName = node [ FieldNames . Name ] . Value < string > ( ) ;
922
915
var inputs = node [ FieldNames . Inputs ] as JObject ;
923
916
if ( inputs == null && node [ FieldNames . Inputs ] != null )
924
917
throw env . Except ( "Unexpected {0} token: '{1}'" , FieldNames . Inputs , node [ FieldNames . Inputs ] ) ;
@@ -927,7 +920,7 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
927
920
if ( outputs == null && node [ FieldNames . Outputs ] != null )
928
921
throw env . Except ( "Unexpected {0} token: '{1}'" , FieldNames . Outputs , node [ FieldNames . Outputs ] ) ;
929
922
930
- var id = context . GenerateId ( name ) ;
923
+ var id = context . GenerateId ( nodeName ) ;
931
924
var unexpectedFields = node . Properties ( ) . Where (
932
925
x => x . Name != FieldNames . Name && x . Name != FieldNames . Inputs && x . Name != FieldNames . Outputs
933
926
&& x . Name != FieldNames . StageId && x . Name != FieldNames . Checkpoint && x . Name != FieldNames . Cost ) ;
@@ -942,7 +935,7 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
942
935
ch . Warning ( "Node '{0}' has unexpected fields that are ignored: {1}" , id , string . Join ( ", " , unexpectedFields . Select ( x => x . Name ) ) ) ;
943
936
}
944
937
945
- result . Add ( new EntryPointNode ( env , ch , moduleCatalog , context , id , name , inputs , outputs , checkpoint , stageId , cost , label , group , weight ) ) ;
938
+ result . Add ( new EntryPointNode ( env , ch , moduleCatalog , context , id , nodeName , inputs , outputs , checkpoint , stageId , cost , label , group , weight , name ) ) ;
946
939
}
947
940
948
941
ch . Done ( ) ;
0 commit comments