@@ -793,6 +793,17 @@ private void CastInputToString<T>(OnnxContext ctx, out OnnxNode node, out long[]
793
793
node . AddAttribute ( "keys_strings" , terms . Select ( item => item . ToString ( ) ) ) ;
794
794
}
795
795
796
+ private void CastInputToFloat < T > ( OnnxContext ctx , out OnnxNode node , out long [ ] termIds , string srcVariableName , int iinfo ,
797
+ string opType , string labelEncoderOutput )
798
+ {
799
+ var castOutput = ctx . AddIntermediateVariable ( NumberDataViewType . Single , "castOutput" ) ;
800
+ var castNode = ctx . CreateNode ( "Cast" , srcVariableName , castOutput , ctx . GetNodeName ( "Cast" ) , "" ) ;
801
+ var t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . Single ) . ToType ( ) ;
802
+ castNode . AddAttribute ( "to" , t ) ;
803
+ node = ctx . CreateNode ( opType , castOutput , labelEncoderOutput , ctx . GetNodeName ( opType ) ) ;
804
+ var terms = GetTermsAndIds < T > ( iinfo , out termIds ) ;
805
+ node . AddAttribute ( "keys_floats" , terms . Select ( item => Convert . ToSingle ( item ) ) ) ;
806
+ }
796
807
private bool SaveAsOnnxCore ( OnnxContext ctx , int iinfo , ColInfo info , string srcVariableName , string dstVariableName )
797
808
{
798
809
OnnxNode node ;
@@ -808,6 +819,11 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
808
819
var terms = GetTermsAndIds < ReadOnlyMemory < char > > ( iinfo , out termIds ) ;
809
820
node . AddAttribute ( "keys_strings" , terms ) ;
810
821
}
822
+ else if ( type . Equals ( BooleanDataViewType . Instance ) )
823
+ {
824
+ // LabelEncoder doesn't support boolean tensors, so values are cast to floats
825
+ CastInputToFloat < Boolean > ( ctx , out node , out termIds , srcVariableName , iinfo , opType , labelEncoderOutput ) ;
826
+ }
811
827
else if ( type . Equals ( NumberDataViewType . Single ) )
812
828
{
813
829
node = ctx . CreateNode ( opType , srcVariableName , labelEncoderOutput , ctx . GetNodeName ( opType ) ) ;
@@ -817,13 +833,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
817
833
else if ( type . Equals ( NumberDataViewType . Double ) )
818
834
{
819
835
// LabelEncoder doesn't support double tensors, so values are cast to floats
820
- var castOutput = ctx . AddIntermediateVariable ( NumberDataViewType . Single , "castOutput" ) ;
821
- castNode = ctx . CreateNode ( "Cast" , srcVariableName , castOutput , ctx . GetNodeName ( "Cast" ) , "" ) ;
822
- var t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . Single ) . ToType ( ) ;
823
- castNode . AddAttribute ( "to" , t ) ;
824
- node = ctx . CreateNode ( opType , castOutput , labelEncoderOutput , ctx . GetNodeName ( opType ) ) ;
825
- var terms = GetTermsAndIds < double > ( iinfo , out termIds ) ;
826
- node . AddAttribute ( "keys_floats" , terms ) ;
836
+ CastInputToFloat < Double > ( ctx , out node , out termIds , srcVariableName , iinfo , opType , labelEncoderOutput ) ;
827
837
}
828
838
else if ( type . Equals ( NumberDataViewType . Int64 ) )
829
839
{
0 commit comments