Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,6 @@ private static readonly FuncInstanceMethodInfo1<Mapper, Delegate> _makeVecTrivia
private readonly SlotsDroppingTransformer _parent;
private readonly int[] _cols;
private readonly DataViewType[] _srcTypes;
private readonly DataViewType[] _rawTypes;
private readonly DataViewType[] _dstTypes;
private readonly SlotDropper[] _slotDropper;
// Track if all the slots of the column are to be dropped.
Expand All @@ -467,7 +466,6 @@ public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema)
_parent = parent;
_cols = new int[_parent.ColumnPairs.Length];
_srcTypes = new DataViewType[_parent.ColumnPairs.Length];
_rawTypes = new DataViewType[_parent.ColumnPairs.Length];
_dstTypes = new DataViewType[_parent.ColumnPairs.Length];
_slotDropper = new SlotDropper[_parent.ColumnPairs.Length];
_suppressed = new bool[_parent.ColumnPairs.Length];
Expand All @@ -480,8 +478,8 @@ public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema)
_srcTypes[i] = inputSchema[_cols[i]].Type;
VectorDataViewType srcVectorType = _srcTypes[i] as VectorDataViewType;

_rawTypes[i] = srcVectorType?.ItemType ?? _srcTypes[i];
if (!IsValidColumnType(_rawTypes[i]))
var rawType = srcVectorType?.ItemType ?? _srcTypes[i];
if (!IsValidColumnType(rawType))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName);

int valueCount = srcVectorType?.Size ?? 1;
Expand Down Expand Up @@ -896,27 +894,26 @@ public void SaveAsOnnx(OnnxContext ctx)
public bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
string opType;
if (_srcTypes[iinfo] is VectorDataViewType)
var slots = _slotDropper[iinfo].GetPreservedSlots();
// vector column is not suppressed
if (slots.Count() > 0)
{
opType = "GatherElements";
IEnumerable<long> slots = _slotDropper[iinfo].GetPreservedSlots();
var slotsVar = ctx.AddInitializer(slots, new long[] { 1, slots.Count() }, "PreservedSlots");
var node = ctx.CreateNode(opType, new[] { srcVariableName, slotsVar }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
node.AddAttribute("axis", 1);
}
// When the vector/scalar column is suppressed, we simply create an empty output vector
else
{
string constVal;
long[] dims = { 1, 1 };
float[] floatVals = { 0.0f };
long[] keyVals = { 0 };
string[] stringVals = { "" };
if (_rawTypes[iinfo] is TextDataViewType)
constVal = ctx.AddInitializer(stringVals, dims);
else if (_rawTypes[iinfo] is KeyDataViewType)
constVal = ctx.AddInitializer(keyVals, dims);
var type = _srcTypes[iinfo].GetItemType();
if (type == TextDataViewType.Instance)
constVal = ctx.AddInitializer(new string[] { "" }, new long[] { 1, 1 });
else if (type == NumberDataViewType.Single)
constVal = ctx.AddInitializer(new float[] { 0 }, new long[] { 1, 1 });
else
constVal = ctx.AddInitializer(floatVals, dims);
constVal = ctx.AddInitializer(new double[] { 0 }, new long[] { 1, 1 });

opType = "Identity";
ctx.CreateNode(opType, constVal, dstVariableName, ctx.GetNodeName(opType), "");
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Transforms/CountFeatureSelection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace Microsoft.ML.Transforms
/// | | |
/// | -- | -- |
/// | Does this estimator need to look at the data to train its parameters? | Yes |
/// | Input column data type | Vector or scalar of numeric, [text](xref:Microsoft.ML.Data.TextDataViewType) or [key](xref:Microsoft.ML.Data.KeyDataViewType) data types|
/// | Input column data type | Vector or scalar of <xref:System.Single>, <xref:System.Double> or [text](xref:Microsoft.ML.Data.TextDataViewType) data types|
/// | Output column data type | Same as the input column|
/// | Exportable to ONNX | Yes |
///
Expand Down
75 changes: 40 additions & 35 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1648,57 +1648,62 @@ public void UseKeyDataViewTypeAsUInt32InOnnxInput()
Done();
}

[Fact]
public void FeatureSelectionOnnxTest()
[Theory]
[InlineData(DataKind.String)]
[InlineData(DataKind.Single)]
[InlineData(DataKind.Double)]
public void FeatureSelectionOnnxTest(DataKind dataKind)
{
var mlContext = new MLContext(seed: 1);

string dataPath = GetDataPath("breast-cancer.txt");

var dataView = ML.Data.LoadFromTextFile(dataPath, new[] {
new TextLoader.Column("ScalarFloat", DataKind.Single, 6),
new TextLoader.Column("VectorFloat", DataKind.Single, 1, 4),
new TextLoader.Column("VectorDouble", DataKind.Double, 4, 8),
var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] {
new TextLoader.Column("Scalar", dataKind, 6),
new TextLoader.Column("Vector", dataKind, 1, 6),
new TextLoader.Column("Label", DataKind.Boolean, 0)
});

var columns = new[] {
new CountFeatureSelectingEstimator.ColumnOptions("FeatureSelectDouble", "VectorDouble", count: 1),
new CountFeatureSelectingEstimator.ColumnOptions("ScalFeatureSelectMissing690", "ScalarFloat", count: 690),
new CountFeatureSelectingEstimator.ColumnOptions("ScalFeatureSelectMissing100", "ScalarFloat", count: 100),
new CountFeatureSelectingEstimator.ColumnOptions("VecFeatureSelectMissing690", "VectorDouble", count: 690),
new CountFeatureSelectingEstimator.ColumnOptions("VecFeatureSelectMissing100", "VectorDouble", count: 100)
};
var pipeline = ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("FeatureSelect", "VectorFloat", count: 1)
.Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount(columns))
.Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("FeatureSelectMIScalarFloat", "ScalarFloat"))
.Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("FeatureSelectMIVectorFloat", "VectorFloat"));
IEstimator<ITransformer>[] pipelines =
{
// one or more features selected
mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("VectorOutput", "Vector", count: 690).
Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("ScalarOutput", "Scalar", count: 100)),

