-
Notifications
You must be signed in to change notification settings - Fork 400
[Feature] TensorDictMap #2283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] TensorDictMap #2283
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2283
Note: Links to docs will display an error until the docs builds have been completed. ⏳ 26 Pending, 2 Unrelated FailuresAs of commit 126ca4a with merge base f764c02 ( FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_single | 58.5684ms | 56.5628ms | 17.6795 Ops/s | 17.0442 Ops/s | |
test_sync | 38.8439ms | 32.8310ms | 30.4590 Ops/s | 31.7465 Ops/s | |
test_async | 68.5401ms | 29.4945ms | 33.9046 Ops/s | 33.0659 Ops/s | |
test_simple | 0.4689s | 0.4017s | 2.4893 Ops/s | 2.4756 Ops/s | |
test_transformed | 0.6200s | 0.5600s | 1.7857 Ops/s | 1.7356 Ops/s | |
test_serial | 1.3115s | 1.2511s | 0.7993 Ops/s | 0.8030 Ops/s | |
test_parallel | 1.1758s | 1.0984s | 0.9104 Ops/s | 0.9214 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2269ms | 25.0360μs | 39.9424 KOps/s | 38.6759 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 43.2910μs | 14.6658μs | 68.1857 KOps/s | 66.6951 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 38.1410μs | 14.6204μs | 68.3975 KOps/s | 66.6336 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 44.0050μs | 8.5751μs | 116.6173 KOps/s | 115.1114 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 74.4780μs | 26.7858μs | 37.3333 KOps/s | 36.4264 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 69.3460μs | 16.0259μs | 62.3990 KOps/s | 59.8946 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.1405ms | 16.6038μs | 60.2273 KOps/s | 60.6092 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 45.4880μs | 9.9879μs | 100.1207 KOps/s | 95.6729 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 82.5830μs | 28.7845μs | 34.7409 KOps/s | 34.2237 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 69.1790μs | 17.6777μs | 56.5686 KOps/s | 54.7673 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 51.7360μs | 16.0271μs | 62.3943 KOps/s | 60.2143 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 31.6490μs | 10.0484μs | 99.5186 KOps/s | 96.4662 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 0.1465ms | 30.3349μs | 32.9653 KOps/s | 32.5394 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 65.2110μs | 19.1458μs | 52.2308 KOps/s | 50.5217 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 58.9490μs | 17.8968μs | 55.8760 KOps/s | 54.4112 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 38.0010μs | 11.6025μs | 86.1886 KOps/s | 84.4970 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 76.6430μs | 28.3662μs | 35.2533 KOps/s | 34.3400 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 64.6000μs | 17.7361μs | 56.3823 KOps/s | 53.9827 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 66.5840μs | 18.8834μs | 52.9564 KOps/s | 52.0773 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 42.8200μs | 11.3623μs | 88.0102 KOps/s | 85.8465 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 80.9610μs | 29.8008μs | 33.5561 KOps/s | 32.6166 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 67.2850μs | 19.1855μs | 52.1227 KOps/s | 50.5730 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 67.9960μs | 20.1945μs | 49.5184 KOps/s | 49.1367 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 50.9050μs | 12.8499μs | 77.8215 KOps/s | 76.3133 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 3.5318ms | 31.6938μs | 31.5519 KOps/s | 30.4054 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 57.5070μs | 20.8953μs | 47.8576 KOps/s | 46.7734 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 69.3600μs | 20.0013μs | 49.9967 KOps/s | 48.5172 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 73.6860μs | 12.8919μs | 77.5683 KOps/s | 75.1527 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 64.6000μs | 32.6057μs | 30.6694 KOps/s | 30.1260 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 60.9640μs | 22.0699μs | 45.3106 KOps/s | 43.7790 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 46.3460μs | 21.2709μs | 47.0125 KOps/s | 44.7079 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 36.1770μs | 14.0040μs | 71.4083 KOps/s | 69.0110 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 11.7396ms | 9.4386ms | 105.9482 Ops/s | 99.0019 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 35.2560ms | 33.2188ms | 30.1034 Ops/s | 29.6275 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2235ms | 0.1674ms | 5.9727 KOps/s | 5.9688 KOps/s | |
test_values[td1_return_estimate-False-False] | 27.0160ms | 23.2004ms | 43.1028 Ops/s | 42.6380 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 35.4179ms | 33.1156ms | 30.1973 Ops/s | 29.4693 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 35.0263ms | 33.2346ms | 30.0891 Ops/s | 28.3963 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 35.3144ms | 33.1683ms | 30.1493 Ops/s | 27.7661 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 8.3682ms | 8.2163ms | 121.7094 Ops/s | 121.0127 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.5193ms | 2.0237ms | 494.1489 Ops/s | 520.4212 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.4531ms | 0.3505ms | 2.8531 KOps/s | 2.8584 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 45.7344ms | 43.3348ms | 23.0761 Ops/s | 22.8792 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 3.9298ms | 3.0265ms | 330.4171 Ops/s | 329.5451 Ops/s | |
test_dqn_speed | 6.5579ms | 1.3751ms | 727.2123 Ops/s | 720.7747 Ops/s | |
test_ddpg_speed | 3.9052ms | 2.8905ms | 345.9585 Ops/s | 345.2737 Ops/s | |
test_sac_speed | 9.8415ms | 8.4837ms | 117.8732 Ops/s | 117.1821 Ops/s | |
test_redq_speed | 14.9508ms | 13.0616ms | 76.5604 Ops/s | 76.6121 Ops/s | |
test_redq_deprec_speed | 13.9800ms | 13.1871ms | 75.8316 Ops/s | 75.9078 Ops/s | |
test_td3_speed | 8.6227ms | 8.3725ms | 119.4385 Ops/s | 118.5144 Ops/s | |
test_cql_speed | 38.8063ms | 36.3204ms | 27.5327 Ops/s | 25.1259 Ops/s | |
test_a2c_speed | 8.3736ms | 7.4161ms | 134.8411 Ops/s | 132.8253 Ops/s | |
test_ppo_speed | 9.5290ms | 7.6960ms | 129.9375 Ops/s | 128.6905 Ops/s | |
test_reinforce_speed | 7.5238ms | 6.6880ms | 149.5214 Ops/s | 150.3489 Ops/s | |
test_iql_speed | 33.4011ms | 32.3856ms | 30.8779 Ops/s | 31.1996 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.7986ms | 3.8079ms | 262.6103 Ops/s | 265.8122 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.6421ms | 0.4723ms | 2.1172 KOps/s | 2.1210 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6604ms | 0.4494ms | 2.2252 KOps/s | 2.2297 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 4.5141ms | 3.7924ms | 263.6833 Ops/s | 269.5668 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.5659ms | 0.4656ms | 2.1477 KOps/s | 2.1343 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6223ms | 0.4403ms | 2.2712 KOps/s | 2.2448 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.4530ms | 1.6800ms | 595.2205 Ops/s | 597.6025 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.9294ms | 1.5794ms | 633.1534 Ops/s | 630.2020 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.0591ms | 3.8507ms | 259.6935 Ops/s | 258.0475 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.7938ms | 0.6078ms | 1.6453 KOps/s | 1.6063 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6454ms | 0.5728ms | 1.7458 KOps/s | 1.6891 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.8310ms | 3.7114ms | 269.4427 Ops/s | 261.5577 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.6425ms | 0.4742ms | 2.1089 KOps/s | 2.0991 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5155ms | 0.4449ms | 2.2475 KOps/s | 2.1806 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.9293ms | 3.6952ms | 270.6241 Ops/s | 260.2560 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6484ms | 0.4682ms | 2.1359 KOps/s | 2.1218 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6739ms | 0.4416ms | 2.2646 KOps/s | 2.2387 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 4.1173ms | 3.8713ms | 258.3081 Ops/s | 252.6158 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.7858ms | 0.6032ms | 1.6577 KOps/s | 1.6245 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7153ms | 0.5758ms | 1.7368 KOps/s | 1.7085 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1016s | 5.8217ms | 171.7711 Ops/s | 122.4393 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 17.0136ms | 12.8652ms | 77.7291 Ops/s | 79.2396 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 1.1471ms | 1.0183ms | 981.9908 Ops/s | 976.0906 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1025s | 5.7141ms | 175.0070 Ops/s | 170.5379 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1100s | 14.8153ms | 67.4978 Ops/s | 76.8115 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 1.5668ms | 1.0349ms | 966.2831 Ops/s | 936.9393 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 99.2384ms | 5.8381ms | 171.2895 Ops/s | 127.5935 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 16.7797ms | 13.0705ms | 76.5083 Ops/s | 77.8708 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 1.7860ms | 1.1883ms | 841.5461 Ops/s | 813.9312 Ops/s |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_single | 0.1068s | 0.1061s | 9.4238 Ops/s | 9.1429 Ops/s | |
test_sync | 95.5394ms | 94.7249ms | 10.5569 Ops/s | 10.4646 Ops/s | |
test_async | 0.1784s | 88.9088ms | 11.2475 Ops/s | 11.1027 Ops/s | |
test_single_pixels | 0.1150s | 0.1148s | 8.7070 Ops/s | 8.5466 Ops/s | |
test_sync_pixels | 75.9209ms | 75.0763ms | 13.3198 Ops/s | 13.6364 Ops/s | |
test_async_pixels | 0.1428s | 70.9977ms | 14.0850 Ops/s | 14.3503 Ops/s | |
test_simple | 0.8501s | 0.7694s | 1.2998 Ops/s | 1.3037 Ops/s | |
test_transformed | 1.0658s | 0.9902s | 1.0099 Ops/s | 0.9973 Ops/s | |
test_serial | 2.2601s | 2.1811s | 0.4585 Ops/s | 0.4533 Ops/s | |
test_parallel | 2.0114s | 1.9295s | 0.5183 Ops/s | 0.5147 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1527ms | 36.6180μs | 27.3090 KOps/s | 26.2674 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 0.1186ms | 20.8753μs | 47.9036 KOps/s | 46.4895 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 0.1263ms | 20.2192μs | 49.4579 KOps/s | 46.2614 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 31.3000μs | 11.7507μs | 85.1016 KOps/s | 82.0583 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 61.2010μs | 39.1892μs | 25.5173 KOps/s | 24.6214 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 0.2054ms | 23.3748μs | 42.7812 KOps/s | 41.5844 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.2240ms | 23.2921μs | 42.9330 KOps/s | 41.7624 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 57.9500μs | 13.9867μs | 71.4966 KOps/s | 68.9300 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 62.7300μs | 41.0731μs | 24.3468 KOps/s | 23.2965 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 0.2015ms | 25.3878μs | 39.3891 KOps/s | 37.9309 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 43.6910μs | 22.6807μs | 44.0903 KOps/s | 41.6499 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 37.6500μs | 14.0218μs | 71.3177 KOps/s | 68.4102 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 82.6410μs | 43.3104μs | 23.0891 KOps/s | 22.1018 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 93.7120μs | 27.7406μs | 36.0483 KOps/s | 35.0453 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 0.1390ms | 24.7718μs | 40.3684 KOps/s | 38.0311 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 35.4900μs | 16.2261μs | 61.6291 KOps/s | 59.1861 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 62.3410μs | 40.7377μs | 24.5473 KOps/s | 23.3656 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 54.7810μs | 25.2926μs | 39.5373 KOps/s | 37.8451 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 50.1010μs | 27.4374μs | 36.4466 KOps/s | 35.3639 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 42.5490μs | 16.3276μs | 61.2458 KOps/s | 60.1769 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 65.3310μs | 43.4797μs | 22.9993 KOps/s | 22.2175 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 59.4710μs | 27.6704μs | 36.1398 KOps/s | 34.8182 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 61.0010μs | 29.8555μs | 33.4947 KOps/s | 32.6351 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 50.3210μs | 18.1419μs | 55.1210 KOps/s | 52.5978 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 4.8120ms | 46.3627μs | 21.5691 KOps/s | 20.8679 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 51.0510μs | 30.1297μs | 33.1898 KOps/s | 31.9551 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 49.9910μs | 29.5039μs | 33.8938 KOps/s | 32.7063 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 66.7210μs | 18.3344μs | 54.5424 KOps/s | 52.5436 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 83.8810μs | 47.6592μs | 20.9823 KOps/s | 20.0702 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 53.3510μs | 32.2893μs | 30.9700 KOps/s | 29.9435 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 54.8510μs | 30.9677μs | 32.2918 KOps/s | 30.7028 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 41.4410μs | 20.3202μs | 49.2121 KOps/s | 46.8325 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 26.1524ms | 25.7266ms | 38.8703 Ops/s | 40.0025 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 96.8067ms | 2.8402ms | 352.0855 Ops/s | 365.8197 Ops/s | |
test_values[td0_return_estimate-False-False] | 89.6200μs | 64.9646μs | 15.3930 KOps/s | 14.9553 KOps/s | |
test_values[td1_return_estimate-False-False] | 57.7869ms | 56.9707ms | 17.5529 Ops/s | 17.9762 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3190ms | 1.0801ms | 925.8298 Ops/s | 918.1787 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 90.8806ms | 89.7140ms | 11.1465 Ops/s | 11.3339 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.3407ms | 1.0803ms | 925.6351 Ops/s | 922.6118 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 25.6887ms | 24.7812ms | 40.3532 Ops/s | 40.5384 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 0.9665ms | 0.7136ms | 1.4013 KOps/s | 1.3843 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.8062ms | 0.6665ms | 1.5005 KOps/s | 1.4886 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.6330ms | 1.4735ms | 678.6552 Ops/s | 678.7814 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.8566ms | 0.6814ms | 1.4676 KOps/s | 1.4631 KOps/s | |
test_dqn_speed | 1.6772ms | 1.4426ms | 693.1698 Ops/s | 676.6337 Ops/s | |
test_ddpg_speed | 3.2460ms | 2.9448ms | 339.5793 Ops/s | 329.6548 Ops/s | |
test_sac_speed | 9.1301ms | 8.5704ms | 116.6809 Ops/s | 114.8209 Ops/s | |
test_redq_speed | 11.8329ms | 10.9327ms | 91.4684 Ops/s | 89.9808 Ops/s | |
test_redq_deprec_speed | 0.1110s | 12.7982ms | 78.1359 Ops/s | 84.4333 Ops/s | |
test_td3_speed | 8.7080ms | 8.5394ms | 117.1041 Ops/s | 114.8851 Ops/s | |
test_cql_speed | 28.2763ms | 27.0789ms | 36.9292 Ops/s | 36.7113 Ops/s | |
test_a2c_speed | 6.6296ms | 5.9199ms | 168.9231 Ops/s | 167.2917 Ops/s | |
test_ppo_speed | 6.5943ms | 6.2266ms | 160.6004 Ops/s | 158.9749 Ops/s | |
test_reinforce_speed | 5.1973ms | 4.8460ms | 206.3549 Ops/s | 204.1456 Ops/s | |
test_iql_speed | 20.9824ms | 20.5401ms | 48.6852 Ops/s | 47.2614 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.4289ms | 5.1077ms | 195.7818 Ops/s | 186.2444 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.6942ms | 0.5196ms | 1.9246 KOps/s | 1.9155 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7164ms | 0.4962ms | 2.0154 KOps/s | 1.9924 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.5166ms | 5.0892ms | 196.4946 Ops/s | 188.4105 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6959ms | 0.5084ms | 1.9671 KOps/s | 1.9312 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6769ms | 0.4904ms | 2.0390 KOps/s | 2.0213 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.1560ms | 1.9788ms | 505.3486 Ops/s | 495.0452 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.1354ms | 1.8909ms | 528.8454 Ops/s | 517.4089 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.5833ms | 5.2664ms | 189.8824 Ops/s | 182.4724 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.8586ms | 0.6691ms | 1.4945 KOps/s | 1.4879 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8496ms | 0.6457ms | 1.5487 KOps/s | 1.5317 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.4611ms | 5.1069ms | 195.8126 Ops/s | 185.6920 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.6854ms | 0.5161ms | 1.9376 KOps/s | 1.9181 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7568ms | 0.5013ms | 1.9950 KOps/s | 1.9973 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.6158ms | 5.0920ms | 196.3861 Ops/s | 188.6939 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6822ms | 0.5126ms | 1.9508 KOps/s | 1.9307 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6692ms | 0.4971ms | 2.0115 KOps/s | 2.0138 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.4205ms | 5.2318ms | 191.1374 Ops/s | 181.9527 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.8486ms | 0.6638ms | 1.5064 KOps/s | 1.4855 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8060ms | 0.6451ms | 1.5501 KOps/s | 1.5343 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1507s | 8.1370ms | 122.8955 Ops/s | 120.4396 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 18.5897ms | 16.3244ms | 61.2579 Ops/s | 60.9398 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 2.1583ms | 1.1322ms | 883.2548 Ops/s | 888.3748 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1293s | 7.7327ms | 129.3207 Ops/s | 95.6292 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 18.6442ms | 16.2801ms | 61.4249 Ops/s | 60.6611 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 2.1761ms | 1.1457ms | 872.8615 Ops/s | 864.9996 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1306s | 7.9640ms | 125.5646 Ops/s | 123.8533 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1458s | 18.9672ms | 52.7227 Ops/s | 59.9978 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 2.3098ms | 1.2724ms | 785.9111 Ops/s | 726.9930 Ops/s |
digits = torch.vmap(torch.dot, (None, 0))( | ||
self.bases, feature_parts.to(self.bases.dtype) | ||
) | ||
digits = digits.reshape(shape=(-1, features.shape[-1] // self.num_bits)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: features.shape[-1] -> num_features
Here is an example of how to build a tree with # The following example requires torchrl and gymnasium to be installed
from torchrl.envs import GymEnv
import torch
from tensordict import TensorDict, LazyStackedTensorDict
from torchrl.data import TensorDictMap, ListStorage
from torchrl.data.map.tree import MCTSForest
from torchrl.envs import BraxEnv
forest = MCTSForest()
env = BraxEnv("ant")
env.set_seed(0)
reset_state = env.reset()
rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
starttd = rollout0[3]
rollout0 = rollout0.exclude("state", ("next", "state"))
forest.extend(rollout0)
print(len(forest))
rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone()).exclude("state", ("next", "state"))
forest.extend(rollout1)
print(len(forest))
rollout0b = env.rollout(6, auto_reset=False, tensordict=starttd).exclude("state", ("next", "state"))
forest.extend(rollout0b)
print(len(forest))
r = rollout0[0]
r.names = None
tree = forest.get_tree(r)
forest.plot(tree)
|
Pretty neat. I was wondering how to update the tree. In particular, can we store a count per child and update it after each rollout? |
binary_to_decimal = BinaryToDecimal( | ||
num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True | ||
) | ||
binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the 10 should be a 1 in this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's on purpose, to test that we convert non binary to binary (see the class constructor)
Closing in favour of #2304 |
This code is copy-pasted from pytorch/tensordict#826
Since we have the opportunity of building composite storages or storages of any type, I would be keen to remove the key_to_storage dict and just have one storage per map (a storage can have anything in it)
Credits to @mjlaali
Missing: