Skip to content

Commit

Permalink
Improve error message when defining custom type for variables (#5114)
Browse files Browse the repository at this point in the history
* helpful error

* helpful error

* remove unused library

* more useful error msg

* add comments

* fix typo

* let error msg contains right onnx format

* work around with ApiCompat

* fix DataView variable description

* a work version, need refactor

* fix test failure

* remove signatures that contains IDataView

* refactor checkin

* checkin

* update

* reorder signatures

* more refactor

* update

* refactor

* fix assertion failures

* review comments

* review more comments

* update
  • Loading branch information
wangyems authored May 28, 2020
1 parent 3cd23c5 commit 65f6c5f
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 63 deletions.
121 changes: 73 additions & 48 deletions src/Microsoft.ML.Data/Data/SchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -326,21 +326,12 @@ public enum Direction
Both = Read | Write
}

/// <summary>
/// Create a schema definition by enumerating all public fields of the given type.
/// </summary>
/// <param name="userType">The type to base the schema on.</param>
/// <param name="direction">Accept fields and properties based on their direction.</param>
/// <returns>The generated schema definition.</returns>
public static SchemaDefinition Create(Type userType, Direction direction = Direction.Both)
internal static MemberInfo[] GetMemberInfos(Type userType, Direction direction)
{
// REVIEW: This will have to be updated whenever we start
// supporting properties and not just fields.
Contracts.CheckValue(userType, nameof(userType));

SchemaDefinition cols = new SchemaDefinition();
HashSet<string> colNames = new HashSet<string>();

var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);
var propertyInfos =
userType
Expand All @@ -349,57 +340,90 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc
((direction & Direction.Write) == Direction.Write && (x.CanWrite && x.GetSetMethod() != null))) &&
x.GetIndexParameters().Length == 0);

var memberInfos = (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
return (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
}

foreach (var memberInfo in memberInfos)
internal static bool NeedToCheckMemberInfo(MemberInfo memberInfo)
{
switch (memberInfo)
{
// Clause to handle the field that may be used to expose the cursor channel.
// This field does not need a column.
// REVIEW: maybe validate the channel attribute now, instead
// of later at cursor creation.
switch (memberInfo)
{
case FieldInfo fieldInfo:
if (fieldInfo.FieldType == typeof(IChannel))
continue;
case FieldInfo fieldInfo:
if (fieldInfo.FieldType == typeof(IChannel))
return false;

// Const fields do not need to be mapped.
if (fieldInfo.IsLiteral)
continue;
// Const fields do not need to be mapped.
if (fieldInfo.IsLiteral)
return false;

break;
break;

case PropertyInfo propertyInfo:
if (propertyInfo.PropertyType == typeof(IChannel))
continue;
break;
case PropertyInfo propertyInfo:
if (propertyInfo.PropertyType == typeof(IChannel))
return false;
break;

default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}

if (memberInfo.GetCustomAttribute<NoColumnAttribute>() != null)
continue;
if (memberInfo.GetCustomAttribute<NoColumnAttribute>() != null)
return false;

var customAttributes = memberInfo.GetCustomAttributes();
var customTypeAttributes = customAttributes.Where(x => x is DataViewTypeAttribute);
if (customTypeAttributes.Count() > 1)
throw Contracts.ExceptParam(nameof(userType), "Member {0} cannot be marked with multiple attributes, {1}, derived from {2}.",
memberInfo.Name, customTypeAttributes, typeof(DataViewTypeAttribute));
else if (customTypeAttributes.Count() == 1)
{
var customTypeAttribute = (DataViewTypeAttribute)customTypeAttributes.First();
customTypeAttribute.Register();
}
return true;
}

var mappingNameAttr = memberInfo.GetCustomAttribute<ColumnNameAttribute>();
string name = mappingNameAttr?.Name ?? memberInfo.Name;
// Disallow duplicate names, because the field enumeration order is not actually
// well defined, so we are not guaranteed to have consistent "hiding" from run to
// run, across different .NET versions.
if (!colNames.Add(name))
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);
internal static bool GetNameAndCustomAttributes(MemberInfo memberInfo, Type userType, HashSet<string> colNames, out string name, out IEnumerable<Attribute> customAttributes)
{
name = null;
customAttributes = null;

if (!NeedToCheckMemberInfo(memberInfo))
return false;

customAttributes = memberInfo.GetCustomAttributes();
var customTypeAttributes = customAttributes.Where(x => x is DataViewTypeAttribute);
if (customTypeAttributes.Count() > 1)
throw Contracts.ExceptParam(nameof(userType), "Member {0} cannot be marked with multiple attributes, {1}, derived from {2}.",
memberInfo.Name, customTypeAttributes, typeof(DataViewTypeAttribute));
else if (customTypeAttributes.Count() == 1)
{
var customTypeAttribute = (DataViewTypeAttribute)customTypeAttributes.First();
customTypeAttribute.Register();
}

var mappingNameAttr = memberInfo.GetCustomAttribute<ColumnNameAttribute>();
name = mappingNameAttr?.Name ?? memberInfo.Name;
// Disallow duplicate names, because the field enumeration order is not actually
// well defined, so we are not guaranteed to have consistent "hiding" from run to
// run, across different .NET versions.
if (!colNames.Add(name))
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);

return true;
}