var model = pipeline.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
// no feature selected => column suppressed
mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("VectorOutput", "Vector", count: 800).
Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("ScalarOutput", "Scalar", count: 800)),

var onnxFileName = "countfeatures.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("VectorOutput", "Vector").
Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("ScalarOutput", "Scalar"))
};
for (int i = 0; i < pipelines.Length; i++)
{
// There's currently no support for suppressed string columns, since ONNX string variable initialization is not supported
if (dataKind == DataKind.String && i > 0)
break;
var model = pipelines[i].Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

SaveOnnxModel(onnxModel, onnxModelPath, null);
var onnxFileName = "countfeatures.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);

if (IsOnnxRuntimeSupported())
{
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedColumns<float>("FeatureSelectMIScalarFloat", "FeatureSelectMIScalarFloat", transformedData, onnxResult);
CompareSelectedColumns<float>("FeatureSelectMIVectorFloat", "FeatureSelectMIVectorFloat", transformedData, onnxResult);
CompareSelectedColumns<float>("ScalFeatureSelectMissing690", "ScalFeatureSelectMissing690", transformedData, onnxResult);
CompareSelectedColumns<double>("VecFeatureSelectMissing690", "VecFeatureSelectMissing690", transformedData, onnxResult);
SaveOnnxModel(onnxModel, onnxModelPath, null);

if (IsOnnxRuntimeSupported())
{
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareResults("VectorOutput", "VectorOutput", transformedData, onnxResult);
CompareResults("ScalarOutput", "ScalarOutput", transformedData, onnxResult);
}
}
Done();
}



[Fact]
public void SelectColumnsOnnxTest()
{
Expand Down