Skip to content

Commit d6c9716

Browse files
committed
Support weakref for CLR types
1 parent 12c0206 commit d6c9716

File tree

12 files changed

+384
-234
lines changed

12 files changed

+384
-234
lines changed

src/embed_tests/Python.EmbeddingTest.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<?xml version="1.0" encoding="utf-8"?>
1+
<?xml version="1.0" encoding="utf-8"?>
22
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003" ToolsVersion="4.0">
33
<PropertyGroup>
44
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
@@ -90,6 +90,7 @@
9090
<Compile Include="pyinitialize.cs" />
9191
<Compile Include="pyrunstring.cs" />
9292
<Compile Include="References.cs" />
93+
<Compile Include="TestClass.cs" />
9394
<Compile Include="TestConverter.cs" />
9495
<Compile Include="TestCustomMarshal.cs" />
9596
<Compile Include="TestDomainReload.cs" />

src/embed_tests/TestClass.cs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
4+
using NUnit.Framework;
5+
6+
using Python.Runtime;
7+
8+
using PyRuntime = Python.Runtime.Runtime;
9+
10+
namespace Python.EmbeddingTest
11+
{
12+
public class TestClass
13+
{
14+
public class MyClass
15+
{
16+
}
17+
18+
[OneTimeSetUp]
19+
public void SetUp()
20+
{
21+
PythonEngine.Initialize();
22+
}
23+
24+
[OneTimeTearDown]
25+
public void Dispose()
26+
{
27+
PythonEngine.Shutdown();
28+
}
29+
30+
[Test]
31+
public void WeakRefForClrObject()
32+
{
33+
var obj = new MyClass();
34+
using (var scope = Py.CreateScope())
35+
{
36+
scope.Set("clr_obj", obj);
37+
scope.Exec(@"
38+
import weakref
39+
ref = weakref.ref(clr_obj)
40+
");
41+
using (PyObject pyobj = scope.Get("clr_obj"))
42+
{
43+
ValidateAttachedGCHandle(obj, pyobj.Handle);
44+
}
45+
}
46+
}
47+
48+
[Test]
49+
public void WeakRefForSubClass()
50+
{
51+
using (var scope = Py.CreateScope())
52+
{
53+
scope.Exec(@"
54+
from Python.EmbeddingTest import TestClass
55+
import weakref
56+
57+
class Sub(TestClass.MyClass):
58+
pass
59+
60+
obj = Sub()
61+
ref = weakref.ref(obj)
62+
");
63+
using (PyObject pyobj = scope.Get("obj"))
64+
{
65+
IntPtr op = pyobj.Handle;
66+
IntPtr type = PyRuntime.PyObject_TYPE(op);
67+
IntPtr clrHandle = Marshal.ReadIntPtr(op, ObjectOffset.magic(type));
68+
var clobj = (CLRObject)GCHandle.FromIntPtr(clrHandle).Target;
69+
Assert.IsTrue(clobj.inst is MyClass);
70+
}
71+
}
72+
}
73+
74+
private static void ValidateAttachedGCHandle(object obj, IntPtr op)
75+
{
76+
IntPtr type = PyRuntime.PyObject_TYPE(op);
77+
IntPtr clrHandle = Marshal.ReadIntPtr(op, ObjectOffset.magic(type));
78+
var clobj = (CLRObject)GCHandle.FromIntPtr(clrHandle).Target;
79+
Assert.True(ReferenceEquals(clobj.inst, obj));
80+
}
81+
}
82+
}

src/runtime/classbase.cs

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
using System;
22
using System.Collections;
3-
using System.Diagnostics;
43
using System.Runtime.InteropServices;
5-
using System.Runtime.Serialization;
64

75
namespace Python.Runtime
86
{
@@ -288,44 +286,40 @@ public static IntPtr tp_repr(IntPtr ob)
288286
public static void tp_dealloc(IntPtr ob)
289287
{
290288
ManagedType self = GetManagedObject(ob);
289+
if (Runtime.PyType_SUPPORTS_WEAKREFS(Runtime.PyObject_TYPE(ob)))
290+
{
291+
Runtime.PyObject_ClearWeakRefs(ob);
292+
}
291293
tp_clear(ob);
292-
Runtime.PyObject_GC_UnTrack(self.pyHandle);
293-
Runtime.PyObject_GC_Del(self.pyHandle);
294+
Runtime.PyObject_GC_UnTrack(ob);
295+
Runtime.PyObject_GC_Del(ob);
294296
self.FreeGCHandle();
295297
}
296298

297299
public static int tp_clear(IntPtr ob)
298300
{
299301
ManagedType self = GetManagedObject(ob);
300-
if (!self.IsTypeObject())
301-
{
302-
ClearObjectDict(ob);
303-
}
304-
self.tpHandle = IntPtr.Zero;
302+
ClearObjectDict(ob);
303+
Runtime.Py_CLEAR(ref self.tpHandle);
305304
return 0;
306305
}
307306

308307
protected override void OnSave(InterDomainContext context)
309308
{
310309
base.OnSave(context);
311-
if (pyHandle != tpHandle)
312-
{
313-
IntPtr dict = GetObjectDict(pyHandle);
314-
Runtime.XIncref(dict);
315-
context.Storage.AddValue("dict", dict);
316-
}
310+
IntPtr dict = GetObjectDict(pyHandle);
311+
Runtime.XIncref(dict);
312+
Runtime.XIncref(tpHandle);
313+
context.Storage.AddValue("dict", dict);
317314
}
318315

319316
protected override void OnLoad(InterDomainContext context)
320317
{
321318
base.OnLoad(context);
322-
if (pyHandle != tpHandle)
323-
{
324-
IntPtr dict = context.Storage.GetValue<IntPtr>("dict");
325-
SetObjectDict(pyHandle, dict);
326-
}
319+
IntPtr dict = context.Storage.GetValue<IntPtr>("dict");
320+
SetObjectDict(pyHandle, dict);
327321
gcHandle = AllocGCHandle();
328-
Marshal.WriteIntPtr(pyHandle, TypeOffset.magic(), (IntPtr)gcHandle);
322+
Marshal.WriteIntPtr(pyHandle, ObjectOffset.magic(tpHandle), (IntPtr)gcHandle);
329323
}
330324
}
331325
}

src/runtime/classmanager.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ private static ClassInfo GetClassInfo(Type type)
443443
}
444444
// Note the given instance might be uninitialized
445445
ob = GetClass(tp);
446+
ob.IncrRefCount();
446447
ci.members[mi.Name] = ob;
447448
continue;
448449
}

