Skip to content

Commit acce352

Browse files
authored
Added boolean support for KeyToValue and ValueToKey (#4900)
* support for booleans * standardize test
1 parent ce0041e commit acce352

File tree

4 files changed

+32
-10
lines changed

4 files changed

+32
-10
lines changed

src/Microsoft.ML.Data/Transforms/KeyToValue.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string ds
512512

513513
var labelEncoderOutput = dstVariableName;
514514
var labelEncoderInput = srcVariableName;
515-
if (TypeOutput == NumberDataViewType.Double)
515+
if (TypeOutput == NumberDataViewType.Double || TypeOutput == BooleanDataViewType.Instance)
516516
labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastNodeOutput");
517517
else if (TypeOutput == NumberDataViewType.Int64 || TypeOutput == NumberDataViewType.UInt16 ||
518518
TypeOutput == NumberDataViewType.Int32 || TypeOutput == NumberDataViewType.Int16 ||
@@ -555,6 +555,15 @@ public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string ds
555555
string[] values = Array.ConvertAll<TValue, string>(_values.GetValues().ToArray(), item => Convert.ToString(item));
556556
node.AddAttribute("values_strings", values);
557557
}
558+
else if (TypeOutput == BooleanDataViewType.Instance)
559+
{
560+
float[] values = Array.ConvertAll<TValue, float>(_values.GetValues().ToArray(), item => Convert.ToSingle(item));
561+
node.AddAttribute("values_floats", values);
562+
opType = "Cast";
563+
castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
564+
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType();
565+
castNode.AddAttribute("to", t);
566+
}
558567
else
559568
return false;
560569

src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,17 @@ private void CastInputToString<T>(OnnxContext ctx, out OnnxNode node, out long[]
793793
node.AddAttribute("keys_strings", terms.Select(item => item.ToString()));
794794
}
795795

796+
// Builds the ONNX nodes for an input whose values must be converted to floats before the
// opType node consumes them (the callers use this for Boolean and Double columns, noting that
// LabelEncoder does not support those tensor types).
//   ctx                 - ONNX context used to create the intermediate variable and the nodes.
//   node                - receives the created opType node, which reads the cast output.
//   termIds             - receives the ids produced by GetTermsAndIds<T> for column iinfo.
//   srcVariableName     - name of the variable fed into the Cast node.
//   iinfo               - column index passed through to GetTermsAndIds<T>.
//   opType              - operator type of the node created after the cast.
//   labelEncoderOutput  - name of the variable the opType node writes to.
private void CastInputToFloat<T>(OnnxContext ctx, out OnnxNode node, out long[] termIds, string srcVariableName, int iinfo,
797+
string opType, string labelEncoderOutput)
798+
{
799+
// Intermediate Single-typed variable that holds the result of the cast.
var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "castOutput");
800+
// Cast node: srcVariableName -> castOutput.
var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
801+
// Target type of the cast is the type corresponding to DataKind.Single.
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
802+
castNode.AddAttribute("to", t);
803+
// The opType node consumes the cast output rather than the original source variable.
node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
804+
var terms = GetTermsAndIds<T>(iinfo, out termIds);
805+
// Terms are converted with Convert.ToSingle so they can be stored in the "keys_floats" attribute.
node.AddAttribute("keys_floats", terms.Select(item => Convert.ToSingle(item)));
806+
}
796807
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
797808
{
798809
OnnxNode node;
@@ -808,6 +819,11 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
808819
var terms = GetTermsAndIds<ReadOnlyMemory<char>>(iinfo, out termIds);
809820
node.AddAttribute("keys_strings", terms);
810821
}
822+
else if (type.Equals(BooleanDataViewType.Instance))
823+
{
824+
// LabelEncoder doesn't support boolean tensors, so values are cast to floats
825+
CastInputToFloat<Boolean>(ctx, out node, out termIds, srcVariableName, iinfo, opType, labelEncoderOutput);
826+
}
811827
else if (type.Equals(NumberDataViewType.Single))
812828
{
813829
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
@@ -817,13 +833,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
817833
else if (type.Equals(NumberDataViewType.Double))
818834
{
819835
// LabelEncoder doesn't support double tensors, so values are cast to floats
820-
var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "castOutput");
821-
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
822-
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
823-
castNode.AddAttribute("to", t);
824-
node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
825-
var terms = GetTermsAndIds<double>(iinfo, out termIds);
826-
node.AddAttribute("keys_floats", terms);
836+
CastInputToFloat<Double>(ctx, out node, out termIds, srcVariableName, iinfo, opType, labelEncoderOutput);
827837
}
828838
else if (type.Equals(NumberDataViewType.Int64))
829839
{

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,10 +1209,11 @@ public void IndicateMissingValuesOnnxConversionTest()
12091209
[InlineData(DataKind.UInt16)]
12101210
[InlineData(DataKind.Double)]
12111211
[InlineData(DataKind.String)]
1212+
[InlineData(DataKind.Boolean)]
12121213
public void ValueToKeyMappingOnnxConversionTest(DataKind valueType)
12131214
{
12141215
var mlContext = new MLContext(seed: 1);
1215-
string filePath = GetDataPath("type-conversion.txt");
1216+
string filePath = (valueType == DataKind.Boolean) ? GetDataPath("type-conversion-boolean.txt") : GetDataPath("type-conversion.txt");
12161217

12171218
TextLoader.Column[] columns = new[]
12181219
{
@@ -1249,10 +1250,11 @@ public void ValueToKeyMappingOnnxConversionTest(DataKind valueType)
12491250
[InlineData(DataKind.UInt16)]
12501251
[InlineData(DataKind.Double)]
12511252
[InlineData(DataKind.String)]
1253+
[InlineData(DataKind.Boolean)]
12521254
public void KeyToValueMappingOnnxConversionTest(DataKind valueType)
12531255
{
12541256
var mlContext = new MLContext(seed: 1);
1255-
string filePath = GetDataPath("type-conversion.txt");
1257+
string filePath = (valueType == DataKind.Boolean) ? GetDataPath("type-conversion-boolean.txt") : GetDataPath("type-conversion.txt");
12561258

12571259
TextLoader.Column[] columns = new[]
12581260
{

test/data/type-conversion-boolean.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
False

0 commit comments

Comments
 (0)