Skip to content

Commit

Permalink
Fix to job indexing/discovery. We need to also look for job functions…
Browse files Browse the repository at this point in the history
… in assemblies that reference binding extension assemblies.
  • Loading branch information
mathewc committed Jul 7, 2015
1 parent 4a894c5 commit 41d1ec0
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.Azure.WebJobs.Host.Bindings;
using Microsoft.Azure.WebJobs.Host.Config;
using Microsoft.Azure.WebJobs.Host.Triggers;

namespace Microsoft.Azure.WebJobs.Host
{
Expand All @@ -12,6 +16,14 @@ namespace Microsoft.Azure.WebJobs.Host
/// </summary>
public static class IExtensionRegistryExtensions
{
private static readonly Type[] ExtensionTypes = new Type[]
{
typeof(ITriggerBindingProvider),
typeof(IBindingProvider),
typeof(IExtensionConfigProvider),
typeof(IArgumentBindingProvider<>)
};

/// <summary>
/// Registers the specified instance.
/// </summary>
Expand Down Expand Up @@ -43,5 +55,22 @@ public static IEnumerable<TExtension> GetExtensions<TExtension>(this IExtensionR

return registry.GetExtensions(typeof(TExtension)).Cast<TExtension>();
}

/// <summary>
/// Returns the set of assemblies that have registered extensions.
/// </summary>
/// <param name="registry">The registry instance.</param>
/// <returns>The unique set of assemblies.</returns>
internal static IEnumerable<Assembly> GetExtensionAssemblies(this IExtensionRegistry registry)
{
HashSet<Assembly> assemblies = new HashSet<Assembly>();
foreach (Type extensionType in ExtensionTypes)
{
var currAssemblies = registry.GetExtensions(extensionType).Select(p => p.GetType().Assembly);
assemblies.UnionWith(currAssemblies);
}

return assemblies;
}
}
}
45 changes: 31 additions & 14 deletions src/Microsoft.Azure.WebJobs.Host/Indexers/DefaultTypeLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,47 @@ internal class DefaultTypeLocator : ITypeLocator
private static readonly string WebJobsAssemblyName = typeof(TableAttribute).Assembly.GetName().Name;

private readonly TextWriter _log;
private readonly IExtensionRegistry _extensions;

public DefaultTypeLocator(TextWriter log)
public DefaultTypeLocator(TextWriter log, IExtensionRegistry extensions)
{
if (log == null)
{
throw new ArgumentNullException("log");
}
if (extensions == null)
{
throw new ArgumentNullException("extensions");
}

_log = log;
_extensions = extensions;
}

