Skip to content

Commit 3e72d19

Browse files
Tensorflow fix (#5547)
* fix tensorflow issue on sample repo * add comments
1 parent 5318cc2 commit 3e72d19

File tree

3 files changed

+36
-13
lines changed

3 files changed

+36
-13
lines changed

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,18 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
509509
var shape = originalShape.dims;
510510

511511
if (shape == null || (shape.Length == 0))
512-
_fullySpecifiedShapes[i] = new TensorShape();
512+
{
513+
// for vector type input TensorShape should same to dim
514+
if (_isInputVector[i])
515+
{
516+
vecType = (VectorDataViewType)type;
517+
var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
518+
_fullySpecifiedShapes[i] = new TensorShape(colTypeDims);
519+
}
520+
else
521+
// for primitive type use default TensorShape
522+
_fullySpecifiedShapes[i] = new TensorShape();
523+
}
513524
else
514525
{
515526
vecType = (VectorDataViewType)type;

src/Microsoft.ML.TensorFlow/TensorflowUtils.cs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,6 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
5252
if (mlType == null || op.NumOutputs <= 0)
5353
continue;
5454

55-
// Construct the final ML.NET type of a Tensorflow variable.
56-
var tensorShape = op.output.TensorShape.dims;
57-
var columnType = new VectorDataViewType(mlType);
58-
if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) &&
59-
(Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0)))
60-
columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());
61-
6255
// There can be at most two metadata fields.
6356
// 1. The first field always presents. Its value is this operator's type. For example,
6457
// if an output is produced by an "Softmax" operator, the value of this field should be "Softmax".
@@ -83,7 +76,24 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
8376
(ref VBuffer<ReadOnlyMemory<char>> value) => { upstreamOperatorNames.CopyTo(ref value); });
8477
}
8578

86-
schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations());
79+
// Construct the final ML.NET type of a Tensorflow variable.
80+
var tensorShape = op.output.TensorShape.dims;
81+
82+
if(tensorShape == null)
83+
{
84+
// primitive column type
85+
schemaBuilder.AddColumn(op.name, mlType, metadataBuilder.ToAnnotations());
86+
}
87+
else
88+
{
89+
// vector column type
90+
DataViewType columnType = new VectorDataViewType(mlType);
91+
if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) &&
92+
(Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0)))
93+
columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());
94+
95+
schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations());
96+
}
8797
}
8898
return schemaBuilder.ToSchema();
8999
}

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,10 +1262,10 @@ class TextOutput
12621262

12631263
class PrimitiveInput
12641264
{
1265-
[LoadColumn(0, 1)]
1265+
[LoadColumn(0)]
12661266
public string input1;
12671267

1268-
[LoadColumn(1, 2)]
1268+
[LoadColumn(1)]
12691269
public string input2;
12701270
}
12711271

@@ -1305,8 +1305,10 @@ public void TensorFlowPrimitiveInputTest()
13051305
{
13061306
using var tensorFlowModel = _mlContext.Model.LoadTensorFlowModel(@"model_primitive_input_test");
13071307
var schema = tensorFlowModel.GetModelSchema();
1308-
Assert.True(schema.TryGetColumnIndex("input1", out var colIndex));
1309-
Assert.True(schema.TryGetColumnIndex("input2", out colIndex));
1308+
Assert.True(schema.GetColumnOrNull("input1").HasValue);
1309+
Assert.True(schema.GetColumnOrNull("input1").Value.Type is TextDataViewType);
1310+
Assert.True(schema.GetColumnOrNull("input2").HasValue);
1311+
Assert.True(schema.GetColumnOrNull("input2").Value.Type is TextDataViewType);
13101312

13111313
var dataview = _mlContext.Data.CreateTextLoader<PrimitiveInput>().Load(new MultiFileSource(null));
13121314

0 commit comments

Comments
 (0)