Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message when defining custom type for variables #5114

Merged
merged 23 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 73 additions & 48 deletions src/Microsoft.ML.Data/Data/SchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -326,21 +326,12 @@ public enum Direction
Both = Read | Write
}

/// <summary>
/// Create a schema definition by enumerating all public fields of the given type.
/// </summary>
/// <param name="userType">The type to base the schema on.</param>
/// <param name="direction">Accept fields and properties based on their direction.</param>
/// <returns>The generated schema definition.</returns>
public static SchemaDefinition Create(Type userType, Direction direction = Direction.Both)
internal static MemberInfo[] GetMemberInfos(Type userType, Direction direction)
{
// REVIEW: This will have to be updated whenever we start
// supporting properties and not just fields.
Contracts.CheckValue(userType, nameof(userType));

SchemaDefinition cols = new SchemaDefinition();
HashSet<string> colNames = new HashSet<string>();

var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);
var propertyInfos =
userType
Expand All @@ -349,57 +340,90 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc
((direction & Direction.Write) == Direction.Write && (x.CanWrite && x.GetSetMethod() != null))) &&
x.GetIndexParameters().Length == 0);

var memberInfos = (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
return (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
}

foreach (var memberInfo in memberInfos)
internal static bool NeedToCheckMemberInfo(MemberInfo memberInfo)
{
wangyems marked this conversation as resolved.
Show resolved Hide resolved
switch (memberInfo)
wangyems marked this conversation as resolved.
Show resolved Hide resolved
{
// Clause to handle the field that may be used to expose the cursor channel.
// This field does not need a column.
// REVIEW: maybe validate the channel attribute now, instead
// of later at cursor creation.
switch (memberInfo)
{
case FieldInfo fieldInfo:
if (fieldInfo.FieldType == typeof(IChannel))
continue;
case FieldInfo fieldInfo:
if (fieldInfo.FieldType == typeof(IChannel))
return false;

// Const fields do not need to be mapped.
if (fieldInfo.IsLiteral)
continue;
// Const fields do not need to be mapped.
if (fieldInfo.IsLiteral)
return false;

break;
break;

case PropertyInfo propertyInfo:
if (propertyInfo.PropertyType == typeof(IChannel))
continue;
break;
case PropertyInfo propertyInfo:
if (propertyInfo.PropertyType == typeof(IChannel))
return false;
break;

default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}

if (memberInfo.GetCustomAttribute<NoColumnAttribute>() != null)
continue;
if (memberInfo.GetCustomAttribute<NoColumnAttribute>() != null)
return false;

var customAttributes = memberInfo.GetCustomAttributes();
var customTypeAttributes = customAttributes.Where(x => x is DataViewTypeAttribute);
if (customTypeAttributes.Count() > 1)
throw Contracts.ExceptParam(nameof(userType), "Member {0} cannot be marked with multiple attributes, {1}, derived from {2}.",
memberInfo.Name, customTypeAttributes, typeof(DataViewTypeAttribute));
else if (customTypeAttributes.Count() == 1)
{
var customTypeAttribute = (DataViewTypeAttribute)customTypeAttributes.First();
customTypeAttribute.Register();
}
return true;
}

var mappingNameAttr = memberInfo.GetCustomAttribute<ColumnNameAttribute>();
string name = mappingNameAttr?.Name ?? memberInfo.Name;
// Disallow duplicate names, because the field enumeration order is not actually
// well defined, so we are not guaranteed to have consistent "hiding" from run to
// run, across different .NET versions.
if (!colNames.Add(name))
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);
internal static bool GetNameAndCustomAttributes(MemberInfo memberInfo, Type userType, HashSet<string> colNames, out string name, out IEnumerable<Attribute> customAttributes)
{
name = null;
customAttributes = null;

if (!NeedToCheckMemberInfo(memberInfo))
return false;

customAttributes = memberInfo.GetCustomAttributes();
var customTypeAttributes = customAttributes.Where(x => x is DataViewTypeAttribute);
if (customTypeAttributes.Count() > 1)
throw Contracts.ExceptParam(nameof(userType), "Member {0} cannot be marked with multiple attributes, {1}, derived from {2}.",
memberInfo.Name, customTypeAttributes, typeof(DataViewTypeAttribute));
else if (customTypeAttributes.Count() == 1)
{
var customTypeAttribute = (DataViewTypeAttribute)customTypeAttributes.First();
customTypeAttribute.Register();
}

var mappingNameAttr = memberInfo.GetCustomAttribute<ColumnNameAttribute>();
name = mappingNameAttr?.Name ?? memberInfo.Name;
// Disallow duplicate names, because the field enumeration order is not actually
// well defined, so we are not guaranteed to have consistent "hiding" from run to
// run, across different .NET versions.
if (!colNames.Add(name))
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);

return true;
}