// Helper to filter out assemblies that don't even reference this SDK.
private static bool DoesAssemblyReferenceSdk(Assembly a)
// Helper to filter out assemblies that don't reference the SDK or
// binding extension assemblies (i.e. possible sources of binding attributes, etc.)
private static bool AssemblyReferencesSdkOrExtension(Assembly assembly, IEnumerable<Assembly> extensionAssemblies)
{
// Don't index methods in our assemblies.
if (typeof(DefaultTypeLocator).Assembly == a)
if (typeof(DefaultTypeLocator).Assembly == assembly)
{
return false;
}

AssemblyName[] referencedAssemblyNames = a.GetReferencedAssemblies();
AssemblyName[] referencedAssemblyNames = assembly.GetReferencedAssemblies();
foreach (var referencedAssemblyName in referencedAssemblyNames)
{
if (String.Equals(referencedAssemblyName.Name, WebJobsAssemblyName, StringComparison.OrdinalIgnoreCase))
{
// the assembly references our core SDK assembly
// containing our built in attribute types
return true;
}

if (extensionAssemblies.Any(p => string.Equals(referencedAssemblyName.Name, p.GetName().Name, StringComparison.OrdinalIgnoreCase)))
{
// the assembly references an extension assembly that may
// contain extension attributes
return true;
}
}
Expand All @@ -52,10 +68,10 @@ public IReadOnlyList<Type> GetTypes()
List<Type> allTypes = new List<Type>();

var assemblies = GetUserAssemblies();
IEnumerable<Assembly> extensionAssemblies = _extensions.GetExtensionAssemblies();
foreach (var assembly in assemblies)
{
var assemblyTypes = FindTypes(assembly);

var assemblyTypes = FindTypes(assembly, extensionAssemblies);
if (assemblyTypes != null)
{
allTypes.AddRange(assemblyTypes.Where(IsJobClass));
Expand Down Expand Up @@ -87,11 +103,12 @@ private static IEnumerable<Assembly> GetUserAssemblies()
return AppDomain.CurrentDomain.GetAssemblies();
}

public Type[] FindTypes(Assembly a)
private Type[] FindTypes(Assembly assembly, IEnumerable<Assembly> extensionAssemblies)
{
// Only try to index assemblies that reference this SDK.
// This avoids trying to index through a bunch of FX assemblies that reflection may not be able to load anyways.
if (!DoesAssemblyReferenceSdk(a))
// Only try to index assemblies that reference the core SDK assembly containing
// binding attributes (or any registered extension assemblies). This ensures we
// don't do more inspection work that is necessary during function indexing.
if (!AssemblyReferencesSdkOrExtension(assembly, extensionAssemblies))
{
return null;
}
Expand All @@ -100,12 +117,12 @@ public Type[] FindTypes(Assembly a)

try
{
types = a.GetTypes();
types = assembly.GetTypes();
}
catch (ReflectionTypeLoadException ex)
{
// TODO: Log this somewhere?
_log.WriteLine("Warning: Only got partial types from assembly: {0}", a.FullName);
_log.WriteLine("Warning: Only got partial types from assembly: {0}", assembly.FullName);
_log.WriteLine("Exception message: {0}", ex.ToString());

// In case of a type load exception, at least get the types that did succeed in loading
Expand All @@ -114,7 +131,7 @@ public Type[] FindTypes(Assembly a)
catch (Exception ex)
{
// TODO: Log this somewhere?
_log.WriteLine("Warning: Failed to get types from assembly: {0}", a.FullName);
_log.WriteLine("Warning: Failed to get types from assembly: {0}", assembly.FullName);
_log.WriteLine("Exception message: {0}", ex.ToString());
}

Expand Down
14 changes: 5 additions & 9 deletions src/Microsoft.Azure.WebJobs.Host/Indexers/FunctionIndexer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal class FunctionIndexer
private readonly IBindingProvider _bindingProvider;
private readonly IJobActivator _activator;
private readonly IFunctionExecutor _executor;
private readonly HashSet<Assembly> _jobTypeAssemblies;
private readonly HashSet<Assembly> _jobAttributeAssemblies;

public FunctionIndexer(ITriggerBindingProvider triggerBindingProvider, IBindingProvider bindingProvider, IJobActivator activator, IFunctionExecutor executor, IExtensionRegistry extensions)
{
Expand Down Expand Up @@ -59,7 +59,7 @@ public FunctionIndexer(ITriggerBindingProvider triggerBindingProvider, IBindingP
_bindingProvider = bindingProvider;
_activator = activator;
_executor = executor;
_jobTypeAssemblies = new HashSet<Assembly>(GetJobTypeAssemblies(extensions, typeof(ITriggerBindingProvider), typeof(IBindingProvider)));
_jobAttributeAssemblies = GetJobAttributeAssemblies(extensions);
}

public async Task IndexTypeAsync(Type type, IFunctionIndexCollector index, CancellationToken cancellationToken)
Expand Down Expand Up @@ -95,25 +95,21 @@ public bool IsJobMethod(MethodInfo method)
return false;
}

private static HashSet<Assembly> GetJobTypeAssemblies(IExtensionRegistry extensions, params Type[] extensionTypes)
private static HashSet<Assembly> GetJobAttributeAssemblies(IExtensionRegistry extensions)
{
// create a set containing our own core assemblies
HashSet<Assembly> assemblies = new HashSet<Assembly>();
assemblies.Add(typeof(BlobAttribute).Assembly);

// add any extension assemblies
foreach (Type extensionType in extensionTypes)
{
var currAssemblies = extensions.GetExtensions(extensionType).Select(p => p.GetType().Assembly);
assemblies.UnionWith(currAssemblies);
}
assemblies.UnionWith(extensions.GetExtensionAssemblies());

return assemblies;
}

private bool HasJobAttribute(CustomAttributeData attributeData)
{
return _jobTypeAssemblies.Contains(attributeData.AttributeType.Assembly);
return _jobAttributeAssemblies.Contains(attributeData.AttributeType.Assembly);
}

public async Task IndexMethodAsync(MethodInfo method, IFunctionIndexCollector index, CancellationToken cancellationToken)
Expand Down
7 changes: 5 additions & 2 deletions src/Microsoft.Azure.WebJobs.Host/JobHostConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public sealed class JobHostConfiguration : IServiceProvider
private string _serviceBusConnectionString;

private string _hostId;
private ITypeLocator _typeLocator = new DefaultTypeLocator(ConsoleProvider.Out);
private ITypeLocator _typeLocator;
private INameResolver _nameResolver = new DefaultNameResolver();
private IJobActivator _activator = DefaultJobActivator.Instance;

Expand Down Expand Up @@ -54,8 +54,11 @@ private JobHostConfiguration(DefaultStorageAccountProvider storageAccountProvide
{
_storageAccountProvider = storageAccountProvider;

IExtensionRegistry extensions = new DefaultExtensionRegistry();
_typeLocator = new DefaultTypeLocator(ConsoleProvider.Out, extensions);

// add our built in services here
AddService<IExtensionRegistry>(new DefaultExtensionRegistry());
AddService<IExtensionRegistry>(extensions);
}

/// <summary>Gets or sets the host ID.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ internal class TestJobHostContextFactory : IJobHostContextFactory

public Task<JobHostContext> CreateAndLogHostStartedAsync(CancellationToken shutdownToken, CancellationToken cancellationToken)
{
ITypeLocator typeLocator = new DefaultTypeLocator(new StringWriter());
ITypeLocator typeLocator = new DefaultTypeLocator(new StringWriter(), new DefaultExtensionRegistry());
INameResolver nameResolver = new RandomNameResolver();
JobHostConfiguration config = new JobHostConfiguration
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Reflection;
using System.Threading;
using Microsoft.Azure.WebJobs.Host.Bindings;
using Microsoft.Azure.WebJobs.Host.Config;
using Microsoft.Azure.WebJobs.Host.Executors;
using Microsoft.Azure.WebJobs.Host.Indexers;
using Microsoft.Azure.WebJobs.Host.Triggers;
Expand All @@ -24,8 +25,10 @@ public void TestFails()
foreach (var method in this.GetType().GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static))
{
Mock<IExtensionRegistry> extensionsMock = new Mock<IExtensionRegistry>(MockBehavior.Strict);
extensionsMock.Setup(p => p.GetExtensions(typeof(IExtensionConfigProvider))).Returns(Enumerable.Empty<object>());
extensionsMock.Setup(p => p.GetExtensions(typeof(ITriggerBindingProvider))).Returns(Enumerable.Empty<object>());
extensionsMock.Setup(p => p.GetExtensions(typeof(IBindingProvider))).Returns(Enumerable.Empty<object>());
extensionsMock.Setup(p => p.GetExtensions(typeof(IArgumentBindingProvider<>))).Returns(Enumerable.Empty<object>());
Mock<IFunctionExecutor> executorMock = new Mock<IFunctionExecutor>(MockBehavior.Strict);

IFunctionIndexCollector stubIndex = new Mock<IFunctionIndexCollector>().Object;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ internal class TestJobHostContextFactory : IJobHostContextFactory

public Task<JobHostContext> CreateAndLogHostStartedAsync(CancellationToken shutdownToken, CancellationToken cancellationToken)
{
ITypeLocator typeLocator = new DefaultTypeLocator(new StringWriter());
ITypeLocator typeLocator = new DefaultTypeLocator(new StringWriter(), new DefaultExtensionRegistry());
INameResolver nameResolver = new RandomNameResolver();
JobHostConfiguration config = new JobHostConfiguration
{
Expand Down

0 comments on commit 41d1ec0

Please sign in to comment.