Skip to content

Commit 715642c

Browse files
authored
feat: support generic network behaviors (MirageNet#574)
It is now possible to use generic NetworkBehaviors. The following example now works fine: ```cs public class Pepe<T> : NetworkBehavior { [SyncVar] public int someVariable; [ClientRpc] public void SomeRpc(string something) { } } class MyBehavior: Pepe<int> {} ``` Note that as of this PR, the synccvar or rpc cannot be generic.
1 parent 4cbf2b4 commit 715642c

20 files changed

+1511
-44
lines changed

Assets/Mirage/Weaver/Extensions.cs

+10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Linq;
33
using Mono.Cecil;
4+
using Mono.Collections.Generic;
45

56
namespace Mirage.Weaver
67
{
@@ -203,5 +204,14 @@ public static T GetField<T>(this CustomAttribute ca, string field, T defaultValu
203204
return defaultValue;
204205
}
205206

207+
public static FieldReference MakeHostGenericIfNeeded(this FieldReference fd)
208+
{
209+
if (fd.DeclaringType.HasGenericParameters)
210+
{
211+
return new FieldReference(fd.Name, fd.FieldType, fd.DeclaringType.Resolve().ConvertToGenericIfNeeded());
212+
}
213+
214+
return fd;
215+
}
206216
}
207217
}

Assets/Mirage/Weaver/MethodExtensions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ public static SequencePoint GetSequencePoint(this MethodDefinition method, Instr
5353
return sequencePoint;
5454
}
5555
}
56-
}
56+
}

Assets/Mirage/Weaver/Processors/ClientRpcProcessor.cs

+3-6
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,15 @@ MethodDefinition GenerateSkeleton(MethodDefinition md, MethodDefinition userCode
5151
{
5252
MethodDefinition rpc = md.DeclaringType.AddMethod(
5353
SkeletonPrefix + md.Name,
54-
MethodAttributes.Family | MethodAttributes.Static | MethodAttributes.HideBySig);
54+
MethodAttributes.Family | MethodAttributes.HideBySig);
5555

56-
_ = rpc.AddParam<NetworkBehaviour>("obj");
5756
_ = rpc.AddParam<NetworkReader>("reader");
5857
_ = rpc.AddParam<INetworkConnection>("senderConnection");
5958
_ = rpc.AddParam<int>("replyId");
6059

6160
ILProcessor worker = rpc.Body.GetILProcessor();
6261

63-
// setup for reader
6462
worker.Append(worker.Create(OpCodes.Ldarg_0));
65-
worker.Append(worker.Create(OpCodes.Castclass, md.DeclaringType));
6663

6764
// NetworkConnection parameter is only required for Client.Connection
6865
Client target = clientRpcAttr.GetField("target", Client.Observers);
@@ -168,7 +165,7 @@ MethodDefinition GenerateStub(MethodDefinition md, CustomAttribute clientRpcAttr
168165
else if (target == Client.Owner)
169166
worker.Append(worker.Create(OpCodes.Ldnull));
170167

171-
worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType));
168+
worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType.ConvertToGenericIfNeeded()));
172169
// invokerClass
173170
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
174171
worker.Append(worker.Create(OpCodes.Ldstr, rpcName));
@@ -237,7 +234,7 @@ public void RegisterClientRpcs(ILProcessor cctorWorker)
237234
*/
238235
void GenerateRegisterRemoteDelegate(ILProcessor worker, MethodDefinition func, string cmdName)
239236
{
240-
TypeDefinition netBehaviourSubclass = func.DeclaringType;
237+
TypeReference netBehaviourSubclass = func.DeclaringType.ConvertToGenericIfNeeded();
241238
worker.Append(worker.Create(OpCodes.Ldtoken, netBehaviourSubclass));
242239
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
243240
worker.Append(worker.Create(OpCodes.Ldstr, cmdName));
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System.Collections.Generic;
1+
using System.Collections.Generic;
22
using Mono.Cecil;
33

44
namespace Mirage.Weaver
@@ -7,9 +7,9 @@ internal class FieldReferenceComparator : IEqualityComparer<FieldReference>
77
{
88
public bool Equals(FieldReference x, FieldReference y)
99
{
10-
return x.FullName == y.FullName;
10+
return x.DeclaringType.FullName == y.DeclaringType.FullName && x.Name == y.Name;
1111
}
1212

13-
public int GetHashCode(FieldReference obj) => obj.FullName.GetHashCode();
13+
public int GetHashCode(FieldReference obj) => (obj.DeclaringType.FullName + "." + obj.Name).GetHashCode();
1414
}
15-
}
15+
}

