Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message when defining custom type for variables #5114

Merged
merged 23 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 109 additions & 84 deletions src/Microsoft.ML.Data/Data/SchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -326,21 +326,12 @@ public enum Direction
Both = Read | Write
}

/// <summary>
/// Create a schema definition by enumerating all public instance fields and properties of the given type.
/// </summary>
/// <param name="userType">The type to base the schema on.</param>
/// <param name="direction">Accept fields and properties based on their direction.</param>
/// <returns>The generated schema definition.</returns>
public static SchemaDefinition Create(Type userType, Direction direction = Direction.Both)
public static MemberInfo[] GetMemberInfos(Type userType, Direction direction)
wangyems marked this conversation as resolved.
Show resolved Hide resolved
{
// REVIEW: This will have to be updated whenever we start
// supporting properties and not just fields.
Contracts.CheckValue(userType, nameof(userType));

SchemaDefinition cols = new SchemaDefinition();
HashSet<string> colNames = new HashSet<string>();

var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);
var propertyInfos =
userType
Expand All @@ -349,98 +340,132 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc
((direction & Direction.Write) == Direction.Write && (x.CanWrite && x.GetSetMethod() != null))) &&
x.GetIndexParameters().Length == 0);

var memberInfos = (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
return (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
}

foreach (var memberInfo in memberInfos)
public static bool CheckMemberInfo(MemberInfo memberInfo)
wangyems marked this conversation as resolved.
Show resolved Hide resolved
{
wangyems marked this conversation as resolved.
Show resolved Hide resolved
switch (memberInfo)
wangyems marked this conversation as resolved.
Show resolved Hide resolved
{
// Clause to handle the field that may be used to expose the cursor channel.
// This field does not need a column.
// REVIEW: maybe validate the channel attribute now, instead
// of later at cursor creation.
switch (memberInfo)
{
case FieldInfo fieldInfo:
if (fieldInfo.FieldType == typeof(IChannel))
continue;
case FieldInfo fieldInfo:
if (fieldInfo.FieldType == typeof(IChannel))
return false;

// Const fields do not need to be mapped.
if (fieldInfo.IsLiteral)
continue;
// Const fields do not need to be mapped.
if (fieldInfo.IsLiteral)
return false;

break;
break;

case PropertyInfo propertyInfo:
if (propertyInfo.PropertyType == typeof(IChannel))
continue;
break;
case PropertyInfo propertyInfo:
if (propertyInfo.PropertyType == typeof(IChannel))
return false;
break;

default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}

if (memberInfo.GetCustomAttribute<NoColumnAttribute>() != null)
return false;

if (memberInfo.GetCustomAttribute<NoColumnAttribute>() != null)
continue;
return true;
}

var customAttributes = memberInfo.GetCustomAttributes();
var customTypeAttributes = customAttributes.Where(x => x is DataViewTypeAttribute);
if (customTypeAttributes.Count() > 1)
throw Contracts.ExceptParam(nameof(userType), "Member {0} cannot be marked with multiple attributes, {1}, derived from {2}.",
memberInfo.Name, customTypeAttributes, typeof(DataViewTypeAttribute));
else if (customTypeAttributes.Count() == 1)
{
var customTypeAttribute = (DataViewTypeAttribute)customTypeAttributes.First();
customTypeAttribute.Register();
}
public static bool ValidateMemberInfo(MemberInfo memberInfo, Type userType, HashSet<string> colNames, out string name, out IEnumerable<Attribute> customAttributes)
wangyems marked this conversation as resolved.
Show resolved Hide resolved
{
// Members that do not need a column (for example, the field or property that may be
// used to expose the cursor channel) are filtered out by CheckMemberInfo below.
// REVIEW: maybe validate the channel attribute now, instead
// of later at cursor creation.

name = null;
customAttributes = null;

var mappingNameAttr = memberInfo.GetCustomAttribute<ColumnNameAttribute>();
string name = mappingNameAttr?.Name ?? memberInfo.Name;
// Disallow duplicate names, because the field enumeration order is not actually
// well defined, so we are not guaranteed to have consistent "hiding" from run to
// run, across different .NET versions.
if (!colNames.Add(name))
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);
if (!CheckMemberInfo(memberInfo))
return false;

customAttributes = memberInfo.GetCustomAttributes();
var customTypeAttributes = customAttributes.Where(x => x is DataViewTypeAttribute);
if (customTypeAttributes.Count() > 1)
throw Contracts.ExceptParam(nameof(userType), "Member {0} cannot be marked with multiple attributes, {1}, derived from {2}.",
memberInfo.Name, customTypeAttributes, typeof(DataViewTypeAttribute));
else if (customTypeAttributes.Count() == 1)
{
var customTypeAttribute = (DataViewTypeAttribute)customTypeAttributes.First();
customTypeAttribute.Register();
}

InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type dataType);
var mappingNameAttr = memberInfo.GetCustomAttribute<ColumnNameAttribute>();
name = mappingNameAttr?.Name ?? memberInfo.Name;
// Disallow duplicate names, because the field enumeration order is not actually
// well defined, so we are not guaranteed to have consistent "hiding" from run to
// run, across different .NET versions.
if (!colNames.Add(name))
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);