src/runtime/clrobject.cs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@ internal CLRObject(object ob, IntPtr tp)
2929
tpHandle = tp;
3030
pyHandle = py;
3131
inst = ob;
32-
33-
// Fix the BaseException args (and __cause__ in case of Python 3)
34-
// slot if wrapping a CLR exception
35-
Exceptions.SetArgsAndCause(py);
3632
}
3733

3834
protected CLRObject()
@@ -48,7 +44,7 @@ static CLRObject GetInstance(object ob, IntPtr pyType)
4844
static CLRObject GetInstance(object ob)
4945
{
5046
ClassBase cc = ClassManager.GetClass(ob.GetType());
51-
return GetInstance(ob, cc.tpHandle);
47+
return GetInstance(ob, cc.pyHandle);
5248
}
5349

5450

@@ -62,7 +58,7 @@ internal static IntPtr GetInstHandle(object ob, IntPtr pyType)
6258
internal static IntPtr GetInstHandle(object ob, Type type)
6359
{
6460
ClassBase cc = ClassManager.GetClass(type);
65-
CLRObject co = GetInstance(ob, cc.tpHandle);
61+
CLRObject co = GetInstance(ob, cc.pyHandle);
6662
return co.pyHandle;
6763
}
6864

src/runtime/exceptions.cs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ internal static Exception ToException(IntPtr ob)
8282
}
8383
return Runtime.PyUnicode_FromString(message);
8484
}
85+
86+
public static int tp_init(IntPtr ob, IntPtr args, IntPtr kwds)
87+
{
88+
Exceptions.SetArgsAndCause(ob);
89+
return 0;
90+
}
91+
8592
}
8693

8794
/// <summary>
@@ -177,15 +184,23 @@ internal static void SetArgsAndCause(IntPtr ob)
177184
args = Runtime.PyTuple_New(0);
178185
}
179186

180-
Marshal.WriteIntPtr(ob, ExceptionOffset.args, args);
181-
187+
int baseOffset = OriginalObjectOffsets.Size;
188+
Runtime.Py_SETREF(ob, baseOffset + ExceptionOffset.args, args);
189+
182190
if (e.InnerException != null)
183191
{
184-
IntPtr cause = CLRObject.GetInstHandle(e.InnerException);
185-
Marshal.WriteIntPtr(ob, ExceptionOffset.cause, cause);
192+
IntPtr cause = GetExceptHandle(e.InnerException);
193+
Runtime.Py_SETREF(ob, baseOffset + ExceptionOffset.cause, cause);
186194
}
187195
}
188196

197+
internal static IntPtr GetExceptHandle(Exception e)
198+
{
199+
IntPtr op = CLRObject.GetInstHandle(e);
200+
SetArgsAndCause(op);
201+
return op;
202+
}
203+
189204
/// <summary>
190205
/// Shortcut for (pointer == NULL) -&gt; throw PythonException
191206
/// </summary>
@@ -283,7 +298,7 @@ public static void SetError(Exception e)
283298
return;
284299
}
285300

286-
IntPtr op = CLRObject.GetInstHandle(e);
301+
IntPtr op = GetExceptHandle(e);
287302
IntPtr etype = Runtime.PyObject_GetAttrString(op, "__class__");
288303
Runtime.PyErr_SetObject(new BorrowedReference(etype), new BorrowedReference(op));
289304
Runtime.XDecref(etype);

0 commit comments

Comments
 (0)