[Feature] TensorDictMap #2283

Closed
wants to merge 7 commits into from

Conversation

Collaborator

@vmoens vmoens commented Jul 9, 2024

This code is copy-pasted from pytorch/tensordict#826

Since we can build composite storages, or storages of any type, I would be keen to remove the key_to_storage dict and have a single storage per map (a storage can hold anything).
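
To make the "one storage per map" idea concrete, here is a minimal sketch in plain Python. It is not the torchrl API: the class name SingleStorageMap and its attributes are hypothetical, a Python list stands in for the storage, and hash collisions are ignored. The point is only that each key resolves to one slot, and that slot can hold a composite value, so no per-entry key_to_storage dict is needed.

```python
class SingleStorageMap:
    """Toy map (hypothetical, not torchrl): keys index into ONE storage."""

    def __init__(self):
        self._index = {}    # hash of the key -> position in the storage
        self._storage = []  # the single storage; a slot can hold anything

    def __setitem__(self, key, value):
        h = hash(key)  # assumption: collisions are ignored in this sketch
        pos = self._index.get(h)
        if pos is None:
            # New key: append the value and remember where it lives.
            self._index[h] = len(self._storage)
            self._storage.append(value)
        else:
            # Known key: overwrite the slot in place.
            self._storage[pos] = value

    def __getitem__(self, key):
        return self._storage[self._index[hash(key)]]

    def __contains__(self, key):
        return hash(key) in self._index

    def __len__(self):
        return len(self._storage)


node_map = SingleStorageMap()
# One composite value per key, instead of one storage per entry name.
node_map[(0, 1)] = {"action": [0, 1], "reward": 1.0}
node_map[(0, 2)] = {"action": [1, 0], "reward": 0.5}
print((0, 1) in node_map, len(node_map))
```

A design with one storage per entry name would need a second lookup to pick the storage before indexing into it; here the composite value travels as a unit.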

Credits to @mjlaali

Missing:

  • Tests
  • Docs
  • contains is broken currently
  • Dynamic storage with a default value was pretty cool, we should have something similar

pytorch-bot bot commented Jul 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2283

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 26 Pending, 2 Unrelated Failures

As of commit 126ca4a with merge base f764c02:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Jul 9, 2024
@vmoens vmoens added the enhancement label Jul 9, 2024

github-actions bot commented Jul 9, 2024

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 91. Improved: $\large\color{#35bf28}7$. Worsened: $\large\color{#d91a1a}2$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_single 58.5684ms 56.5628ms 17.6795 Ops/s 17.0442 Ops/s $\color{#35bf28}+3.73\%$
test_sync 38.8439ms 32.8310ms 30.4590 Ops/s 31.7465 Ops/s $\color{#d91a1a}-4.06\%$
test_async 68.5401ms 29.4945ms 33.9046 Ops/s 33.0659 Ops/s $\color{#35bf28}+2.54\%$
test_simple 0.4689s 0.4017s 2.4893 Ops/s 2.4756 Ops/s $\color{#35bf28}+0.55\%$
test_transformed 0.6200s 0.5600s 1.7857 Ops/s 1.7356 Ops/s $\color{#35bf28}+2.89\%$
test_serial 1.3115s 1.2511s 0.7993 Ops/s 0.8030 Ops/s $\color{#d91a1a}-0.46\%$
test_parallel 1.1758s 1.0984s 0.9104 Ops/s 0.9214 Ops/s $\color{#d91a1a}-1.19\%$
test_step_mdp_speed[True-True-True-True-True] 0.2269ms 25.0360μs 39.9424 KOps/s 38.6759 KOps/s $\color{#35bf28}+3.27\%$
test_step_mdp_speed[True-True-True-True-False] 43.2910μs 14.6658μs 68.1857 KOps/s 66.6951 KOps/s $\color{#35bf28}+2.23\%$
test_step_mdp_speed[True-True-True-False-True] 38.1410μs 14.6204μs 68.3975 KOps/s 66.6336 KOps/s $\color{#35bf28}+2.65\%$
test_step_mdp_speed[True-True-True-False-False] 44.0050μs 8.5751μs 116.6173 KOps/s 115.1114 KOps/s $\color{#35bf28}+1.31\%$
test_step_mdp_speed[True-True-False-True-True] 74.4780μs 26.7858μs 37.3333 KOps/s 36.4264 KOps/s $\color{#35bf28}+2.49\%$
test_step_mdp_speed[True-True-False-True-False] 69.3460μs 16.0259μs 62.3990 KOps/s 59.8946 KOps/s $\color{#35bf28}+4.18\%$
test_step_mdp_speed[True-True-False-False-True] 0.1405ms 16.6038μs 60.2273 KOps/s 60.6092 KOps/s $\color{#d91a1a}-0.63\%$
test_step_mdp_speed[True-True-False-False-False] 45.4880μs 9.9879μs 100.1207 KOps/s 95.6729 KOps/s $\color{#35bf28}+4.65\%$
test_step_mdp_speed[True-False-True-True-True] 82.5830μs 28.7845μs 34.7409 KOps/s 34.2237 KOps/s $\color{#35bf28}+1.51\%$
test_step_mdp_speed[True-False-True-True-False] 69.1790μs 17.6777μs 56.5686 KOps/s 54.7673 KOps/s $\color{#35bf28}+3.29\%$
test_step_mdp_speed[True-False-True-False-True] 51.7360μs 16.0271μs 62.3943 KOps/s 60.2143 KOps/s $\color{#35bf28}+3.62\%$
test_step_mdp_speed[True-False-True-False-False] 31.6490μs 10.0484μs 99.5186 KOps/s 96.4662 KOps/s $\color{#35bf28}+3.16\%$
test_step_mdp_speed[True-False-False-True-True] 0.1465ms 30.3349μs 32.9653 KOps/s 32.5394 KOps/s $\color{#35bf28}+1.31\%$
test_step_mdp_speed[True-False-False-True-False] 65.2110μs 19.1458μs 52.2308 KOps/s 50.5217 KOps/s $\color{#35bf28}+3.38\%$
test_step_mdp_speed[True-False-False-False-True] 58.9490μs 17.8968μs 55.8760 KOps/s 54.4112 KOps/s $\color{#35bf28}+2.69\%$
test_step_mdp_speed[True-False-False-False-False] 38.0010μs 11.6025μs 86.1886 KOps/s 84.4970 KOps/s $\color{#35bf28}+2.00\%$
test_step_mdp_speed[False-True-True-True-True] 76.6430μs 28.3662μs 35.2533 KOps/s 34.3400 KOps/s $\color{#35bf28}+2.66\%$
test_step_mdp_speed[False-True-True-True-False] 64.6000μs 17.7361μs 56.3823 KOps/s 53.9827 KOps/s $\color{#35bf28}+4.45\%$
test_step_mdp_speed[False-True-True-False-True] 66.5840μs 18.8834μs 52.9564 KOps/s 52.0773 KOps/s $\color{#35bf28}+1.69\%$
test_step_mdp_speed[False-True-True-False-False] 42.8200μs 11.3623μs 88.0102 KOps/s 85.8465 KOps/s $\color{#35bf28}+2.52\%$
test_step_mdp_speed[False-True-False-True-True] 80.9610μs 29.8008μs 33.5561 KOps/s 32.6166 KOps/s $\color{#35bf28}+2.88\%$
test_step_mdp_speed[False-True-False-True-False] 67.2850μs 19.1855μs 52.1227 KOps/s 50.5730 KOps/s $\color{#35bf28}+3.06\%$
test_step_mdp_speed[False-True-False-False-True] 67.9960μs 20.1945μs 49.5184 KOps/s 49.1367 KOps/s $\color{#35bf28}+0.78\%$
test_step_mdp_speed[False-True-False-False-False] 50.9050μs 12.8499μs 77.8215 KOps/s 76.3133 KOps/s $\color{#35bf28}+1.98\%$
test_step_mdp_speed[False-False-True-True-True] 3.5318ms 31.6938μs 31.5519 KOps/s 30.4054 KOps/s $\color{#35bf28}+3.77\%$
test_step_mdp_speed[False-False-True-True-False] 57.5070μs 20.8953μs 47.8576 KOps/s 46.7734 KOps/s $\color{#35bf28}+2.32\%$
test_step_mdp_speed[False-False-True-False-True] 69.3600μs 20.0013μs 49.9967 KOps/s 48.5172 KOps/s $\color{#35bf28}+3.05\%$
test_step_mdp_speed[False-False-True-False-False] 73.6860μs 12.8919μs 77.5683 KOps/s 75.1527 KOps/s $\color{#35bf28}+3.21\%$
test_step_mdp_speed[False-False-False-True-True] 64.6000μs 32.6057μs 30.6694 KOps/s 30.1260 KOps/s $\color{#35bf28}+1.80\%$
test_step_mdp_speed[False-False-False-True-False] 60.9640μs 22.0699μs 45.3106 KOps/s 43.7790 KOps/s $\color{#35bf28}+3.50\%$
test_step_mdp_speed[False-False-False-False-True] 46.3460μs 21.2709μs 47.0125 KOps/s 44.7079 KOps/s $\textbf{\color{#35bf28}+5.15\%}$
test_step_mdp_speed[False-False-False-False-False] 36.1770μs 14.0040μs 71.4083 KOps/s 69.0110 KOps/s $\color{#35bf28}+3.47\%$
test_values[generalized_advantage_estimate-True-True] 11.7396ms 9.4386ms 105.9482 Ops/s 99.0019 Ops/s $\textbf{\color{#35bf28}+7.02\%}$
test_values[vec_generalized_advantage_estimate-True-True] 35.2560ms 33.2188ms 30.1034 Ops/s 29.6275 Ops/s $\color{#35bf28}+1.61\%$
test_values[td0_return_estimate-False-False] 0.2235ms 0.1674ms 5.9727 KOps/s 5.9688 KOps/s $\color{#35bf28}+0.07\%$
test_values[td1_return_estimate-False-False] 27.0160ms 23.2004ms 43.1028 Ops/s 42.6380 Ops/s $\color{#35bf28}+1.09\%$
test_values[vec_td1_return_estimate-False-False] 35.4179ms 33.1156ms 30.1973 Ops/s 29.4693 Ops/s $\color{#35bf28}+2.47\%$
test_values[td_lambda_return_estimate-True-False] 35.0263ms 33.2346ms 30.0891 Ops/s 28.3963 Ops/s $\textbf{\color{#35bf28}+5.96\%}$
test_values[vec_td_lambda_return_estimate-True-False] 35.3144ms 33.1683ms 30.1493 Ops/s 27.7661 Ops/s $\textbf{\color{#35bf28}+8.58\%}$
test_gae_speed[generalized_advantage_estimate-False-1-512] 8.3682ms 8.2163ms 121.7094 Ops/s 121.0127 Ops/s $\color{#35bf28}+0.58\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 2.5193ms 2.0237ms 494.1489 Ops/s 520.4212 Ops/s $\textbf{\color{#d91a1a}-5.05\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.4531ms 0.3505ms 2.8531 KOps/s 2.8584 KOps/s $\color{#d91a1a}-0.18\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 45.7344ms 43.3348ms 23.0761 Ops/s 22.8792 Ops/s $\color{#35bf28}+0.86\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 3.9298ms 3.0265ms 330.4171 Ops/s 329.5451 Ops/s $\color{#35bf28}+0.26\%$
test_dqn_speed 6.5579ms 1.3751ms 727.2123 Ops/s 720.7747 Ops/s $\color{#35bf28}+0.89\%$
test_ddpg_speed 3.9052ms 2.8905ms 345.9585 Ops/s 345.2737 Ops/s $\color{#35bf28}+0.20\%$
test_sac_speed 9.8415ms 8.4837ms 117.8732 Ops/s 117.1821 Ops/s $\color{#35bf28}+0.59\%$
test_redq_speed 14.9508ms 13.0616ms 76.5604 Ops/s 76.6121 Ops/s $\color{#d91a1a}-0.07\%$
test_redq_deprec_speed 13.9800ms 13.1871ms 75.8316 Ops/s 75.9078 Ops/s $\color{#d91a1a}-0.10\%$
test_td3_speed 8.6227ms 8.3725ms 119.4385 Ops/s 118.5144 Ops/s $\color{#35bf28}+0.78\%$
test_cql_speed 38.8063ms 36.3204ms 27.5327 Ops/s 25.1259 Ops/s $\textbf{\color{#35bf28}+9.58\%}$
test_a2c_speed 8.3736ms 7.4161ms 134.8411 Ops/s 132.8253 Ops/s $\color{#35bf28}+1.52\%$
test_ppo_speed 9.5290ms 7.6960ms 129.9375 Ops/s 128.6905 Ops/s $\color{#35bf28}+0.97\%$
test_reinforce_speed 7.5238ms 6.6880ms 149.5214 Ops/s 150.3489 Ops/s $\color{#d91a1a}-0.55\%$
test_iql_speed 33.4011ms 32.3856ms 30.8779 Ops/s 31.1996 Ops/s $\color{#d91a1a}-1.03\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 5.7986ms 3.8079ms 262.6103 Ops/s 265.8122 Ops/s $\color{#d91a1a}-1.20\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.6421ms 0.4723ms 2.1172 KOps/s 2.1210 KOps/s $\color{#d91a1a}-0.18\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.6604ms 0.4494ms 2.2252 KOps/s 2.2297 KOps/s $\color{#d91a1a}-0.20\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 4.5141ms 3.7924ms 263.6833 Ops/s 269.5668 Ops/s $\color{#d91a1a}-2.18\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.5659ms 0.4656ms 2.1477 KOps/s 2.1343 KOps/s $\color{#35bf28}+0.63\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.6223ms 0.4403ms 2.2712 KOps/s 2.2448 KOps/s $\color{#35bf28}+1.17\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 2.4530ms 1.6800ms 595.2205 Ops/s 597.6025 Ops/s $\color{#d91a1a}-0.40\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 1.9294ms 1.5794ms 633.1534 Ops/s 630.2020 Ops/s $\color{#35bf28}+0.47\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 6.0591ms 3.8507ms 259.6935 Ops/s 258.0475 Ops/s $\color{#35bf28}+0.64\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.7938ms 0.6078ms 1.6453 KOps/s 1.6063 KOps/s $\color{#35bf28}+2.43\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.6454ms 0.5728ms 1.7458 KOps/s 1.6891 KOps/s $\color{#35bf28}+3.36\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.8310ms 3.7114ms 269.4427 Ops/s 261.5577 Ops/s $\color{#35bf28}+3.01\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.6425ms 0.4742ms 2.1089 KOps/s 2.0991 KOps/s $\color{#35bf28}+0.46\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.5155ms 0.4449ms 2.2475 KOps/s 2.1806 KOps/s $\color{#35bf28}+3.07\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 5.9293ms 3.6952ms 270.6241 Ops/s 260.2560 Ops/s $\color{#35bf28}+3.98\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.6484ms 0.4682ms 2.1359 KOps/s 2.1218 KOps/s $\color{#35bf28}+0.67\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.6739ms 0.4416ms 2.2646 KOps/s 2.2387 KOps/s $\color{#35bf28}+1.16\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 4.1173ms 3.8713ms 258.3081 Ops/s 252.6158 Ops/s $\color{#35bf28}+2.25\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.7858ms 0.6032ms 1.6577 KOps/s 1.6245 KOps/s $\color{#35bf28}+2.04\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.7153ms 0.5758ms 1.7368 KOps/s 1.7085 KOps/s $\color{#35bf28}+1.66\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.1016s 5.8217ms 171.7711 Ops/s 122.4393 Ops/s $\textbf{\color{#35bf28}+40.29\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 17.0136ms 12.8652ms 77.7291 Ops/s 79.2396 Ops/s $\color{#d91a1a}-1.91\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 1.1471ms 1.0183ms 981.9908 Ops/s 976.0906 Ops/s $\color{#35bf28}+0.60\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1025s 5.7141ms 175.0070 Ops/s 170.5379 Ops/s $\color{#35bf28}+2.62\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 0.1100s 14.8153ms 67.4978 Ops/s 76.8115 Ops/s $\textbf{\color{#d91a1a}-12.13\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 1.5668ms 1.0349ms 966.2831 Ops/s 936.9393 Ops/s $\color{#35bf28}+3.13\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 99.2384ms 5.8381ms 171.2895 Ops/s 127.5935 Ops/s $\textbf{\color{#35bf28}+34.25\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 16.7797ms 13.0705ms 76.5083 Ops/s 77.8708 Ops/s $\color{#d91a1a}-1.75\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 1.7860ms 1.1883ms 841.5461 Ops/s 813.9312 Ops/s $\color{#35bf28}+3.39\%$

github-actions bot commented Jul 9, 2024

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of GPU Benchmark Tests

Total Benchmarks: 94. Improved: $\large\color{#35bf28}11$. Worsened: $\large\color{#d91a1a}2$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_single 0.1068s 0.1061s 9.4238 Ops/s 9.1429 Ops/s $\color{#35bf28}+3.07\%$
test_sync 95.5394ms 94.7249ms 10.5569 Ops/s 10.4646 Ops/s $\color{#35bf28}+0.88\%$
test_async 0.1784s 88.9088ms 11.2475 Ops/s 11.1027 Ops/s $\color{#35bf28}+1.30\%$
test_single_pixels 0.1150s 0.1148s 8.7070 Ops/s 8.5466 Ops/s $\color{#35bf28}+1.88\%$
test_sync_pixels 75.9209ms 75.0763ms 13.3198 Ops/s 13.6364 Ops/s $\color{#d91a1a}-2.32\%$
test_async_pixels 0.1428s 70.9977ms 14.0850 Ops/s 14.3503 Ops/s $\color{#d91a1a}-1.85\%$
test_simple 0.8501s 0.7694s 1.2998 Ops/s 1.3037 Ops/s $\color{#d91a1a}-0.30\%$
test_transformed 1.0658s 0.9902s 1.0099 Ops/s 0.9973 Ops/s $\color{#35bf28}+1.26\%$
test_serial 2.2601s 2.1811s 0.4585 Ops/s 0.4533 Ops/s $\color{#35bf28}+1.13\%$
test_parallel 2.0114s 1.9295s 0.5183 Ops/s 0.5147 Ops/s $\color{#35bf28}+0.69\%$
test_step_mdp_speed[True-True-True-True-True] 0.1527ms 36.6180μs 27.3090 KOps/s 26.2674 KOps/s $\color{#35bf28}+3.97\%$
test_step_mdp_speed[True-True-True-True-False] 0.1186ms 20.8753μs 47.9036 KOps/s 46.4895 KOps/s $\color{#35bf28}+3.04\%$
test_step_mdp_speed[True-True-True-False-True] 0.1263ms 20.2192μs 49.4579 KOps/s 46.2614 KOps/s $\textbf{\color{#35bf28}+6.91\%}$
test_step_mdp_speed[True-True-True-False-False] 31.3000μs 11.7507μs 85.1016 KOps/s 82.0583 KOps/s $\color{#35bf28}+3.71\%$
test_step_mdp_speed[True-True-False-True-True] 61.2010μs 39.1892μs 25.5173 KOps/s 24.6214 KOps/s $\color{#35bf28}+3.64\%$
test_step_mdp_speed[True-True-False-True-False] 0.2054ms 23.3748μs 42.7812 KOps/s 41.5844 KOps/s $\color{#35bf28}+2.88\%$
test_step_mdp_speed[True-True-False-False-True] 0.2240ms 23.2921μs 42.9330 KOps/s 41.7624 KOps/s $\color{#35bf28}+2.80\%$
test_step_mdp_speed[True-True-False-False-False] 57.9500μs 13.9867μs 71.4966 KOps/s 68.9300 KOps/s $\color{#35bf28}+3.72\%$
test_step_mdp_speed[True-False-True-True-True] 62.7300μs 41.0731μs 24.3468 KOps/s 23.2965 KOps/s $\color{#35bf28}+4.51\%$
test_step_mdp_speed[True-False-True-True-False] 0.2015ms 25.3878μs 39.3891 KOps/s 37.9309 KOps/s $\color{#35bf28}+3.84\%$
test_step_mdp_speed[True-False-True-False-True] 43.6910μs 22.6807μs 44.0903 KOps/s 41.6499 KOps/s $\textbf{\color{#35bf28}+5.86\%}$
test_step_mdp_speed[True-False-True-False-False] 37.6500μs 14.0218μs 71.3177 KOps/s 68.4102 KOps/s $\color{#35bf28}+4.25\%$
test_step_mdp_speed[True-False-False-True-True] 82.6410μs 43.3104μs 23.0891 KOps/s 22.1018 KOps/s $\color{#35bf28}+4.47\%$
test_step_mdp_speed[True-False-False-True-False] 93.7120μs 27.7406μs 36.0483 KOps/s 35.0453 KOps/s $\color{#35bf28}+2.86\%$
test_step_mdp_speed[True-False-False-False-True] 0.1390ms 24.7718μs 40.3684 KOps/s 38.0311 KOps/s $\textbf{\color{#35bf28}+6.15\%}$
test_step_mdp_speed[True-False-False-False-False] 35.4900μs 16.2261μs 61.6291 KOps/s 59.1861 KOps/s $\color{#35bf28}+4.13\%$
test_step_mdp_speed[False-True-True-True-True] 62.3410μs 40.7377μs 24.5473 KOps/s 23.3656 KOps/s $\textbf{\color{#35bf28}+5.06\%}$
test_step_mdp_speed[False-True-True-True-False] 54.7810μs 25.2926μs 39.5373 KOps/s 37.8451 KOps/s $\color{#35bf28}+4.47\%$
test_step_mdp_speed[False-True-True-False-True] 50.1010μs 27.4374μs 36.4466 KOps/s 35.3639 KOps/s $\color{#35bf28}+3.06\%$
test_step_mdp_speed[False-True-True-False-False] 42.5490μs 16.3276μs 61.2458 KOps/s 60.1769 KOps/s $\color{#35bf28}+1.78\%$
test_step_mdp_speed[False-True-False-True-True] 65.3310μs 43.4797μs 22.9993 KOps/s 22.2175 KOps/s $\color{#35bf28}+3.52\%$
test_step_mdp_speed[False-True-False-True-False] 59.4710μs 27.6704μs 36.1398 KOps/s 34.8182 KOps/s $\color{#35bf28}+3.80\%$
test_step_mdp_speed[False-True-False-False-True] 61.0010μs 29.8555μs 33.4947 KOps/s 32.6351 KOps/s $\color{#35bf28}+2.63\%$
test_step_mdp_speed[False-True-False-False-False] 50.3210μs 18.1419μs 55.1210 KOps/s 52.5978 KOps/s $\color{#35bf28}+4.80\%$
test_step_mdp_speed[False-False-True-True-True] 4.8120ms 46.3627μs 21.5691 KOps/s 20.8679 KOps/s $\color{#35bf28}+3.36\%$
test_step_mdp_speed[False-False-True-True-False] 51.0510μs 30.1297μs 33.1898 KOps/s 31.9551 KOps/s $\color{#35bf28}+3.86\%$
test_step_mdp_speed[False-False-True-False-True] 49.9910μs 29.5039μs 33.8938 KOps/s 32.7063 KOps/s $\color{#35bf28}+3.63\%$
test_step_mdp_speed[False-False-True-False-False] 66.7210μs 18.3344μs 54.5424 KOps/s 52.5436 KOps/s $\color{#35bf28}+3.80\%$
test_step_mdp_speed[False-False-False-True-True] 83.8810μs 47.6592μs 20.9823 KOps/s 20.0702 KOps/s $\color{#35bf28}+4.54\%$
test_step_mdp_speed[False-False-False-True-False] 53.3510μs 32.2893μs 30.9700 KOps/s 29.9435 KOps/s $\color{#35bf28}+3.43\%$
test_step_mdp_speed[False-False-False-False-True] 54.8510μs 30.9677μs 32.2918 KOps/s 30.7028 KOps/s $\textbf{\color{#35bf28}+5.18\%}$
test_step_mdp_speed[False-False-False-False-False] 41.4410μs 20.3202μs 49.2121 KOps/s 46.8325 KOps/s $\textbf{\color{#35bf28}+5.08\%}$
test_values[generalized_advantage_estimate-True-True] 26.1524ms 25.7266ms 38.8703 Ops/s 40.0025 Ops/s $\color{#d91a1a}-2.83\%$
test_values[vec_generalized_advantage_estimate-True-True] 96.8067ms 2.8402ms 352.0855 Ops/s 365.8197 Ops/s $\color{#d91a1a}-3.75\%$
test_values[td0_return_estimate-False-False] 89.6200μs 64.9646μs 15.3930 KOps/s 14.9553 KOps/s $\color{#35bf28}+2.93\%$
test_values[td1_return_estimate-False-False] 57.7869ms 56.9707ms 17.5529 Ops/s 17.9762 Ops/s $\color{#d91a1a}-2.36\%$
test_values[vec_td1_return_estimate-False-False] 1.3190ms 1.0801ms 925.8298 Ops/s 918.1787 Ops/s $\color{#35bf28}+0.83\%$
test_values[td_lambda_return_estimate-True-False] 90.8806ms 89.7140ms 11.1465 Ops/s 11.3339 Ops/s $\color{#d91a1a}-1.65\%$
test_values[vec_td_lambda_return_estimate-True-False] 1.3407ms 1.0803ms 925.6351 Ops/s 922.6118 Ops/s $\color{#35bf28}+0.33\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 25.6887ms 24.7812ms 40.3532 Ops/s 40.5384 Ops/s $\color{#d91a1a}-0.46\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 0.9665ms 0.7136ms 1.4013 KOps/s 1.3843 KOps/s $\color{#35bf28}+1.23\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.8062ms 0.6665ms 1.5005 KOps/s 1.4886 KOps/s $\color{#35bf28}+0.80\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 1.6330ms 1.4735ms 678.6552 Ops/s 678.7814 Ops/s $\color{#d91a1a}-0.02\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 0.8566ms 0.6814ms 1.4676 KOps/s 1.4631 KOps/s $\color{#35bf28}+0.31\%$
test_dqn_speed 1.6772ms 1.4426ms 693.1698 Ops/s 676.6337 Ops/s $\color{#35bf28}+2.44\%$
test_ddpg_speed 3.2460ms 2.9448ms 339.5793 Ops/s 329.6548 Ops/s $\color{#35bf28}+3.01\%$
test_sac_speed 9.1301ms 8.5704ms 116.6809 Ops/s 114.8209 Ops/s $\color{#35bf28}+1.62\%$
test_redq_speed 11.8329ms 10.9327ms 91.4684 Ops/s 89.9808 Ops/s $\color{#35bf28}+1.65\%$
test_redq_deprec_speed 0.1110s 12.7982ms 78.1359 Ops/s 84.4333 Ops/s $\textbf{\color{#d91a1a}-7.46\%}$
test_td3_speed 8.7080ms 8.5394ms 117.1041 Ops/s 114.8851 Ops/s $\color{#35bf28}+1.93\%$
test_cql_speed 28.2763ms 27.0789ms 36.9292 Ops/s 36.7113 Ops/s $\color{#35bf28}+0.59\%$
test_a2c_speed 6.6296ms 5.9199ms 168.9231 Ops/s 167.2917 Ops/s $\color{#35bf28}+0.98\%$
test_ppo_speed 6.5943ms 6.2266ms 160.6004 Ops/s 158.9749 Ops/s $\color{#35bf28}+1.02\%$
test_reinforce_speed 5.1973ms 4.8460ms 206.3549 Ops/s 204.1456 Ops/s $\color{#35bf28}+1.08\%$
test_iql_speed 20.9824ms 20.5401ms 48.6852 Ops/s 47.2614 Ops/s $\color{#35bf28}+3.01\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 5.4289ms 5.1077ms 195.7818 Ops/s 186.2444 Ops/s $\textbf{\color{#35bf28}+5.12\%}$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.6942ms 0.5196ms 1.9246 KOps/s 1.9155 KOps/s $\color{#35bf28}+0.48\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.7164ms 0.4962ms 2.0154 KOps/s 1.9924 KOps/s $\color{#35bf28}+1.15\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 5.5166ms 5.0892ms 196.4946 Ops/s 188.4105 Ops/s $\color{#35bf28}+4.29\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.6959ms 0.5084ms 1.9671 KOps/s 1.9312 KOps/s $\color{#35bf28}+1.86\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.6769ms 0.4904ms 2.0390 KOps/s 2.0213 KOps/s $\color{#35bf28}+0.88\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 2.1560ms 1.9788ms 505.3486 Ops/s 495.0452 Ops/s $\color{#35bf28}+2.08\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 2.1354ms 1.8909ms 528.8454 Ops/s 517.4089 Ops/s $\color{#35bf28}+2.21\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 5.5833ms 5.2664ms 189.8824 Ops/s 182.4724 Ops/s $\color{#35bf28}+4.06\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.8586ms 0.6691ms 1.4945 KOps/s 1.4879 KOps/s $\color{#35bf28}+0.44\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.8496ms 0.6457ms 1.5487 KOps/s 1.5317 KOps/s $\color{#35bf28}+1.11\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 5.4611ms 5.1069ms 195.8126 Ops/s 185.6920 Ops/s $\textbf{\color{#35bf28}+5.45\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.6854ms 0.5161ms 1.9376 KOps/s 1.9181 KOps/s $\color{#35bf28}+1.02\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.7568ms 0.5013ms 1.9950 KOps/s 1.9973 KOps/s $\color{#d91a1a}-0.12\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 5.6158ms 5.0920ms 196.3861 Ops/s 188.6939 Ops/s $\color{#35bf28}+4.08\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.6822ms 0.5126ms 1.9508 KOps/s 1.9307 KOps/s $\color{#35bf28}+1.04\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.6692ms 0.4971ms 2.0115 KOps/s 2.0138 KOps/s $\color{#d91a1a}-0.11\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 5.4205ms 5.2318ms 191.1374 Ops/s 181.9527 Ops/s $\textbf{\color{#35bf28}+5.05\%}$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.8486ms 0.6638ms 1.5064 KOps/s 1.4855 KOps/s $\color{#35bf28}+1.41\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.8060ms 0.6451ms 1.5501 KOps/s 1.5343 KOps/s $\color{#35bf28}+1.03\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.1507s 8.1370ms 122.8955 Ops/s 120.4396 Ops/s $\color{#35bf28}+2.04\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 18.5897ms 16.3244ms 61.2579 Ops/s 60.9398 Ops/s $\color{#35bf28}+0.52\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 2.1583ms 1.1322ms 883.2548 Ops/s 888.3748 Ops/s $\color{#d91a1a}-0.58\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1293s 7.7327ms 129.3207 Ops/s 95.6292 Ops/s $\textbf{\color{#35bf28}+35.23\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 18.6442ms 16.2801ms 61.4249 Ops/s 60.6611 Ops/s $\color{#35bf28}+1.26\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 2.1761ms 1.1457ms 872.8615 Ops/s 864.9996 Ops/s $\color{#35bf28}+0.91\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1306s 7.9640ms 125.5646 Ops/s 123.8533 Ops/s $\color{#35bf28}+1.38\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 0.1458s 18.9672ms 52.7227 Ops/s 59.9978 Ops/s $\textbf{\color{#d91a1a}-12.13\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 2.3098ms 1.2724ms 785.9111 Ops/s 726.9930 Ops/s $\textbf{\color{#35bf28}+8.10\%}$

digits = torch.vmap(torch.dot, (None, 0))(
    self.bases, feature_parts.to(self.bases.dtype)
)
digits = digits.reshape(shape=(-1, features.shape[-1] // self.num_bits))

nit: features.shape[-1] -> num_features
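
For readers following the snippet under review, this is a plain-Python sketch of what the reshape and per-row dot product appear to compute: each group of num_bits features becomes one integer. The helper name binary_to_digits is hypothetical, and the most-significant-bit-first ordering of the bases is an assumption, since the definition of self.bases is not shown here.

```python
def binary_to_digits(bits, num_bits):
    """Convert a flat list of 0/1 values into one integer per num_bits-wide chunk."""
    num_features = len(bits)
    if num_features % num_bits != 0:
        raise ValueError("number of features must be a multiple of num_bits")
    # Powers of two, most significant bit first (assumed ordering).
    bases = [2 ** i for i in range(num_bits - 1, -1, -1)]
    digits = []
    for start in range(0, num_features, num_bits):
        chunk = bits[start:start + num_bits]
        # Dot product of the chunk with the bases, one chunk at a time.
        digits.append(sum(b * x for b, x in zip(bases, chunk)))
    # The result has num_features // num_bits entries, matching the reshape above.
    return digits


print(binary_to_digits([0, 0, 1, 0, 0, 0, 0, 1], num_bits=4))  # [2, 1]
```

Naming the length once as num_features, as the nit suggests, keeps the divisibility check and the output size in terms of the same quantity.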

Collaborator Author

vmoens commented Jul 16, 2024

Here is an example of how to build a tree with TensorDictMap and MCTSForest:

# The following example requires torchrl and brax to be installed
from torchrl.data.map.tree import MCTSForest
from torchrl.envs import BraxEnv

forest = MCTSForest()

env = BraxEnv("ant")
env.set_seed(0)
reset_state = env.reset()

# First rollout, starting from the reset state
rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
# Keep step 3 of the first rollout as a branching point for later
starttd = rollout0[3]
# Drop the simulator state before storing the rollout in the forest
rollout0 = rollout0.exclude("state", ("next", "state"))
forest.extend(rollout0)
print(len(forest))

# Second rollout, starting from the same reset state
rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone()).exclude("state", ("next", "state"))
forest.extend(rollout1)
print(len(forest))

# Third rollout, branching off step 3 of the first rollout
rollout0b = env.rollout(6, auto_reset=False, tensordict=starttd).exclude("state", ("next", "state"))
forest.extend(rollout0b)
print(len(forest))

# Retrieve the tree rooted at the first step of the first rollout
r = rollout0[0]
r.names = None
tree = forest.get_tree(r)

forest.plot(tree)

[Image: tree plot produced by forest.plot(tree)]

mjlaali commented Jul 17, 2024

Pretty neat. I was wondering how to update the tree. In particular, can we store a count per child and update it after each rollout?

binary_to_decimal = BinaryToDecimal(
    num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True
)
binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]])

I think the 10 should be a 1 in this case?

Collaborator Author

I think it's on purpose, to test that we convert non binary to binary (see the class constructor)
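
As an illustration of that point, the sketch below binarizes the second test row before any bit-to-integer conversion. The rule used (strictly positive values become 1, everything else 0) is an assumption about what convert_to_binary=True does, and to_binary is a hypothetical helper, not the class's method.

```python
def to_binary(values):
    """Map strictly positive entries to 1 and all others to 0 (assumed rule)."""
    return [1 if v > 0 else 0 for v in values]


row = [0, 0, 0, 0, 0, 0, 10, 0]
print(to_binary(row))  # [0, 0, 0, 0, 0, 0, 1, 0]
```

Under that assumption the 10 in the test input behaves exactly like a 1, which is what the test is meant to exercise.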

Collaborator Author

vmoens commented Jul 22, 2024

Closing in favour of #2304

@vmoens vmoens closed this Jul 22, 2024
@vmoens vmoens mentioned this pull request Aug 4, 2024
23 tasks
@vmoens vmoens deleted the td2td branch August 7, 2024 01:39
Labels
CLA Signed, enhancement
4 participants