diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs new file mode 100644 index 0000000000..688c5d1fe5 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs @@ -0,0 +1,179 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Data; + +namespace Samples.Dynamic +{ + class CustomMappingWithInMemoryCustomType + { + static public void Example() + { + var mlContext = new MLContext(); + // Build in-memory data. + var tribe = new List() { new AlienHero("ML.NET", 2, 1000, 2000, 3000, 4000, 5000, 6000, 7000) }; + + // Build a ML.NET pipeline and make prediction. + var tribeDataView = mlContext.Data.LoadFromEnumerable(tribe); + var pipeline = mlContext.Transforms.CustomMapping(AlienFusionProcess.GetMapping(), contractName: null); + var model = pipeline.Fit(tribeDataView); + var tribeTransformed = model.Transform(tribeDataView); + + // Print out prediction produced by the model. + var firstAlien = mlContext.Data.CreateEnumerable(tribeTransformed, false).First(); + Console.WriteLine($"We got a super alien with name {firstAlien.Name}, age {firstAlien.Merged.Age}, " + + $"height {firstAlien.Merged.Height}, weight {firstAlien.Merged.Weight}, and {firstAlien.Merged.HandCount} hands."); + + // Expected output: + // We got a super alien with name Super Unknown, age 4002, height 6000, weight 8000, and 10000 hands. + + // Create a prediction engine and print out its prediction. + var engine = mlContext.Model.CreatePredictionEngine(model); + var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8); + var superAlien = engine.Predict(alien); + Console.Write($"We got a super alien with name {superAlien.Name}, age {superAlien.Merged.Age}, " + + $"height {superAlien.Merged.Height}, weight {superAlien.Merged.Weight}, and {superAlien.Merged.HandCount} hands."); + + // Expected output: + // We got a super alien with name Super Unknown, age 6, height 8, weight 10, and 12 hands. + } + + // A custom type which ML.NET doesn't know yet. Its value will be loaded as a DataView column in this test. + private class AlienBody + { + public int Age { get; set; } + public float Height { get; set; } + public float Weight { get; set; } + public int HandCount { get; set; } + + public AlienBody(int age, float height, float weight, int handCount) + { + Age = age; + Height = height; + Weight = weight; + HandCount = handCount; + } + } + + // DataViewTypeAttribute applied to class AlienBody members. + private sealed class AlienTypeAttributeAttribute : DataViewTypeAttribute + { + public int RaceId { get; } + + // Create an DataViewTypeAttribute> from raceId to a AlienBody. + public AlienTypeAttributeAttribute(int raceId) + { + RaceId = raceId; + } + + // A function implicitly invoked by ML.NET when processing a custom type. + // It binds a DataViewType to a custom type plus its attributes. + public override void Register() + { + DataViewTypeManager.Register(new DataViewAlienBodyType(RaceId), typeof(AlienBody), new[] { this }); + } + + public override bool Equals(DataViewTypeAttribute other) + { + if (other is AlienTypeAttributeAttribute) + return RaceId == ((AlienTypeAttributeAttribute)other).RaceId; + return false; + } + + public override int GetHashCode() => RaceId.GetHashCode(); + } + + // A custom class with a type which ML.NET doesn't know yet. Its value will be loaded as a DataView row in this test. + // It will be the input of AlienFusionProcess.MergeBody(AlienHero, SuperAlienHero). + // + // The members One> and Two" would be mapped to different types inside ML.NET type system because they + // have different AlienTypeAttributeAttribute's. For example, the column type of One would be DataViewAlienBodyType + // with RaceId=100. + // + private class AlienHero + { + public string Name { get; set; } + + [AlienTypeAttribute(100)] + public AlienBody One { get; set; } + + [AlienTypeAttribute(200)] + public AlienBody Two { get; set; } + + public AlienHero() + { + Name = "Unknown"; + One = new AlienBody(0, 0, 0, 0); + Two = new AlienBody(0, 0, 0, 0); + } + + public AlienHero(string name, + int age, float height, float weight, int handCount, + int anotherAge, float anotherHeight, float anotherWeight, int anotherHandCount) + { + Name = "Unknown"; + One = new AlienBody(age, height, weight, handCount); + Two = new AlienBody(anotherAge, anotherHeight, anotherWeight, anotherHandCount); + } + } + + // Type of AlienBody in ML.NET's type system. + // It usually shows up as DataViewSchema.Column.Type among IDataView.Schema. + private class DataViewAlienBodyType : StructuredDataViewType + { + public int RaceId { get; } + + public DataViewAlienBodyType(int id) : base(typeof(AlienBody)) + { + RaceId = id; + } + + public override bool Equals(DataViewType other) + { + if (other is DataViewAlienBodyType otherAlien) + return otherAlien.RaceId == RaceId; + return false; + } + + public override int GetHashCode() + { + return RaceId.GetHashCode(); + } + } + + // The output type of processing AlienHero using AlienFusionProcess.MergeBody(AlienHero, SuperAlienHero). + private class SuperAlienHero + { + public string Name { get; set; } + + [AlienTypeAttribute(007)] + public AlienBody Merged { get; set; } + + public SuperAlienHero() + { + Name = "Unknown"; + Merged = new AlienBody(0, 0, 0, 0); + } + } + + // The implementation of custom mapping is MergeBody. It accepts AlienHero and produces SuperAlienHero. + private class AlienFusionProcess + { + public static void MergeBody(AlienHero input, SuperAlienHero output) + { + output.Name = "Super " + input.Name; + output.Merged.Age = input.One.Age + input.Two.Age; + output.Merged.Height = input.One.Height + input.Two.Height; + output.Merged.Weight = input.One.Weight + input.Two.Weight; + output.Merged.HandCount = input.One.HandCount + input.Two.HandCount; + } + + public static Action GetMapping() + { + return MergeBody; + } + } + + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScaleInMemory.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScaleInMemory.cs new file mode 100644 index 0000000000..883dfa5dc1 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScaleInMemory.cs @@ -0,0 +1,85 @@ +using System; +using System.Drawing; +using Microsoft.ML; +using Microsoft.ML.Transforms.Image; + +namespace Samples.Dynamic +{ + class ConvertToGrayScaleInMemory + { + static public void Example() + { + var mlContext = new MLContext(); + // Create an image list. + var images = new[] { new ImageDataPoint(2, 3, Color.Blue), new ImageDataPoint(2, 3, Color.Red) }; + + // Convert the list of data points to an IDataView object, which is consumable by ML.NET API. + var data = mlContext.Data.LoadFromEnumerable(images); + + // Convert image to gray scale. + var pipeline = mlContext.Transforms.ConvertToGrayscale("GrayImage", "Image"); + + // Fit the model. + var model = pipeline.Fit(data); + + // Test path: image files -> IDataView -> Enumerable of Bitmaps. + var transformedData = model.Transform(data); + + // Load images in DataView back to Enumerable. + var transformedDataPoints = mlContext.Data.CreateEnumerable(transformedData, false); + + // Print out input and output pixels. + foreach (var dataPoint in transformedDataPoints) + { + var image = dataPoint.Image; + var grayImage = dataPoint.GrayImage; + for (int x = 0; x < grayImage.Width; ++x) + { + for (int y = 0; y < grayImage.Height; ++y) + { + var pixel = image.GetPixel(x, y); + var grayPixel = grayImage.GetPixel(x, y); + Console.WriteLine($"The original pixel is {pixel} and its pixel in gray is {grayPixel}"); + } + } + } + + // Expected output: + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + } + + private class ImageDataPoint + { + [ImageType(3, 4)] + public Bitmap Image { get; set; } + + [ImageType(3, 4)] + public Bitmap GrayImage { get; set; } + + public ImageDataPoint() + { + Image = null; + GrayImage = null; + } + + public ImageDataPoint(int width, int height, Color color) + { + Image = new Bitmap(width, height); + for (int i = 0; i < width; ++i) + for (int j = 0; j < height; ++j) + Image.SetPixel(i, j, color); + } + } + } +} diff --git a/src/Microsoft.ML.Data/Data/DataViewTypeManager.cs b/src/Microsoft.ML.Data/Data/DataViewTypeManager.cs new file mode 100644 index 0000000000..fda18633ff --- /dev/null +++ b/src/Microsoft.ML.Data/Data/DataViewTypeManager.cs @@ -0,0 +1,227 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Microsoft.ML.Internal.CpuMath.Core; +using Microsoft.ML.Internal.Utilities; + +namespace Microsoft.ML.Data +{ + /// + /// A singleton class for managing the map between ML.NET and C# . + /// To support custom column type in , the column's underlying type (e.g., a C# class's type) + /// should be registered with a class derived from . + /// + public static class DataViewTypeManager + { + /// + /// Types have been used in ML.NET type systems. They can have multiple-to-one type mapping. + /// For example, UInt32 and Key can be mapped to . This class enforces one-to-one mapping for all + /// user-registered types. + /// + private static HashSet _bannedRawTypes = new HashSet() + { + typeof(Boolean), typeof(SByte), typeof(Byte), + typeof(Int16), typeof(UInt16), typeof(Int32), typeof(UInt32), + typeof(Int64), typeof(UInt64), typeof(Single), typeof(Double), + typeof(string), typeof(ReadOnlySpan), typeof(ReadOnlyMemory), + typeof(VBuffer<>), typeof(Nullable<>), typeof(DateTime), typeof(DateTimeOffset), + typeof(TimeSpan), typeof(DataViewRowId) + }; + + /// + /// Mapping from a plus its s to a . + /// + private static Dictionary _rawTypeToDataViewTypeMap = new Dictionary(); + + /// + /// Mapping from a to a plus its s. + /// + private static Dictionary _dataViewTypeToRawTypeMap = new Dictionary(); + + /// + /// The lock that one should acquire if the state of will be accessed or modified. + /// + private static object _lock = new object(); + + /// + /// Returns the registered for and its . + /// + internal static DataViewType GetDataViewType(Type type, IEnumerable typeAttributes = null) + { + lock (_lock) + { + // Compute the ID of type with extra attributes. + var rawType = new TypeWithAttributes(type, typeAttributes); + + // Get the DataViewType's ID which typeID is mapped into. + if (!_rawTypeToDataViewTypeMap.TryGetValue(rawType, out DataViewType dataViewType)) + throw Contracts.ExceptParam(nameof(type), $"The raw type {type} with attributes {typeAttributes} is not registered with a DataView type."); + + // Retrieve the actual DataViewType identified by dataViewType. + return dataViewType; + } + } + + /// + /// If has been registered with a , this function returns . + /// Otherwise, this function returns . + /// + internal static bool Knows(Type type, IEnumerable typeAttributes = null) + { + lock (_lock) + { + // Compute the ID of type with extra attributes. + var rawType = new TypeWithAttributes(type, typeAttributes); + + // Check if this ID has been associated with a DataViewType. + // Note that the dictionary below contains (rawType, dataViewType) pairs (key type is TypeWithAttributes, and value type is DataViewType). + if (_rawTypeToDataViewTypeMap.ContainsKey(rawType)) + return true; + else + return false; + } + } + + /// + /// If has been registered with a , this function returns . + /// Otherwise, this function returns . + /// + internal static bool Knows(DataViewType dataViewType) + { + lock (_lock) + { + // Check if this the ID has been associated with a DataViewType. + // Note that the dictionary below contains (dataViewType, rawType) pairs (key type is DataViewType, and value type is TypeWithAttributes). + if (_dataViewTypeToRawTypeMap.ContainsKey(dataViewType)) + return true; + else + return false; + } + } + + /// + /// This function tells that should be representation of data in in + /// ML.NET's type system. The registered must be a standard C# object's type. + /// + /// Native type in C#. + /// The corresponding type of in ML.NET's type system. + /// The s attached to . + public static void Register(DataViewType dataViewType, Type type, IEnumerable typeAttributes = null) + { + lock (_lock) + { + if (_bannedRawTypes.Contains(type)) + throw Contracts.ExceptParam(nameof(type), $"Type {type} has been registered as ML.NET's default supported type, " + + $"so it can't not be registered again."); + + var rawType = new TypeWithAttributes(type, typeAttributes); + + if (_rawTypeToDataViewTypeMap.ContainsKey(rawType) && _rawTypeToDataViewTypeMap[rawType].Equals(dataViewType) && + _dataViewTypeToRawTypeMap.ContainsKey(dataViewType) && _dataViewTypeToRawTypeMap[dataViewType].Equals(rawType)) + // This type pair has been registered. Note that registering one data type pair multiple times is allowed. + return; + + if (_rawTypeToDataViewTypeMap.ContainsKey(rawType) && !_rawTypeToDataViewTypeMap[rawType].Equals(dataViewType)) + { + // There is a pair of (rawType, anotherDataViewType) in _typeToDataViewType so we cannot register + // (rawType, dataViewType) again. The assumption here is that one rawType can only be associated + // with one dataViewType. + var associatedDataViewType = _rawTypeToDataViewTypeMap[rawType]; + throw Contracts.ExceptParam(nameof(type), $"Repeated type register. The raw type {type} " + + $"has been associated with {associatedDataViewType} so it cannot be associated with {dataViewType}."); + } + + if (_dataViewTypeToRawTypeMap.ContainsKey(dataViewType) && !_dataViewTypeToRawTypeMap[dataViewType].Equals(rawType)) + { + // There is a pair of (dataViewType, anotherRawType) in _dataViewTypeToType so we cannot register + // (dataViewType, rawType) again. The assumption here is that one dataViewType can only be associated + // with one rawType. + var associatedRawType = _dataViewTypeToRawTypeMap[dataViewType].TargetType; + throw Contracts.ExceptParam(nameof(dataViewType), $"Repeated type register. The DataView type {dataViewType} " + + $"has been associated with {associatedRawType} so it cannot be associated with {type}."); + } + + _rawTypeToDataViewTypeMap.Add(rawType, dataViewType); + _dataViewTypeToRawTypeMap.Add(dataViewType, rawType); + } + } + + /// + /// An instance of represents an unique key of its and . + /// + private class TypeWithAttributes + { + /// + /// The underlying type. + /// + public Type TargetType { get; } + + /// + /// The underlying type's attributes. Together with , uniquely defines + /// a key when using as the key type in . Note that the + /// uniqueness is determined by and below. + /// + private IEnumerable _associatedAttributes; + + public TypeWithAttributes(Type type, IEnumerable attributes) + { + TargetType = type; + _associatedAttributes = attributes; + } + + public override bool Equals(object obj) + { + if (obj is TypeWithAttributes other) + { + // Flag of having the same type. + var sameType = TargetType.Equals(other.TargetType); + // Flag of having the attribute configurations. + var sameAttributeConfig = true; + + if (_associatedAttributes == null && other._associatedAttributes == null) + sameAttributeConfig = true; + else if (_associatedAttributes == null && other._associatedAttributes != null) + sameAttributeConfig = false; + else if (_associatedAttributes != null && other._associatedAttributes == null) + sameAttributeConfig = false; + else if (_associatedAttributes.Count() != other._associatedAttributes.Count()) + sameAttributeConfig = false; + else + { + var zipped = _associatedAttributes.Zip(other._associatedAttributes, (attr, otherAttr) => (attr, otherAttr)); + foreach (var (attr, otherAttr) in zipped) + { + if (!attr.Equals(otherAttr)) + sameAttributeConfig = false; + } + } + + return sameType && sameAttributeConfig; + } + return false; + } + + /// + /// This function computes a hashing ID from and attributes attached to it. + /// If a type is defined as a member in a , can be obtained by calling + /// . + /// + public override int GetHashCode() + { + if (_associatedAttributes == null) + return TargetType.GetHashCode(); + + var code = TargetType.GetHashCode(); + foreach (var attr in _associatedAttributes) + code = Hashing.CombineHash(code, attr.GetHashCode()); + return code; + } + + } + } +} diff --git a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs index e08960ffaf..157d35d940 100644 --- a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs +++ b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs @@ -382,6 +382,17 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc if (memberInfo.GetCustomAttribute() != null) continue; + var customAttributes = memberInfo.GetCustomAttributes(); + var customTypeAttributes = customAttributes.Where(x => x is DataViewTypeAttribute); + if (customTypeAttributes.Count() > 1) + throw Contracts.ExceptParam(nameof(userType), "Member {0} cannot be marked with multiple attributes, {1}, derived from {2}.", + memberInfo.Name, customTypeAttributes, typeof(DataViewTypeAttribute)); + else if (customTypeAttributes.Count() == 1) + { + var customTypeAttribute = (DataViewTypeAttribute)customTypeAttributes.First(); + customTypeAttribute.Register(); + } + var mappingNameAttr = memberInfo.GetCustomAttribute(); string name = mappingNameAttr?.Name ?? memberInfo.Name; // Disallow duplicate names, because the field enumeration order is not actually @@ -392,37 +403,42 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type dataType); - PrimitiveDataViewType itemType; - var keyAttr = memberInfo.GetCustomAttribute(); - if (keyAttr != null) - { - if (!KeyDataViewType.IsValidDataType(dataType)) - throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name); - if (keyAttr.KeyCount == null) - itemType = new KeyDataViewType(dataType, dataType.ToMaxInt()); - else - itemType = new KeyDataViewType(dataType, keyAttr.KeyCount.Count.GetValueOrDefault()); - } - else - itemType = ColumnTypeExtensions.PrimitiveTypeFromType(dataType); - // Get the column type. DataViewType columnType; - var vectorAttr = memberInfo.GetCustomAttribute(); - if (vectorAttr != null && !isVector) - throw Contracts.ExceptParam(nameof(userType), $"Member {memberInfo.Name} marked with {nameof(VectorTypeAttribute)}, but does not appear to be a vector type", memberInfo.Name); - if (isVector) + if (!DataViewTypeManager.Knows(dataType, customAttributes)) { - int[] dims = vectorAttr?.Dims; - if (dims != null && dims.Any(d => d < 0)) - throw Contracts.ExceptParam(nameof(userType), "Some of member {0}'s dimension lengths are negative"); - if (Utils.Size(dims) == 0) - columnType = new VectorDataViewType(itemType, 0); + PrimitiveDataViewType itemType; + var keyAttr = memberInfo.GetCustomAttribute(); + if (keyAttr != null) + { + if (!KeyDataViewType.IsValidDataType(dataType)) + throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name); + if (keyAttr.KeyCount == null) + itemType = new KeyDataViewType(dataType, dataType.ToMaxInt()); + else + itemType = new KeyDataViewType(dataType, keyAttr.KeyCount.Count.GetValueOrDefault()); + } + else + itemType = ColumnTypeExtensions.PrimitiveTypeFromType(dataType); + + var vectorAttr = memberInfo.GetCustomAttribute(); + if (vectorAttr != null && !isVector) + throw Contracts.ExceptParam(nameof(userType), $"Member {memberInfo.Name} marked with {nameof(VectorTypeAttribute)}, but does not appear to be a vector type", memberInfo.Name); + if (isVector) + { + int[] dims = vectorAttr?.Dims; + if (dims != null && dims.Any(d => d < 0)) + throw Contracts.ExceptParam(nameof(userType), "Some of member {0}'s dimension lengths are negative"); + if (Utils.Size(dims) == 0) + columnType = new VectorDataViewType(itemType, 0); + else + columnType = new VectorDataViewType(itemType, dims); + } else - columnType = new VectorDataViewType(itemType, dims); + columnType = itemType; } else - columnType = itemType; + columnType = DataViewTypeManager.GetDataViewType(dataType, customAttributes); cols.Add(new Column(memberInfo.Name, columnType, name)); } diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 7fde918a09..6df350c20c 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -269,6 +269,10 @@ private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Col return Utils.MarshalInvoke(delForKey, keyRawType, peek, colType); } } + else if (DataViewTypeManager.Knows(colType)) + { + del = CreateDirectGetterDelegate; + } else { // REVIEW: Is this even possible? @@ -843,7 +847,7 @@ public AnnotationInfo(string kind, T value, DataViewType annotationType = null) Contracts.Assert(value != null); bool isVector; Type itemType; - InternalSchemaDefinition.GetVectorAndItemType(typeof(T), "annotation value", out isVector, out itemType); + InternalSchemaDefinition.GetVectorAndItemType("annotation value", typeof(T), null, out isVector, out itemType); if (annotationType == null) { diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index 14e97cf100..7e6981140f 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -119,7 +119,7 @@ public void AssertRep() Contracts.Assert(Generator.GetMethodInfo().ReturnType == typeof(void)); // Checks that the return type of the generator is compatible with ColumnType. - GetVectorAndItemType(ComputedReturnType, "return type", out bool isVector, out Type itemType); + GetVectorAndItemType("return type", ComputedReturnType, null, out bool isVector, out Type itemType); Contracts.Assert(isVector == ColumnType is VectorDataViewType); Contracts.Assert(itemType == ColumnType.GetItemType().RawType); } @@ -147,11 +147,11 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector switch (memberInfo) { case FieldInfo fieldInfo: - GetVectorAndItemType(fieldInfo.FieldType, fieldInfo.Name, out isVector, out itemType); + GetVectorAndItemType(fieldInfo.Name, fieldInfo.FieldType, fieldInfo.GetCustomAttributes(), out isVector, out itemType); break; case PropertyInfo propertyInfo: - GetVectorAndItemType(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out itemType); + GetVectorAndItemType(propertyInfo.Name, propertyInfo.PropertyType, propertyInfo.GetCustomAttributes(), out isVector, out itemType); break; default: @@ -165,13 +165,14 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector /// and also the associated data type for this type. If a valid data type could not /// be determined, this will throw. /// - /// The type of the variable to inspect. /// The name of the variable to inspect. + /// The type of the variable to inspect. + /// Attribute of . It can be if attributes don't exist. /// Whether this appears to be a vector type. /// /// The corresponding RawType of the type, or items of this type if vector. /// - public static void GetVectorAndItemType(Type rawType, string name, out bool isVector, out Type itemType) + public static void GetVectorAndItemType(string name, Type rawType, IEnumerable attributes, out bool isVector, out Type itemType) { // Determine whether this is a vector, and also determine the raw item type. isVector = true; @@ -185,9 +186,12 @@ public static void GetVectorAndItemType(Type rawType, string name, out bool isVe isVector = false; } + // The internal type of string is ReadOnlyMemory. That is, string will be stored as ReadOnlyMemory in IDataView. if (itemType == typeof(string)) itemType = typeof(ReadOnlyMemory); - else if (!itemType.TryGetDataKind(out _)) + // Check if the itemType extracted from rawType is supported by ML.NET's type system. + // It must be one of either ML.NET's pre-defined types or custom types registered by the user. + else if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType, attributes)) throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name); } @@ -242,7 +246,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us var parameterType = col.ReturnType; if (parameterType == null) throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No return parameter found in computed column."); - GetVectorAndItemType(parameterType, "returnType", out isVector, out dataItemType); + GetVectorAndItemType("returnType", parameterType, null, out isVector, out dataItemType); } // Infer the column name. var colName = string.IsNullOrEmpty(col.ColumnName) ? col.MemberName : col.ColumnName; diff --git a/src/Microsoft.ML.Data/DataView/TypedCursor.cs b/src/Microsoft.ML.Data/DataView/TypedCursor.cs index b5ec10fd2a..f7b79feb21 100644 --- a/src/Microsoft.ML.Data/DataView/TypedCursor.cs +++ b/src/Microsoft.ML.Data/DataView/TypedCursor.cs @@ -319,6 +319,10 @@ private Action GenerateSetter(DataViewRow input, int index, InternalSchema del = CreateDirectSetter; } + else if (DataViewTypeManager.Knows(colType)) + { + del = CreateDirectSetter; + } else { // REVIEW: Is this even possible? diff --git a/src/Microsoft.ML.Data/Utils/ApiUtils.cs b/src/Microsoft.ML.Data/Utils/ApiUtils.cs index 6704cbda1e..117e1981cb 100644 --- a/src/Microsoft.ML.Data/Utils/ApiUtils.cs +++ b/src/Microsoft.ML.Data/Utils/ApiUtils.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using System.Linq; using System.Reflection; using System.Reflection.Emit; using Microsoft.ML.Data; @@ -16,14 +18,15 @@ namespace Microsoft.ML internal static class ApiUtils { - private static OpCode GetAssignmentOpCode(Type t) + private static OpCode GetAssignmentOpCode(Type t, IEnumerable attributes) { // REVIEW: This should be a Dictionary based solution. // DvTypes, strings, arrays, all nullable types, VBuffers and RowId. if (t == typeof(ReadOnlyMemory) || t == typeof(string) || t.IsArray || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) || - t == typeof(DateTime) || t == typeof(DateTimeOffset) || t == typeof(TimeSpan) || t == typeof(DataViewRowId)) + t == typeof(DateTime) || t == typeof(DateTimeOffset) || t == typeof(TimeSpan) || + t == typeof(DataViewRowId) || DataViewTypeManager.Knows(t, attributes)) { return OpCodes.Stobj; } @@ -56,7 +59,7 @@ internal static Delegate GeneratePeek(InternalSchemaDefinition.Colum case FieldInfo fieldInfo: Type fieldType = fieldInfo.FieldType; - var assignmentOpCode = GetAssignmentOpCode(fieldType); + var assignmentOpCode = GetAssignmentOpCode(fieldType, fieldInfo.GetCustomAttributes()); Func func = GeneratePeek; var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); @@ -65,7 +68,7 @@ internal static Delegate GeneratePeek(InternalSchemaDefinition.Colum case PropertyInfo propertyInfo: Type propertyType = propertyInfo.PropertyType; - var assignmentOpCodeProp = GetAssignmentOpCode(propertyType); + var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes()); Func funcProp = GeneratePeek; var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); @@ -132,7 +135,7 @@ internal static Delegate GeneratePoke(InternalSchemaDefinition.Colum case FieldInfo fieldInfo: Type fieldType = fieldInfo.FieldType; - var assignmentOpCode = GetAssignmentOpCode(fieldType); + var assignmentOpCode = GetAssignmentOpCode(fieldType, fieldInfo.GetCustomAttributes()); Func func = GeneratePoke; var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); @@ -141,7 +144,7 @@ internal static Delegate GeneratePoke(InternalSchemaDefinition.Colum case PropertyInfo propertyInfo: Type propertyType = propertyInfo.PropertyType; - var assignmentOpCodeProp = GetAssignmentOpCode(propertyType); + var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes()); Func funcProp = GeneratePoke; var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); diff --git a/src/Microsoft.ML.DataView/DataViewType.cs b/src/Microsoft.ML.DataView/DataViewType.cs index 153ba02261..9a12cfe981 100644 --- a/src/Microsoft.ML.DataView/DataViewType.cs +++ b/src/Microsoft.ML.DataView/DataViewType.cs @@ -41,6 +41,10 @@ private protected DataViewType(Type rawType) // Object.Equals(Object) and GetHashCode. In classes below where Equals(ColumnType other) // is effectively a referencial comparison, there is no need to override base class implementations // of Object.Equals(Object) (and GetHashCode) since its also a referencial comparison. + /// + /// Return if is equivalent to and otherwise. + /// + /// Another to be compared with . public abstract bool Equals(DataViewType other); } @@ -461,4 +465,25 @@ public override bool Equals(DataViewType other) public override string ToString() => "TimeSpan"; } + + /// + /// should be used to decorated class properties and fields, if that class' instances will be loaded as ML.NET . + /// The function will be called to register a for a with its s. + /// Whenever a value typed to the registered and its s, that value's type (i.e., a ) + /// in would be the associated . + /// + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] + public abstract class DataViewTypeAttribute : Attribute, IEquatable + { + /// + /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custom type plus its attributes. + /// + public abstract void Register(); + + /// + /// Return if is equivalent to and otherwise. + /// + /// Another to be compared with . + public abstract bool Equals(DataViewTypeAttribute other); + } } \ No newline at end of file diff --git a/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs b/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs index 2ffd5ba97c..78cbc74874 100644 --- a/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs @@ -26,6 +26,7 @@ public static class ImageEstimatorsCatalog /// /// /// public static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index 3b50354177..6a30084f81 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Drawing; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; @@ -9,16 +10,76 @@ namespace Microsoft.ML.Transforms.Image { + /// + /// Allows a member to be marked as a , primarily allowing one to set + /// the shape of an image field. + /// + public sealed class ImageTypeAttribute : DataViewTypeAttribute + { + /// + /// The height of the image type. + /// + internal int Height { get; } + + /// + /// The width of the image type. + /// + internal int Width { get; } + + /// + /// Create an image type without knowing its height and width. + /// + public ImageTypeAttribute() + { + } + + /// + /// Create an image type with known height and width. + /// + public ImageTypeAttribute(int height, int width) + { + Contracts.CheckParam(width > 0, nameof(width), "Should be positive number"); + Contracts.CheckParam(height > 0, nameof(height), "Should be positive number"); + Height = height; + Width = width; + } + + /// + /// Images with the same width and height should equal. + /// + public override bool Equals(DataViewTypeAttribute other) + { + if (other is ImageTypeAttribute otherImage) + return Height == otherImage.Height && Width == otherImage.Width; + return false; + } + + /// + /// Produce the same hash code for all images with the same height and the same width. + /// + public override int GetHashCode() + { + return Hashing.CombineHash(Height.GetHashCode(), Width.GetHashCode()); + } + + public override void Register() + { + DataViewTypeManager.Register(new ImageDataViewType(Height, Width), typeof(Bitmap), new[] { this }); + } + } + public sealed class ImageDataViewType : StructuredDataViewType { public readonly int Height; public readonly int Width; + public ImageDataViewType(int height, int width) : base(typeof(Bitmap)) { Contracts.CheckParam(height > 0, nameof(height), "Must be positive."); Contracts.CheckParam(width > 0, nameof(width), " Must be positive."); Contracts.CheckParam((long)height * width <= int.MaxValue / 4, nameof(height), nameof(height) + " * " + nameof(width) + " is too large."); + Height = height; Width = width; } @@ -38,11 +99,6 @@ public override bool Equals(DataViewType other) return Width == tmp.Width; } - public override bool Equals(object other) - { - return other is DataViewType tmp && Equals(tmp); - } - public override int GetHashCode() { return Hashing.CombineHash(Height.GetHashCode(), Width.GetHashCode()); diff --git a/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs b/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs index 9dd54c37f8..819a518188 100644 --- a/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs +++ b/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs @@ -38,6 +38,7 @@ public static class CustomMappingCatalog /// /// public static CustomMappingEstimator CustomMapping(this TransformsCatalog catalog, Action mapAction, string contractName, diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs new file mode 100644 index 0000000000..b51d8952d5 --- /dev/null +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -0,0 +1,281 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.Transforms; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.RunTests +{ + public class TestCustomTypeRegister : TestDataViewBase + { + public TestCustomTypeRegister(ITestOutputHelper helper) + : base(helper) + { + } + + /// + /// A custom type which ML.NET doesn't know yet. Its value will be loaded as a DataView column in this test. + /// + private class AlienBody + { + public int Age { get; set; } + public float Height { get; set; } + public float Weight { get; set; } + public int HandCount { get; set; } + + public AlienBody(int age, float height, float weight, int handCount) + { + Age = age; + Height = height; + Weight = weight; + HandCount = handCount; + } + } + + /// + /// applied to class members. + /// + private sealed class AlienTypeAttributeAttribute : DataViewTypeAttribute + { + public int RaceId { get; } + + /// + /// Create an from to a . + /// + public AlienTypeAttributeAttribute(int raceId) + { + RaceId = raceId; + } + + /// + /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custom type plus its attributes. + /// + public override void Register() + { + DataViewTypeManager.Register(new DataViewAlienBodyType(RaceId), typeof(AlienBody), new[] { this }); + } + + public override bool Equals(DataViewTypeAttribute other) + { + if (other is AlienTypeAttributeAttribute) + return RaceId == ((AlienTypeAttributeAttribute)other).RaceId; + return false; + } + + public override int GetHashCode() => RaceId.GetHashCode(); + } + + /// + /// A custom class with a type which ML.NET doesn't know yet. Its value will be loaded as a DataView row in this test. + /// It will be the input of . + /// + /// and would be mapped to different types inside ML.NET type system because they + /// have different s. For example, the column type of would + /// be . + /// + private class AlienHero + { + public string Name { get; set; } + + [AlienTypeAttribute(100)] + public AlienBody One { get; set; } + + [AlienTypeAttribute(200)] + public AlienBody Two { get; set; } + + public AlienHero() + { + Name = "Unknown"; + One = new AlienBody(0, 0, 0, 0); + Two = new AlienBody(0, 0, 0, 0); + } + + public AlienHero(string name, + int age, float height, float weight, int handCount, + int anotherAge, float anotherHeight, float anotherWeight, int anotherHandCount) + { + Name = "Unknown"; + One = new AlienBody(age, height, weight, handCount); + Two = new AlienBody(anotherAge, anotherHeight, anotherWeight, anotherHandCount); + } + } + + /// + /// Type of in ML.NET's type system. + /// It usually shows up as among . + /// + private class DataViewAlienBodyType : StructuredDataViewType + { + public int RaceId { get; } + + public DataViewAlienBodyType(int id) : base(typeof(AlienBody)) + { + RaceId = id; + } + + public override bool Equals(DataViewType other) + { + if (other is DataViewAlienBodyType otherAlien) + return otherAlien.RaceId == RaceId; + return false; + } + + public override int GetHashCode() + { + return RaceId.GetHashCode(); + } + } + + /// + /// The output type of processing using . + /// + private class SuperAlienHero + { + public string Name { get; set; } + + [AlienTypeAttribute(007)] + public AlienBody Merged { get; set; } + + public SuperAlienHero() + { + Name = "Unknown"; + Merged = new AlienBody(0, 0, 0, 0); + } + } + + /// + /// A mapping from to . It is used to create a + /// in . + /// + [CustomMappingFactoryAttribute("LambdaAlienHero")] + private class AlienFusionProcess : CustomMappingFactory + { + public static void MergeBody(AlienHero input, SuperAlienHero output) + { + output.Name = "Super " + input.Name; + output.Merged.Age = input.One.Age + input.Two.Age; + output.Merged.Height = input.One.Height + input.Two.Height; + output.Merged.Weight = input.One.Weight + input.Two.Weight; + output.Merged.HandCount = input.One.HandCount + input.Two.HandCount; + } + + public override Action GetMapping() + { + return MergeBody; + } + } + + [Fact] + public void RegisterTypeWithAttribute() + { + // Build in-memory data. + var tribe = new List() { new AlienHero("ML.NET", 2, 1000, 2000, 3000, 4000, 5000, 6000, 7000) }; + + // Build a ML.NET pipeline and make prediction. + var tribeDataView = ML.Data.LoadFromEnumerable(tribe); + var heroEstimator = new CustomMappingEstimator(ML, AlienFusionProcess.MergeBody, "LambdaAlienHero"); + var model = heroEstimator.Fit(tribeDataView); + var tribeTransformed = model.Transform(tribeDataView); + var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); + + // Make sure the pipeline output is correct. + Assert.Equal(tribeEnumerable[0].Name, "Super " + tribe[0].Name); + Assert.Equal(tribeEnumerable[0].Merged.Age, tribe[0].One.Age + tribe[0].Two.Age); + Assert.Equal(tribeEnumerable[0].Merged.Height, tribe[0].One.Height + tribe[0].Two.Height); + Assert.Equal(tribeEnumerable[0].Merged.Weight, tribe[0].One.Weight + tribe[0].Two.Weight); + Assert.Equal(tribeEnumerable[0].Merged.HandCount, tribe[0].One.HandCount + tribe[0].Two.HandCount); + + // Build prediction engine from the trained pipeline. + var engine = ML.Model.CreatePredictionEngine(model); + var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8); + var superAlien = engine.Predict(alien); + + // Make sure the prediction engine produces expected result. + Assert.Equal(superAlien.Name, "Super " + alien.Name); + Assert.Equal(superAlien.Merged.Age, alien.One.Age + alien.Two.Age); + Assert.Equal(superAlien.Merged.Height, alien.One.Height + alien.Two.Height); + Assert.Equal(superAlien.Merged.Weight, alien.One.Weight + alien.Two.Weight); + Assert.Equal(superAlien.Merged.HandCount, alien.One.HandCount + alien.Two.HandCount); + } + + [Fact] + public void TestTypeManager() + { + // Semantically identical DataViewTypes should produce the same hash code. + var a = new DataViewAlienBodyType(9527); + var aCode = a.GetHashCode(); + var b = new DataViewAlienBodyType(9527); + var bCode = b.GetHashCode(); + + Assert.Equal(aCode, bCode); + + // Semantically identical attributes should produce the same hash code. + var c = new AlienTypeAttributeAttribute(1228); + var cCode = c.GetHashCode(); + var d = new AlienTypeAttributeAttribute(1228); + var dCode = d.GetHashCode(); + + Assert.Equal(cCode, dCode); + + // Check registering the same type pair is OK. + // Note that "a" and "b" should be identical. + DataViewTypeManager.Register(a, typeof(AlienBody)); + DataViewTypeManager.Register(a, typeof(AlienBody)); + DataViewTypeManager.Register(b, typeof(AlienBody)); + DataViewTypeManager.Register(b, typeof(AlienBody)); + + // Check if register of (a, typeof(AlienBody)) successes. + Assert.True(DataViewTypeManager.Knows(a)); + Assert.True(DataViewTypeManager.Knows(b)); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody))); + Assert.Equal(a, DataViewTypeManager.GetDataViewType(typeof(AlienBody))); + Assert.Equal(b, DataViewTypeManager.GetDataViewType(typeof(AlienBody))); + + // Make sure registering the same type twice throws. + bool isWrong = false; + try + { + // "a" has been registered with AlienBody without any attribute, so the user can't + // register "a" again with AlienBody plus the attribute "c." + DataViewTypeManager.Register(a, typeof(AlienBody), new[] { c }); + } + catch + { + isWrong = true; + } + Assert.True(isWrong); + + // Make sure registering the same type twice throws. + isWrong = false; + try + { + // AlienBody has been registered with "a," so user can't register it with + // "new DataViewAlienBodyType(5566)" again. + DataViewTypeManager.Register(new DataViewAlienBodyType(5566), typeof(AlienBody)); + } + catch + { + isWrong = true; + } + Assert.True(isWrong); + + // Register a type with attribute. + var e = new DataViewAlienBodyType(7788); + var f = new AlienTypeAttributeAttribute(8877); + DataViewTypeManager.Register(e, typeof(AlienBody), new[] { f }); + Assert.True(DataViewTypeManager.Knows(e)); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { f })); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { f })); + // "e" is associated with typeof(AlienBody) with "f," so the call below should return true. + Assert.Equal(e, DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { f })); + // "a" is associated with typeof(AlienBody) without any attribute, so the call below should return false. + Assert.NotEqual(a, DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { f })); + } + } +} diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index ef1abf93a9..77a913fd5c 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -184,6 +184,89 @@ public void TestGreyscaleTransformImages() Done(); } + [Fact] + public void TestGrayScaleInMemory() + { + // Create an image list. + var images = new List(){ new ImageDataPoint(10, 10, Color.Blue), new ImageDataPoint(10, 10, Color.Red) }; + + // Convert the list of data points to an IDataView object, which is consumable by ML.NET API. + var data = ML.Data.LoadFromEnumerable(images); + + // Convert image to gray scale. + var pipeline = ML.Transforms.ConvertToGrayscale("GrayImage", "Image"); + + // Fit the model. + var model = pipeline.Fit(data); + + // Test path: image files -> IDataView -> Enumerable of Bitmaps. + var transformedData = model.Transform(data); + + // Load images in DataView back to Enumerable. + var transformedDataPoints = ML.Data.CreateEnumerable(transformedData, false); + + foreach (var dataPoint in transformedDataPoints) + { + var image = dataPoint.Image; + var grayImage = dataPoint.GrayImage; + + Assert.NotNull(grayImage); + + Assert.Equal(image.Width, grayImage.Width); + Assert.Equal(image.Height, grayImage.Height); + + for (int x = 0; x < grayImage.Width; ++x) + { + for (int y = 0; y < grayImage.Height; ++y) + { + var pixel = grayImage.GetPixel(x, y); + // greyscale image has same values for R, G and B. + Assert.True(pixel.R == pixel.G && pixel.G == pixel.B); + } + } + } + + var engine = ML.Model.CreatePredictionEngine(model); + var singleImage = new ImageDataPoint(17, 36, Color.Pink); + var transformedSingleImage = engine.Predict(singleImage); + + Assert.Equal(singleImage.Image.Height, transformedSingleImage.GrayImage.Height); + Assert.Equal(singleImage.Image.Width, transformedSingleImage.GrayImage.Width); + + for (int x = 0; x < transformedSingleImage.GrayImage.Width; ++x) + { + for (int y = 0; y < transformedSingleImage.GrayImage.Height; ++y) + { + var pixel = transformedSingleImage.GrayImage.GetPixel(x, y); + // greyscale image has same values for R, G and B. + Assert.True(pixel.R == pixel.G && pixel.G == pixel.B); + } + } + } + + private class ImageDataPoint + { + [ImageType(10, 10)] + public Bitmap Image { get; set; } + + [ImageType(10, 10)] + public Bitmap GrayImage { get; set; } + + public ImageDataPoint() + { + Image = null; + GrayImage = null; + } + + public ImageDataPoint(int width, int height, Color color) + { + Image = new Bitmap(width, height); + for (int i = 0; i < width; ++i) + for (int j = 0; j < height; ++j) + Image.SetPixel(i, j, color); + } + } + [Fact] public void TestBackAndForthConversionWithAlphaInterleave() {