// Get the column type.
DataViewType columnType;
if (!DataViewTypeManager.Knows(dataType, customAttributes))
return true;
}

/// <summary>
/// Create a schema definition by enumerating all public instance fields and properties of the given type.
/// </summary>
/// <param name="userType">The type to base the schema on.</param>
/// <param name="direction">Accept fields and properties based on their direction.</param>
/// <returns>The generated schema definition.</returns>
public static SchemaDefinition Create(Type userType, Direction direction = Direction.Both)
{
var memberInfos = GetMemberInfos(userType, direction);

SchemaDefinition cols = new SchemaDefinition();
HashSet<string> colNames = new HashSet<string>();

foreach (var memberInfo in memberInfos)
{
if (ValidateMemberInfo(memberInfo, userType, colNames, out string name, out IEnumerable<Attribute> customAttributes))
wangyems marked this conversation as resolved.
Show resolved Hide resolved
{
PrimitiveDataViewType itemType;
var keyAttr = memberInfo.GetCustomAttribute<KeyTypeAttribute>();
if (keyAttr != null)
{
if (!KeyDataViewType.IsValidDataType(dataType))
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name);
if (keyAttr.KeyCount == null)
itemType = new KeyDataViewType(dataType, dataType.ToMaxInt());
else
itemType = new KeyDataViewType(dataType, keyAttr.KeyCount.Count.GetValueOrDefault());
}
else
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(dataType);
InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type dataType);

var vectorAttr = memberInfo.GetCustomAttribute<VectorTypeAttribute>();
if (vectorAttr != null && !isVector)
throw Contracts.ExceptParam(nameof(userType), $"Member {memberInfo.Name} marked with {nameof(VectorTypeAttribute)}, but does not appear to be a vector type", memberInfo.Name);
if (isVector)
// Get the column type.
DataViewType columnType;
if (!DataViewTypeManager.Knows(dataType, customAttributes))
{
int[] dims = vectorAttr?.Dims;
if (dims != null && dims.Any(d => d < 0))
throw Contracts.ExceptParam(nameof(userType), "Some of member {0}'s dimension lengths are negative");
if (Utils.Size(dims) == 0)
columnType = new VectorDataViewType(itemType, 0);
PrimitiveDataViewType itemType;
var keyAttr = memberInfo.GetCustomAttribute<KeyTypeAttribute>();
if (keyAttr != null)
{
if (!KeyDataViewType.IsValidDataType(dataType))
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name);
if (keyAttr.KeyCount == null)
itemType = new KeyDataViewType(dataType, dataType.ToMaxInt());
else
itemType = new KeyDataViewType(dataType, keyAttr.KeyCount.Count.GetValueOrDefault());
}
else
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(dataType);

var vectorAttr = memberInfo.GetCustomAttribute<VectorTypeAttribute>();
if (vectorAttr != null && !isVector)
throw Contracts.ExceptParam(nameof(userType), $"Member {memberInfo.Name} marked with {nameof(VectorTypeAttribute)}, but does not appear to be a vector type", memberInfo.Name);
if (isVector)
{
int[] dims = vectorAttr?.Dims;
if (dims != null && dims.Any(d => d < 0))
throw Contracts.ExceptParam(nameof(userType), "Some of member {0}'s dimension lengths are negative");
if (Utils.Size(dims) == 0)
columnType = new VectorDataViewType(itemType, 0);
else
columnType = new VectorDataViewType(itemType, dims);
}
else
columnType = new VectorDataViewType(itemType, dims);
columnType = itemType;
}
else
columnType = itemType;
}
else
columnType = DataViewTypeManager.GetDataViewType(dataType, customAttributes);
columnType = DataViewTypeManager.GetDataViewType(dataType, customAttributes);

