Skip to content

Commit 7e87320

Browse files
committed
Add Python object reference tracking
This feature adds the capability to track Python objects throughout their lifecycle in memray. The implementation introduces a reference tracking system that records PyObject creation and destruction events, storing them as TrackedObject records in the memray capture file with associated stack traces. The Tracker class now accepts a reference_tracking flag to enable this functionality and provides a get_surviving_objects() method to retrieve live objects at the end of tracking. Object tracking works across threads and supports native, Python, and hybrid stack traces, providing insights into Python object lifetimes. This enhancement extends memray beyond memory allocation tracking to give developers a powerful tool for understanding object lifecycles and identifying potential reference leaks. With this addition, memray can now correlate memory allocations with the Python objects that trigger them, enabling more precise debugging of memory usage patterns. The feature is particularly useful for identifying objects that persist longer than expected or detecting unexpected object creation patterns. By tracking both memory allocations and object references, memray provides a more complete picture of a program's memory behavior, helping developers optimize memory usage and resolve leaks more effectively. This implementation leverages the new Python reference tracing API that was introduced in CPython 3.13. Signed-off-by: Pablo Galindo <pablogsal@gmail.com>
1 parent 864fc2d commit 7e87320

20 files changed

+1124
-7
lines changed

news/752.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add object tracking capabilities to memray

src/memray/_memray.pyi

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,35 @@ class AllocationRecord:
6565
def __lt__(self, other: Any) -> Any: ...
6666
def __ne__(self, other: Any) -> Any: ...
6767

