@@ -2653,6 +2653,153 @@ def test_prb_update_max_priority(self, max_priority_within_buffer):
2653
2653
assert rb ._sampler ._max_priority [0 ] == 21
2654
2654
assert rb ._sampler ._max_priority [1 ] == 0
2655
2655
2656
def test_prb_ndim(self):
    """Check priority updates on prioritized buffers for 1d and 2d data.

    Six configurations are run one after the other: a ``ReplayBuffer`` with a
    ``PrioritizedSampler`` (RB), a ``TensorDictReplayBuffer`` with a
    ``PrioritizedSampler`` (TRB) and a ``TensorDictPrioritizedReplayBuffer``
    (TPRB), each first with a batch of shape ``[10]`` and then with a batch
    of shape ``[2, 5]`` written to a storage built with ``ndim=2``.

    Every configuration asserts that the first ten sum-tree leaves equal 1
    right after ``extend``, equal 2 after ``update_priority(idx, 2)``, and
    that the sampled entries carry the new value after one more update.
    Where that last update uses 10_000, a second sample is asserted to
    contain only indices that were part of the first sample.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    n_items = 10

    def leaf_priorities(buffer):
        # Snapshot of the first ``n_items`` leaves of the sampler's sum-tree.
        tree = buffer._sampler._sum_tree
        return torch.tensor([tree[i] for i in range(n_items)])

    def make_1d_data():
        return TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10])

    def make_2d_data():
        return TensorDict(
            {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2},
            [2, 5],
        )

    def all_drawn_from(new_index, old_index, multidim=False):
        # True when every entry of ``new_index`` also appears in ``old_index``.
        # For multidim indices the last dimension holds the coordinates, so
        # all of them must match for two entries to count as equal.
        hits = new_index.unsqueeze(0) == old_index.unsqueeze(1)
        if multidim:
            hits = hits.all(-1)
        return hits.any(0).all()

    # Case 1: 1d data, RB. The index comes from the ``info`` dict.
    buffer = ReplayBuffer(
        sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0),
        storage=LazyTensorStorage(100),
        batch_size=4,
    )
    written = buffer.extend(make_1d_data())
    assert (leaf_priorities(buffer) == 1).all()
    buffer.update_priority(written, 2)
    assert (leaf_priorities(buffer) == 2).all()
    _, info = buffer.sample(return_info=True)
    sampled_index = info["index"]
    buffer.update_priority(sampled_index, 3)
    assert (leaf_priorities(buffer)[sampled_index] == 3).all()

    # Case 2: 1d data, TRB. The index is read from the sampled tensordict.
    buffer = TensorDictReplayBuffer(
        sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0),
        storage=LazyTensorStorage(100),
        batch_size=4,
    )
    written = buffer.extend(make_1d_data())
    assert (leaf_priorities(buffer) == 1).all()
    buffer.update_priority(written, 2)
    assert (leaf_priorities(buffer) == 2).all()
    batch = buffer.sample()
    buffer.update_priority(batch["index"], 3)
    assert (leaf_priorities(buffer)[batch["index"]] == 3).all()

    # Case 3: 1d data, TPRB. Priorities are pushed through the "p" entry.
    buffer = TensorDictPrioritizedReplayBuffer(
        alpha=1.0,
        beta=1.0,
        storage=LazyTensorStorage(100),
        batch_size=4,
        priority_key="p",
    )
    written = buffer.extend(make_1d_data())
    assert (leaf_priorities(buffer) == 1).all()
    buffer.update_priority(written, 2)
    assert (leaf_priorities(buffer) == 2).all()
    batch = buffer.sample()

    batch["p"] = torch.ones(4) * 10_000
    buffer.update_tensordict_priority(batch)
    assert (leaf_priorities(buffer)[batch["index"]] == 10_000).all()

    resampled = buffer.sample()
    # The priority of the first sample is now so high that the second sample
    # must be made exclusively of those same items.
    assert all_drawn_from(resampled["index"], batch["index"])

    # Case 4: 2d data, RB. ``info["index"]`` indexes the (5, 2) view directly.
    buffer = ReplayBuffer(
        sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0),
        storage=LazyTensorStorage(100, ndim=2),
        batch_size=4,
    )
    written = buffer.extend(make_2d_data())
    assert (leaf_priorities(buffer) == 1).all()
    buffer.update_priority(written, 2)
    assert (leaf_priorities(buffer) == 2).all()

    _, info = buffer.sample(return_info=True)
    sampled_index = info["index"]
    buffer.update_priority(sampled_index, 3)
    grid = leaf_priorities(buffer).reshape((5, 2))
    assert (grid[sampled_index] == 3).all()

    # Case 5: 2d data, TRB. The stacked index is unbound along its last dim
    # to index the (5, 2) view.
    buffer = TensorDictReplayBuffer(
        sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0),
        storage=LazyTensorStorage(100, ndim=2),
        batch_size=4,
    )
    written = buffer.extend(make_2d_data())
    assert (leaf_priorities(buffer) == 1).all()
    buffer.update_priority(written, 2)
    assert (leaf_priorities(buffer) == 2).all()

    batch = buffer.sample()
    buffer.update_priority(batch["index"], 10_000)
    grid = leaf_priorities(buffer).reshape((5, 2))
    assert (grid[batch["index"].unbind(-1)] == 10_000).all()

    resampled = buffer.sample()
    assert all_drawn_from(resampled["index"], batch["index"], multidim=True)

    # Case 6: 2d data, TPRB. The first update uses a 0-d tensor priority.
    buffer = TensorDictPrioritizedReplayBuffer(
        alpha=1.0,
        beta=1.0,
        storage=LazyTensorStorage(100, ndim=2),
        batch_size=4,
        priority_key="p",
    )
    written = buffer.extend(make_2d_data())
    assert (leaf_priorities(buffer) == 1).all()
    buffer.update_priority(written, torch.ones(()) * 2)
    assert (leaf_priorities(buffer) == 2).all()
    batch = buffer.sample()
    # Make the sampled items so much more likely than the rest that the
    # buffer is expected to draw them again.
    batch["p"] = torch.ones(4) * 10_000
    buffer.update_tensordict_priority(batch)
    grid = leaf_priorities(buffer).reshape((5, 2))
    assert (grid[batch["index"].unbind(-1)] == 10_000).all()

    resampled = buffer.sample()
    assert all_drawn_from(resampled["index"], batch["index"], multidim=True)
2802
+
2656
2803
2657
2804
def test_prioritized_slice_sampler_doc_example ():
2658
2805
sampler = PrioritizedSliceSampler (max_capacity = 9 , num_slices = 3 , alpha = 0.7 , beta = 0.9 )
0 commit comments