Skip to content

Commit

Permalink
perf: use field and method map to find original member
Browse files Browse the repository at this point in the history
  • Loading branch information
Clazex committed Mar 8, 2023
1 parent c6f5b99 commit 1ebfc82
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 58 deletions.
2 changes: 2 additions & 0 deletions HKReflect.Fody/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ internal static void ParallelForEach<T>(this IEnumerable<T> self, Action<T> acti
internal static bool IsInNamespace(this TypeReference self, string ns) =>
self.DeclaringType?.IsInNamespace(ns) ?? (self.Namespace == ns || self.Namespace.StartsWith(ns + '.'));

internal static bool IsHKReflectType(this TypeReference self) => self.IsInNamespace(nameof(HKReflect));


internal static GenericInstanceMethod MakeGenericMethod(this MethodReference self, params TypeReference[] arguments) {
if (arguments.Length != self.GenericParameters.Count) {
Expand Down
2 changes: 1 addition & 1 deletion HKReflect.Fody/FieldProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace HKReflect.Fody;

public sealed partial class ModuleWeaver {
private void ProcessField(FieldDefinition fieldDef) {
if (fieldDef.FieldType.IsInNamespace(nameof(HKReflect))) {
if (fieldDef.FieldType.IsHKReflectType()) {
throw new WeavingException(fieldDef.FullName + " contains reflected type, please use original type instead");
}
}
Expand Down
61 changes: 22 additions & 39 deletions HKReflect.Fody/InstructionProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,53 +9,42 @@
namespace HKReflect.Fody;

public sealed partial class ModuleWeaver {
private void ProcessInstruction(Instruction inst, MethodDefinition methodDef, MethodBody body, Instruction[] branchInsts) {
if (inst.Operand is TypeReference typeRef && typeRef.IsInNamespace(nameof(HKReflect))) {
private void ProcessInstruction(
Instruction inst,
MethodDefinition methodDef,
TypeDefinition typeDef,
Instruction[] branchInsts
) {
if (inst.Operand is TypeReference typeRef && typeRef.IsHKReflectType()) {
throw new WeavingException(methodDef.FullName + " contains typeof on HKReflect types");
} else if (inst.Operand is FieldReference fieldRef && fieldRef.DeclaringType.IsInNamespace(nameof(HKReflect))) {
if (fieldRef.DeclaringType.IsInNamespace(nameof(HKReflect))) {
if (fieldRef.DeclaringType.IsInNamespace("HKReflect.Static")) {
ProcessFieldAccessStatic(inst, fieldRef);
} else {
ProcessFieldAccess(inst, fieldRef, body, branchInsts);
}
}
} else if (inst.Operand is MethodReference methodRef) {
} else if (inst.Operand is FieldReference fieldRef && fieldRef.DeclaringType.IsHKReflectType()) {
ProcessFieldAccess(inst, fieldRef, methodDef, typeDef, branchInsts);
} else if (inst.Operand is MethodReference methodRef && methodRef.DeclaringType.IsHKReflectType()) {
if (methodRef.DeclaringType.FullName == "HKReflect.Reflector") {
Reroute(branchInsts, inst, inst.Next);
body.GetILProcessor().Remove(inst);
methodDef.Body.GetILProcessor().Remove(inst);
} else if (methodRef.DeclaringType.FullName == "HKReflect.Singletons") {
Instruction accessInst = CreateSingletonInstanceGetInstruction(methodRef.ReturnType);
inst.OpCode = accessInst.OpCode;
inst.Operand = accessInst.Operand;
} else if (methodRef.DeclaringType.IsInNamespace(nameof(HKReflect))) {
if (methodRef.DeclaringType.IsInNamespace("HKReflect.Static")) {
ProcessMethodCallStatic(inst, methodRef);
} else {
ProcessMethodCall(inst, methodRef);
}
} else {
ProcessMethodCall(inst, methodRef);
}
}
}

private void ProcessFieldAccessInternal(Instruction inst, FieldReference fieldRef) =>
inst.Operand = ModuleDefinition.ImportReference(fieldRef);

private void ProcessFieldAccess(Instruction inst, FieldReference fieldRef, MethodBody body, Instruction[] branchInsts) {
TypeDefinition origType = FindOrigType(fieldRef.DeclaringType);

if (origType.FullName == "PlayerData" && inst.OpCode.Code is Code.Ldfld or Code.Stfld) {
ProcessFieldAccessPlayerData(inst, origType, fieldRef, body, branchInsts);
private void ProcessFieldAccess(Instruction inst, FieldReference fieldRef, MethodDefinition methodDef, TypeDefinition typeDef, Instruction[] branchInsts) {
if (fieldRef.DeclaringType.FullName == "HKReflect.PlayerData" && inst.OpCode.Code is Code.Ldfld or Code.Stfld) {
ProcessFieldAccessPlayerData(inst, fieldRef, methodDef, typeDef, branchInsts);
return;
}

ProcessFieldAccessInternal(inst, FindOrigField(origType, fieldRef));
inst.Operand = ModuleDefinition.ImportReference(FindOrigField(fieldRef.DeclaringType, fieldRef));
}

private void ProcessFieldAccessStatic(Instruction inst, FieldReference fieldRef) =>
ProcessFieldAccessInternal(inst, FindOrigField(FindOrigTypeStatic(fieldRef.DeclaringType), fieldRef));
private void ProcessFieldAccessPlayerData(Instruction inst, FieldReference fieldRef, MethodDefinition methodDef, TypeDefinition typeDef, Instruction[] branchInsts) {
TypeDefinition pdType = FindTypeDefinition("PlayerData");

private void ProcessFieldAccessPlayerData(Instruction inst, TypeDefinition pdType, FieldReference fieldRef, MethodBody body, Instruction[] branchInsts) {
MethodReference methodRef = inst.OpCode.Code switch {
Code.Ldfld => fieldRef.FieldType.FullName switch {
"System.Boolean" => pdType.Methods.First(method => method.Name == "GetBool"),
Expand All @@ -73,25 +62,19 @@ private void ProcessFieldAccessPlayerData(Instruction inst, TypeDefinition pdTyp
_ => pdType.Methods.First(method => method.Name == "SetVariableSwappedArgs")
.MakeGenericMethod(ModuleDefinition.ImportReference(fieldRef.FieldType)),
},
Code code => throw new WeavingException($"{body.Method.FullName} contains invalid opcode {code} for accessing field {fieldRef.FullName}")
Code code => throw new WeavingException($"{methodDef.FullName} contains invalid opcode {code} for accessing field {fieldRef.FullName}")
};

var ldFldNameInst = Instruction.Create(OpCodes.Ldstr, fieldRef.Name);
body.GetILProcessor().InsertBefore(inst, ldFldNameInst);
methodDef.Body.GetILProcessor().InsertBefore(inst, ldFldNameInst);
Reroute(branchInsts, inst, ldFldNameInst);
inst.OpCode = OpCodes.Callvirt;
inst.Operand = ModuleDefinition.ImportReference(methodRef);
}


private void ProcessMethodCallInternal(Instruction inst, MethodReference methodRef) =>
inst.Operand = ModuleDefinition.ImportReference(methodRef);

private void ProcessMethodCall(Instruction inst, MethodReference methodRef) =>
ProcessMethodCallInternal(inst, FindOrigMethod(FindOrigType(methodRef.DeclaringType), methodRef));

private void ProcessMethodCallStatic(Instruction inst, MethodReference methodRef) =>
ProcessMethodCallInternal(inst, FindOrigMethod(FindOrigTypeStatic(methodRef.DeclaringType), methodRef));
inst.Operand = ModuleDefinition.ImportReference(FindOrigMethod(methodRef.DeclaringType, methodRef));


private Instruction CreateSingletonInstanceGetInstruction(TypeReference typeRef) => typeRef.FullName switch {
Expand Down
11 changes: 6 additions & 5 deletions HKReflect.Fody/MethodProcessor.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Generic;
using System.Linq;

using Fody;
Expand All @@ -9,17 +10,17 @@
namespace HKReflect.Fody;

public sealed partial class ModuleWeaver {
private void ProcessMethod(MethodDefinition methodDef) {
private void ProcessMethod(MethodDefinition methodDef, TypeDefinition typeDef) {
foreach (GenericParameter gp in methodDef.GenericParameters) {
foreach (GenericParameterConstraint constraint in gp.Constraints) {
if (constraint.ConstraintType.IsInNamespace(nameof(HKReflect))) {
if (constraint.ConstraintType.IsHKReflectType()) {
throw new WeavingException(methodDef.FullName + " contains generic constraint(s) of reflected type, please use original type instead");
}
}
}

foreach (ParameterDefinition paramDef in methodDef.Parameters) {
if (paramDef.ParameterType.IsInNamespace(nameof(HKReflect))) {
if (paramDef.ParameterType.IsHKReflectType()) {
throw new WeavingException(methodDef.FullName + " contains parameter(s) of reflected type, please use original type instead");
}
}
Expand All @@ -35,15 +36,15 @@ private void ProcessMethod(MethodDefinition methodDef) {
.ToArray();

foreach (VariableDefinition varDef in body.Variables) {
if (varDef.VariableType.IsInNamespace(nameof(HKReflect))) {
if (varDef.VariableType.IsHKReflectType()) {
varDef.VariableType = ModuleDefinition.ImportReference(FindOrigType(varDef.VariableType));
}
}

body.SimplifyMacros();

foreach (Instruction inst in body.Instructions.ToArray()) {
ProcessInstruction(inst, methodDef, body, branchInsts);
ProcessInstruction(inst, methodDef, typeDef, branchInsts);
}

body.Optimize();
Expand Down
71 changes: 60 additions & 11 deletions HKReflect.Fody/OrigFinder.cs
Original file line number Diff line number Diff line change
@@ -1,24 +1,73 @@
using System.Collections.Generic;
using System.Linq;

using Mono.Cecil;
using Mono.Cecil.Cil;

namespace HKReflect.Fody;

public sealed partial class ModuleWeaver {
private TypeDefinition FindOrigType(TypeReference typeRef) =>
FindTypeDefinition(typeRef.FullName.StripStart("HKReflect").StripStart("."));
private static readonly Dictionary<string, (Dictionary<string, FieldReference> fieldMap, Dictionary<string, MethodReference> methodMap)> origMap = new();

private TypeDefinition FindOrigTypeStatic(TypeReference typeRef) => FindTypeDefinition(
typeRef.FullName.StripStart("HKReflect.Static").StripStart(".").StripEnd("R")
);
private static void BuildOrigMap(
TypeDefinition type,
out Dictionary<string, FieldReference> fieldMap,
out Dictionary<string, MethodReference> methodMap
) {
fieldMap = new();
methodMap = new();

Instruction[] rawFieldMap = type.Methods.Single(method => method.Name == "<OrigFields>")
.Body.Instructions.ToArray();

for (int i = 0; i < rawFieldMap.Length; i += 2) {
if (rawFieldMap[i].OpCode.Code == Code.Ret) {
break;
}

fieldMap.Add((string) rawFieldMap[i].Operand, (FieldReference) rawFieldMap[i + 1].Operand);
}

Instruction[] rawMethodMap = type.Methods.Single(method => method.Name == "<OrigMethods>")
.Body.Instructions.ToArray();

for (int i = 0; i < rawMethodMap.Length; i += 2) {
if (rawMethodMap[i].OpCode.Code == Code.Ret) {
break;
}

methodMap.Add((string) rawMethodMap[i].Operand, (MethodReference) rawMethodMap[i + 1].Operand);
}
}

private void GetOrigMap(
TypeReference type,
out Dictionary<string, FieldReference> fieldMap,
out Dictionary<string, MethodReference> methodMap
) {
string fullName = type.FullName;

lock (origMap) {
if (origMap.ContainsKey(fullName)) {
(fieldMap, methodMap) = origMap[fullName];
} else {
BuildOrigMap(type as TypeDefinition ?? type.Resolve(), out fieldMap, out methodMap);
origMap[fullName] = (fieldMap, methodMap);
}
}
}

private TypeDefinition FindOrigTypeSingleton(TypeReference typeRef) => FindTypeDefinition(
typeRef.FullName.StripStart("HKReflect.Singleton").StripStart(".").StripEnd("R")
private TypeDefinition FindOrigType(TypeReference typeRef) => FindTypeDefinition(
typeRef.FullName.StripStart("HKReflect").StripStart(".")
);

private FieldDefinition FindOrigField(TypeDefinition origType, FieldReference fieldRef) =>
origType.Fields.First(fieldDef => fieldDef.Name == fieldRef.Name);
private FieldReference FindOrigField(TypeReference type, FieldReference fieldRef) {
GetOrigMap(type, out Dictionary<string, FieldReference> fieldMap, out _);
return fieldMap[fieldRef.Name];
}

private MethodDefinition FindOrigMethod(TypeDefinition origType, MethodReference methodRef) =>
origType.Methods.First(methodDef => methodDef.Name == methodRef.Name);
private MethodReference FindOrigMethod(TypeReference type, MethodReference methodRef) {
GetOrigMap(type, out _, out Dictionary<string, MethodReference> methodMap);
return methodMap[methodRef.GetElementMethod().FullName];
}
}
4 changes: 3 additions & 1 deletion HKReflect.Fody/TypeProcessor.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System.Collections.Generic;

using Mono.Cecil;

namespace HKReflect.Fody;
Expand All @@ -10,6 +12,6 @@ private void ProcessType(TypeDefinition typeDef) {

typeDef.Fields.ForEach(ProcessField);

typeDef.Methods.ParallelForEach(ProcessMethod);
typeDef.Methods.ParallelForEach(method => ProcessMethod(method, typeDef));
}
}
41 changes: 41 additions & 0 deletions ReflectGen/InstanceClassGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Runtime.CompilerServices;

using Mono.Cecil;
using Mono.Cecil.Cil;

namespace ReflectGen;

Expand Down Expand Up @@ -35,6 +36,30 @@ Dictionary<TypeDefinition, string> typesToAddInheritance
| (typeDef.IsNested ? TypeAttributes.NestedPublic : TypeAttributes.Public)
));

Lazy<ILProcessor> fieldMap = new(() => {
MethodDefinition fieldMapMethod = new(
"<OrigFields>",
MethodAttributes.Assembly | MethodAttributes.Static | MethodAttributes.SpecialName,
module.TypeSystem.Void
);
resTypeDef.Value.Methods.Add(fieldMapMethod);
return fieldMapMethod.Body.GetILProcessor();
});

Lazy<ILProcessor> methodMap = new(() => {
MethodDefinition methodMapMethod = new(
"<OrigMethods>",
MethodAttributes.Assembly | MethodAttributes.Static | MethodAttributes.SpecialName,
module.TypeSystem.Void
);
resTypeDef.Value.Methods.Add(methodMapMethod);
return methodMapMethod.Body.GetILProcessor();
});

foreach (FieldDefinition fieldDef in typeDef.Fields) {
if (IsCompilerGenerated(fieldDef) || !IsPubliclyAvailable(fieldDef)) {
continue;
Expand All @@ -45,6 +70,9 @@ Dictionary<TypeDefinition, string> typesToAddInheritance
FieldAttributes.Public | (fieldDef.IsStatic ? FieldAttributes.Static : FieldAttributes.CompilerControlled),
module.ImportReference(fieldDef.FieldType)
));

fieldMap.Value.Emit(OpCodes.Ldstr, fieldDef.Name);
fieldMap.Value.Emit(OpCodes.Stfld, module.ImportReference(fieldDef));
}

foreach (PropertyDefinition propDef in typeDef.Properties) {
Expand All @@ -69,6 +97,9 @@ Dictionary<TypeDefinition, string> typesToAddInheritance

resTypeDef.Value.Methods.Add(get);
resPropDef.GetMethod = get;

methodMap.Value.Emit(OpCodes.Ldstr, get.FullName);
methodMap.Value.Emit(OpCodes.Call, module.ImportReference(propDef.GetMethod));
}

if (propDef.SetMethod != null) {
Expand All @@ -84,6 +115,9 @@ Dictionary<TypeDefinition, string> typesToAddInheritance

resTypeDef.Value.Methods.Add(set);
resPropDef.SetMethod = set;

methodMap.Value.Emit(OpCodes.Ldstr, set.FullName);
methodMap.Value.Emit(OpCodes.Call, module.ImportReference(propDef.SetMethod));
}

resTypeDef.Value.Properties.Add(resPropDef);
Expand Down Expand Up @@ -111,6 +145,7 @@ Dictionary<TypeDefinition, string> typesToAddInheritance
(methodDef.Attributes & (~MethodAttributes.MemberAccessMask)) | MethodAttributes.Public,
module.ImportReference(methodDef.ReturnType)
);

foreach (ParameterDefinition origPd in methodDef.Parameters) {
ParameterDefinition pd = new(origPd.Name, origPd.Attributes, module.ImportReference(origPd.ParameterType));
foreach (CustomAttribute attr in origPd.CustomAttributes) {
Expand All @@ -121,6 +156,9 @@ Dictionary<TypeDefinition, string> typesToAddInheritance
}

resTypeDef.Value.Methods.Add(resMethodDef);

methodMap.Value.Emit(OpCodes.Ldstr, resMethodDef.FullName);
methodMap.Value.Emit(OpCodes.Call, module.ImportReference(methodDef));
}

foreach (TypeDefinition nestedType in typeDef.NestedTypes) {
Expand All @@ -132,6 +170,9 @@ Dictionary<TypeDefinition, string> typesToAddInheritance
if (resTypeDef.IsValueCreated) {
TypeDefinition resTypeDefVal = resTypeDef.Value;

fieldMap.Value.Emit(OpCodes.Ret);
methodMap.Value.Emit(OpCodes.Ret);

MethodDefinition reflectMethodDef = new("Reflect", MethodAttributes.Public | MethodAttributes.Static, resTypeDefVal);
reflectMethodDef.Parameters.Add(new("self", ParameterAttributes.None, module.ImportReference(typeDef)));
reflectMethodDef.CustomAttributes.Add(new(module.ImportReference(typeof(ExtensionAttribute).GetConstructor(Type.EmptyTypes))));
Expand Down
Loading

0 comments on commit 1ebfc82

Please sign in to comment.