@@ -32,6 +32,7 @@ from _memray.records cimport Allocation as _Allocation
32
32
from _memray.records cimport FileFormat as _FileFormat
33
33
from _memray.records cimport MemoryRecord
34
34
from _memray.records cimport MemorySnapshot as _MemorySnapshot
35
+ from _memray.records cimport TrackedObject
35
36
from _memray.sink cimport FileSink
36
37
from _memray.sink cimport NullSink
37
38
from _memray.sink cimport Sink
@@ -56,7 +57,9 @@ from _memray.source cimport SocketSource
56
57
from _memray.tracking_api cimport Tracker as NativeTracker
57
58
from _memray.tracking_api cimport install_trace_function
58
59
from _memray.tracking_api cimport set_up_pthread_fork_handlers
60
+ from cpython cimport Py_DECREF
59
61
from cpython cimport PyErr_CheckSignals
62
+ from cpython cimport PyObject
60
63
from libc.math cimport ceil
61
64
from libc.stdint cimport uint64_t
62
65
from libcpp cimport bool
@@ -67,6 +70,7 @@ from libcpp.memory cimport shared_ptr
67
70
from libcpp.memory cimport unique_ptr
68
71
from libcpp.string cimport string as cppstring
69
72
from libcpp.unordered_map cimport unordered_map
73
+ from libcpp.unordered_set cimport unordered_set
70
74
from libcpp.utility cimport move
71
75
from libcpp.vector cimport vector
72
76
@@ -362,6 +366,100 @@ cdef class AllocationRecord:
362
366
f" allocations={self.n_allocations}>" )
363
367
364
368
369
@cython.freelist(1024)
cdef class TrackedObjectRecord:
    """Python-level view of a single tracked-object record.

    The record is backed by a tuple indexed as ``(tid, address, is_created,
    stack_id, native_stack_id, native_segment_generation)``. Stack traces are
    resolved on demand through the attached ``RecordReader`` and memoized per
    ``(kind, max_stacks)`` pair.
    """
    # Raw record fields; the properties below expose them by index.
    cdef object _tuple
    # Memoized stack traces, keyed by (trace kind, max_stacks).
    cdef dict _stack_trace_cache
    # Reader used to resolve thread names and stack frames. __init__ does not
    # set it, so the code that builds the record has to attach it.
    cdef shared_ptr[RecordReader] _reader

    def __init__(self, record):
        """Store ``record`` as the backing tuple and start with an empty cache."""
        self._stack_trace_cache = {}
        self._tuple = record

    def __eq__(self, other):
        """Compare backing tuples; defer to ``other`` for foreign types."""
        if not isinstance(other, TrackedObjectRecord):
            return NotImplemented
        return self._tuple == (<TrackedObjectRecord>other)._tuple

    def __hash__(self):
        """Hash of the backing tuple, consistent with ``__eq__``."""
        return hash(self._tuple)

    @property
    def tid(self):
        """Thread id (tuple index 0); ``-1`` is reported as a merged thread."""
        return self._tuple[0]

    @property
    def address(self):
        """Address of the tracked object (tuple index 1)."""
        return self._tuple[1]

    @property
    def is_created(self):
        """Creation flag (tuple index 2).

        NOTE(review): presumably true for a creation event and false for a
        destruction event -- confirm against ``TrackedObject.toPythonObject``.
        """
        return self._tuple[2]

    @property
    def stack_id(self):
        """Identifier of the Python stack (tuple index 3)."""
        return self._tuple[3]

    @property
    def native_stack_id(self):
        """Identifier of the native stack (tuple index 4)."""
        return self._tuple[4]

    @property
    def native_segment_generation(self):
        """Native segment generation (tuple index 5)."""
        return self._tuple[5]

    @property
    def thread_name(self):
        """Name of the recording thread, looked up through the reader.

        A ``tid`` of ``-1`` short-circuits to a fixed label without touching
        the reader; any other value requires a reader to be attached.
        """
        thread_id = self.tid
        if thread_id == -1:
            return " merged thread"
        assert self._reader.get() != NULL, " Cannot get thread name without reader."
        return self._reader.get().getThreadName(thread_id)

    def stack_trace(self, max_stacks=None):
        """Return the Python stack trace for this record, memoized.

        ``max_stacks`` is forwarded unchanged to the module-level
        ``stack_trace`` helper and is part of the cache key.
        """
        key = (" python", max_stacks)
        cache = self._stack_trace_cache
        if key in cache:
            return cache[key]
        frames = stack_trace(
            self._reader.get(),
            self.tid,
            None,  # No allocator for objects
            self.stack_id,
            max_stacks,
        )
        cache[key] = frames
        return frames

    def native_stack_trace(self, max_stacks=None):
        """Return the native stack trace for this record, memoized.

        ``max_stacks`` is forwarded unchanged to the module-level
        ``native_stack_trace`` helper and is part of the cache key.
        """
        key = (" native", max_stacks)
        cache = self._stack_trace_cache
        if key in cache:
            return cache[key]
        frames = native_stack_trace(
            self._reader.get(),
            None,  # No allocator for objects
            self.native_stack_id,
            self.native_segment_generation,
            max_stacks,
        )
        cache[key] = frames
        return frames

    def hybrid_stack_trace(self, max_stacks=None):
        """Return the hybrid (Python + native) stack trace, memoized.

        ``max_stacks`` is forwarded unchanged to the module-level
        ``hybrid_stack_trace`` helper and is part of the cache key.
        """
        key = (" hybrid", max_stacks)
        cache = self._stack_trace_cache
        if key in cache:
            return cache[key]
        frames = hybrid_stack_trace(
            self._reader.get(),
            self.tid,
            None,  # No allocator for objects
            self.stack_id,
            self.native_stack_id,
            self.native_segment_generation,
            max_stacks,
        )
        cache[key] = frames
        return frames

    def __repr__(self):
        """Debug representation showing tid and address in hexadecimal."""
        return (f" TrackedObjectRecord<tid={hex(self.tid)}, address={hex(self.address)}, "
                f" is_created={self.is_created}, stack_id={self.stack_id}>")
