Skip to content

Commit

Permalink
Fix invalid handle bug happening when TypeBuilder type used in exce…
Browse files Browse the repository at this point in the history
…ption catch clause (#106665)

* Handle TypeBuilder exception type in catch clause

* Remove PersistedAssemblyBuilder.IsDynamic from ref instead of overriding it
  • Loading branch information
buyaa-n authored Aug 20, 2024
1 parent c47fc5f commit 57502fc
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,6 @@ public sealed class PersistedAssemblyBuilder : System.Reflection.Emit.AssemblyBu
{
public PersistedAssemblyBuilder(System.Reflection.AssemblyName name, System.Reflection.Assembly coreAssembly, System.Collections.Generic.IEnumerable<System.Reflection.Emit.CustomAttributeBuilder>? assemblyAttributes = null) { }
public override string? FullName { get { throw null; } }
public override bool IsDynamic { get { throw null; } }
public override System.Reflection.Module ManifestModule { get { throw null; } }
[System.Diagnostics.CodeAnalysis.RequiresDynamicCode("Defining a dynamic assembly requires dynamic code.")]
protected override System.Reflection.Emit.ModuleBuilder DefineDynamicModuleCore(string name) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ internal sealed class ILGeneratorImpl : ILGenerator
private int _localCount;
private Dictionary<Label, LabelInfo> _labelTable = new(2);
private List<KeyValuePair<object, BlobWriter>> _memberReferences = new();
private List<ExceptionBlock> _exceptionStack = new();
private List<ExceptionBlock> _exceptionStack = new(); // tracks the exception nesting
private List<ExceptionHandlerInfo> _exceptionBlocks = new(); // keeps all ExceptionHandler blocks
private Dictionary<SymbolDocumentWriter, List<SequencePoint>> _documentToSequencePoints = new();

internal ILGeneratorImpl(MethodBuilderImpl methodBuilder, int size)
Expand All @@ -54,6 +55,32 @@ internal ILGeneratorImpl(MethodBuilderImpl methodBuilder, int size)
internal Scope Scope => _scope;
internal Dictionary<SymbolDocumentWriter, List<SequencePoint>> DocumentToSequencePoints => _documentToSequencePoints;

internal void AddExceptionBlocks()
{
foreach(ExceptionHandlerInfo eb in _exceptionBlocks)
{
switch (eb.Kind)
{
case ExceptionRegionKind.Catch:
_cfBuilder.AddCatchRegion(GetMetaLabel(eb.TryStart), GetMetaLabel(eb.TryEnd),
GetMetaLabel(eb.HandlerStart), GetMetaLabel(eb.HandlerEnd), _moduleBuilder.GetTypeHandle(eb.ExceptionType!));
break;
case ExceptionRegionKind.Filter:
_cfBuilder.AddFilterRegion(GetMetaLabel(eb.TryStart), GetMetaLabel(eb.TryEnd),
GetMetaLabel(eb.HandlerStart), GetMetaLabel(eb.HandlerEnd), GetMetaLabel(eb.FilterStart));
break;
case ExceptionRegionKind.Fault:
_cfBuilder.AddFaultRegion(GetMetaLabel(eb.TryStart), GetMetaLabel(eb.TryEnd),
GetMetaLabel(eb.HandlerStart), GetMetaLabel(eb.HandlerEnd));
break;
case ExceptionRegionKind.Finally:
_cfBuilder.AddFinallyRegion(GetMetaLabel(eb.TryStart), GetMetaLabel(eb.TryEnd),
GetMetaLabel(eb.HandlerStart), GetMetaLabel(eb.HandlerEnd));
break;
}
}
}

public override int ILOffset => _il.Offset;

public override void BeginCatchBlock(Type? exceptionType)
Expand Down Expand Up @@ -91,8 +118,8 @@ public override void BeginCatchBlock(Type? exceptionType)

currentExBlock.HandleStart = DefineLabel();
currentExBlock.HandleEnd = DefineLabel();
_cfBuilder.AddCatchRegion(GetMetaLabel(currentExBlock.TryStart), GetMetaLabel(currentExBlock.TryEnd),
GetMetaLabel(currentExBlock.HandleStart), GetMetaLabel(currentExBlock.HandleEnd), _moduleBuilder.GetTypeHandle(exceptionType));
_exceptionBlocks.Add(new ExceptionHandlerInfo(ExceptionRegionKind.Catch, currentExBlock.TryStart,
currentExBlock.TryEnd, currentExBlock.HandleStart, currentExBlock.HandleEnd, default, exceptionType));
MarkLabel(currentExBlock.HandleStart);
}

Expand Down Expand Up @@ -124,9 +151,9 @@ public override void BeginExceptFilterBlock()
currentExBlock.FilterStart = DefineLabel();
currentExBlock.HandleStart = DefineLabel();
currentExBlock.HandleEnd = DefineLabel();
_cfBuilder.AddFilterRegion(GetMetaLabel(currentExBlock.TryStart), GetMetaLabel(currentExBlock.TryEnd),
GetMetaLabel(currentExBlock.HandleStart), GetMetaLabel(currentExBlock.HandleEnd), GetMetaLabel(currentExBlock.FilterStart));
currentExBlock.State = ExceptionState.Filter;
_exceptionBlocks.Add(new ExceptionHandlerInfo(ExceptionRegionKind.Filter, currentExBlock.TryStart,
currentExBlock.TryEnd, currentExBlock.HandleStart, currentExBlock.HandleEnd, currentExBlock.FilterStart));
MarkLabel(currentExBlock.FilterStart);
// Stack depth for "filter" starts at one.
_currentStackDepth = 1;
Expand Down Expand Up @@ -166,8 +193,8 @@ public override void BeginFaultBlock()

currentExBlock.HandleStart = DefineLabel();
currentExBlock.HandleEnd = DefineLabel();
_cfBuilder.AddFaultRegion(GetMetaLabel(currentExBlock.TryStart), GetMetaLabel(currentExBlock.TryEnd),
GetMetaLabel(currentExBlock.HandleStart), GetMetaLabel(currentExBlock.HandleEnd));
_exceptionBlocks.Add(new ExceptionHandlerInfo(ExceptionRegionKind.Fault, currentExBlock.TryStart,
currentExBlock.TryEnd, currentExBlock.HandleStart, currentExBlock.HandleEnd));
currentExBlock.State = ExceptionState.Fault;
MarkLabel(currentExBlock.HandleStart);
// Stack depth for "fault" starts at zero.
Expand Down Expand Up @@ -197,8 +224,8 @@ public override void BeginFinallyBlock()
MarkLabel(currentExBlock.TryEnd);
currentExBlock.HandleStart = DefineLabel();
currentExBlock.HandleEnd = finallyEndLabel;
_cfBuilder.AddFinallyRegion(GetMetaLabel(currentExBlock.TryStart), GetMetaLabel(currentExBlock.TryEnd),
GetMetaLabel(currentExBlock.HandleStart), GetMetaLabel(currentExBlock.HandleEnd));
_exceptionBlocks.Add(new ExceptionHandlerInfo(ExceptionRegionKind.Finally, currentExBlock.TryStart,
currentExBlock.TryEnd, currentExBlock.HandleStart, currentExBlock.HandleEnd));
currentExBlock.State = ExceptionState.Finally;
MarkLabel(currentExBlock.HandleStart);
// Stack depth for "finally" starts at zero.
Expand Down Expand Up @@ -835,6 +862,31 @@ internal sealed class ExceptionBlock
public ExceptionState State;
}

