@@ -86,6 +86,7 @@ typedef struct {
86
86
char * end ;
87
87
char * buf ;
88
88
_Py_hashtable_t * hashtable ;
89
+ int last_index ;
89
90
int version ;
90
91
} WFILE ;
91
92
@@ -276,37 +277,35 @@ w_ref(PyObject *v, char *flag, WFILE *p)
276
277
return 0 ; /* not writing object references */
277
278
278
279
/* if it has only one reference, it definitely isn't shared */
279
- if (Py_REFCNT (v ) == 1 )
280
+ if (Py_REFCNT (v ) == 1 ) {
280
281
return 0 ;
282
+ }
281
283
282
284
entry = _Py_HASHTABLE_GET_ENTRY (p -> hashtable , v );
283
- if (entry != NULL ) {
284
- /* write the reference index to the stream */
285
- _Py_HASHTABLE_ENTRY_READ_DATA (p -> hashtable , entry , w );
285
+ if (entry == NULL ) {
286
+ return 0 ;
287
+ }
288
+
289
+ _Py_HASHTABLE_ENTRY_READ_DATA (p -> hashtable , entry , w );
290
+ // w >= 0: index written by previous w_ref()
291
+ // w < 0 : refcnt counted by w_count_refs()
292
+ if (w == -1 ) {
293
+ // This object is used only once.
294
+ return 0 ;
295
+ }
296
+
297
+ if (w >= 0 ) {
286
298
/* we don't store "long" indices in the dict */
287
299
assert (0 <= w && w <= 0x7fffffff );
288
300
w_byte (TYPE_REF , p );
289
301
w_long (w , p );
290
302
return 1 ;
291
303
} else {
292
- size_t s = p -> hashtable -> entries ;
293
- /* we don't support long indices */
294
- if (s >= 0x7fffffff ) {
295
- PyErr_SetString (PyExc_ValueError , "too many objects" );
296
- goto err ;
297
- }
298
- w = (int )s ;
299
- Py_INCREF (v );
300
- if (_Py_HASHTABLE_SET (p -> hashtable , v , w ) < 0 ) {
301
- Py_DECREF (v );
302
- goto err ;
303
- }
304
+ w = p -> last_index ++ ;
305
+ _Py_HASHTABLE_ENTRY_WRITE_DATA (p -> hashtable , entry , w );
304
306
* flag |= FLAG_REF ;
305
307
return 0 ;
306
308
}
307
- err :
308
- p -> error = WFERR_UNMARSHALLABLE ;
309
- return 1 ;
310
309
}
311
310
312
311
static void
@@ -584,18 +583,135 @@ w_complex_object(PyObject *v, char flag, WFILE *p)
584
583
}
585
584
586
585
static int
587
- w_init_refs ( WFILE * wf , int version )
586
+ w_count_refs ( PyObject * v , WFILE * p )
588
587
{
589
- if (version >= 3 ) {
590
- wf -> hashtable = _Py_hashtable_new (sizeof (PyObject * ), sizeof (int ),
591
- _Py_hashtable_hash_ptr ,
592
- _Py_hashtable_compare_direct );
593
- if (wf -> hashtable == NULL ) {
594
- PyErr_NoMemory ();
595
- return -1 ;
588
+ if (p -> depth > MAX_MARSHAL_STACK_DEPTH ) {
589
+ PyErr_SetString (PyExc_ValueError ,
590
+ "object too deeply nested to marshal" );
591
+ goto err ;
592
+ }
593
+
594
+ if (v == NULL ||
595
+ v == Py_None ||
596
+ v == PyExc_StopIteration ||
597
+ v == Py_Ellipsis ||
598
+ v == Py_False ||
599
+ v == Py_True ) {
600
+ return 0 ;
601
+ }
602
+
603
+ /* if it has only one reference, it definitely isn't shared */
604
+ if (Py_REFCNT (v ) > 1 ) {
605
+ // Use negative number to count refs
606
+ _Py_hashtable_entry_t * entry = _Py_HASHTABLE_GET_ENTRY (p -> hashtable , v );
607
+ if (entry != NULL ) {
608
+ int w ;
609
+ _Py_HASHTABLE_ENTRY_READ_DATA (p -> hashtable , entry , w );
610
+ assert (w < 0 );
611
+ w -- ;
612
+ _Py_HASHTABLE_ENTRY_WRITE_DATA (p -> hashtable , entry , w );
613
+ return 0 ;
614
+ }
615
+ else {
616
+ size_t s = p -> hashtable -> entries ;
617
+ /* we don't support long indices */
618
+ if (s >= 0x7fffffff ) {
619
+ PyErr_SetString (PyExc_ValueError , "too many objects" );
620
+ goto err ;
621
+ }
622
+ int w = -1 ;
623
+ Py_INCREF (v );
624
+ if (_Py_HASHTABLE_SET (p -> hashtable , v , w ) < 0 ) {
625
+ Py_DECREF (v );
626
+ goto err ;
627
+ }
628
+ }
629
+ }
630
+
631
+ // These logic should be same to w_object()
632
+ p -> depth ++ ;
633
+
634
+ Py_ssize_t i , n ;
635
+ if (PyTuple_CheckExact (v )) {
636
+ n = PyTuple_Size (v );
637
+ for (i = 0 ; i < n ; i ++ ) {
638
+ w_count_refs (PyTuple_GET_ITEM (v , i ), p );
639
+ }
640
+ }
641
+ else if (PyList_CheckExact (v )) {
642
+ n = PyList_GET_SIZE (v );
643
+ for (i = 0 ; i < n ; i ++ ) {
644
+ w_count_refs (PyList_GET_ITEM (v , i ), p );
596
645
}
597
646
}
647
+ else if (PyDict_CheckExact (v )) {
648
+ PyObject * key , * value ;
649
+ i = 0 ;
650
+ while (PyDict_Next (v , & i , & key , & value )) {
651
+ w_count_refs (key , p );
652
+ w_count_refs (value , p );
653
+ }
654
+ }
655
+ else if (PyAnySet_CheckExact (v )) {
656
+ PyObject * value , * it ;
657
+
658
+ it = PyObject_GetIter (v );
659
+ if (it == NULL ) {
660
+ p -> depth -- ;
661
+ goto err ;
662
+ }
663
+ while ((value = PyIter_Next (it )) != NULL ) {
664
+ w_count_refs (value , p );
665
+ Py_DECREF (value );
666
+ }
667
+ Py_DECREF (it );
668
+ if (PyErr_Occurred ()) {
669
+ p -> depth -- ;
670
+ goto err ;
671
+ }
672
+ }
673
+ else if (PyCode_Check (v )) {
674
+ PyCodeObject * co = (PyCodeObject * )v ;
675
+ w_count_refs (co -> co_code , p );
676
+ w_count_refs (co -> co_consts , p );
677
+ w_count_refs (co -> co_names , p );
678
+ w_count_refs (co -> co_varnames , p );
679
+ w_count_refs (co -> co_freevars , p );
680
+ w_count_refs (co -> co_cellvars , p );
681
+ w_count_refs (co -> co_filename , p );
682
+ w_count_refs (co -> co_name , p );
683
+ w_count_refs (co -> co_lnotab , p );
684
+ }
685
+
686
+ p -> depth -- ;
687
+
688
+ if (p -> error == WFERR_UNMARSHALLABLE ) {
689
+ return 1 ;
690
+ }
598
691
return 0 ;
692
+
693
+ err :
694
+ p -> error = WFERR_UNMARSHALLABLE ;
695
+ return 1 ;
696
+ }
697
+
698
+ static int
699
+ w_init_refs (WFILE * wf , int version , PyObject * x )
700
+ {
701
+ if (version < 3 ) {
702
+ return 0 ;
703
+ }
704
+
705
+ wf -> hashtable = _Py_hashtable_new (sizeof (PyObject * ), sizeof (int ),
706
+ _Py_hashtable_hash_ptr ,
707
+ _Py_hashtable_compare_direct );
708
+ if (wf -> hashtable == NULL ) {
709
+ PyErr_NoMemory ();
710
+ return -1 ;
711
+ }
712
+ wf -> last_index = 0 ;
713
+
714
+ return w_count_refs (x , wf );
599
715
}
600
716
601
717
static int
@@ -645,8 +761,9 @@ PyMarshal_WriteObjectToFile(PyObject *x, FILE *fp, int version)
645
761
wf .end = wf .ptr + sizeof (buf );
646
762
wf .error = WFERR_OK ;
647
763
wf .version = version ;
648
- if (w_init_refs (& wf , version ))
764
+ if (w_init_refs (& wf , version , x )) {
649
765
return ; /* caller mush check PyErr_Occurred() */
766
+ }
650
767
w_object (x , & wf );
651
768
w_clear_refs (& wf );
652
769
w_flush (& wf );
@@ -1621,7 +1738,7 @@ PyMarshal_WriteObjectToString(PyObject *x, int version)
1621
1738
wf .end = wf .ptr + PyBytes_Size (wf .str );
1622
1739
wf .error = WFERR_OK ;
1623
1740
wf .version = version ;
1624
- if (w_init_refs (& wf , version )) {
1741
+ if (w_init_refs (& wf , version , x )) {
1625
1742
Py_DECREF (wf .str );
1626
1743
return NULL ;
1627
1744
}
0 commit comments