cols.Add(new Column(memberInfo.Name, columnType, name));
cols.Add(new Column(memberInfo.Name, columnType, name));
}
}
return cols;
}
Expand Down
37 changes: 22 additions & 15 deletions src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,7 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector
}
}

/// <summary>
/// Given a type and name for a variable, returns whether this appears to be a vector type,
/// and also the associated data type for this type. If a valid data type could not
/// be determined, this will throw.
/// </summary>
/// <param name="name">The name of the variable to inspect.</param>
/// <param name="rawType">The type of the variable to inspect.</param>
/// <param name="attributes">Attributes of <paramref name="rawType"/>. It can be <see langword="null"/> if there are no attributes.</param>
/// <param name="isVector">Whether this appears to be a vector type.</param>
/// <param name="itemType">
/// The corresponding <see cref="PrimitiveDataViewType"/> RawType of the type, or items of this type if vector.
/// </param>
public static void GetVectorAndItemType(string name, Type rawType, IEnumerable<Attribute> attributes, out bool isVector, out Type itemType)
public static void GetMappedType(Type rawType, out Type itemType, out bool isVector)
{
// Determine whether this is a vector, and also determine the raw item type.
isVector = true;
Expand All @@ -189,10 +177,29 @@ public static void GetVectorAndItemType(string name, Type rawType, IEnumerable<A
// The internal type of string is ReadOnlyMemory<char>. That is, string will be stored as ReadOnlyMemory<char> in IDataView.
if (itemType == typeof(string))
itemType = typeof(ReadOnlyMemory<char>);
}

/// <summary>
/// Given a type and name for a variable, returns whether this appears to be a vector type,
/// and also the associated data type for this type. If a valid data type could not
/// be determined, this will throw.
/// </summary>
/// <param name="name">The name of the variable to inspect.</param>
/// <param name="rawType">The type of the variable to inspect.</param>
/// <param name="attributes">Attributes of <paramref name="rawType"/>. It can be <see langword="null"/> if there are no attributes.</param>
/// <param name="isVector">Whether this appears to be a vector type.</param>
/// <param name="itemType">
/// The corresponding <see cref="PrimitiveDataViewType"/> RawType of the type, or items of this type if vector.
/// </param>
public static void GetVectorAndItemType(string name, Type rawType, IEnumerable<Attribute> attributes, out bool isVector, out Type itemType)
{
GetMappedType(rawType, out itemType, out isVector);
// Check if the itemType extracted from rawType is supported by ML.NET's type system.
// It must be one of either ML.NET's pre-defined types or custom types registered by the user.
else if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType, attributes))
throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name);
if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType, attributes))
{
throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type and registered custom types for member {0}", name);
}
}

public static InternalSchemaDefinition Create(Type userType, SchemaDefinition.Direction direction)
Expand Down
Loading