Skip to content

Commit

Permalink
Fixes KMeans scoring differences between ORT and OnnxRunner (dotnet#4942
Browse files Browse the repository at this point in the history
)

* support for batch inferencing on ORT models

* resolving comments
  • Loading branch information
Lynx1820 authored and mstfbl committed Mar 17, 2020
1 parent 2f67666 commit d7a9228
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 4 deletions.
8 changes: 5 additions & 3 deletions src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -331,19 +331,21 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
var nameX = featureColumn;

// Compute X^2 from X
var nameX2 = ctx.AddIntermediateVariable(null, "X2", true);
var nameX2 = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1), "X2");
var reduceNodeX2 = ctx.CreateNode("ReduceSumSquare", nameX, nameX2, ctx.GetNodeName("ReduceSumSquare"), "");
reduceNodeX2.AddAttribute("axes", new long[] { 1 });

// Compute -2XC^T. Note that Gemm always takes three inputs. Since we only have two here,
// a dummy one, named zero, is created.
var dataViewType = new VectorDataViewType(NumberDataViewType.Single, _centroids.Length);
var zeroName = ctx.AddInitializer(new float[] { 0f }, null, "zero");
var nameXC2 = ctx.AddIntermediateVariable(null, "XC2", true);
var nameXC2 = ctx.AddIntermediateVariable(dataViewType, "XC2");
var gemmNodeXC2 = ctx.CreateNode("Gemm", new[] { nameX, nameC, zeroName }, new[] { nameXC2 }, ctx.GetNodeName("Gemm"), "");
gemmNodeXC2.AddAttribute("alpha", -2f);
gemmNodeXC2.AddAttribute("transB", 1);

// Compute Z = X^2 - 2XC^T
var nameZ = ctx.AddIntermediateVariable(null, "Z", true);
var nameZ = ctx.AddIntermediateVariable(dataViewType, "Z");
var addNodeZ = ctx.CreateNode("Add", new[] { nameX2, nameXC2 }, new[] { nameZ }, ctx.GetNodeName("Add"), "");

// Compute Y = Z + C^2
Expand Down
65 changes: 64 additions & 1 deletion test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,16 @@
"X2"
],
"name": "ReduceSumSquare",
"opType": "ReduceSumSquare"
"opType": "ReduceSumSquare",
"attribute": [
{
"name": "axes",
"ints": [
"1"
],
"type": "INTS"
}
]
},
{
"input": [
Expand Down Expand Up @@ -377,6 +386,60 @@
}
}
},
{
"name": "X2",
"type": {
"tensorType": {
"elemType": 1,
"shape": {
"dim": [
{
"dimValue": "-1"
},
{
"dimValue": "1"
}
]
}
}
}
},
{
"name": "XC2",
"type": {
"tensorType": {
"elemType": 1,
"shape": {
"dim": [
{
"dimValue": "-1"
},
{
"dimValue": "4"
}
]
}
}
}
},
{
"name": "Z",
"type": {
"tensorType": {
"elemType": 1,
"shape": {
"dim": [
{
"dimValue": "-1"
},
{
"dimValue": "4"
}
]
}
}
}
},
{
"name": "Features.output",
"type": {
Expand Down

0 comments on commit d7a9228

Please sign in to comment.