68+
class TrackedObjectRecord:
69+
@property
70+
def tid(self) -> int: ...
71+
@property
72+
def address(self) -> int: ...
73+
@property
74+
def is_created(self) -> bool: ...
75+
@property
76+
def frame_index(self) -> int: ...
77+
@property
78+
def native_frame_id(self) -> int: ...
79+
@property
80+
def native_segment_generation(self) -> int: ...
81+
def toPythonObject(self) -> Any: ...
82+
def hybrid_stack_trace(
83+
self,
84+
max_stacks: Optional[int] = None,
85+
) -> List[Union[PythonStackElement, NativeStackElement]]: ...
86+
def native_stack_trace(
87+
self, max_stacks: Optional[int] = None
88+
) -> List[NativeStackElement]: ...
89+
def stack_trace(
90+
self, max_stacks: Optional[int] = None
91+
) -> List[PythonStackElement]: ...
92+
def __eq__(self, other: Any) -> Any: ...
93+
def __hash__(self) -> Any: ...
94+
def __repr__(self) -> str: ...
95+
def __str__(self) -> str: ...
96+
6897
class Interval:
6998
def __init__(
7099
self,
@@ -169,6 +198,9 @@ class FileReader:
169198
@property
170199
def closed(self) -> bool: ...
171200
def close(self) -> None: ...
201+
def get_tracked_objects(
202+
self, filter_objs: Optional[Iterable[Any]] = None
203+
) -> Iterable[TrackedObjectRecord]: ...
172204

173205
def compute_statistics(
174206
file_name: Union[str, Path],
@@ -212,6 +244,7 @@ class Tracker:
212244
follow_fork: bool = ...,
213245
trace_python_allocators: bool = ...,
214246
file_format: FileFormat = ...,
247+
reference_tracking: bool = ...,
215248
) -> None: ...
216249
@overload
217250
def __init__(
@@ -223,6 +256,7 @@ class Tracker:
223256
follow_fork: bool = ...,
224257
trace_python_allocators: bool = ...,
225258
file_format: FileFormat = ...,
259+
reference_tracking: bool = ...,
226260
) -> None: ...
227261
def __enter__(self) -> Any: ...
228262
def __exit__(

src/memray/_memray.pyx

Lines changed: 190 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ from _memray.records cimport Allocation as _Allocation
3232
from _memray.records cimport FileFormat as _FileFormat
3333
from _memray.records cimport MemoryRecord
3434
from _memray.records cimport MemorySnapshot as _MemorySnapshot
35+
from _memray.records cimport TrackedObject
3536
from _memray.sink cimport FileSink
3637
from _memray.sink cimport NullSink
3738
from _memray.sink cimport Sink
@@ -56,7 +57,9 @@ from _memray.source cimport SocketSource
5657
from _memray.tracking_api cimport Tracker as NativeTracker
5758
from _memray.tracking_api cimport install_trace_function
5859
from _memray.tracking_api cimport set_up_pthread_fork_handlers
60+
from cpython cimport Py_DECREF
5961
from cpython cimport PyErr_CheckSignals
62+
from cpython cimport PyObject
6063
from libc.math cimport ceil
6164
from libc.stdint cimport uint64_t
6265
from libcpp cimport bool
@@ -67,6 +70,7 @@ from libcpp.memory cimport shared_ptr
6770
from libcpp.memory cimport unique_ptr
6871
from libcpp.string cimport string as cppstring
6972
from libcpp.unordered_map cimport unordered_map
73+
from libcpp.unordered_set cimport unordered_set
7074
from libcpp.utility cimport move
7175
from libcpp.vector cimport vector
7276

@@ -362,6 +366,100 @@ cdef class AllocationRecord:
362366
f"allocations={self.n_allocations}>")
363367

364368

369+
@cython.freelist(1024)
370+
cdef class TrackedObjectRecord:
371+
cdef object _tuple
372+
cdef dict _stack_trace_cache
373+
cdef shared_ptr[RecordReader] _reader
374+
375+
def __init__(self, record):
376+
self._tuple = record
377+
self._stack_trace_cache = {}
378+
379+
def __eq__(self, other):
380+
cdef TrackedObjectRecord _other
381+
if isinstance(other, TrackedObjectRecord):
382+
_other = other
383+
return self._tuple == _other._tuple
384+
return NotImplemented
385+
386+
def __hash__(self):
387+
return hash(self._tuple)
388+
389+
@property
390+
def tid(self):
391+
return self._tuple[0]
392+
393+
@property
394+
def address(self):
395+
return self._tuple[1]
396+
397+
@property
398+
def is_created(self):
399+
return self._tuple[2]
400+
401+
@property
402+
def stack_id(self):
403+
return self._tuple[3]
404+
405+
@property
406+
def native_stack_id(self):
407+
return self._tuple[4]
408+
409+
@property
410+
def native_segment_generation(self):
411+
return self._tuple[5]
412+
413+
@property
414+
def thread_name(self):
415+
if self.tid == -1:
416+
return "merged thread"
417+
assert self._reader.get() != NULL, "Cannot get thread name without reader."
418+
return self._reader.get().getThreadName(self.tid)
419+
420+
def stack_trace(self, max_stacks=None):
421+
cache_key = ("python", max_stacks)
422+
if cache_key not in self._stack_trace_cache:
423+
self._stack_trace_cache[cache_key] = stack_trace(
424+
self._reader.get(),
425+
self.tid,
426+
None, # No allocator for objects
427+
self.stack_id,
428+
max_stacks,
429+
)
430+
return self._stack_trace_cache[cache_key]
431+
432+
def native_stack_trace(self, max_stacks=None):
433+
cache_key = ("native", max_stacks)
434+
if cache_key not in self._stack_trace_cache:
435+
self._stack_trace_cache[cache_key] = native_stack_trace(
436+
self._reader.get(),
437+
None, # No allocator for objects
438+
self.native_stack_id,
439+
self.native_segment_generation,
440+
max_stacks,
441+
)
442+
return self._stack_trace_cache[cache_key]
443+
444+
def hybrid_stack_trace(self, max_stacks=None):
445+
cache_key = ("hybrid", max_stacks)
446+
if cache_key not in self._stack_trace_cache:
447+
self._stack_trace_cache[cache_key] = hybrid_stack_trace(
448+
self._reader.get(),
449+
self.tid,
450+
None, # No allocator for objects
451+
self.stack_id,
452+
self.native_stack_id,
453+
self.native_segment_generation,
454+
max_stacks,
455+
)
456+
return self._stack_trace_cache[cache_key]
457+
458+
def __repr__(self):
459+
return (f"TrackedObjectRecord<tid={hex(self.tid)}, address={hex(self.address)}, "
460+
f"is_created={self.is_created}, stack_id={self.stack_id}>")
461+
462+
365463
@cython.freelist(1024)
366464
cdef class Interval:
367465
cdef public size_t allocated_before_snapshot
@@ -628,12 +726,14 @@ cdef class Tracker:
628726
of supported file formats and their limitations.
629727
"""
630728
cdef bool _native_traces
729+
cdef bool _reference_tracking
631730
cdef unsigned int _memory_interval_ms
632731
cdef bool _follow_fork
633732
cdef bool _trace_python_allocators
634733
cdef object _previous_profile_func
635734
cdef object _previous_thread_profile_func
636735
cdef unique_ptr[RecordWriter] _writer
736+
cdef object _surviving_objects
637737

638738
cdef unique_ptr[Sink] _make_writer(self, destination) except*:
639739
# Creating a Sink can raise Python exceptions (if is interrupted by signal
@@ -658,12 +758,20 @@ cdef class Tracker:
658758
def __cinit__(self, object file_name=None, *, object destination=None,
659759
bool native_traces=False, unsigned int memory_interval_ms = 10,
660760
bool follow_fork=False, bool trace_python_allocators=False,
661-
FileFormat file_format=FileFormat.ALL_ALLOCATIONS):
761+
reference_tracking=False, FileFormat file_format=FileFormat.ALL_ALLOCATIONS):
662762
if (file_name, destination).count(None) != 1:
663763
raise TypeError("Exactly one of 'file_name' or 'destination' argument must be specified")
664764

765+
# Check Python version if reference tracking is enabled
766+
if reference_tracking and sys.version_info < (3, 13):
767+
raise RuntimeError(
768+
"Python object reference tracking requires Python 3.13 or later. "
769+
f"Current version: {sys.version_info.major}.{sys.version_info.minor}"
770+
)
771+
665772
cdef cppstring command_line = " ".join(sys.argv)
666773
self._native_traces = native_traces
774+
self._reference_tracking = reference_tracking
667775
self._memory_interval_ms = memory_interval_ms
668776
self._follow_fork = follow_fork
669777
self._trace_python_allocators = trace_python_allocators
@@ -714,17 +822,21 @@ cdef class Tracker:
714822
if "greenlet" in sys.modules:
715823
NativeTracker.beginTrackingGreenlets()
716824

825+
self._surviving_objects = []
826+
717827
NativeTracker.createTracker(
718828
move(writer),
719829
self._native_traces,
720830
self._memory_interval_ms,
721831
self._follow_fork,
722832
self._trace_python_allocators,
833+
self._reference_tracking,
723834
)
724835
return self
725836

726837
@cython.profile(False)
727838
def __exit__(self, exc_type, exc_value, exc_traceback):
839+
self._populate_suriving_objects()
728840
with tracker_creation_lock:
729841
NativeTracker.destroyTracker()
730842
sys.setprofile(self._previous_profile_func)
@@ -733,6 +845,27 @@ cdef class Tracker:
733845
for attr in ("_name", "_ident"):
734846
delattr(threading.Thread, attr)
735847

848+
cdef void _populate_suriving_objects(self):
849+
assert NativeTracker.getTracker() != NULL
850+
cdef unordered_set[PyObject*] objects = NativeTracker.getTracker().getSurvivingObjects()
851+
for obj in objects:
852+
pass
853+
self._surviving_objects.append(<object>obj)
854+
Py_DECREF(<object>obj)
855+
856+
def get_surviving_objects(self):
857+
"""Get a list of objects that were alive at the end of the tracking period.
858+
859+
Returns:
860+
list: A list of objects that were alive at the end of the tracking period.
861+
"""
862+
if sys.version_info < (3, 13):
863+
raise RuntimeError(
864+
"Python object reference tracking requires Python 3.13 or later. "
865+
f"Current version: {sys.version_info.major}.{sys.version_info.minor}"
866+
)
867+
return self._surviving_objects
868+
736869

737870
def start_thread_trace(frame, event, arg):
738871
if event in {"call", "c_call"}:
@@ -946,6 +1079,8 @@ cdef class FileReader:
9461079
)
9471080
elif ret == RecordResult.RecordResultMemorySnapshot:
9481081
self._memory_snapshots.push_back(reader.getLatestMemorySnapshot())
1082+
elif ret == RecordResult.RecordResultObjectRecord:
1083+
pass
9491084
else:
9501085
break
9511086

@@ -1239,6 +1374,58 @@ cdef class FileReader:
12391374
for record in self._memory_snapshots:
12401375
yield MemorySnapshot(record.ms_since_epoch, record.rss, record.heap)
12411376

1377+
def get_tracked_objects(self, filter_objs=None):
1378+
"""Return all tracked objects from the file.
1379+
1380+
This method yields TrackedObjectRecord instances representing Python objects
1381+
that were tracked during the recording session.
1382+
1383+
Args:
1384+
filter_objs (list, optional): A list of objects to filter the results by.
1385+
1386+
Returns:
1387+
An iterator of TrackedObjectRecord instances.
1388+
"""
1389+
self._ensure_not_closed()
1390+
1391+
# Check Python version
1392+
if sys.version_info < (3, 13):
1393+
raise RuntimeError(
1394+
"Python object reference tracking requires Python 3.13 or later. "
1395+
f"Current version: {sys.version_info.major}.{sys.version_info.minor}"
1396+
)
1397+
1398+
object_ids = set()
1399+
if filter_objs is not None:
1400+
object_ids = {id(obj) for obj in filter_objs}
1401+
1402+
cdef shared_ptr[RecordReader] reader_sp = make_shared[RecordReader](
1403+
unique_ptr[FileSource](new FileSource(self._path))
1404+
)
1405+
cdef RecordReader* reader = reader_sp.get()
1406+
1407+
cdef TrackedObject tracked_object
1408+
while True:
1409+
PyErr_CheckSignals()
1410+
ret = reader.nextRecord()
1411+
if ret == RecordResult.RecordResultObjectRecord:
1412+
tracked_object = reader.getLatestObject()
1413+
if filter_objs is not None and tracked_object.address not in object_ids:
1414+
continue
1415+
object_record = TrackedObjectRecord(tracked_object.toPythonObject())
1416+
object_record._reader = reader_sp
1417+
yield object_record
1418+
elif ret == RecordResult.RecordResultAllocationRecord:
1419+
pass
1420+
elif ret == RecordResult.RecordResultMemoryRecord:
1421+
pass
1422+
elif ret == RecordResult.RecordResultMemorySnapshot:
1423+
pass
1424+
else:
1425+
break
1426+
1427+
reader.close()
1428+
12421429
@property
12431430
def metadata(self):
12441431
return _create_metadata(self._header, self._high_watermark.peak_memory)
@@ -1283,6 +1470,8 @@ def compute_statistics(
12831470
pass
12841471
elif ret == RecordResult.RecordResultMemorySnapshot:
12851472
pass
1473+
elif ret == RecordResult.RecordResultObjectRecord:
1474+
pass
12861475
else:
12871476
assert ret != RecordResult.RecordResultMemorySnapshot
12881477
assert ret != RecordResult.RecordResultAggregatedAllocationRecord

src/memray/_memray/compat.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,16 @@ threadStateGetInterpreter(PyThreadState* tstate)
9090
void
9191
setprofileAllThreads(Py_tracefunc func, PyObject* arg);
9292

93+
typedef int (*refTracer)(PyObject*, int event, void* data);
94+
95+
inline int
96+
refTracerSetTracer(refTracer tracer, void* data)
97+
{
98+
#if PY_VERSION_HEX >= 0x030D0000
99+
return PyRefTracer_SetTracer(reinterpret_cast<PyRefTracer>(tracer), data);
100+
#else
101+
return 0;
102+
#endif
103+
}
104+
93105
} // namespace memray::compat

src/memray/_memray/hooks.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,4 +486,11 @@ PyGILState_Ensure() noexcept
486486
return ret;
487487
}
488488

489+
int
490+
pyreftracer(PyObject* obj, int event, void* data) noexcept
491+
{
492+
tracking_api::Tracker::trackObject(obj, event);
493+
return 0;
494+
}
495+
489496
} // namespace memray::intercept

0 commit comments

Comments
 (0)