Skip to content

Commit c5d3638

Browse files
committed
add comments
1 parent 5773543 commit c5d3638

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,13 +510,15 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
510510

511511
if (shape == null || (shape.Length == 0))
512512
{
513+
// for vector type input TensorShape should same to dim
513514
if (_isInputVector[i])
514515
{
515516
vecType = (VectorDataViewType)type;
516517
var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
517518
_fullySpecifiedShapes[i] = new TensorShape(colTypeDims);
518519
}
519520
else
521+
// for primitive type use default TensorShape
520522
_fullySpecifiedShapes[i] = new TensorShape();
521523
}
522524
else

src/Microsoft.ML.TensorFlow/TensorflowUtils.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
8181

8282
if(tensorShape == null)
8383
{
84+
// primitive column type
8485
schemaBuilder.AddColumn(op.name, mlType, metadataBuilder.ToAnnotations());
8586
}
8687
else
8788
{
89+
// vector column type
8890
DataViewType columnType = new VectorDataViewType(mlType);
8991
if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) &&
9092
(Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0)))

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,8 +1305,10 @@ public void TensorFlowPrimitiveInputTest()
13051305
{
13061306
using var tensorFlowModel = _mlContext.Model.LoadTensorFlowModel(@"model_primitive_input_test");
13071307
var schema = tensorFlowModel.GetModelSchema();
1308-
Assert.True(schema.TryGetColumnIndex("input1", out var colIndex));
1309-
Assert.True(schema.TryGetColumnIndex("input2", out colIndex));
1308+
Assert.True(schema.GetColumnOrNull("input1").HasValue);
1309+
Assert.True(schema.GetColumnOrNull("input1").Value.Type is TextDataViewType);
1310+
Assert.True(schema.GetColumnOrNull("input2").HasValue);
1311+
Assert.True(schema.GetColumnOrNull("input2").Value.Type is TextDataViewType);
13101312

13111313
var dataview = _mlContext.Data.CreateTextLoader<PrimitiveInput>().Load(new MultiFileSource(null));
13121314

0 commit comments

Comments
 (0)