From 43a6a81185a04f05b47e168685dd22d624ac2b35 Mon Sep 17 00:00:00 2001 From: Lehonti Ramos <17771375+Lehonti@users.noreply.github.com> Date: Fri, 25 Aug 2023 23:58:35 +0200 Subject: [PATCH] File-scoped namespaces in files under `EntryPoints` (`Microsoft.ML.Core`) (#6790) Co-authored-by: Lehonti Ramos --- .../EntryPoints/EntryPointModuleAttribute.cs | 35 +- .../EntryPoints/EntryPointUtils.cs | 213 ++- .../EntryPoints/ModuleArgs.cs | 1193 ++++++++--------- .../EntryPoints/PredictorModel.cs | 99 +- .../EntryPoints/TransformModel.cs | 103 +- 5 files changed, 819 insertions(+), 824 deletions(-) diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs index 1dcfd675c1..14bf50f20c 100644 --- a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs +++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs @@ -3,23 +3,22 @@ // See the LICENSE file in the project root for more information. using System; -namespace Microsoft.ML.EntryPoints -{ - /// - /// This is a signature for classes that are 'holders' of entry points and components. - /// - [BestFriend] - internal delegate void SignatureEntryPointModule(); +namespace Microsoft.ML.EntryPoints; + +/// +/// This is a signature for classes that are 'holders' of entry points and components. +/// +[BestFriend] +internal delegate void SignatureEntryPointModule(); - /// - /// A simplified assembly attribute for marking EntryPoint modules. - /// - [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] - [BestFriend] - internal sealed class EntryPointModuleAttribute : LoadableClassAttributeBase - { - public EntryPointModuleAttribute(Type loaderType) - : base(null, typeof(void), loaderType, null, new[] { typeof(SignatureEntryPointModule) }, loaderType.FullName) - { } - } +/// +/// A simplified assembly attribute for marking EntryPoint modules. +/// +[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] +[BestFriend] +internal sealed class EntryPointModuleAttribute : LoadableClassAttributeBase +{ + public EntryPointModuleAttribute(Type loaderType) + : base(null, typeof(void), loaderType, null, new[] { typeof(SignatureEntryPointModule) }, loaderType.FullName) + { } } diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs index 3cf2e81016..d4ee9c2a42 100644 --- a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs +++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs @@ -9,127 +9,126 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -namespace Microsoft.ML.EntryPoints +namespace Microsoft.ML.EntryPoints; + +[BestFriend] +internal static class EntryPointUtils { - [BestFriend] - internal static class EntryPointUtils + private static readonly FuncStaticMethodInfo1 _isValueWithinRangeMethodInfo + = new FuncStaticMethodInfo1(IsValueWithinRange); + + private static bool IsValueWithinRange(TlcModule.RangeAttribute range, object obj) { - private static readonly FuncStaticMethodInfo1 _isValueWithinRangeMethodInfo - = new FuncStaticMethodInfo1(IsValueWithinRange); + T val; + if (obj is Optional asOptional) + val = asOptional.Value; + else + val = (T)obj; - private static bool IsValueWithinRange(TlcModule.RangeAttribute range, object obj) - { - T val; - if (obj is Optional asOptional) - val = asOptional.Value; - else - val = (T)obj; - - return - (range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) && - (range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) && - (range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) && - (range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0); - } + return + (range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) && + (range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) && + (range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) && + (range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0); + } - public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val) - { - Contracts.AssertValue(range); - Contracts.AssertValue(val); - // Avoid trying to cast double as float. If range - // was specified using floats, but value being checked - // is double, change range to be of type double - if (range.Type == typeof(float) && val is double) - range.CastToDouble(); - return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val); - } + public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val) + { + Contracts.AssertValue(range); + Contracts.AssertValue(val); + // Avoid trying to cast double as float. If range + // was specified using floats, but value being checked + // is double, change range to be of type double + if (range.Type == typeof(float) && val is double) + range.CastToDouble(); + return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val); + } - /// - /// Performs checks on an EntryPoint input class equivalent to the checks that are done - /// when parsing a JSON EntryPoint graph. - /// - /// Call this method from EntryPoint methods to ensure that range and required checks are performed - /// in a consistent manner when EntryPoints are created directly from code. - /// - public static void CheckInputArgs(IExceptionContext ectx, object args) + /// + /// Performs checks on an EntryPoint input class equivalent to the checks that are done + /// when parsing a JSON EntryPoint graph. + /// + /// Call this method from EntryPoint methods to ensure that range and required checks are performed + /// in a consistent manner when EntryPoints are created directly from code. + /// + public static void CheckInputArgs(IExceptionContext ectx, object args) + { + foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)) { - foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)) - { - var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() - as ArgumentAttribute; - if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) - continue; - - var fieldVal = fieldInfo.GetValue(args); - var fieldType = fieldInfo.FieldType; - - // Optionals are either left in their Implicit constructed state or - // a new Explicit optional is constructed. They should never be set - // to null. - if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null) - throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name); - - if (attr.IsRequired) - { - bool equalToDefault; - if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>)) - equalToDefault = !((Optional)fieldVal).IsExplicit; - else - equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null; - - if (equalToDefault) - throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name); - } - - var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault() - as TlcModule.RangeAttribute; - if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal)) - throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name); - } - } + var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() + as ArgumentAttribute; + if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) + continue; - public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input) - { - Contracts.CheckValue(env, nameof(env)); - var host = env.Register(hostName); - host.CheckValue(input, nameof(input)); - CheckInputArgs(host, input); - return host; - } + var fieldVal = fieldInfo.GetValue(args); + var fieldType = fieldInfo.FieldType; - /// - /// Searches for the given column name in the schema. This method applies a - /// common policy that throws an exception if the column is not found - /// and the column name was explicitly specified. If the column is not found - /// and the column name was not explicitly specified, it returns null. - /// - public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional value) - { - Contracts.CheckValueOrNull(ectx); - ectx.CheckValue(schema, nameof(schema)); - ectx.CheckValue(value, nameof(value)); + // Optionals are either left in their Implicit constructed state or + // a new Explicit optional is constructed. They should never be set + // to null. + if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null) + throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name); - if (value == "") - return null; - if (schema.GetColumnOrNull(value) == null) + if (attr.IsRequired) { - if (value.IsExplicit) - throw ectx.Except("Column '{0}' not found", value); - return null; + bool equalToDefault; + if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>)) + equalToDefault = !((Optional)fieldVal).IsExplicit; + else + equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null; + + if (equalToDefault) + throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name); } - return value; + + var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault() + as TlcModule.RangeAttribute; + if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal)) + throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name); } + } - /// - /// Converts EntryPoint Optional{T} types into nullable types, with the - /// implicit value being converted to the null value. - /// - public static T? AsNullable(this Optional opt) where T : struct + public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(hostName); + host.CheckValue(input, nameof(input)); + CheckInputArgs(host, input); + return host; + } + + /// + /// Searches for the given column name in the schema. This method applies a + /// common policy that throws an exception if the column is not found + /// and the column name was explicitly specified. If the column is not found + /// and the column name was not explicitly specified, it returns null. + /// + public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional value) + { + Contracts.CheckValueOrNull(ectx); + ectx.CheckValue(schema, nameof(schema)); + ectx.CheckValue(value, nameof(value)); + + if (value == "") + return null; + if (schema.GetColumnOrNull(value) == null) { - if (opt.IsExplicit) - return opt.Value; - else - return null; + if (value.IsExplicit) + throw ectx.Except("Column '{0}' not found", value); + return null; } + return value; + } + + /// + /// Converts EntryPoint Optional{T} types into nullable types, with the + /// implicit value being converted to the null value. + /// + public static T? AsNullable(this Optional opt) where T : struct + { + if (opt.IsExplicit) + return opt.Value; + else + return null; } } diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs index 867df0a648..15d0fc56e6 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs @@ -9,723 +9,722 @@ using Microsoft.ML.Data; using Microsoft.ML.Runtime; -namespace Microsoft.ML.EntryPoints +namespace Microsoft.ML.EntryPoints; + +/// +/// This class defines attributes to annotate module inputs, outputs, entry points etc. when defining +/// the module interface. +/// +[BestFriend] +internal static class TlcModule { /// - /// This class defines attributes to annotate module inputs, outputs, entry points etc. when defining - /// the module interface. + /// An attribute used to annotate the component. /// - [BestFriend] - internal static class TlcModule + [AttributeUsage(AttributeTargets.Class)] + public sealed class ComponentAttribute : Attribute { /// - /// An attribute used to annotate the component. + /// The load name of the component. Must be unique within its kind. /// - [AttributeUsage(AttributeTargets.Class)] - public sealed class ComponentAttribute : Attribute + public string Name { get; set; } + + /// + /// UI friendly name. Can contain spaces and other forbidden for Name symbols. + /// + public string FriendlyName { get; set; } + + /// + /// Alternative names of the component. Each alias must also be unique in the component's kind. + /// + public string[] Aliases { get; set; } + + /// + /// Comma-separated . + /// + public string Alias { - /// - /// The load name of the component. Must be unique within its kind. - /// - public string Name { get; set; } - - /// - /// UI friendly name. Can contain spaces and other forbidden for Name symbols. - /// - public string FriendlyName { get; set; } - - /// - /// Alternative names of the component. Each alias must also be unique in the component's kind. - /// - public string[] Aliases { get; set; } - - /// - /// Comma-separated . - /// - public string Alias + get { - get - { - if (Aliases == null) - return null; - return string.Join(",", Aliases); - } - set + if (Aliases == null) + return null; + return string.Join(",", Aliases); + } + set + { + if (string.IsNullOrWhiteSpace(value)) + Aliases = null; + else { - if (string.IsNullOrWhiteSpace(value)) - Aliases = null; - else - { - var parts = value.Split(','); - Aliases = parts.Select(x => x.Trim()).ToArray(); - } + var parts = value.Split(','); + Aliases = parts.Select(x => x.Trim()).ToArray(); } } - - /// - /// Description of the component. - /// - public string Desc { get; set; } - - /// - /// This should indicate a name of an embedded resource that contains detailed documents - /// for the component, for example, markdown document with the .md extension. The embedded resource - /// is assumed to be in the same assembly as the class on which this attribute is ascribed. - /// - public string DocName { get; set; } } /// - /// An attribute used to annotate the signature interface. - /// Effectively, this is a way to associate the signature interface with a user-friendly name. + /// Description of the component. /// - [AttributeUsage(AttributeTargets.Interface)] - public sealed class ComponentKindAttribute : Attribute - { - public readonly string Kind; - - public ComponentKindAttribute(string kind) - { - Kind = kind; - } - } + public string Desc { get; set; } /// - /// An attribute used to annotate the kind of entry points. - /// Typically it is used on the input classes. + /// This should indicate a name of an embedded resource that contains detailed documents + /// for the component, for example, markdown document with the .md extension. The embedded resource + /// is assumed to be in the same assembly as the class on which this attribute is ascribed. /// - [AttributeUsage(AttributeTargets.Class)] - public sealed class EntryPointKindAttribute : Attribute - { - public readonly Type[] Kinds; + public string DocName { get; set; } + } - public EntryPointKindAttribute(params Type[] kinds) - { - Kinds = kinds; - } + /// + /// An attribute used to annotate the signature interface. + /// Effectively, this is a way to associate the signature interface with a user-friendly name. + /// + [AttributeUsage(AttributeTargets.Interface)] + public sealed class ComponentKindAttribute : Attribute + { + public readonly string Kind; + + public ComponentKindAttribute(string kind) + { + Kind = kind; } + } - /// - /// An attribute used to annotate the outputs of the module. - /// - [AttributeUsage(AttributeTargets.Field)] - public sealed class OutputAttribute : Attribute + /// + /// An attribute used to annotate the kind of entry points. + /// Typically it is used on the input classes. + /// + [AttributeUsage(AttributeTargets.Class)] + public sealed class EntryPointKindAttribute : Attribute + { + public readonly Type[] Kinds; + + public EntryPointKindAttribute(params Type[] kinds) { - /// - /// Official name of the output. If it is not specified, the field name is used. - /// - public string Name { get; set; } - - /// - /// The description of the output. - /// - public string Desc { get; set; } - - /// - /// The rank order of the output. Because .NET reflection returns members in an unspecified order, this - /// is the only way to ensure consistency. - /// - public Double SortOrder { get; set; } + Kinds = kinds; } + } + /// + /// An attribute used to annotate the outputs of the module. + /// + [AttributeUsage(AttributeTargets.Field)] + public sealed class OutputAttribute : Attribute + { /// - /// An attribute to indicate that a field is optional in an EntryPoint module. - /// A node can be run without optional input fields. + /// Official name of the output. If it is not specified, the field name is used. /// - [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] - public sealed class OptionalInputAttribute : Attribute { } + public string Name { get; set; } /// - /// An attribute used to annotate the valid range of a numeric input. + /// The description of the output. /// - [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] - public sealed class RangeAttribute : Attribute - { - private object _min; - private object _max; - private object _inf; - private object _sup; - private Type _type; - - /// - /// The target type of this range attribute, as determined by the type of - /// the set range bound values. - /// - public Type Type => _type; - - /// - /// An inclusive lower bound of the value. - /// - public object Min - { - get { return _min; } - set - { - CheckType(value); - Contracts.Check(_inf == null, - "The minimum and infimum cannot be both set in a range attribute"); - Contracts.Check(_max == null || ((IComparable)_max).CompareTo(value) != -1, - "The minimum must be less than or equal to the maximum"); - Contracts.Check(_sup == null || ((IComparable)_sup).CompareTo(value) == 1, - "The minimum must be less than the supremum"); - _min = value; - } - } - - /// - /// An inclusive upper bound of the value. - /// - public object Max - { - get { return _max; } - set - { - CheckType(value); - Contracts.Check(_sup == null, - "The maximum and supremum cannot be both set in a range attribute"); - Contracts.Check(_min == null || ((IComparable)_min).CompareTo(value) != 1, - "The maximum must be greater than or equal to the minimum"); - Contracts.Check(_inf == null || ((IComparable)_inf).CompareTo(value) == -1, - "The maximum must be greater than the infimum"); - _max = value; - } - } + public string Desc { get; set; } - /// - /// An exclusive lower bound of the value. - /// - public object Inf - { - get { return _inf; } - set - { - CheckType(value); - Contracts.Check(_min == null, - "The infimum and minimum cannot be both set in a range attribute"); - Contracts.Check(_max == null || ((IComparable)_max).CompareTo(value) == 1, - "The infimum must be less than the maximum"); - Contracts.Check(_sup == null || ((IComparable)_sup).CompareTo(value) == 1, - "The infimum must be less than the supremum"); - _inf = value; - } - } + /// + /// The rank order of the output. Because .NET reflection returns members in an unspecified order, this + /// is the only way to ensure consistency. + /// + public Double SortOrder { get; set; } + } - /// - /// An exclusive upper bound of the value. - /// - public object Sup - { - get { return _sup; } - set - { - CheckType(value); - Contracts.Check(_max == null, - "The supremum and maximum cannot be both set in a range attribute"); - Contracts.Check(_min == null || ((IComparable)_min).CompareTo(value) == -1, - "The supremum must be greater than the minimum"); - Contracts.Check(_inf == null || ((IComparable)_inf).CompareTo(value) == -1, - "The supremum must be greater than the infimum"); - _sup = value; - } - } + /// + /// An attribute to indicate that a field is optional in an EntryPoint module. + /// A node can be run without optional input fields. + /// + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] + public sealed class OptionalInputAttribute : Attribute { } - private void CheckType(object val) - { - Contracts.CheckValue(val, nameof(val)); - if (_type == null) - { - Contracts.Check(val is IComparable, "Type for range attribute must support IComparable"); - _type = val.GetType(); - } - else - Contracts.Check(_type == val.GetType(), "All Range attribute values must be of the same type"); - } + /// + /// An attribute used to annotate the valid range of a numeric input. + /// + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] + public sealed class RangeAttribute : Attribute + { + private object _min; + private object _max; + private object _inf; + private object _sup; + private Type _type; - public void CastToDouble() - { - _type = typeof(double); - if (_inf != null) - _inf = Convert.ToDouble(_inf); - if (_min != null) - _min = Convert.ToDouble(_min); - if (_max != null) - _max = Convert.ToDouble(_max); - if (_sup != null) - _sup = Convert.ToDouble(_sup); - } + /// + /// The target type of this range attribute, as determined by the type of + /// the set range bound values. + /// + public Type Type => _type; - public override string ToString() + /// + /// An inclusive lower bound of the value. + /// + public object Min + { + get { return _min; } + set { - string optionalTypeSpecifier = ""; - if (_type == typeof(double)) - optionalTypeSpecifier = "d"; - else if (_type == typeof(float)) - optionalTypeSpecifier = "f"; - - var pieces = new List(); - if (_inf != null) - pieces.Add($"Inf = {_inf}{optionalTypeSpecifier}"); - if (_min != null) - pieces.Add($"Min = {_min}{optionalTypeSpecifier}"); - if (_max != null) - pieces.Add($"Max = {_max}{optionalTypeSpecifier}"); - if (_sup != null) - pieces.Add($"Sup = {_sup}{optionalTypeSpecifier}"); - return $"[TlcModule.Range({string.Join(", ", pieces)})]"; + CheckType(value); + Contracts.Check(_inf == null, + "The minimum and infimum cannot be both set in a range attribute"); + Contracts.Check(_max == null || ((IComparable)_max).CompareTo(value) != -1, + "The minimum must be less than or equal to the maximum"); + Contracts.Check(_sup == null || ((IComparable)_sup).CompareTo(value) == 1, + "The minimum must be less than the supremum"); + _min = value; } } /// - /// An attribute used to indicate suggested sweep ranges for parameter sweeping. + /// An inclusive upper bound of the value. /// - public abstract class SweepableParamAttribute : Attribute + public object Max { - public string Name { get; set; } - private IComparable _rawValue; - public virtual IComparable RawValue + get { return _max; } + set { - get => _rawValue; - set - { - if (!Frozen) - _rawValue = value; - } + CheckType(value); + Contracts.Check(_sup == null, + "The maximum and supremum cannot be both set in a range attribute"); + Contracts.Check(_min == null || ((IComparable)_min).CompareTo(value) != 1, + "The maximum must be greater than or equal to the minimum"); + Contracts.Check(_inf == null || ((IComparable)_inf).CompareTo(value) == -1, + "The maximum must be greater than the infimum"); + _max = value; } - - // The raw value will store an index for discrete parameters, - // but sometimes we want the text or numeric value itself, - // not the hot index. The processed value does that for discrete - // params. For other params, it just returns the raw value itself. - public virtual IComparable ProcessedValue() => _rawValue; - - // Allows for hyperparameter value freezing, so that sweeps - // will not alter the current value when true. - public bool Frozen { get; set; } - - // Allows the sweepable param to be set directly using the - // available ValueText attribute on IParameterValues (from - // the ParameterSets used in the old hyperparameter sweepers). - public abstract void SetUsingValueText(string valueText); - - public abstract SweepableParamAttribute Clone(); } /// - /// An attribute used to indicate suggested sweep ranges for discrete parameter sweeping. - /// The value is the index of the option chosen. Use Options[Value] to get the corresponding - /// option using the index. + /// An exclusive lower bound of the value. /// - [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] - public sealed class SweepableDiscreteParamAttribute : SweepableParamAttribute + public object Inf { - public object[] Options { get; } - - public SweepableDiscreteParamAttribute(string name, object[] values, bool isBool = false) : this(values, isBool) + get { return _inf; } + set { - Name = name; - } - - public SweepableDiscreteParamAttribute(object[] values, bool isBool = false) - { - Options = isBool ? new object[] { false, true } : values; - } - - public override IComparable RawValue - { - get => base.RawValue; - set - { - var val = Convert.ToInt32(value); - if (!Frozen && 0 <= val && val < Options.Length) - base.RawValue = val; - } + CheckType(value); + Contracts.Check(_min == null, + "The infimum and minimum cannot be both set in a range attribute"); + Contracts.Check(_max == null || ((IComparable)_max).CompareTo(value) == 1, + "The infimum must be less than the maximum"); + Contracts.Check(_sup == null || ((IComparable)_sup).CompareTo(value) == 1, + "The infimum must be less than the supremum"); + _inf = value; } - - public override void SetUsingValueText(string valueText) - { - for (int i = 0; i < Options.Length; i++) - if (valueText == Options[i].ToString()) - RawValue = i; - } - - public int IndexOf(object option) - { - for (int i = 0; i < Options.Length; i++) - if (option == Options[i]) - return i; - return -1; - } - - private static string TranslateOption(object o) - { - switch (o) - { - case float _: - case double _: - return $"{o}f"; - case long _: - case int _: - case byte _: - case short _: - return o.ToString(); - case bool _: - return o.ToString().ToLower(); - case Enum _: - var type = o.GetType(); - var defaultName = $"Enums.{type.Name}.{o.ToString()}"; - var name = type.FullName?.Replace("+", "."); - if (name == null) - return defaultName; - var index1 = name.LastIndexOf(".", StringComparison.Ordinal); - var index2 = name.Substring(0, index1).LastIndexOf(".", StringComparison.Ordinal) + 1; - if (index2 >= 0) - return $"{name.Substring(index2)}.{o.ToString()}"; - return defaultName; - default: - return $"\"{o}\""; - } - } - - public override SweepableParamAttribute Clone() => - new SweepableDiscreteParamAttribute(Name, Options) { RawValue = RawValue, Frozen = Frozen }; - - public override string ToString() - { - var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", "; - return $"[TlcModule.{GetType().Name}({name}new object[]{{{string.Join(", ", Options.Select(TranslateOption))}}})]"; - } - - public override IComparable ProcessedValue() => (IComparable)Options[(int)RawValue]; } /// - /// An attribute used to indicate suggested sweep ranges for float parameter sweeping. + /// An exclusive upper bound of the value. /// - [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] - public sealed class SweepableFloatParamAttribute : SweepableParamAttribute + public object Sup { - public float Min { get; } - public float Max { get; } - public float? StepSize { get; } - public int? NumSteps { get; } - public bool IsLogScale { get; } - - public SweepableFloatParamAttribute(string name, float min, float max, float stepSize = -1, int numSteps = -1, - bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale) + get { return _sup; } + set { - Name = name; - } - - public SweepableFloatParamAttribute(float min, float max, float stepSize = -1, int numSteps = -1, bool isLogScale = false) - { - Min = min; - Max = max; - if (!stepSize.Equals(-1)) - StepSize = stepSize; - if (numSteps != -1) - NumSteps = numSteps; - IsLogScale = isLogScale; + CheckType(value); + Contracts.Check(_max == null, + "The supremum and maximum cannot be both set in a range attribute"); + Contracts.Check(_min == null || ((IComparable)_min).CompareTo(value) == -1, + "The supremum must be greater than the minimum"); + Contracts.Check(_inf == null || ((IComparable)_inf).CompareTo(value) == -1, + "The supremum must be greater than the infimum"); + _sup = value; } + } - public override void SetUsingValueText(string valueText) + private void CheckType(object val) + { + Contracts.CheckValue(val, nameof(val)); + if (_type == null) { - RawValue = float.Parse(valueText); + Contracts.Check(val is IComparable, "Type for range attribute must support IComparable"); + _type = val.GetType(); } + else + Contracts.Check(_type == val.GetType(), "All Range attribute values must be of the same type"); + } - public override SweepableParamAttribute Clone() => - new SweepableFloatParamAttribute(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen }; + public void CastToDouble() + { + _type = typeof(double); + if (_inf != null) + _inf = Convert.ToDouble(_inf); + if (_min != null) + _min = Convert.ToDouble(_min); + if (_max != null) + _max = Convert.ToDouble(_max); + if (_sup != null) + _sup = Convert.ToDouble(_sup); + } - public override string ToString() - { - var optional = new StringBuilder(); - if (StepSize != null) - optional.Append($", stepSize:{StepSize}"); - if (NumSteps != null) - optional.Append($", numSteps:{NumSteps}"); - if (IsLogScale) - optional.Append($", isLogScale:true"); - var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", "; - return $"[TlcModule.{GetType().Name}({name}{Min}f, {Max}f{optional})]"; - } + public override string ToString() + { + string optionalTypeSpecifier = ""; + if (_type == typeof(double)) + optionalTypeSpecifier = "d"; + else if (_type == typeof(float)) + optionalTypeSpecifier = "f"; + + var pieces = new List(); + if (_inf != null) + pieces.Add($"Inf = {_inf}{optionalTypeSpecifier}"); + if (_min != null) + pieces.Add($"Min = {_min}{optionalTypeSpecifier}"); + if (_max != null) + pieces.Add($"Max = {_max}{optionalTypeSpecifier}"); + if (_sup != null) + pieces.Add($"Sup = {_sup}{optionalTypeSpecifier}"); + return $"[TlcModule.Range({string.Join(", ", pieces)})]"; } + } - /// - /// An attribute used to indicate suggested sweep ranges for long parameter sweeping. - /// - [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] - public sealed class SweepableLongParamAttribute : SweepableParamAttribute + /// + /// An attribute used to indicate suggested sweep ranges for parameter sweeping. + /// + public abstract class SweepableParamAttribute : Attribute + { + public string Name { get; set; } + private IComparable _rawValue; + public virtual IComparable RawValue { - public long Min { get; } - public long Max { get; } - public float? StepSize { get; } - public int? NumSteps { get; } - public bool IsLogScale { get; } - - public SweepableLongParamAttribute(string name, long min, long max, float stepSize = -1, int numSteps = -1, - bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale) + get => _rawValue; + set { - Name = name; + if (!Frozen) + _rawValue = value; } + } - public SweepableLongParamAttribute(long min, long max, float stepSize = -1, int numSteps = -1, bool isLogScale = false) - { - Min = min; - Max = max; - if (!stepSize.Equals(-1)) - StepSize = stepSize; - if (numSteps != -1) - NumSteps = numSteps; - IsLogScale = isLogScale; - } + // The raw value will store an index for discrete parameters, + // but sometimes we want the text or numeric value itself, + // not the hot index. The processed value does that for discrete + // params. For other params, it just returns the raw value itself. + public virtual IComparable ProcessedValue() => _rawValue; - public override void SetUsingValueText(string valueText) - { - RawValue = long.Parse(valueText); - } + // Allows for hyperparameter value freezing, so that sweeps + // will not alter the current value when true. + public bool Frozen { get; set; } + + // Allows the sweepable param to be set directly using the + // available ValueText attribute on IParameterValues (from + // the ParameterSets used in the old hyperparameter sweepers). + public abstract void SetUsingValueText(string valueText); - public override SweepableParamAttribute Clone() => - new SweepableLongParamAttribute(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen }; + public abstract SweepableParamAttribute Clone(); + } + + /// + /// An attribute used to indicate suggested sweep ranges for discrete parameter sweeping. + /// The value is the index of the option chosen. Use Options[Value] to get the corresponding + /// option using the index. + /// + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] + public sealed class SweepableDiscreteParamAttribute : SweepableParamAttribute + { + public object[] Options { get; } - public override string ToString() + public SweepableDiscreteParamAttribute(string name, object[] values, bool isBool = false) : this(values, isBool) + { + Name = name; + } + + public SweepableDiscreteParamAttribute(object[] values, bool isBool = false) + { + Options = isBool ? new object[] { false, true } : values; + } + + public override IComparable RawValue + { + get => base.RawValue; + set { - var optional = new StringBuilder(); - if (StepSize != null) - optional.Append($", stepSize:{StepSize}"); - if (NumSteps != null) - optional.Append($", numSteps:{NumSteps}"); - if (IsLogScale) - optional.Append($", isLogScale:true"); - var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", "; - return $"[TlcModule.{GetType().Name}({name}{Min}, {Max}{optional})]"; + var val = Convert.ToInt32(value); + if (!Frozen && 0 <= val && val < Options.Length) + base.RawValue = val; } } - /// - /// An attribute to mark an entry point of a module. - /// - [AttributeUsage(AttributeTargets.Method)] - public sealed class EntryPointAttribute : Attribute + public override void SetUsingValueText(string valueText) { - /// - /// The entry point name. - /// - public string Name { get; set; } - - /// - /// The entry point description. - /// - public string Desc { get; set; } - - /// - /// UI friendly name. Can contain spaces and other forbidden for Name symbols. - /// - public string UserName { get; set; } - - /// - /// Short name of the Entry Point - /// - public string ShortName { get; set; } + for (int i = 0; i < Options.Length; i++) + if (valueText == Options[i].ToString()) + RawValue = i; } - /// - /// The list of data types that are supported as inputs or outputs. - /// - public enum DataKind + public int IndexOf(object option) { - /// - /// Not used. - /// - Unknown = 0, - /// - /// Integer, including long. - /// - Int, - /// - /// Unsigned integer, including ulong. - /// - UInt, - /// - /// Floating point, including double. - /// - Float, - /// - /// A char. - /// - Char, - /// - /// A string. - /// - String, - /// - /// A boolean value. - /// - Bool, - /// - /// A dataset, represented by an . - /// - DataView, - /// - /// A file handle, represented by an . - /// - FileHandle, - /// - /// A transform model, represented by an . - /// - TransformModel, - /// - /// A predictor model, represented by an . - /// - PredictorModel, - /// - /// An enum: one value of a specified list. - /// - Enum, - /// - /// An array (0 or more values of the same type, accessible by index). - /// - Array, - /// - /// A dictionary (0 or more values of the same type, identified by a unique string key). - /// The underlying C# representation is - /// - Dictionary, - /// - /// A component of a specified kind. The component is identified by the "load name" (unique per kind) and, - /// optionally, a set of parameters, unique to each component. Example: "BinaryClassifierEvaluator{threshold=0.5}". - /// The C# representation is . - /// - Component + for (int i = 0; i < Options.Length; i++) + if (option == Options[i]) + return i; + return -1; } - public static DataKind GetDataType(Type type) + private static string TranslateOption(object o) { - Contracts.AssertValue(type); - - // If this is a Optional-wrapped type, unwrap it and examine - // the inner type. - if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>))) - type = type.GetGenericArguments()[0]; - - if (type == typeof(char)) - return DataKind.Char; - if (type == typeof(string)) - return DataKind.String; - if (type == typeof(bool)) - return DataKind.Bool; - if (type == typeof(int) || type == typeof(long)) - return DataKind.Int; - if (type == typeof(uint) || type == typeof(ulong)) - return DataKind.UInt; - if (type == typeof(Single) || type == typeof(Double)) - return DataKind.Float; - if (typeof(IDataView).IsAssignableFrom(type)) - return DataKind.DataView; - if (typeof(TransformModel).IsAssignableFrom(type)) - return DataKind.TransformModel; - if (typeof(PredictorModel).IsAssignableFrom(type)) - return DataKind.PredictorModel; - if (typeof(IFileHandle).IsAssignableFrom(type)) - return DataKind.FileHandle; - if (type.IsEnum) - return DataKind.Enum; - if (type.IsArray) - return DataKind.Array; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Dictionary<,>) - && type.GetGenericArguments()[0] == typeof(string)) + switch (o) { - return DataKind.Dictionary; + case float _: + case double _: + return $"{o}f"; + case long _: + case int _: + case byte _: + case short _: + return o.ToString(); + case bool _: + return o.ToString().ToLower(); + case Enum _: + var type = o.GetType(); + var defaultName = $"Enums.{type.Name}.{o.ToString()}"; + var name = type.FullName?.Replace("+", "."); + if (name == null) + return defaultName; + var index1 = name.LastIndexOf(".", StringComparison.Ordinal); + var index2 = name.Substring(0, index1).LastIndexOf(".", StringComparison.Ordinal) + 1; + if (index2 >= 0) + return $"{name.Substring(index2)}.{o.ToString()}"; + return defaultName; + default: + return $"\"{o}\""; } - if (typeof(IComponentFactory).IsAssignableFrom(type)) - return DataKind.Component; - - return DataKind.Unknown; } - public static bool IsNumericKind(DataKind kind) + public override SweepableParamAttribute Clone() => + new SweepableDiscreteParamAttribute(Name, Options) { RawValue = RawValue, Frozen = Frozen }; + + public override string ToString() { - return kind == DataKind.Int || kind == DataKind.UInt || kind == DataKind.Float; + var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", "; + return $"[TlcModule.{GetType().Name}({name}new object[]{{{string.Join(", ", Options.Select(TranslateOption))}}})]"; } + + public override IComparable ProcessedValue() => (IComparable)Options[(int)RawValue]; } /// - /// The untyped base class for 'maybe'. + /// An attribute used to indicate suggested sweep ranges for float parameter sweeping. /// - [BestFriend] - internal abstract class Optional + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] + public sealed class SweepableFloatParamAttribute : SweepableParamAttribute { - /// - /// Whether the value was set 'explicitly', or 'implicitly'. - /// - public readonly bool IsExplicit; + public float Min { get; } + public float Max { get; } + public float? StepSize { get; } + public int? NumSteps { get; } + public bool IsLogScale { get; } + + public SweepableFloatParamAttribute(string name, float min, float max, float stepSize = -1, int numSteps = -1, + bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale) + { + Name = name; + } - public abstract object GetValue(); + public SweepableFloatParamAttribute(float min, float max, float stepSize = -1, int numSteps = -1, bool isLogScale = false) + { + Min = min; + Max = max; + if (!stepSize.Equals(-1)) + StepSize = stepSize; + if (numSteps != -1) + NumSteps = numSteps; + IsLogScale = isLogScale; + } + + public override void SetUsingValueText(string valueText) + { + RawValue = float.Parse(valueText); + } - private protected Optional(bool isExplicit) + public override SweepableParamAttribute Clone() => + new SweepableFloatParamAttribute(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen }; + + public override string ToString() { - IsExplicit = isExplicit; + var optional = new StringBuilder(); + if (StepSize != null) + optional.Append($", stepSize:{StepSize}"); + if (NumSteps != null) + optional.Append($", numSteps:{NumSteps}"); + if (IsLogScale) + optional.Append($", isLogScale:true"); + var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", "; + return $"[TlcModule.{GetType().Name}({name}{Min}f, {Max}f{optional})]"; } } /// - /// This is a 'maybe' class that is able to differentiate the cases when the value is set 'explicitly', or 'implicitly'. - /// The idea is that if the default value is specified by the user, in some cases it needs to be treated differently - /// than if it's auto-filled. - /// - /// An example is the weight column: the default behavior is to use 'Weight' column if it's present. But if the user explicitly sets - /// the weight column to be 'Weight', we need to actually enforce the presence of the column. + /// An attribute used to indicate suggested sweep ranges for long parameter sweeping. /// - /// The type of the value - [BestFriend] - internal sealed class Optional : Optional + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] + public sealed class SweepableLongParamAttribute : SweepableParamAttribute { - public readonly T Value; + public long Min { get; } + public long Max { get; } + public float? StepSize { get; } + public int? NumSteps { get; } + public bool IsLogScale { get; } + + public SweepableLongParamAttribute(string name, long min, long max, float stepSize = -1, int numSteps = -1, + bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale) + { + Name = name; + } - private Optional(bool isExplicit, T value) - : base(isExplicit) + public SweepableLongParamAttribute(long min, long max, float stepSize = -1, int numSteps = -1, bool isLogScale = false) { - Value = value; + Min = min; + Max = max; + if (!stepSize.Equals(-1)) + StepSize = stepSize; + if (numSteps != -1) + NumSteps = numSteps; + IsLogScale = isLogScale; } - /// - /// Create the 'implicit' value. - /// - public static Optional Implicit(T value) + public override void SetUsingValueText(string valueText) { - return new Optional(false, value); + RawValue = long.Parse(valueText); } - public static Optional Explicit(T value) + public override SweepableParamAttribute Clone() => + new SweepableLongParamAttribute(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen }; + + public override string ToString() { - return new Optional(true, value); + var optional = new StringBuilder(); + if (StepSize != null) + optional.Append($", stepSize:{StepSize}"); + if (NumSteps != null) + optional.Append($", numSteps:{NumSteps}"); + if (IsLogScale) + optional.Append($", isLogScale:true"); + var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", "; + return $"[TlcModule.{GetType().Name}({name}{Min}, {Max}{optional})]"; } + } + /// + /// An attribute to mark an entry point of a module. + /// + [AttributeUsage(AttributeTargets.Method)] + public sealed class EntryPointAttribute : Attribute + { /// - /// The implicit conversion into . + /// The entry point name. /// - public static implicit operator T(Optional optional) - { - return optional.Value; - } + public string Name { get; set; } /// - /// The implicit conversion from . - /// This will assume that the parameter is set 'explicitly'. + /// The entry point description. /// - public static implicit operator Optional(T value) - { - return new Optional(true, value); - } + public string Desc { get; set; } - public override object GetValue() - { - return Value; - } + /// + /// UI friendly name. Can contain spaces and other forbidden for Name symbols. + /// + public string UserName { get; set; } - public override string ToString() + /// + /// Short name of the Entry Point + /// + public string ShortName { get; set; } + } + + /// + /// The list of data types that are supported as inputs or outputs. + /// + public enum DataKind + { + /// + /// Not used. + /// + Unknown = 0, + /// + /// Integer, including long. + /// + Int, + /// + /// Unsigned integer, including ulong. + /// + UInt, + /// + /// Floating point, including double. + /// + Float, + /// + /// A char. + /// + Char, + /// + /// A string. + /// + String, + /// + /// A boolean value. + /// + Bool, + /// + /// A dataset, represented by an . + /// + DataView, + /// + /// A file handle, represented by an . + /// + FileHandle, + /// + /// A transform model, represented by an . + /// + TransformModel, + /// + /// A predictor model, represented by an . + /// + PredictorModel, + /// + /// An enum: one value of a specified list. + /// + Enum, + /// + /// An array (0 or more values of the same type, accessible by index). + /// + Array, + /// + /// A dictionary (0 or more values of the same type, identified by a unique string key). + /// The underlying C# representation is + /// + Dictionary, + /// + /// A component of a specified kind. The component is identified by the "load name" (unique per kind) and, + /// optionally, a set of parameters, unique to each component. Example: "BinaryClassifierEvaluator{threshold=0.5}". + /// The C# representation is . + /// + Component + } + + public static DataKind GetDataType(Type type) + { + Contracts.AssertValue(type); + + // If this is a Optional-wrapped type, unwrap it and examine + // the inner type. + if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>))) + type = type.GetGenericArguments()[0]; + + if (type == typeof(char)) + return DataKind.Char; + if (type == typeof(string)) + return DataKind.String; + if (type == typeof(bool)) + return DataKind.Bool; + if (type == typeof(int) || type == typeof(long)) + return DataKind.Int; + if (type == typeof(uint) || type == typeof(ulong)) + return DataKind.UInt; + if (type == typeof(Single) || type == typeof(Double)) + return DataKind.Float; + if (typeof(IDataView).IsAssignableFrom(type)) + return DataKind.DataView; + if (typeof(TransformModel).IsAssignableFrom(type)) + return DataKind.TransformModel; + if (typeof(PredictorModel).IsAssignableFrom(type)) + return DataKind.PredictorModel; + if (typeof(IFileHandle).IsAssignableFrom(type)) + return DataKind.FileHandle; + if (type.IsEnum) + return DataKind.Enum; + if (type.IsArray) + return DataKind.Array; + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Dictionary<,>) + && type.GetGenericArguments()[0] == typeof(string)) { - if (Value == null) - return ""; - return Value.ToString(); + return DataKind.Dictionary; } + if (typeof(IComponentFactory).IsAssignableFrom(type)) + return DataKind.Component; + + return DataKind.Unknown; + } + + public static bool IsNumericKind(DataKind kind) + { + return kind == DataKind.Int || kind == DataKind.UInt || kind == DataKind.Float; + } +} + +/// +/// The untyped base class for 'maybe'. +/// +[BestFriend] +internal abstract class Optional +{ + /// + /// Whether the value was set 'explicitly', or 'implicitly'. + /// + public readonly bool IsExplicit; + + public abstract object GetValue(); + + private protected Optional(bool isExplicit) + { + IsExplicit = isExplicit; + } +} + +/// +/// This is a 'maybe' class that is able to differentiate the cases when the value is set 'explicitly', or 'implicitly'. +/// The idea is that if the default value is specified by the user, in some cases it needs to be treated differently +/// than if it's auto-filled. +/// +/// An example is the weight column: the default behavior is to use 'Weight' column if it's present. But if the user explicitly sets +/// the weight column to be 'Weight', we need to actually enforce the presence of the column. +/// +/// The type of the value +[BestFriend] +internal sealed class Optional : Optional +{ + public readonly T Value; + + private Optional(bool isExplicit, T value) + : base(isExplicit) + { + Value = value; + } + + /// + /// Create the 'implicit' value. + /// + public static Optional Implicit(T value) + { + return new Optional(false, value); + } + + public static Optional Explicit(T value) + { + return new Optional(true, value); + } + + /// + /// The implicit conversion into . + /// + public static implicit operator T(Optional optional) + { + return optional.Value; + } + + /// + /// The implicit conversion from . + /// This will assume that the parameter is set 'explicitly'. + /// + public static implicit operator Optional(T value) + { + return new Optional(true, value); + } + + public override object GetValue() + { + return Value; + } + + public override string ToString() + { + if (Value == null) + return ""; + return Value.ToString(); } } diff --git a/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs b/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs index aef1b8a298..9eeaddadfe 100644 --- a/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs +++ b/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs @@ -6,64 +6,63 @@ using Microsoft.ML.Data; using Microsoft.ML.Runtime; -namespace Microsoft.ML.EntryPoints +namespace Microsoft.ML.EntryPoints; + +/// +/// Base type for standard predictor model port type. +/// +[BestFriend] +internal abstract class PredictorModel { - /// - /// Base type for standard predictor model port type. - /// [BestFriend] - internal abstract class PredictorModel + private protected PredictorModel() { - [BestFriend] - private protected PredictorModel() - { - } + } - /// - /// Save the model to the given stream. - /// - [BestFriend] - internal abstract void Save(IHostEnvironment env, Stream stream); + /// + /// Save the model to the given stream. + /// + [BestFriend] + internal abstract void Save(IHostEnvironment env, Stream stream); - /// - /// Extract only the transform portion of the predictor model. - /// - [BestFriend] - internal abstract TransformModel TransformModel { get; } + /// + /// Extract only the transform portion of the predictor model. + /// + [BestFriend] + internal abstract TransformModel TransformModel { get; } - /// - /// Extract the predictor object out of the predictor model. - /// - [BestFriend] - internal abstract IPredictor Predictor { get; } + /// + /// Extract the predictor object out of the predictor model. + /// + [BestFriend] + internal abstract IPredictor Predictor { get; } - /// - /// Apply the predictor model to the transform model and return the resulting predictor model. - /// - [BestFriend] - internal abstract PredictorModel Apply(IHostEnvironment env, TransformModel transformModel); + /// + /// Apply the predictor model to the transform model and return the resulting predictor model. + /// + [BestFriend] + internal abstract PredictorModel Apply(IHostEnvironment env, TransformModel transformModel); - /// - /// For a given input data, return role mapped data and the predictor object. - /// The scoring entry point will hopefully know how to construct a scorer out of them. - /// - [BestFriend] - internal abstract void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor); + /// + /// For a given input data, return role mapped data and the predictor object. + /// The scoring entry point will hopefully know how to construct a scorer out of them. + /// + [BestFriend] + internal abstract void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor); - /// - /// Returns a string array containing the label names of the label column type predictor was trained on. - /// If the training label is a key with text key value annotation, it should return this annotation. The order of the labels should be consistent - /// with the key values. Otherwise, it returns null. - /// - /// - /// The column type of the label the predictor was trained on. - [BestFriend] - internal abstract string[] GetLabelInfo(IHostEnvironment env, out DataViewType labelType); + /// + /// Returns a string array containing the label names of the label column type predictor was trained on. + /// If the training label is a key with text key value annotation, it should return this annotation. The order of the labels should be consistent + /// with the key values. Otherwise, it returns null. + /// + /// + /// The column type of the label the predictor was trained on. + [BestFriend] + internal abstract string[] GetLabelInfo(IHostEnvironment env, out DataViewType labelType); - /// - /// Returns the that was used in training. - /// - [BestFriend] - internal abstract RoleMappedSchema GetTrainingSchema(IHostEnvironment env); - } + /// + /// Returns the that was used in training. + /// + [BestFriend] + internal abstract RoleMappedSchema GetTrainingSchema(IHostEnvironment env); } diff --git a/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs index 8e2ddb7e66..a3d2eac26c 100644 --- a/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs @@ -6,65 +6,64 @@ using Microsoft.ML.Data; using Microsoft.ML.Runtime; -namespace Microsoft.ML.EntryPoints +namespace Microsoft.ML.EntryPoints; + +/// +/// Interface for standard transform model port type. +/// +[BestFriend] +internal abstract class TransformModel { - /// - /// Interface for standard transform model port type. - /// [BestFriend] - internal abstract class TransformModel + private protected TransformModel() { - [BestFriend] - private protected TransformModel() - { - } + } - /// - /// The input schema that this transform model was originally instantiated on. - /// Note that the schema may have columns that aren't needed by this transform model. - /// If an exists with this schema, then applying this transform model to it - /// shouldn't fail because of column type issues. - /// - // REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note - // however that doing so may cause issues for composing transform models. For example, - // if transform model A needs column X and model B needs Y, that is NOT produced by A, - // then trimming A's input schema would cause composition to fail. - [BestFriend] - internal abstract DataViewSchema InputSchema { get; } + /// + /// The input schema that this transform model was originally instantiated on. + /// Note that the schema may have columns that aren't needed by this transform model. + /// If an exists with this schema, then applying this transform model to it + /// shouldn't fail because of column type issues. + /// + // REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note + // however that doing so may cause issues for composing transform models. For example, + // if transform model A needs column X and model B needs Y, that is NOT produced by A, + // then trimming A's input schema would cause composition to fail. + [BestFriend] + internal abstract DataViewSchema InputSchema { get; } - /// - /// The output schema that this transform model was originally instantiated on. The schema resulting - /// from may differ from this, similarly to how - /// may differ from the schema of dataviews we apply this transform model to. - /// - [BestFriend] - internal abstract DataViewSchema OutputSchema { get; } + /// + /// The output schema that this transform model was originally instantiated on. The schema resulting + /// from may differ from this, similarly to how + /// may differ from the schema of dataviews we apply this transform model to. + /// + [BestFriend] + internal abstract DataViewSchema OutputSchema { get; } - /// - /// Apply the transform(s) in the model to the given input data. - /// - [BestFriend] - internal abstract IDataView Apply(IHostEnvironment env, IDataView input); + /// + /// Apply the transform(s) in the model to the given input data. + /// + [BestFriend] + internal abstract IDataView Apply(IHostEnvironment env, IDataView input); - /// - /// Apply the transform(s) in the model to the given transform model. - /// - [BestFriend] - internal abstract TransformModel Apply(IHostEnvironment env, TransformModel input); + /// + /// Apply the transform(s) in the model to the given transform model. + /// + [BestFriend] + internal abstract TransformModel Apply(IHostEnvironment env, TransformModel input); - /// - /// Save the model to the given stream. - /// - [BestFriend] - internal abstract void Save(IHostEnvironment env, Stream stream); + /// + /// Save the model to the given stream. + /// + [BestFriend] + internal abstract void Save(IHostEnvironment env, Stream stream); - /// - /// Returns the transform model as an that can output a row - /// given a row with the same schema as . - /// - /// The transform model as an . If not all transforms - /// in the pipeline are then it returns . - [BestFriend] - internal abstract IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx); - } + /// + /// Returns the transform model as an that can output a row + /// given a row with the same schema as . + /// + /// The transform model as an . If not all transforms + /// in the pipeline are then it returns . + [BestFriend] + internal abstract IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx); }