Assets/Mirage/Weaver/Processors/PropertySiteProcessor.cs

+47-9
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,19 @@ void ProcessInstructionSetterField(Instruction i, FieldReference opField)
3434
// does it set a field that we replaced?
3535
if (Setters.TryGetValue(opField, out MethodDefinition replacement))
3636
{
37-
//replace with property
38-
i.OpCode = OpCodes.Call;
39-
i.Operand = replacement;
37+
if (opField.DeclaringType.IsGenericInstance || opField.DeclaringType.HasGenericParameters) // We're calling to a generic class
38+
{
39+
FieldReference newField = i.Operand as FieldReference;
40+
GenericInstanceType genericType = (GenericInstanceType)newField.DeclaringType;
41+
i.OpCode = OpCodes.Callvirt;
42+
i.Operand = replacement.MakeHostInstanceGeneric(genericType);
43+
}
44+
else
45+
{
46+
//replace with property
47+
i.OpCode = OpCodes.Call;
48+
i.Operand = replacement;
49+
}
4050
}
4151
}
4252

@@ -46,31 +56,59 @@ void ProcessInstructionGetterField(Instruction i, FieldReference opField)
4656
// does it set a field that we replaced?
4757
if (Getters.TryGetValue(opField, out MethodDefinition replacement))
4858
{
49-
//replace with property
50-
i.OpCode = OpCodes.Call;
51-
i.Operand = replacement;
59+
if (opField.DeclaringType.IsGenericInstance || opField.DeclaringType.HasGenericParameters) // We're calling to a generic class
60+
{
61+
FieldReference newField = i.Operand as FieldReference;
62+
GenericInstanceType genericType = (GenericInstanceType)newField.DeclaringType;
63+
i.OpCode = OpCodes.Callvirt;
64+
i.Operand = replacement.MakeHostInstanceGeneric(genericType);
65+
}
66+
else
67+
{
68+
//replace with property
69+
i.OpCode = OpCodes.Call;
70+
i.Operand = replacement;
71+
}
5272
}
5373
}
5474

5575
Instruction ProcessInstruction(MethodDefinition md, Instruction instr, SequencePoint sequencePoint)
5676
{
5777
if (instr.OpCode == OpCodes.Stfld && instr.Operand is FieldReference opFieldst)
5878
{
79+
FieldReference resolved = opFieldst.Resolve();
80+
if (resolved == null)
81+
{
82+
resolved = opFieldst.DeclaringType.Resolve().GetField(opFieldst.Name);
83+
}
84+
5985
// this instruction sets the value of a field. cache the field reference.
60-
ProcessInstructionSetterField(instr, opFieldst);
86+
ProcessInstructionSetterField(instr, resolved);
6187
}
6288

6389
if (instr.OpCode == OpCodes.Ldfld && instr.Operand is FieldReference opFieldld)
6490
{
91+
FieldReference resolved = opFieldld.Resolve();
92+
if (resolved == null)
93+
{
94+
resolved = opFieldld.DeclaringType.Resolve().GetField(opFieldld.Name);
95+
}
96+
6597
// this instruction gets the value of a field. cache the field reference.
66-
ProcessInstructionGetterField(instr, opFieldld);
98+
ProcessInstructionGetterField(instr, resolved);
6799
}
68100

69101
if (instr.OpCode == OpCodes.Ldflda && instr.Operand is FieldReference opFieldlda)
70102
{
103+
FieldReference resolved = opFieldlda.Resolve();
104+
if (resolved == null)
105+
{
106+
resolved = opFieldlda.DeclaringType.Resolve().GetField(opFieldlda.Name);
107+
}
108+
71109
// loading a field by reference, watch out for initobj instruction
72110
// see https://github.com/vis2k/Mirror/issues/696
73-
return ProcessInstructionLoadAddress(md, instr, opFieldlda);
111+
return ProcessInstructionLoadAddress(md, instr, resolved);
74112
}
75113

76114
return instr;

Assets/Mirage/Weaver/Processors/RpcProcessor.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System.Linq;
1+
using System.Linq;
22
using System.Reflection;
33
using Cysharp.Threading.Tasks;
44
using Mirage.RemoteCalls;
@@ -338,15 +338,15 @@ public void FixRemoteCallToBaseMethod(TypeDefinition type, MethodDefinition meth
338338
calledMethod.Name == baseRemoteCallName)
339339
{
340340
TypeDefinition baseType = type.BaseType.Resolve();
341-
MethodDefinition baseMethod = baseType.GetMethodInBaseType(callName);
341+
MethodReference baseMethod = baseType.GetMethodInBaseType(callName);
342342

343343
if (baseMethod == null)
344344
{
345345
logger.Error($"Could not find base method for {callName}", method);
346346
return;
347347
}
348348

349-
if (!baseMethod.IsVirtual)
349+
if (!baseMethod.Resolve().IsVirtual)
350350
{
351351
logger.Error($"Could not find base method that was virtual {callName}", method);
352352
return;
@@ -375,4 +375,4 @@ static bool IsCallToMethod(Instruction instruction, out MethodDefinition calledM
375375
}
376376

377377
}
378-
}
378+
}

