Potentially mishandled continuous action space shape #172
Comments
Hello, thank you for your feedback. We have identified this bug and fixed it in the latest commit, 27188cf. As you can see, there may be a more efficient implementation of this data processing code; we would greatly appreciate it if you could provide an optimized version. Best wishes!
Of course, the current modifications are made for the sake of program compatibility. The better approach is to identify the location where the shape of the action tensor changes from
Identifying the problem
I found the problem. I'm using a batch size of 32 in these examples, and a continuous action space size of 2. At this point in the code, the action space for the whole batch seems properly shaped:
shape of
Then, there's this line:
LightZero/lzero/policy/sampled_efficientzero.py Lines 333 to 335 in 95e94b9
shape of
This is a problem the next time
LightZero/lzero/policy/sampled_efficientzero.py Lines 416 to 418 in 95e94b9
shape of
Testing variables
Here is the initial
from numpy import array
array([[[-0.4490683972835541, -0.9950420260429382],
[-0.4014930129051208, -0.7829872369766235],
[-0.4213773608207703, -0.7039519548416138],
[-0.4044414460659027, -0.834709882736206 ],
[-0.1590566784143448, -0.2195666283369064]],
[[-0.4014930129051208, -0.7829872369766235],
[-0.4213773608207703, -0.7039519548416138],
[-0.4044414460659027, -0.834709882736206 ],
[-0.1590566784143448, -0.2195666283369064],
[-0.7565191984176636, -0.2842720746994019]],
[[-0.4213773608207703, -0.7039519548416138],
[-0.4044414460659027, -0.834709882736206 ],
[-0.1590566784143448, -0.2195666283369064],
[-0.7565191984176636, -0.2842720746994019],
[ 0.6578846573829651, 0.8156458139419556]],
[[-0.4044414460659027, -0.834709882736206 ],
[-0.1590566784143448, -0.2195666283369064],
[-0.7565191984176636, -0.2842720746994019],
[ 0.6578846573829651, 0.8156458139419556],
[ 0.7340825796127319, 0.8159331679344177]],
[[-0.1590566784143448, -0.2195666283369064],
[-0.7565191984176636, -0.2842720746994019],
[ 0.6578846573829651, 0.8156458139419556],
[ 0.7340825796127319, 0.8159331679344177],
[-0.4984613955020905, 0.2341146618127823]],
[[ 0.6578846573829651, 0.8156458139419556],
[ 0.7340825796127319, 0.8159331679344177],
[-0.4984613955020905, 0.2341146618127823],
[ 0.2147858291864395, 0.8802192211151123],
[ 0.3157244141177024, 0.122345928192084 ]],
[[ 0.7340825796127319, 0.8159331679344177],
[-0.4984613955020905, 0.2341146618127823],
[ 0.2147858291864395, 0.8802192211151123],
[ 0.0447234260510479, -0.0061242674195186],
[ 2.076320333718342 , 0.0612052095300819]],
[[-0.4984613955020905, 0.2341146618127823],
[ 0.2147858291864395, 0.8802192211151123],
[-0.0753794324983024, 0.018401418199892 ],
[ 0.230092215421578 , -0.154497093666865 ],
[ 0.8510225575898217, 0.0945327318063604]],
[[ 0.0692359060049057, -0.5581157207489014],
[-0.0967390909790993, 0.6544007062911987],
[-0.5564347505569458, -0.5217161774635315],
[ 0.227927178144455 , -0.5079849362373352],
[-0.7363128662109375, 0.7372166514396667]],
[[-0.5564347505569458, -0.5217161774635315],
[ 0.227927178144455 , -0.5079849362373352],
[-0.7363128662109375, 0.7372166514396667],
[ 0.8861203789710999, 0.7655220031738281],
[-0.8919497132301331, 0.7664076685905457]],
[[ 0.8861203789710999, 0.7655220031738281],
[-0.8919497132301331, 0.7664076685905457],
[-0.8111757040023804, 0.3714333176612854],
[ 0.9299460443093313, 0.62140516600195 ],
[-0.7294452606963147, -0.5507582288642449]],
[[-0.4489460289478302, -0.4065003991127014],
[-0.0877294093370438, 0.6240969896316528],
[-0.145333394408226 , 0.0135438507422805],
[ 0.7303228378295898, 0.324310839176178 ],
[ 0.1930731981992722, -0.6004103422164917]],
[[-0.0877294093370438, 0.6240969896316528],
[-0.145333394408226 , 0.0135438507422805],
[ 0.7303228378295898, 0.324310839176178 ],
[ 0.1930731981992722, -0.6004103422164917],
[-0.2952068150043488, 0.7421513199806213]],
[[-0.145333394408226 , 0.0135438507422805],
[ 0.7303228378295898, 0.324310839176178 ],
[ 0.1930731981992722, -0.6004103422164917],
[-0.2952068150043488, 0.7421513199806213],
[ 0.3888700902462006, -0.6838327646255493]],
[[ 0.7991305589675903, -0.6135419011116028],
[-1.4756834812262265, -1.5111042458027768],
[-1.0753735258729593, -0.0687822111641535],
[-0.6506161214320586, -0.2692845067856774],
[-1.509456993847584 , 0.4623811175903334]],
[[ 0.9131344556808472, -0.9124650955200195],
[-0.6182714104652405, 0.8433078527450562],
[ 0.9737944006919861, 0.3280912041664124],
[-0.503142774105072 , -0.4597557783126831],
[ 0.1592527478933334, -0.3711529672145844]],
[[ 0.97501140832901 , -0.9102342128753662],
[-0.6536456346511841, -0.6314908266067505],
[-0.3102632761001587, 0.5417718887329102],
[ 0.5856770137135533, -0.1399733906201436],
[ 0.9160984176427144, 0.0171331457675763]],
[[-0.6536456346511841, -0.6314908266067505],
[-0.3102632761001587, 0.5417718887329102],
[ 1.1057067309019377, 0.2459597965178402],
[-2.78280593866222 , -2.279670235885282 ],
[ 0.3490433187008873, 0.2584693277416903]],
[[-0.3102632761001587, 0.5417718887329102],
[ 1.7954634122483244, -0.1587822891046758],
[ 0.8435877018233144, -1.792088634720315 ],
[-0.6243843906905057, 1.004054336454013 ],
[ 1.0931276206988136, -0.8503439391331218]],
[[-0.0523804016411304, 0.9104548692703247],
[-0.982227087020874 , -0.4803890585899353],
[-0.9744316339492798, 0.9138423204421997],
[-0.8951768279075623, -0.8566049337387085],
[ 0.6736979484558105, -0.8682780265808105]],
[[-0.9744316339492798, 0.9138423204421997],
[-0.8951768279075623, -0.8566049337387085],
[ 0.6736979484558105, -0.8682780265808105],
[-0.9753063321113586, -0.1303187161684036],
[ 0.0419043824076653, -0.9880978465080261]],
[[ 0.6736979484558105, -0.8682780265808105],
[-0.9753063321113586, -0.1303187161684036],
[ 0.0419043824076653, -0.9880978465080261],
[-0.8325971961021423, -0.4706742167472839],
[ 0.4065617024898529, -0.1309379935264587]],
[[-0.9753063321113586, -0.1303187161684036],
[ 0.0419043824076653, -0.9880978465080261],
[-0.8325971961021423, -0.4706742167472839],
[ 0.4065617024898529, -0.1309379935264587],
[-0.9776009976312126, 0.0929417044544182]],
[[ 0.0419043824076653, -0.9880978465080261],
[-0.8325971961021423, -0.4706742167472839],
[ 0.4065617024898529, -0.1309379935264587],
[-1.2604611072245877, 0.3255282013572517],
[-0.6375371714023212, -0.2479576952053556]],
[[ 0.4065617024898529, -0.1309379935264587],
[ 1.0853122572285252, -1.113450095645758 ],
[ 0.6383299872593756, 0.4615825320021247],
[-0.4526769121750209, -0.5026150726186746],
[-0.4429909946263973, -0.6435901784670306]],
[[ 0.8820233941078186, 0.4469195306301117],
[ 0.1421182006597519, -0.3563036918640137],
[ 0.4080510437488556, -0.0753544196486473],
[-0.9183218479156494, -0.715552031993866 ],
[ 0.3346443176269531, -0.7762950658798218]],
[[-0.9183218479156494, -0.715552031993866 ],
[ 0.3346443176269531, -0.7762950658798218],
[ 0.3761063516139984, 0.7810404896736145],
[-0.9150596857070923, 0.6707392930984497],
[ 0.1624187082052231, -0.0102332159876823]],
[[ 0.3346443176269531, -0.7762950658798218],
[ 0.3761063516139984, 0.7810404896736145],
[-0.9150596857070923, 0.6707392930984497],
[ 0.1624187082052231, -0.0102332159876823],
[ 0.9643478393554688, 0.5449203252792358]],
[[ 0.3761063516139984, 0.7810404896736145],
[-0.9150596857070923, 0.6707392930984497],
[ 0.1624187082052231, -0.0102332159876823],
[ 0.9643478393554688, 0.5449203252792358],
[-0.6978347897529602, 0.3335686028003693]],
[[ 0.1624187082052231, -0.0102332159876823],
[ 0.9643478393554688, 0.5449203252792358],
[-0.6978347897529602, 0.3335686028003693],
[ 0.1176412274461428, -1.5554193438851447],
[-0.639407819647646 , 0.070470209032387 ]],
[[ 0.9643478393554688, 0.5449203252792358],
[-0.6978347897529602, 0.3335686028003693],
[ 0.9371315195670832, -0.5187535949083031],
[ 1.6102262403242997, 0.7461320383571783],
[-1.0328825079779835, 0.5870174496427473]],
[[-0.6978347897529602, 0.3335686028003693],
[ 0.4909302677064525, -1.8273185492431 ],
[-1.3642024846362164, 0.4962908880233692],
[ 0.9805795313754992, -1.0712426831281652],
[ 0.7456210692367394, -2.0334058101172436]]])
And here is
from torch import tensor
tensor([[[[-0.4491],
[-0.9950]],
[[-0.4015],
[-0.7830]],
[[-0.4214],
[-0.7040]],
[[-0.4044],
[-0.8347]],
[[-0.1591],
[-0.2196]]],
[[[-0.4015],
[-0.7830]],
[[-0.4214],
[-0.7040]],
[[-0.4044],
[-0.8347]],
[[-0.1591],
[-0.2196]],
[[-0.7565],
[-0.2843]]],
[[[-0.4214],
[-0.7040]],
[[-0.4044],
[-0.8347]],
[[-0.1591],
[-0.2196]],
[[-0.7565],
[-0.2843]],
[[ 0.6579],
[ 0.8156]]],
[[[-0.4044],
[-0.8347]],
[[-0.1591],
[-0.2196]],
[[-0.7565],
[-0.2843]],
[[ 0.6579],
[ 0.8156]],
[[ 0.7341],
[ 0.8159]]],
[[[-0.1591],
[-0.2196]],
[[-0.7565],
[-0.2843]],
[[ 0.6579],
[ 0.8156]],
[[ 0.7341],
[ 0.8159]],
[[-0.4985],
[ 0.2341]]],
[[[ 0.6579],
[ 0.8156]],
[[ 0.7341],
[ 0.8159]],
[[-0.4985],
[ 0.2341]],
[[ 0.2148],
[ 0.8802]],
[[ 0.3157],
[ 0.1223]]],
[[[ 0.7341],
[ 0.8159]],
[[-0.4985],
[ 0.2341]],
[[ 0.2148],
[ 0.8802]],
[[ 0.0447],
[-0.0061]],
[[ 2.0763],
[ 0.0612]]],
[[[-0.4985],
[ 0.2341]],
[[ 0.2148],
[ 0.8802]],
[[-0.0754],
[ 0.0184]],
[[ 0.2301],
[-0.1545]],
[[ 0.8510],
[ 0.0945]]],
[[[ 0.0692],
[-0.5581]],
[[-0.0967],
[ 0.6544]],
[[-0.5564],
[-0.5217]],
[[ 0.2279],
[-0.5080]],
[[-0.7363],
[ 0.7372]]],
[[[-0.5564],
[-0.5217]],
[[ 0.2279],
[-0.5080]],
[[-0.7363],
[ 0.7372]],
[[ 0.8861],
[ 0.7655]],
[[-0.8919],
[ 0.7664]]],
[[[ 0.8861],
[ 0.7655]],
[[-0.8919],
[ 0.7664]],
[[-0.8112],
[ 0.3714]],
[[ 0.9299],
[ 0.6214]],
[[-0.7294],
[-0.5508]]],
[[[-0.4489],
[-0.4065]],
[[-0.0877],
[ 0.6241]],
[[-0.1453],
[ 0.0135]],
[[ 0.7303],
[ 0.3243]],
[[ 0.1931],
[-0.6004]]],
[[[-0.0877],
[ 0.6241]],
[[-0.1453],
[ 0.0135]],
[[ 0.7303],
[ 0.3243]],
[[ 0.1931],
[-0.6004]],
[[-0.2952],
[ 0.7422]]],
[[[-0.1453],
[ 0.0135]],
[[ 0.7303],
[ 0.3243]],
[[ 0.1931],
[-0.6004]],
[[-0.2952],
[ 0.7422]],
[[ 0.3889],
[-0.6838]]],
[[[ 0.7991],
[-0.6135]],
[[-1.4757],
[-1.5111]],
[[-1.0754],
[-0.0688]],
[[-0.6506],
[-0.2693]],
[[-1.5095],
[ 0.4624]]],
[[[ 0.9131],
[-0.9125]],
[[-0.6183],
[ 0.8433]],
[[ 0.9738],
[ 0.3281]],
[[-0.5031],
[-0.4598]],
[[ 0.1593],
[-0.3712]]],
[[[ 0.9750],
[-0.9102]],
[[-0.6536],
[-0.6315]],
[[-0.3103],
[ 0.5418]],
[[ 0.5857],
[-0.1400]],
[[ 0.9161],
[ 0.0171]]],
[[[-0.6536],
[-0.6315]],
[[-0.3103],
[ 0.5418]],
[[ 1.1057],
[ 0.2460]],
[[-2.7828],
[-2.2797]],
[[ 0.3490],
[ 0.2585]]],
[[[-0.3103],
[ 0.5418]],
[[ 1.7955],
[-0.1588]],
[[ 0.8436],
[-1.7921]],
[[-0.6244],
[ 1.0041]],
[[ 1.0931],
[-0.8503]]],
[[[-0.0524],
[ 0.9105]],
[[-0.9822],
[-0.4804]],
[[-0.9744],
[ 0.9138]],
[[-0.8952],
[-0.8566]],
[[ 0.6737],
[-0.8683]]],
[[[-0.9744],
[ 0.9138]],
[[-0.8952],
[-0.8566]],
[[ 0.6737],
[-0.8683]],
[[-0.9753],
[-0.1303]],
[[ 0.0419],
[-0.9881]]],
[[[ 0.6737],
[-0.8683]],
[[-0.9753],
[-0.1303]],
[[ 0.0419],
[-0.9881]],
[[-0.8326],
[-0.4707]],
[[ 0.4066],
[-0.1309]]],
[[[-0.9753],
[-0.1303]],
[[ 0.0419],
[-0.9881]],
[[-0.8326],
[-0.4707]],
[[ 0.4066],
[-0.1309]],
[[-0.9776],
[ 0.0929]]],
[[[ 0.0419],
[-0.9881]],
[[-0.8326],
[-0.4707]],
[[ 0.4066],
[-0.1309]],
[[-1.2605],
[ 0.3255]],
[[-0.6375],
[-0.2480]]],
[[[ 0.4066],
[-0.1309]],
[[ 1.0853],
[-1.1135]],
[[ 0.6383],
[ 0.4616]],
[[-0.4527],
[-0.5026]],
[[-0.4430],
[-0.6436]]],
[[[ 0.8820],
[ 0.4469]],
[[ 0.1421],
[-0.3563]],
[[ 0.4081],
[-0.0754]],
[[-0.9183],
[-0.7156]],
[[ 0.3346],
[-0.7763]]],
[[[-0.9183],
[-0.7156]],
[[ 0.3346],
[-0.7763]],
[[ 0.3761],
[ 0.7810]],
[[-0.9151],
[ 0.6707]],
[[ 0.1624],
[-0.0102]]],
[[[ 0.3346],
[-0.7763]],
[[ 0.3761],
[ 0.7810]],
[[-0.9151],
[ 0.6707]],
[[ 0.1624],
[-0.0102]],
[[ 0.9643],
[ 0.5449]]],
[[[ 0.3761],
[ 0.7810]],
[[-0.9151],
[ 0.6707]],
[[ 0.1624],
[-0.0102]],
[[ 0.9643],
[ 0.5449]],
[[-0.6978],
[ 0.3336]]],
[[[ 0.1624],
[-0.0102]],
[[ 0.9643],
[ 0.5449]],
[[-0.6978],
[ 0.3336]],
[[ 0.1176],
[-1.5554]],
[[-0.6394],
[ 0.0705]]],
[[[ 0.9643],
[ 0.5449]],
[[-0.6978],
[ 0.3336]],
[[ 0.9371],
[-0.5188]],
[[ 1.6102],
[ 0.7461]],
[[-1.0329],
[ 0.5870]]],
[[[-0.6978],
[ 0.3336]],
[[ 0.4909],
[-1.8273]],
[[-1.3642],
[ 0.4963]],
[[ 0.9806],
[-1.0712]],
[[ 0.7456],
[-2.0334]]]])
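The two dumps above hold the same values and differ only in layout: the NumPy array has shape (32, 5, 2), while the tensor carries an extra trailing axis of length 1, giving (32, 5, 2, 1). A minimal NumPy sketch of that difference follows. It uses zero-filled stand-ins rather than the printed values, and the variable names are illustrative assumptions, not identifiers from LightZero:

```python
import numpy as np

# Stand-ins with the shapes of the two dumps above (values omitted).
batch_size, num_steps, action_dim = 32, 5, 2
action_batch = np.zeros((batch_size, num_steps, action_dim))

# Appending a singleton axis reproduces the layout of the second dump.
with_extra_axis = action_batch[..., np.newaxis]
print(with_extra_axis.shape)  # (32, 5, 2, 1)

# Squeezing only the last axis restores the original layout.
restored = np.squeeze(with_extra_axis, axis=-1)
print(restored.shape)  # (32, 5, 2)
```

Passing `axis=-1` keeps the squeeze from touching any other axis that happens to have length 1.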
Hello, thank you for your detailed feedback. We have confirmed that this was a redundant operation and it has been fixed in the latest commit on the main branch. Thank you once again for your active contribution.
* Add SumToThree pooltool env * Woops * Update datatypes and add single inference mode * Move core into pooltool * Add some speed and memory profiling for env debug * Trying to get CNNs working * Patch #172 * Setup first experiment * Fix up sumtothreeimage * Update obs space to be float * Move image_representation into fork - It was in pooltool ai-framework branch - By moving it here, main branch of pooltool can be used * Start a README * Begin test suite for sum_to_three_env * Add tests for datatypes * Finish test suite for sum_to_three_env * rename tests -> characterize * Delete * Increase to 300,000 replay buffer * Finish README * Fix image link * Link the discussion page * Update pooltool API calls to 0.3.0 * Switch to dataclasses - attrs is not standard library, best not to impose my standards - Also had some docs * Progress on documentation and variable naming * Finish docs for datatypes.py * Data structure changes - Additionally, move reward function into reward module and add options to select different rewards via cfg * Parameterize action space bounds - Remove clunky class methods * Add a module docstring * Finish docstrings for sum_to_three coordinate environment * rm pooltool __init__.py - LSP was getting confused with the `import pooltool` statement * Add pytest * Add pooltool-billiards * Add docs for reward space * Add tests for grayscale conversion, add docs * Add module doc for reward.py * Add docs for image_representation * Fix image env * Update info about px parameter * Add serialie/deserialize methods for RenderConfig * Three things: - move px to RenderConfig - serialize/deserialization methods for RenderConfig - Mimic the refactor in cts env to the image env * Use channels in renderconfig * Buff image_representation visualization - Add an animation * Start consolidation * More consolidation between observation types * consolidate image and coordinate observation types * Remove old file * Add default config * Single source state setting * Add 
tests * Unused * Add default render config option - Store as attribute * Add speed test script * Small changes * Add sum to three to feature table * Update pooltool README * Move observation/ and reward.py into utils.py * polish(pu): polish sum_to_three configs * feature(pu): add sum_to_three_vector_obs_sac_config.py and polish related config names * polish(pu): polish sum_to_three configs * polish(pu): polish pooltool configs --------- Co-authored-by: dyyoungg <yangdeyu@sensetime.com> Co-authored-by: 蒲源 <2402552459@qq.com> Co-authored-by: 蒲源 <48008469+puyuan1996@users.noreply.github.com>
I'm running a sampled EfficientZero experiment with a discrete, image-based observation space and a two-parameter continuous action space (continuous gaussian policy).
After collecting enough episodes to sample a mini-batch, I encounter the following error during iteration 0 of the training procedure:
Here is the relevant code (in SampledEfficientZeroModel._dynamics):
LightZero/lzero/model/sampled_efficientzero_model.py Lines 406 to 420 in 3823560
This method successfully runs during the collection step of train_muzero (LightZero/lzero/entry/train_muzero.py, Line 160 in 3823560).
During successful passes of the method, the action shape is torch.Size([6, 2]). This makes sense because I have 6 collector environments and the action space is 2D.
But the method fails with the above error during the train step of train_muzero (LightZero/lzero/entry/train_muzero.py, Lines 184 to 185 in 3823560).
During failure, the action shape is torch.Size([128, 2, 1]) (I'm using a batch size of 128). Here are some printouts from my debugger:
I think either (1) the action passed to _dynamics has been erroneously shaped or (2) the _dynamics method needs to be extended to support action tensors with 3 dimensions:
After making this change, the program runs.
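For illustration, option (2) could look roughly like the sketch below. This is not the change referred to above, which is not shown here, and it is not LightZero code: `normalize_action_shape` is a hypothetical helper that accepts both of the shapes reported in this issue and returns a 2-D tensor.

```python
import torch


def normalize_action_shape(action: torch.Tensor) -> torch.Tensor:
    """Return a (batch, action_dim) tensor.

    Hypothetical helper, not part of LightZero: accepts either
    (batch, action_dim) or (batch, action_dim, 1).
    """
    if action.dim() == 3 and action.shape[-1] == 1:
        # Drop only the trailing singleton axis; squeeze(-1) leaves the
        # batch axis alone even when the batch size is 1.
        return action.squeeze(-1)
    if action.dim() == 2:
        return action
    raise ValueError(f"unexpected action shape: {tuple(action.shape)}")


collect_action = torch.zeros(6, 2)     # shape seen during collection
train_action = torch.zeros(128, 2, 1)  # shape seen during training
print(normalize_action_shape(collect_action).shape)  # torch.Size([6, 2])
print(normalize_action_shape(train_action).shape)    # torch.Size([128, 2])
```

Raising on any other shape, rather than reshaping silently, would make a recurrence of option (1) visible at the call site instead of hiding it.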
Please let me know if I can provide further information that's helpful.