Skip to content

Commit

Permalink
Moved opcode-related logic to a respective helper
Browse files Browse the repository at this point in the history
  • Loading branch information
Kir-Antipov committed Oct 6, 2023
1 parent fb29062 commit a8d9478
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 80 deletions.
11 changes: 5 additions & 6 deletions src/HotAvalonia/AvaloniaRuntimeXamlScanner.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Reflection.Emit;
using HotAvalonia.Helpers;
using HotAvalonia.Reflection;

Expand Down Expand Up @@ -196,9 +195,9 @@ private static bool TryExtractControlUri(MethodInfo populateMethod, [NotNullWhen
if (methodBody is null)
return false;

int ldstrLocation = methodBody.Length > commonLdstrLocation && methodBody[commonLdstrLocation] == OpCodes.Ldstr.Value
int ldstrLocation = methodBody.Length > commonLdstrLocation && methodBody[commonLdstrLocation] == OpCodeHelper.LdstrValue
? commonLdstrLocation
: MethodBodyReader.IndexOf(methodBody, OpCodes.Ldstr.Value);
: MethodBodyReader.IndexOf(methodBody, OpCodeHelper.LdstrValue);

int uriTokenLocation = ldstrLocation + 1;

Expand Down Expand Up @@ -251,19 +250,19 @@ private static IEnumerable<AvaloniaControlInfo> ExtractAvaloniaControls(ReadOnly
while (reader.Next())
{
short opCode = reader.OpCode.Value;
if (opCode == OpCodes.Ret.Value)
if (opCode is OpCodeHelper.RetValue)
{
(str, uri) = (null, null);
continue;
}

if (opCode == OpCodes.Ldstr.Value)
if (opCode is OpCodeHelper.LdstrValue)
{
str = reader.ResolveString(module);
continue;
}

if (opCode != OpCodes.Call.Value && opCode != OpCodes.Newobj.Value)
if (opCode is not (OpCodeHelper.CallValue or OpCodeHelper.NewobjValue))
continue;

MethodBase method = reader.ResolveMethod(module);
Expand Down
165 changes: 165 additions & 0 deletions src/HotAvalonia/Helpers/OpCodeHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
using System.Reflection;
using System.Reflection.Emit;

namespace HotAvalonia.Helpers;

/// <summary>
/// Provides helper methods for working with CIL opcodes.
/// </summary>
internal static class OpCodeHelper
{
/// <summary>
/// The <see cref="OpCode.Value"/> for the 'call' instruction.
/// </summary>
internal const short CallValue = 0x28;

/// <summary>
/// The <see cref="OpCode.Value"/> for the 'ret' instruction.
/// </summary>
internal const short RetValue = 0x2A;

/// <summary>
/// The <see cref="OpCode.Value"/> for the 'switch' instruction.
/// </summary>
internal const short SwitchValue = 0x45;

/// <summary>
/// The <see cref="OpCode.Value"/> for the 'ldstr' instruction.
/// </summary>
internal const short LdstrValue = 0x72;

/// <summary>
/// The <see cref="OpCode.Value"/> for the 'newobj' instruction.
/// </summary>
internal const short NewobjValue = 0x73;

/// <summary>
/// The flag indicating that the <see cref="OpCode"/> is represented by 2 bytes.
/// </summary>
private const int Int16OpCodeFlag = 0xFE00;

/// <summary>
/// The marker value for two-byte opcodes.
/// </summary>
private const int Int16OpCodeMarker = 0xFE;

/// <summary>
/// The instruction set containing all <see cref="OpCode"/> instances.
/// </summary>
private static readonly Lazy<OpCode[]> s_opCodes = new(CreateInstructionSet, isThreadSafe: true);

/// <summary>
/// Attempts to read an <see cref="OpCode"/> from the given IL span.
/// </summary>
/// <param name="il">The span containing the IL bytecode.</param>
/// <param name="opCode">When this method returns, contains the <see cref="OpCode"/> if the method is successful.</param>
/// <returns><c>true</c> if an <see cref="OpCode"/> was successfully read; otherwise, <c>false</c>.</returns>
public static bool TryReadOpCode(ReadOnlySpan<byte> il, out OpCode opCode)
{
opCode = OpCodes.Nop;

if (il.IsEmpty)
return false;

int opCodeValue = il[0];
if (opCodeValue >= Int16OpCodeMarker)
{
if (il.Length < 1)
return false;

opCodeValue = (opCodeValue << 8) | il[1];
}

return TryGetOpCode(opCodeValue, out opCode);
}

/// <summary>
/// Tries to get the <see cref="OpCode"/> associated with the given value.
/// </summary>
/// <param name="value">The opcode value.</param>
/// <param name="opCode">When this method returns, contains the <see cref="OpCode"/> if the method is successful.</param>
/// <returns><c>true</c> if the <see cref="OpCode"/> was found for the given value; otherwise, <c>false</c>.</returns>
public static bool TryGetOpCode(int value, out OpCode opCode)
{
int index = GetOpCodeIndex(value);
OpCode[] opCodes = s_opCodes.Value;
if ((uint)index < (uint)opCodes.Length)
{
opCode = opCodes[index];
return true;
}

opCode = OpCodes.Nop;
return false;
}

/// <summary>
/// Creates an instruction set containing all <see cref="OpCode"/> instances.
/// </summary>
/// <returns>An array of <see cref="OpCode"/> instances.</returns>
private static OpCode[] CreateInstructionSet()
{
List<OpCode> opCodes = new(256);
opCodes.AddRange(ExtractAllOpCodes());

int maxIndex = opCodes.Max(static x => GetOpCodeIndex(x.Value));
int instructionSetSize = maxIndex + 1;
OpCode[] instructionSet = new OpCode[instructionSetSize];

foreach (OpCode opCode in opCodes)
{
instructionSet[GetOpCodeIndex(opCode.Value)] = opCode;
}

return instructionSet;
}

/// <summary>
/// Gets the index of the given opcode in the instruction set.
/// </summary>
/// <param name="value">The opcode value.</param>
/// <returns>The index of the opcode.</returns>
private static int GetOpCodeIndex(int value)
=> (value & Int16OpCodeFlag) == Int16OpCodeFlag ? (256 + (value & 0xFF)) : (value & 0xFF);

/// <summary>
/// Extracts all opcodes from the <see cref="OpCodes"/> class.
/// </summary>
/// <returns>
/// All opcodes defined in the <see cref="OpCodes"/> class.
/// </returns>
private static IEnumerable<OpCode> ExtractAllOpCodes()
{
FieldInfo[] fields = typeof(OpCodes).GetFields(BindingFlags.Public | BindingFlags.Static);
return fields.Where(static x => x.FieldType == typeof(OpCode)).Select(static x => (OpCode)x.GetValue(null));
}

/// <summary>
/// Calculates the size of the operand associated with the given opcode.
/// </summary>
/// <param name="opCode">The operation code in question.</param>
/// <returns>
/// The size of the operand in bytes; or 0 if the opcode has no operand.
/// </returns>
public static int GetOperandSize(this OpCode opCode)
=> opCode.Size is 0 ? 0 : GetOperandSize(opCode.OperandType);

/// <summary>
/// Determines the size of an operand based on its type.
/// </summary>
/// <param name="operandType">The type of the operand.</param>
/// <returns>
/// The size of the operand in bytes.
/// </returns>
public static int GetOperandSize(this OperandType operandType) => operandType switch
{
OperandType.InlineBrTarget or OperandType.InlineField or OperandType.InlineI
or OperandType.InlineMethod or OperandType.InlineSig or OperandType.InlineString
or OperandType.InlineSwitch or OperandType.InlineTok or OperandType.InlineType
or OperandType.ShortInlineR => sizeof(int),
OperandType.InlineI8 or OperandType.InlineR => sizeof(long),
OperandType.InlineVar => sizeof(short),
OperandType.ShortInlineBrTarget or OperandType.ShortInlineI or OperandType.ShortInlineVar => sizeof(byte),
_ => 0,
};
}
84 changes: 10 additions & 74 deletions src/HotAvalonia/Reflection/MethodBodyReader.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.InteropServices;
using HotAvalonia.Helpers;

#if NETSTANDARD2_0
using BitConverter = HotAvalonia.Helpers.BitHelper;
Expand All @@ -14,17 +15,6 @@ namespace HotAvalonia.Reflection;
/// </summary>
internal struct MethodBodyReader
{
/// <summary>
/// A flag representing a 16-bit opcode value.
/// </summary>
private const short Int16OpCodeFlag = 0xFE;

/// <summary>
/// A dictionary containing <see cref="OpCodes"/> mapped to their respective opcode values.
/// </summary>
private static readonly Lazy<Dictionary<short, OpCode>> s_opCodes = new(
static () => GetAllOpCodes().ToDictionary(static x => x.Value));

/// <summary>
/// The byte sequence that constitutes the method body.
/// </summary>
Expand Down Expand Up @@ -80,7 +70,7 @@ public readonly ReadOnlySpan<byte> Operand
if (size is 0)
return Array.Empty<byte>();

return _methodBody.Slice(_position + size, GetOperandSize(_opCode.OperandType)).Span;
return _methodBody.Slice(_position + size, _opCode.OperandType.GetOperandSize()).Span;
}
}

Expand All @@ -91,7 +81,7 @@ public readonly ReadOnlySpan<int> JumpTable
{
get
{
if (_opCode.Value != OpCodes.Switch.Value)
if (_opCode.Value is not OpCodeHelper.SwitchValue)
return Array.Empty<int>();

int start = _position + sizeof(byte) + sizeof(int);
Expand All @@ -112,28 +102,15 @@ public bool Next()
{
ReadOnlySpan<byte> methodBody = _methodBody.Span;
int nextPosition = _bytesConsumed;
if (nextPosition >= methodBody.Length)
return false;

short newOpCodeValue = methodBody[nextPosition];
int operandStart = nextPosition + 1;
if (newOpCodeValue >= Int16OpCodeFlag)
{
if (operandStart >= methodBody.Length)
return false;

newOpCodeValue = (short)((newOpCodeValue << 8) | methodBody[operandStart]);
++operandStart;
}

if (!s_opCodes.Value.TryGetValue(newOpCodeValue, out OpCode newOpCode))
if (!OpCodeHelper.TryReadOpCode(methodBody.Slice(nextPosition), out OpCode newOpCode))
return false;

int nextBytesConsumed = operandStart + GetOperandSize(newOpCode.OperandType);
int operandStart = nextPosition + newOpCode.Size;
int nextBytesConsumed = operandStart + newOpCode.OperandType.GetOperandSize();
if (nextBytesConsumed > methodBody.Length)
return false;

if (newOpCode.Value == OpCodes.Switch.Value)
if (newOpCode.Value is OpCodeHelper.SwitchValue)
{
int n = BitConverter.ToInt32(methodBody.Slice(operandStart));
nextBytesConsumed += n * sizeof(int);
Expand All @@ -157,7 +134,7 @@ public bool Next()
/// The zero-based index of the first occurrence of the specified op code in the method body;
/// or -1 if the opcode is not found.
/// </returns>
public static int IndexOf(ReadOnlyMemory<byte> methodBody, short opCode)
public static int IndexOf(ReadOnlyMemory<byte> methodBody, int opCode)
{
MethodBodyReader reader = new(methodBody);
while (reader.Next())
Expand Down Expand Up @@ -373,53 +350,12 @@ public readonly Type ResolveType(Module module, Type[] genericTypeArguments, Typ
/// <exception cref="InvalidOperationException">Thrown when the operand size does not match the expected size.</exception>
private readonly void EnsureOperandSize(int size)
{
if (GetOperandSize(_opCode) == size)
if (_opCode.GetOperandSize() == size)
return;

ThrowInvalidOperationException_SizeDoesNotMatch(size, GetOperandSize(_opCode));
ThrowInvalidOperationException_SizeDoesNotMatch(size, _opCode.GetOperandSize());

static void ThrowInvalidOperationException_SizeDoesNotMatch(int expectedSize, int actualSize)
=> throw new InvalidOperationException($"The operand size ({actualSize} bytes) does not match the expected size ({expectedSize} bytes).");
}

/// <summary>
/// Calculates the size of the operand associated with the given opcode.
/// </summary>
/// <param name="opCode">The operation code in question.</param>
/// <returns>
/// The size of the operand in bytes; or 0 if the opcode has no operand.
/// </returns>
private static int GetOperandSize(OpCode opCode)
=> opCode.Size is 0 ? 0 : GetOperandSize(opCode.OperandType);

/// <summary>
/// Determines the size of an operand based on its type.
/// </summary>
/// <param name="operandType">The type of the operand.</param>
/// <returns>
/// The size of the operand in bytes.
/// </returns>
public static int GetOperandSize(OperandType operandType) => operandType switch
{
OperandType.InlineBrTarget or OperandType.InlineField or OperandType.InlineI
or OperandType.InlineMethod or OperandType.InlineSig or OperandType.InlineString
or OperandType.InlineSwitch or OperandType.InlineTok or OperandType.InlineType
or OperandType.ShortInlineR => sizeof(int),
OperandType.InlineI8 or OperandType.InlineR => sizeof(long),
OperandType.InlineVar => sizeof(short),
OperandType.ShortInlineBrTarget or OperandType.ShortInlineI or OperandType.ShortInlineVar => sizeof(byte),
_ => 0,
};

/// <summary>
/// Retrieves all opcodes defined in the <see cref="OpCodes"/> class.
/// </summary>
/// <returns>
/// All opcodes defined in the <see cref="OpCodes"/> class.
/// </returns>
private static IEnumerable<OpCode> GetAllOpCodes()
{
FieldInfo[] fields = typeof(OpCodes).GetFields(BindingFlags.Public | BindingFlags.Static);
return fields.Where(static x => x.FieldType == typeof(OpCode)).Select(static x => (OpCode)x.GetValue(null));
}
}

0 comments on commit a8d9478

Please sign in to comment.