/// <summary>
/// Create a schema definition by enumerating all public fields of the given type.
/// </summary>
/// <param name="userType">The type to base the schema on.</param>
/// <param name="direction">Accept fields and properties based on their direction.</param>
/// <returns>The generated schema definition.</returns>
public static SchemaDefinition Create(Type userType, Direction direction = Direction.Both)
{
var memberInfos = GetMemberInfos(userType, direction);

SchemaDefinition cols = new SchemaDefinition();
HashSet<string> colNames = new HashSet<string>();

foreach (var memberInfo in memberInfos)
{
if (!GetNameAndCustomAttributes(memberInfo, userType, colNames, out string name, out IEnumerable<Attribute> customAttributes))
continue;

InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type dataType);

Expand Down Expand Up @@ -442,6 +466,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc

cols.Add(new Column(memberInfo.Name, columnType, name));
}

return cols;
}
}
Expand Down
37 changes: 22 additions & 15 deletions src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,7 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector
}
}

/// <summary>
/// Given a type and name for a variable, returns whether this appears to be a vector type,
/// and also the associated data type for this type. If a valid data type could not
/// be determined, this will throw.
/// </summary>
/// <param name="name">The name of the variable to inspect.</param>
/// <param name="rawType">The type of the variable to inspect.</param>
/// <param name="attributes">Attribute of <paramref name="rawType"/>. It can be <see langword="null"/> if attributes don't exist.</param>
/// <param name="isVector">Whether this appears to be a vector type.</param>
/// <param name="itemType">
/// The corresponding <see cref="PrimitiveDataViewType"/> RawType of the type, or items of this type if vector.
/// </param>
public static void GetVectorAndItemType(string name, Type rawType, IEnumerable<Attribute> attributes, out bool isVector, out Type itemType)
public static void GetMappedType(Type rawType, out Type itemType, out bool isVector)
{
// Determine whether this is a vector, and also determine the raw item type.
isVector = true;
Expand All @@ -189,10 +177,29 @@ public static void GetVectorAndItemType(string name, Type rawType, IEnumerable<A
// The internal type of string is ReadOnlyMemory<char>. That is, string will be stored as ReadOnlyMemory<char> in IDataView.
if (itemType == typeof(string))
itemType = typeof(ReadOnlyMemory<char>);
}

/// <summary>
/// Given a type and name for a variable, returns whether this appears to be a vector type,
/// and also the associated data type for this type. If a valid data type could not
/// be determined, this will throw.
/// </summary>
/// <param name="name">The name of the variable to inspect.</param>
/// <param name="rawType">The type of the variable to inspect.</param>
/// <param name="attributes">Attribute of <paramref name="rawType"/>. It can be <see langword="null"/> if attributes don't exist.</param>
/// <param name="isVector">Whether this appears to be a vector type.</param>
/// <param name="itemType">
/// The corresponding <see cref="PrimitiveDataViewType"/> RawType of the type, or items of this type if vector.
/// </param>
public static void GetVectorAndItemType(string name, Type rawType, IEnumerable<Attribute> attributes, out bool isVector, out Type itemType)
{
GetMappedType(rawType, out itemType, out isVector);
// Check if the itemType extracted from rawType is supported by ML.NET's type system.
// It must be one of either ML.NET's pre-defined types or custom types registered by the user.
else if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType, attributes))
throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name);
if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType, attributes))
{
throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type and registered custom types for member {0}", name);
}
}

public static InternalSchemaDefinition Create(Type userType, SchemaDefinition.Direction direction)
Expand Down
86 changes: 86 additions & 0 deletions src/Microsoft.ML.Data/DataView/TypedCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,90 @@ public RowCursor<TRow>[] GetCursorSet(Func<int, bool> additionalColumnsPredicate
.ToArray();
}

private static void ValidateMemberInfo(MemberInfo memberInfo, IDataView data)
{
wangyems marked this conversation as resolved.
Show resolved Hide resolved
wangyems marked this conversation as resolved.
Show resolved Hide resolved
if (!SchemaDefinition.NeedToCheckMemberInfo(memberInfo))
return;
wangyems marked this conversation as resolved.
Show resolved Hide resolved

wangyems marked this conversation as resolved.
Show resolved Hide resolved
var mappingNameAttr = memberInfo.GetCustomAttribute<ColumnNameAttribute>();
var singleName = mappingNameAttr?.Name ?? memberInfo.Name;

Type actualType = null;
bool isVector = false;
IEnumerable<Attribute> customAttributes = null;
switch (memberInfo)
{
case FieldInfo fieldInfo:
InternalSchemaDefinition.GetMappedType(fieldInfo.FieldType, out actualType, out isVector);
customAttributes = fieldInfo.GetCustomAttributes();
break;

case PropertyInfo propertyInfo:
InternalSchemaDefinition.GetMappedType(propertyInfo.PropertyType, out actualType, out isVector);
customAttributes = propertyInfo.GetCustomAttributes();
break;

default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}

if (!actualType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(actualType, customAttributes))
{
int colIndex;
data.Schema.TryGetColumnIndex(singleName, out colIndex);
DataViewType expectedType = data.Schema[colIndex].Type;
if (!actualType.Equals(expectedType.RawType))
throw Contracts.ExceptParam(nameof(actualType), $"The expected type '{expectedType.RawType}' does not match the type of the '{singleName}' member: '{actualType}'. Please change the {singleName} member to '{expectedType.RawType}'");
}
}