internal struct ExceptionHandlerInfo
{
public readonly ExceptionRegionKind Kind;
public readonly Label TryStart, TryEnd, HandlerStart, HandlerEnd, FilterStart;
public Type? ExceptionType;

public ExceptionHandlerInfo(
ExceptionRegionKind kind,
Label tryStart,
Label tryEnd,
Label handlerStart,
Label handlerEnd,
Label filterStart = default,
Type? catchType = null)
{
Kind = kind;
TryStart = tryStart;
TryEnd = tryEnd;
HandlerStart = handlerStart;
HandlerEnd = handlerEnd;
FilterStart = filterStart;
ExceptionType = catchType;
}
}

internal enum ExceptionState
{
Undefined,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ private void WriteMethods(List<MethodBuilderImpl> methods, List<GenericTypeParam
if (il != null)
{
FillMemberReferences(il);
il.AddExceptionBlocks();
StandaloneSignatureHandle signature = il.LocalCount == 0 ? default :
_metadataBuilder.AddStandaloneSignature(_metadataBuilder.GetOrAddBlob(MetadataSignatureHelper.GetLocalSignature(il.Locals, this)));
offset = AddMethodBody(method, il, signature, methodBodyEncoder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,157 @@ public void SimpleTryCatchBlock()
}
}

[Fact]
public void TryCatchWithTypeBuilderException()
{
using (TempFile file = TempFile.Create())
{
PersistedAssemblyBuilder ab = new PersistedAssemblyBuilder(new AssemblyName("MyAssembly"), typeof(object).Assembly);
ModuleBuilder mb = ab.DefineDynamicModule("MyModule");
TypeBuilder tb = mb.DefineType("MyType", TypeAttributes.Public);
TypeBuilder exceptionType = mb.DefineType("MyException", TypeAttributes.Public, typeof(Exception));
MethodBuilder method = tb.DefineMethod("Method", MethodAttributes.Public | MethodAttributes.Static, typeof(int), [typeof(int), typeof(int)]);
ILGenerator ilGenerator = method.GetILGenerator();
ilGenerator.BeginExceptionBlock();
ilGenerator.Emit(OpCodes.Ldarg_0);
ilGenerator.Emit(OpCodes.Ldarg_1);
ilGenerator.Emit(OpCodes.Add);
ilGenerator.BeginCatchBlock(exceptionType);
ilGenerator.Emit(OpCodes.Ldc_I4_0);
ilGenerator.EndExceptionBlock();
ilGenerator.Emit(OpCodes.Ret);
tb.CreateType();
exceptionType.CreateType();
ab.Save(file.Path);

using (MetadataLoadContext mlc = new MetadataLoadContext(new CoreMetadataAssemblyResolver()))
{
Assembly assemblyFromDisk = mlc.LoadFromAssemblyPath(file.Path);
Type typeFromDisk = assemblyFromDisk.Modules.First().GetType("MyType");
MethodBody body = typeFromDisk.GetMethod("Method").GetMethodBody();
Assert.Equal(1, body.ExceptionHandlingClauses.Count);
Assert.Equal("MyException", body.ExceptionHandlingClauses[0].CatchType.FullName);
Assert.Equal(ExceptionHandlingClauseOptions.Clause, body.ExceptionHandlingClauses[0].Flags);
}
}
}

[Fact]
public void TryMultipleCatchFinallyBlocks()
{
using (TempFile file = TempFile.Create())
{
PersistedAssemblyBuilder ab = AssemblySaveTools.PopulateAssemblyBuilderAndTypeBuilder(out TypeBuilder tb);
MethodBuilder method = tb.DefineMethod("Method", MethodAttributes.Public | MethodAttributes.Static, typeof(int), [typeof(int), typeof(int)]);
FieldBuilder fb = tb.DefineField("Field", typeof(int), FieldAttributes.Public | FieldAttributes.Static);
Type dBZException = typeof(DivideByZeroException);
TypeBuilder myExceptionType = ab.GetDynamicModule("MyModule").DefineType("MyException", TypeAttributes.Public, typeof(Exception));
myExceptionType.CreateType();
Type exception = typeof(Exception);
Type overflowException = typeof(OverflowException);
ILGenerator ilGenerator = method.GetILGenerator();
LocalBuilder local = ilGenerator.DeclareLocal(typeof(int));
Label exBlock = ilGenerator.BeginExceptionBlock();
Label check100 = ilGenerator.DefineLabel();
Label leave = ilGenerator.DefineLabel();
ilGenerator.Emit(OpCodes.Ldarg_0);
ilGenerator.Emit(OpCodes.Ldarg_1);
ilGenerator.Emit(OpCodes.Div);
ilGenerator.Emit(OpCodes.Stloc_0);
ilGenerator.Emit(OpCodes.Ldloc_0);
ilGenerator.Emit(OpCodes.Brtrue, check100);
ilGenerator.ThrowException(myExceptionType);
ilGenerator.MarkLabel(check100);
ilGenerator.Emit(OpCodes.Ldarg_1);
ilGenerator.Emit(OpCodes.Ldc_I4, 100);
ilGenerator.Emit(OpCodes.Bne_Un, leave);
ilGenerator.ThrowException(overflowException);
ilGenerator.MarkLabel(leave);
ilGenerator.BeginCatchBlock(dBZException);
ilGenerator.EmitWriteLine("Error: division by zero");
ilGenerator.Emit(OpCodes.Ldc_I4_M1);
ilGenerator.Emit(OpCodes.Stloc_0);
ilGenerator.BeginCatchBlock(myExceptionType);
ilGenerator.EmitWriteLine("Error: MyException");
ilGenerator.Emit(OpCodes.Ldc_I4_S, 2);
ilGenerator.Emit(OpCodes.Stloc_0);
ilGenerator.BeginCatchBlock(exception);
ilGenerator.EmitWriteLine("Error: generic Exception");
ilGenerator.Emit(OpCodes.Ldc_I4_S, 3);
ilGenerator.Emit(OpCodes.Stloc_0);
ilGenerator.BeginFinallyBlock();
ilGenerator.EmitWriteLine("Finally block");
ilGenerator.Emit(OpCodes.Ldc_I4_S, 30);
ilGenerator.Emit(OpCodes.Stsfld, fb);
ilGenerator.EndExceptionBlock();
ilGenerator.Emit(OpCodes.Ldloc_0);
ilGenerator.Emit(OpCodes.Ret);
tb.CreateType();
ab.Save(file.Path);

TestAssemblyLoadContext tlc = new TestAssemblyLoadContext();
Assembly assemblyFromDisk = tlc.LoadFromAssemblyPath(file.Path);
Type typeFromDisk = assemblyFromDisk.GetType("MyType");
MethodInfo methodFromDisk = typeFromDisk.GetMethod("Method");
MethodBody body = methodFromDisk.GetMethodBody();
Assert.Equal(4, body.ExceptionHandlingClauses.Count);
Assert.Equal(ExceptionHandlingClauseOptions.Clause, body.ExceptionHandlingClauses[0].Flags);
Assert.Equal(ExceptionHandlingClauseOptions.Clause, body.ExceptionHandlingClauses[1].Flags);
Assert.Equal(ExceptionHandlingClauseOptions.Clause, body.ExceptionHandlingClauses[2].Flags);
Assert.Equal(ExceptionHandlingClauseOptions.Finally, body.ExceptionHandlingClauses[3].Flags);
Assert.Equal(dBZException.FullName, body.ExceptionHandlingClauses[0].CatchType.FullName);
Assert.Equal("MyException", body.ExceptionHandlingClauses[1].CatchType.FullName);
Assert.Equal(exception.FullName, body.ExceptionHandlingClauses[2].CatchType.FullName);
/*
public class MyException : Exception { }
public class MyType
{
public static int Field;
public static int Method(int a, int b)
{
int res;
try{
res = a/b;
if (res == 0)
throw new MyException();
if (b == 100)
throw new OverflowException();
}
catch(DivideByZeroException)
{
Console.WriteLine("Divide by zero caught");
res = -1;
}
catch(MyException)
{
Console.WriteLine("MyException caught");
res = 2;
}
catch(Exception)
{
Console.WriteLine("Divide by zero!");
res = 3;
}
finally
{
Console.WriteLine("Finally block");
Field = 30;
}
return res;
}
}*/
FieldInfo field = typeFromDisk.GetField("Field");
Assert.Equal(0, field.GetValue(null));
Assert.Equal(5, methodFromDisk.Invoke(null, new object[] { 50, 10 }));
Assert.Equal(30, field.GetValue(null));
Assert.Equal(-1, methodFromDisk.Invoke(null, new object[] { 1, 0 }));
Assert.Equal(2, methodFromDisk.Invoke(null, new object[] { 0, 1 }));
Assert.Equal(3, methodFromDisk.Invoke(null, new object[] { 1000, 100 }));
tlc.Unload();
}
}

[Fact]
public void TryMultipleCatchBlocks()
{
Expand Down

0 comments on commit 57502fc

Please sign in to comment.