/// <summary>
/// Create a schema definition by enumerating all public fields of the given type.
/// </summary>
/// <param name="userType">The type to base the schema on.</param>
/// <param name="direction">Accept fields and properties based on their direction.</param>
/// <returns>The generated schema definition.</returns>
public static SchemaDefinition Create(Type userType, Direction direction = Direction.Both)
{
var memberInfos = GetMemberInfos(userType, direction);

SchemaDefinition cols = new SchemaDefinition();
HashSet<string> colNames = new HashSet<string>();

foreach (var memberInfo in memberInfos)
{
if (!GetNameAndCustomAttributes(memberInfo, userType, colNames, out string name, out IEnumerable<Attribute> customAttributes))
continue;

InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type dataType);

Expand Down Expand Up @@ -442,6 +466,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc

cols.Add(new Column(memberInfo.Name, columnType, name));
}

return cols;
}
}
Expand Down
37 changes: 22 additions & 15 deletions src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,7 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector
}
}

/// <summary>
/// Given a type and name for a variable, returns whether this appears to be a vector type,
/// and also the associated data type for this type. If a valid data type could not
/// be determined, this will throw.
/// </summary>
/// <param name="name">The name of the variable to inspect.</param>
/// <param name="rawType">The type of the variable to inspect.</param>
/// <param name="attributes">Attribute of <paramref name="rawType"/>. It can be <see langword="null"/> if attributes don't exist.</param>
/// <param name="isVector">Whether this appears to be a vector type.</param>
/// <param name="itemType">
/// The corresponding <see cref="PrimitiveDataViewType"/> RawType of the type, or items of this type if vector.
/// </param>
public static void GetVectorAndItemType(string name, Type rawType, IEnumerable<Attribute> attributes, out bool isVector, out Type itemType)
public static void GetMappedType(Type rawType, out Type itemType, out bool isVector)
{
// Determine whether this is a vector, and also determine the raw item type.
isVector = true;
Expand All @@ -189,10 +177,29 @@ public static void GetVectorAndItemType(string name, Type rawType, IEnumerable<A
// The internal type of string is ReadOnlyMemory<char>. That is, string will be stored as ReadOnlyMemory<char> in IDataView.
if (itemType == typeof(string))
itemType = typeof(ReadOnlyMemory<char>);
}

/// <summary>
/// Given a type and name for a variable, returns whether this appears to be a vector type,
/// and also the associated data type for this type. If a valid data type could not
/// be determined, this will throw.
/// </summary>
/// <param name="name">The name of the variable to inspect.</param>
/// <param name="rawType">The type of the variable to inspect.</param>
/// <param name="attributes">Attribute of <paramref name="rawType"/>. It can be <see langword="null"/> if attributes don't exist.</param>
/// <param name="isVector">Whether this appears to be a vector type.</param>
/// <param name="itemType">
/// The corresponding <see cref="PrimitiveDataViewType"/> RawType of the type, or items of this type if vector.
/// </param>
public static void GetVectorAndItemType(string name, Type rawType, IEnumerable<Attribute> attributes, out bool isVector, out Type itemType)
{
GetMappedType(rawType, out itemType, out isVector);
// Check if the itemType extracted from rawType is supported by ML.NET's type system.
// It must be one of either ML.NET's pre-defined types or custom types registered by the user.
else if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType, attributes))
throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name);
if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType, attributes))
{
throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type and registered custom types for member {0}", name);
}
}

public static InternalSchemaDefinition Create(Type userType, SchemaDefinition.Direction direction)
Expand Down
86 changes: 86 additions & 0 deletions src/Microsoft.ML.Data/DataView/TypedCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,90 @@ public RowCursor<TRow>[] GetCursorSet(Func<int, bool> additionalColumnsPredicate
.ToArray();
}