private static void ValidateUserType(SchemaDefinition schemaDefinition, Type userType, IDataView data)
{
//Get memberInfos
MemberInfo[] memberInfos = null;
if (schemaDefinition == null)
{
memberInfos = SchemaDefinition.GetMemberInfos(userType, SchemaDefinition.Direction.Write);

if (memberInfos == null)
return;

foreach (var memberInfo in memberInfos)
ValidateMemberInfo(memberInfo, data);
}
else
{
for (int i = 0; i < schemaDefinition.Count; ++i)
{
var col = schemaDefinition[i];
if (col.MemberName == null)
throw Contracts.ExceptParam(nameof(schemaDefinition), "Null field name detected in schema definition");

MemberInfo memberInfo = null;
// Infer the column name.
var colName = string.IsNullOrEmpty(col.ColumnName) ? col.MemberName : col.ColumnName;

if (col.Generator == null)
{
memberInfo = userType.GetField(col.MemberName);

if (memberInfo == null)
memberInfo = userType.GetProperty(col.MemberName);

if ((memberInfo is FieldInfo && (memberInfo as FieldInfo).FieldType == typeof(IChannel)) ||
(memberInfo is PropertyInfo && (memberInfo as PropertyInfo).PropertyType == typeof(IChannel)))
continue;
}
else
{
memberInfo = col.ReturnType;
}
ValidateMemberInfo(memberInfo, data);
}
wangyems marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// <summary>
/// Create a Cursorable object on a given data view.
/// </summary>
Expand All @@ -231,6 +315,8 @@ public static TypedCursorable<TRow> Create(IHostEnvironment env, IDataView data,
env.AssertValue(data);
env.AssertValueOrNull(schemaDefinition);

ValidateUserType(schemaDefinition, typeof(TRow), data);

var outSchema = schemaDefinition == null
? InternalSchemaDefinition.Create(typeof(TRow), SchemaDefinition.Direction.Write)
: InternalSchemaDefinition.Create(typeof(TRow), schemaDefinition);
Expand Down
39 changes: 39 additions & 0 deletions test/Microsoft.ML.Tests/OnnxSequenceTypeWithAttributesTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,46 @@ public void OnnxSequenceTypeWithColumnNameAttributeTest()
{
Assert.Equal(onnxOut[keys[i]], input.Input[i]);
}
}

public class WrongOutputObj
{
[ColumnName("output")]
[OnnxSequenceType(typeof(IEnumerable<float>))]
public IEnumerable<float> Output;
}

public static PredictionEngine<FloatInput, WrongOutputObj> LoadModelWithWrongCustomType(string onnxModelFilePath)
{
var ctx = new MLContext(1);
var dataView = ctx.Data.LoadFromEnumerable(new List<FloatInput>());

var pipeline = ctx.Transforms.ApplyOnnxModel(
modelFile: onnxModelFilePath,
outputColumnNames: new[] { "output" }, inputColumnNames: new[] { "input" });

var model = pipeline.Fit(dataView);
return ctx.Model.CreatePredictionEngine<FloatInput, WrongOutputObj>(model);
}

[OnnxFact]
public void OnnxSequenceTypeWithColumnNameAttributeTestWithWrongCustomType()
{
var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapString.onnx");
var expectedExceptionMessage = "The expected type 'System.Collections.Generic.IEnumerable`1[System.Collections.Generic.IDictionary`2[System.String,System.Single]]'" +
" does not match the type of the 'output' member: 'System.Collections.Generic.IEnumerable`1[System.Single]'." +
" Please change the output member to 'System.Collections.Generic.IEnumerable`1[System.Collections.Generic.IDictionary`2[System.String,System.Single]]'";
try
{
var predictor = LoadModelWithWrongCustomType(modelFile);
Assert.True(false);
}
catch (System.Exception ex)
{
//truncate the string to only necessary information as Linux and Windows have different way of encoding the string
Assert.Equal(expectedExceptionMessage, ex.Message.Substring(0, 387));
return;
}
}
}
}