Skip to content

Commit

Permalink
File-scoped namespaces in files under EntryPoints (`Microsoft.ML.Co…
Browse files Browse the repository at this point in the history
…re`) (#6790)

Co-authored-by: Lehonti Ramos <john@doe>
  • Loading branch information
Lehonti and Lehonti Ramos authored Aug 25, 2023
1 parent 92eccad commit 43a6a81
Show file tree
Hide file tree
Showing 5 changed files with 819 additions and 824 deletions.
35 changes: 17 additions & 18 deletions src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,22 @@
// See the LICENSE file in the project root for more information.

using System;
namespace Microsoft.ML.EntryPoints
{
/// <summary>
/// This is a signature for classes that are 'holders' of entry points and components.
/// </summary>
[BestFriend]
internal delegate void SignatureEntryPointModule();
namespace Microsoft.ML.EntryPoints;

/// <summary>
/// This is a signature for classes that are 'holders' of entry points and components.
/// </summary>
[BestFriend]
internal delegate void SignatureEntryPointModule();

/// <summary>
/// A simplified assembly attribute for marking EntryPoint modules.
/// </summary>
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
[BestFriend]
internal sealed class EntryPointModuleAttribute : LoadableClassAttributeBase
{
public EntryPointModuleAttribute(Type loaderType)
: base(null, typeof(void), loaderType, null, new[] { typeof(SignatureEntryPointModule) }, loaderType.FullName)
{ }
}
/// <summary>
/// A simplified assembly attribute for marking EntryPoint modules.
/// </summary>
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
[BestFriend]
internal sealed class EntryPointModuleAttribute : LoadableClassAttributeBase
{
public EntryPointModuleAttribute(Type loaderType)
: base(null, typeof(void), loaderType, null, new[] { typeof(SignatureEntryPointModule) }, loaderType.FullName)
{ }
}
213 changes: 106 additions & 107 deletions src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,127 +9,126 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.EntryPoints
namespace Microsoft.ML.EntryPoints;

[BestFriend]
internal static class EntryPointUtils
{
[BestFriend]
internal static class EntryPointUtils
private static readonly FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool> _isValueWithinRangeMethodInfo
= new FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool>(IsValueWithinRange<int>);

private static bool IsValueWithinRange<T>(TlcModule.RangeAttribute range, object obj)
{
private static readonly FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool> _isValueWithinRangeMethodInfo
= new FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool>(IsValueWithinRange<int>);
T val;
if (obj is Optional<T> asOptional)
val = asOptional.Value;
else
val = (T)obj;

private static bool IsValueWithinRange<T>(TlcModule.RangeAttribute range, object obj)
{
T val;
if (obj is Optional<T> asOptional)
val = asOptional.Value;
else
val = (T)obj;

return
(range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) &&
(range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) &&
(range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) &&
(range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0);
}
return
(range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) &&
(range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) &&
(range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) &&
(range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0);
}

public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val)
{
Contracts.AssertValue(range);
Contracts.AssertValue(val);
// Avoid trying to cast double as float. If range
// was specified using floats, but value being checked
// is double, change range to be of type double
if (range.Type == typeof(float) && val is double)
range.CastToDouble();
return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val);
}
public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val)
{
Contracts.AssertValue(range);
Contracts.AssertValue(val);
// Avoid trying to cast double as float. If range
// was specified using floats, but value being checked
// is double, change range to be of type double
if (range.Type == typeof(float) && val is double)
range.CastToDouble();
return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val);
}

/// <summary>
/// Performs checks on an EntryPoint input class equivalent to the checks that are done
/// when parsing a JSON EntryPoint graph.
///
/// Call this method from EntryPoint methods to ensure that range and required checks are performed
/// in a consistent manner when EntryPoints are created directly from code.
/// </summary>
public static void CheckInputArgs(IExceptionContext ectx, object args)
/// <summary>
/// Performs checks on an EntryPoint input class equivalent to the checks that are done
/// when parsing a JSON EntryPoint graph.
///
/// Call this method from EntryPoint methods to ensure that range and required checks are performed
/// in a consistent manner when EntryPoints are created directly from code.
/// </summary>
public static void CheckInputArgs(IExceptionContext ectx, object args)
{
foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance))
{
foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance))
{
var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault()
as ArgumentAttribute;
if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly)
continue;

var fieldVal = fieldInfo.GetValue(args);
var fieldType = fieldInfo.FieldType;

// Optionals are either left in their Implicit constructed state or
// a new Explicit optional is constructed. They should never be set
// to null.
if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null)
throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name);

if (attr.IsRequired)
{
bool equalToDefault;
if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>))
equalToDefault = !((Optional)fieldVal).IsExplicit;
else
equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null;

if (equalToDefault)
throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name);
}

var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault()
as TlcModule.RangeAttribute;
if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal))
throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name);
}
}
var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault()
as ArgumentAttribute;
if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly)
continue;

public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register(hostName);
host.CheckValue(input, nameof(input));
CheckInputArgs(host, input);
return host;
}
var fieldVal = fieldInfo.GetValue(args);
var fieldType = fieldInfo.FieldType;

/// <summary>
/// Searches for the given column name in the schema. This method applies a
/// common policy that throws an exception if the column is not found
/// and the column name was explicitly specified. If the column is not found
/// and the column name was not explicitly specified, it returns null.
/// </summary>
public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional<string> value)
{
Contracts.CheckValueOrNull(ectx);
ectx.CheckValue(schema, nameof(schema));
ectx.CheckValue(value, nameof(value));
// Optionals are either left in their Implicit constructed state or
// a new Explicit optional is constructed. They should never be set
// to null.
if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null)
throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name);

if (value == "")
return null;
if (schema.GetColumnOrNull(value) == null)
if (attr.IsRequired)
{
if (value.IsExplicit)
throw ectx.Except("Column '{0}' not found", value);
return null;
bool equalToDefault;
if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>))
equalToDefault = !((Optional)fieldVal).IsExplicit;
else
equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null;

if (equalToDefault)
throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name);
}
return value;

var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault()
as TlcModule.RangeAttribute;
if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal))
throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name);
}
}

/// <summary>
/// Converts EntryPoint Optional{T} types into nullable types, with the
/// implicit value being converted to the null value.
/// </summary>
public static T? AsNullable<T>(this Optional<T> opt) where T : struct
public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register(hostName);
host.CheckValue(input, nameof(input));
CheckInputArgs(host, input);
return host;
}

/// <summary>
/// Searches for the given column name in the schema. This method applies a
/// common policy that throws an exception if the column is not found
/// and the column name was explicitly specified. If the column is not found
/// and the column name was not explicitly specified, it returns null.
/// </summary>
public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional<string> value)
{
Contracts.CheckValueOrNull(ectx);
ectx.CheckValue(schema, nameof(schema));
ectx.CheckValue(value, nameof(value));

if (value == "")
return null;
if (schema.GetColumnOrNull(value) == null)
{
if (opt.IsExplicit)
return opt.Value;
else
return null;
if (value.IsExplicit)
throw ectx.Except("Column '{0}' not found", value);
return null;
}
return value;
}

/// <summary>
/// Converts EntryPoint Optional{T} types into nullable types, with the
/// implicit value being converted to the null value.
/// </summary>
public static T? AsNullable<T>(this Optional<T> opt) where T : struct
{
if (opt.IsExplicit)
return opt.Value;
else
return null;
}
}
Loading

0 comments on commit 43a6a81

Please sign in to comment.