461
+
462
+
365
463
@ cython.freelist (1024 )
366
464
cdef class Interval:
367
465
cdef public size_t allocated_before_snapshot
@@ -628,12 +726,14 @@ cdef class Tracker:
628
726
of supported file formats and their limitations.
629
727
"""
630
728
cdef bool _native_traces
729
+ cdef bool _reference_tracking
631
730
cdef unsigned int _memory_interval_ms
632
731
cdef bool _follow_fork
633
732
cdef bool _trace_python_allocators
634
733
cdef object _previous_profile_func
635
734
cdef object _previous_thread_profile_func
636
735
cdef unique_ptr[RecordWriter] _writer
736
+ cdef object _surviving_objects
637
737
638
738
cdef unique_ptr[Sink] _make_writer(self , destination) except * :
639
739
# Creating a Sink can raise Python exceptions (if is interrupted by signal
@@ -658,12 +758,20 @@ cdef class Tracker:
658
758
def __cinit__ (self , object file_name = None , *, object destination = None ,
659
759
bool native_traces = False , unsigned int memory_interval_ms = 10 ,
660
760
bool follow_fork = False , bool trace_python_allocators = False ,
661
- FileFormat file_format = FileFormat.ALL_ALLOCATIONS):
761
+ reference_tracking = False , FileFormat file_format = FileFormat.ALL_ALLOCATIONS):
662
762
if (file_name, destination).count(None ) != 1 :
663
763
raise TypeError (" Exactly one of 'file_name' or 'destination' argument must be specified" )
664
764
765
+ # Check Python version if reference tracking is enabled
766
+ if reference_tracking and sys.version_info < (3 , 13 ):
767
+ raise RuntimeError (
768
+ " Python object reference tracking requires Python 3.13 or later. "
769
+ f" Current version: {sys.version_info.major}.{sys.version_info.minor}"
770
+ )
771
+
665
772
cdef cppstring command_line = " " .join(sys.argv)
666
773
self ._native_traces = native_traces
774
+ self ._reference_tracking = reference_tracking
667
775
self ._memory_interval_ms = memory_interval_ms
668
776
self ._follow_fork = follow_fork
669
777
self ._trace_python_allocators = trace_python_allocators
@@ -714,17 +822,21 @@ cdef class Tracker:
714
822
if " greenlet" in sys.modules:
715
823
NativeTracker.beginTrackingGreenlets()
716
824
825
+ self ._surviving_objects = []
826
+
717
827
NativeTracker.createTracker(
718
828
move(writer),
719
829
self ._native_traces,
720
830
self ._memory_interval_ms,
721
831
self ._follow_fork,
722
832
self ._trace_python_allocators,
833
+ self ._reference_tracking,
723
834
)
724
835
return self
725
836
726
837
@ cython.profile (False )
727
838
def __exit__ (self , exc_type , exc_value , exc_traceback ):
839
+ self ._populate_suriving_objects()
728
840
with tracker_creation_lock:
729
841
NativeTracker.destroyTracker()
730
842
sys.setprofile(self ._previous_profile_func)
@@ -733,6 +845,27 @@ cdef class Tracker:
733
845
for attr in (" _name" , " _ident" ):
734
846
delattr (threading.Thread, attr)
735
847
848
cdef void _populate_suriving_objects(self):
    """Append the native tracker's surviving objects to ``_surviving_objects``.

    Requires the native tracker to still exist, so it has to run before the
    tracker is destroyed. ``self._surviving_objects`` must already be a list.
    """
    assert NativeTracker.getTracker() != NULL
    cdef unordered_set[PyObject*] objects = NativeTracker.getTracker().getSurvivingObjects()
    # Declared explicitly so the <object> casts below do not rely on type
    # inference for the loop variable.
    cdef PyObject* obj
    for obj in objects:
        # The list takes its own reference to the object when appending.
        self._surviving_objects.append(<object>obj)
        # NOTE(review): presumably releases a reference that the native
        # tracker held on behalf of the set returned above -- confirm against
        # getSurvivingObjects() before changing the reference handling.
        Py_DECREF(<object>obj)
855
+
856
def get_surviving_objects(self):
    """Return the objects that were alive at the end of the tracking period.

    Returns:
        list: The list held in ``self._surviving_objects``. The stored list
        itself is returned, not a copy.

    Raises:
        RuntimeError: If the running interpreter is older than Python 3.13.
    """
    supported = sys.version_info >= (3, 13)
    if not supported:
        raise RuntimeError(
            " Python object reference tracking requires Python 3.13 or later. "
            f" Current version: {sys.version_info.major}.{sys.version_info.minor}"
        )
    return self._surviving_objects
868
+
736
869
737
870
def start_thread_trace (frame , event , arg ):
738
871
if event in {" call" , " c_call" }:
@@ -946,6 +1079,8 @@ cdef class FileReader:
946
1079
)
947
1080
elif ret == RecordResult.RecordResultMemorySnapshot:
948
1081
self ._memory_snapshots.push_back(reader.getLatestMemorySnapshot())
1082
+ elif ret == RecordResult.RecordResultObjectRecord:
1083
+ pass
949
1084
else :
950
1085
break
951
1086
@@ -1239,6 +1374,58 @@ cdef class FileReader:
1239
1374
for record in self ._memory_snapshots:
1240
1375
yield MemorySnapshot(record.ms_since_epoch, record.rss, record.heap)
1241
1376
1377
def get_tracked_objects(self, filter_objs=None):
    """Yield a TrackedObjectRecord for each object record in the file.

    The file at ``self._path`` is read from the start with a fresh reader.
    Allocation, memory and memory-snapshot records are skipped; any other
    result from the reader ends the iteration.

    Because this is a generator, the closed-file check and the Python
    version check only run once iteration starts.

    Args:
        filter_objs (iterable, optional): When given, only records whose
            address equals ``id(obj)`` for one of these objects are yielded.
            An empty iterable therefore yields nothing.

    Yields:
        TrackedObjectRecord: One per matching object record. Every record
        shares ownership of the reader created for this pass.

    Raises:
        RuntimeError: If the running interpreter is older than Python 3.13.
    """
    self._ensure_not_closed()

    # Check Python version
    if sys.version_info < (3, 13):
        raise RuntimeError(
            " Python object reference tracking requires Python 3.13 or later. "
            f" Current version: {sys.version_info.major}.{sys.version_info.minor}"
        )

    # None means "no filtering"; an empty set filters everything out.
    wanted_ids = None
    if filter_objs is not None:
        wanted_ids = {id(obj) for obj in filter_objs}

    cdef shared_ptr[RecordReader] reader_sp = make_shared[RecordReader](
        unique_ptr[FileSource](new FileSource(self._path))
    )
    cdef RecordReader* reader = reader_sp.get()

    cdef TrackedObject tracked_object
    cdef TrackedObjectRecord object_record
    try:
        while True:
            PyErr_CheckSignals()
            ret = reader.nextRecord()
            if ret == RecordResult.RecordResultObjectRecord:
                tracked_object = reader.getLatestObject()
                if wanted_ids is not None and tracked_object.address not in wanted_ids:
                    continue
                object_record = TrackedObjectRecord(tracked_object.toPythonObject())
                object_record._reader = reader_sp
                yield object_record
            elif ret == RecordResult.RecordResultAllocationRecord:
                pass
            elif ret == RecordResult.RecordResultMemoryRecord:
                pass
            elif ret == RecordResult.RecordResultMemorySnapshot:
                pass
            else:
                break
    finally:
        # Close the reader even when the consumer stops iterating early or a
        # signal handler raises: yielded records keep the reader alive through
        # their shared_ptr, so it would otherwise stay open as long as they do.
        reader.close()
1428
+
1242
1429
@property
1243
1430
def metadata (self ):
1244
1431
return _create_metadata(self ._header, self ._high_watermark.peak_memory)
@@ -1283,6 +1470,8 @@ def compute_statistics(
1283
1470
pass
1284
1471
elif ret == RecordResult.RecordResultMemorySnapshot:
1285
1472
pass
1473
+ elif ret == RecordResult.RecordResultObjectRecord:
1474
+ pass
1286
1475
else :
1287
1476
assert ret != RecordResult.RecordResultMemorySnapshot
1288
1477
assert ret != RecordResult.RecordResultAggregatedAllocationRecord
0 commit comments