Assets/Mirage/Weaver/Processors/ServerRpcProcessor.cs

+3-5
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ MethodDefinition GenerateStub(MethodDefinition md, CustomAttribute serverRpcAttr
7979
// invoke internal send and return
8080
// load 'base.' to call the SendServerRpc function with
8181
worker.Append(worker.Create(OpCodes.Ldarg_0));
82-
worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType));
82+
worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType.ConvertToGenericIfNeeded()));
8383
// invokerClass
8484
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
8585
worker.Append(worker.Create(OpCodes.Ldstr, cmdName));
@@ -146,10 +146,9 @@ private void CallSendServerRpc(MethodDefinition md, ILProcessor worker)
146146
MethodDefinition GenerateSkeleton(MethodDefinition method, MethodDefinition userCodeFunc)
147147
{
148148
MethodDefinition cmd = method.DeclaringType.AddMethod(SkeletonPrefix + method.Name,
149-
MethodAttributes.Family | MethodAttributes.Static | MethodAttributes.HideBySig,
149+
MethodAttributes.Family | MethodAttributes.HideBySig,
150150
userCodeFunc.ReturnType);
151151

152-
_ = cmd.AddParam<NetworkBehaviour>("obj");
153152
_ = cmd.AddParam<NetworkReader>("reader");
154153
_ = cmd.AddParam<INetworkConnection>("senderConnection");
155154
_ = cmd.AddParam<int>("replyId");
@@ -159,7 +158,6 @@ MethodDefinition GenerateSkeleton(MethodDefinition method, MethodDefinition user
159158

160159
// setup for reader
161160
worker.Append(worker.Create(OpCodes.Ldarg_0));
162-
worker.Append(worker.Create(OpCodes.Castclass, method.DeclaringType));
163161

164162
if (!ReadArguments(method, worker, false))
165163
return cmd;
@@ -214,7 +212,7 @@ void GenerateRegisterServerRpcDelegate(ILProcessor worker, ServerRpcMethod cmdRe
214212
bool requireAuthority = cmdResult.requireAuthority;
215213

216214
TypeDefinition netBehaviourSubclass = skeleton.DeclaringType;
217-
worker.Append(worker.Create(OpCodes.Ldtoken, netBehaviourSubclass));
215+
worker.Append(worker.Create(OpCodes.Ldtoken, netBehaviourSubclass.ConvertToGenericIfNeeded()));
218216
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
219217
worker.Append(worker.Create(OpCodes.Ldstr, cmdName));
220218
worker.Append(worker.Create(OpCodes.Ldnull));

Assets/Mirage/Weaver/Processors/SyncObjectProcessor.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public void ProcessSyncObjects(TypeDefinition td)
3030
{
3131
foreach (FieldDefinition fd in td.Fields)
3232
{
33-
if (fd.FieldType.IsGenericParameter) // Just ignore all generic objects.
33+
if (fd.FieldType.IsGenericParameter || fd.ContainsGenericParameter) // Just ignore all generic objects.
3434
{
3535
continue;
3636
}

Assets/Mirage/Weaver/Processors/SyncVarProcessor.cs

+18-9
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,15 @@ private void StoreField(FieldDefinition fd, ParameterDefinition valueParam, ILPr
203203
MethodReference setter = module.ImportReference(fd.FieldType.Resolve().GetMethod("set_Value"));
204204

205205
worker.Append(worker.Create(OpCodes.Ldarg_0));
206-
worker.Append(worker.Create(OpCodes.Ldflda, fd));
206+
worker.Append(worker.Create(OpCodes.Ldflda, fd.MakeHostGenericIfNeeded()));
207207
worker.Append(worker.Create(OpCodes.Ldarg, valueParam));
208208
worker.Append(worker.Create(OpCodes.Call, setter));
209209
}
210210
else
211211
{
212212
worker.Append(worker.Create(OpCodes.Ldarg_0));
213213
worker.Append(worker.Create(OpCodes.Ldarg, valueParam));
214-
worker.Append(worker.Create(OpCodes.Stfld, fd));
214+
worker.Append(worker.Create(OpCodes.Stfld, fd.MakeHostGenericIfNeeded()));
215215
}
216216
}
217217

@@ -221,7 +221,7 @@ private void LoadField(FieldDefinition fd, TypeReference originalType, ILProces
221221

222222
if (IsWrapped(fd.FieldType))
223223
{
224-
worker.Append(worker.Create(OpCodes.Ldflda, fd));
224+
worker.Append(worker.Create(OpCodes.Ldflda, fd.MakeHostGenericIfNeeded()));
225225
MethodReference getter = module.ImportReference(fd.FieldType.Resolve().GetMethod("get_Value"));
226226
worker.Append(worker.Create(OpCodes.Call, getter));
227227

@@ -236,7 +236,7 @@ private void LoadField(FieldDefinition fd, TypeReference originalType, ILProces
236236
}
237237
else
238238
{
239-
worker.Append(worker.Create(OpCodes.Ldfld, fd));
239+
worker.Append(worker.Create(OpCodes.Ldfld, fd.MakeHostGenericIfNeeded()));
240240
}
241241
}
242242

@@ -422,7 +422,16 @@ void WriteEndFunctionCall()
422422
{
423423
// only use Callvirt when not static
424424
OpCode opcode = hookMethod.IsStatic ? OpCodes.Call : OpCodes.Callvirt;
425-
worker.Append(worker.Create(opcode, hookMethod));
425+
MethodReference hookMethodReference = hookMethod;
426+
427+
if (hookMethodReference.DeclaringType.HasGenericParameters)
428+
{
429+
// we need to get the Type<T>.HookMethod so convert it to a generic<T>.
430+
var genericType = (GenericInstanceType)hookMethod.DeclaringType.ConvertToGenericIfNeeded();
431+
hookMethodReference = hookMethod.MakeHostInstanceGeneric(genericType);
432+
}
433+
434+
worker.Append(worker.Create(opcode, module.ImportReference(hookMethodReference)));
426435
}
427436
}
428437

@@ -453,7 +462,7 @@ void GenerateSerialization(TypeDefinition netBehaviourSubclass)
453462
// loc_0, this local variable is to determine if any variable was dirty
454463
VariableDefinition dirtyLocal = serialize.AddLocal<bool>();
455464

456-
MethodDefinition baseSerialize = netBehaviourSubclass.BaseType.Resolve().GetMethodInBaseType(SerializeMethodName);
465+
MethodReference baseSerialize = netBehaviourSubclass.BaseType.GetMethodInBaseType(SerializeMethodName);
457466
if (baseSerialize != null)
458467
{
459468
// base
@@ -538,7 +547,7 @@ private void WriteVariable(ILProcessor worker, ParameterDefinition writerParamet
538547
worker.Append(worker.Create(OpCodes.Ldarg, writerParameter));
539548
// this
540549
worker.Append(worker.Create(OpCodes.Ldarg_0));
541-
worker.Append(worker.Create(OpCodes.Ldfld, syncVar));
550+
worker.Append(worker.Create(OpCodes.Ldfld, syncVar.MakeHostGenericIfNeeded()));
542551
MethodReference writeFunc = writers.GetWriteFunc(syncVar.FieldType, null);
543552
if (writeFunc != null)
544553
{
@@ -574,7 +583,7 @@ void GenerateDeSerialization(TypeDefinition netBehaviourSubclass)
574583
serialize.Body.InitLocals = true;
575584
VariableDefinition dirtyBitsLocal = serialize.AddLocal<long>();
576585

577-
MethodDefinition baseDeserialize = netBehaviourSubclass.BaseType.Resolve().GetMethodInBaseType(DeserializeMethodName);
586+
MethodReference baseDeserialize = netBehaviourSubclass.BaseType.GetMethodInBaseType(DeserializeMethodName);
578587
if (baseDeserialize != null)
579588
{
580589
// base
@@ -681,7 +690,7 @@ void DeserializeField(FieldDefinition syncVar, ILProcessor serWorker, MethodDefi
681690
// reader.Read()
682691
serWorker.Append(serWorker.Create(OpCodes.Call, readFunc));
683692
// syncvar
684-
serWorker.Append(serWorker.Create(OpCodes.Stfld, syncVar));
693+
serWorker.Append(serWorker.Create(OpCodes.Stfld, syncVar.MakeHostGenericIfNeeded()));
685694

686695
if (hookMethod != null)
687696
{

0 commit comments

Comments
 (0)