private static void ValidateMemberInfo(MemberInfo memberInfo, IDataView data)
{
if (!SchemaDefinition.NeedToCheckMemberInfo(memberInfo))
return;

var mappingNameAttr = memberInfo.GetCustomAttribute<ColumnNameAttribute>();
var singleName = mappingNameAttr?.Name ?? memberInfo.Name;

Type actualType = null;
bool isVector = false;
IEnumerable<Attribute> customAttributes = null;
switch (memberInfo)
{
case FieldInfo fieldInfo:
InternalSchemaDefinition.GetMappedType(fieldInfo.FieldType, out actualType, out isVector);
customAttributes = fieldInfo.GetCustomAttributes();
break;

case PropertyInfo propertyInfo:
InternalSchemaDefinition.GetMappedType(propertyInfo.PropertyType, out actualType, out isVector);
customAttributes = propertyInfo.GetCustomAttributes();
break;

default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}

if (!actualType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(actualType, customAttributes))
{
int colIndex;
data.Schema.TryGetColumnIndex(singleName, out colIndex);
DataViewType expectedType = data.Schema[colIndex].Type;
if (!actualType.Equals(expectedType.RawType))
throw Contracts.ExceptParam(nameof(actualType), $"The expected type '{expectedType.RawType}' does not match the type of the '{singleName}' member: '{actualType}'. Please change the {singleName} member to '{expectedType.RawType}'");
}
}

private static void ValidateUserType(SchemaDefinition schemaDefinition, Type userType, IDataView data)
{
//Get memberInfos
MemberInfo[] memberInfos = null;
if (schemaDefinition == null)
{
memberInfos = SchemaDefinition.GetMemberInfos(userType, SchemaDefinition.Direction.Write);

if (memberInfos == null)
return;

foreach (var memberInfo in memberInfos)
ValidateMemberInfo(memberInfo, data);
}
else
{
for (int i = 0; i < schemaDefinition.Count; ++i)
{
var col = schemaDefinition[i];
if (col.MemberName == null)
throw Contracts.ExceptParam(nameof(schemaDefinition), "Null field name detected in schema definition");

MemberInfo memberInfo = null;
// Infer the column name.
var colName = string.IsNullOrEmpty(col.ColumnName) ? col.MemberName : col.ColumnName;

if (col.Generator == null)
{
memberInfo = userType.GetField(col.MemberName);

if (memberInfo == null)
memberInfo = userType.GetProperty(col.MemberName);

if ((memberInfo is FieldInfo && (memberInfo as FieldInfo).FieldType == typeof(IChannel)) ||
(memberInfo is PropertyInfo && (memberInfo as PropertyInfo).PropertyType == typeof(IChannel)))
continue;
}
else
{
memberInfo = col.ReturnType;
}
ValidateMemberInfo(memberInfo, data);
}
}
}

/// <summary>
/// Create a Cursorable object on a given data view.
/// </summary>
Expand All @@ -231,6 +315,8 @@ public static TypedCursorable<TRow> Create(IHostEnvironment env, IDataView data,
env.AssertValue(data);
env.AssertValueOrNull(schemaDefinition);

ValidateUserType(schemaDefinition, typeof(TRow), data);

var outSchema = schemaDefinition == null
? InternalSchemaDefinition.Create(typeof(TRow), SchemaDefinition.Direction.Write)
: InternalSchemaDefinition.Create(typeof(TRow), schemaDefinition);
Expand Down
39 changes: 39 additions & 0 deletions test/Microsoft.ML.Tests/OnnxSequenceTypeWithAttributesTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,46 @@ public void OnnxSequenceTypeWithColumnNameAttributeTest()
{
Assert.Equal(onnxOut[keys[i]], input.Input[i]);
}
}

public class WrongOutputObj
{
[ColumnName("output")]
[OnnxSequenceType(typeof(IEnumerable<float>))]
public IEnumerable<float> Output;
}

public static PredictionEngine<FloatInput, WrongOutputObj> LoadModelWithWrongCustomType(string onnxModelFilePath)
{
var ctx = new MLContext(1);
var dataView = ctx.Data.LoadFromEnumerable(new List<FloatInput>());

var pipeline = ctx.Transforms.ApplyOnnxModel(
modelFile: onnxModelFilePath,
outputColumnNames: new[] { "output" }, inputColumnNames: new[] { "input" });

var model = pipeline.Fit(dataView);
return ctx.Model.CreatePredictionEngine<FloatInput, WrongOutputObj>(model);
}

[OnnxFact]
public void OnnxSequenceTypeWithColumnNameAttributeTestWithWrongCustomType()
{
var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapString.onnx");
var expectedExceptionMessage = "The expected type 'System.Collections.Generic.IEnumerable`1[System.Collections.Generic.IDictionary`2[System.String,System.Single]]'" +
" does not match the type of the 'output' member: 'System.Collections.Generic.IEnumerable`1[System.Single]'." +
" Please change the output member to 'System.Collections.Generic.IEnumerable`1[System.Collections.Generic.IDictionary`2[System.String,System.Single]]'";
try
{
var predictor = LoadModelWithWrongCustomType(modelFile);
Assert.True(false);
}
catch (System.Exception ex)
{
//truncate the string to only necessary information as Linux and Windows have different way of encoding the string
Assert.Equal(expectedExceptionMessage, ex.Message.Substring(0, 387));
return;
}
}
}
}

0 comments on commit 65f6c5f

Please sign in to comment.