Potentially mishandled continuous action space shape #172

Closed
ekiefl opened this issue Dec 22, 2023 · 4 comments
Labels: bug (Something isn't working)

Comments

@ekiefl
Contributor

ekiefl commented Dec 22, 2023

I'm running a sampled EfficientZero experiment with a discrete, image-based observation space and a two-parameter continuous action space (continuous gaussian policy).

After collecting enough episodes to sample a mini-batch, I encounter the following error during iteration 0 of the training procedure:

Traceback (most recent call last):
  File "zoo/pooltool/sum_to_three/config/sum_to_three_image_config.py", line 119, in <module>
    train_muzero(
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/entry/train_muzero.py", line 185, in train_muzero
    log_vars = learner.train(train_data, collector.envstep)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/ding/worker/learner/base_learner.py", line 165, in wrapper
    ret = fn(*args, **kwargs)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/ding/worker/learner/base_learner.py", line 205, in train
    log_vars = self._policy.forward(data, **policy_kwargs)
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/policy/sampled_efficientzero.py", line 416, in _forward_learn
    network_output = self._learn_model.recurrent_inference(
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/model/sampled_efficientzero_model.py", line 306, in recurrent_inference
    next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action)
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/model/sampled_efficientzero_model.py", line 418, in _dynamics
    action_encoding = action_encoding_tmp.expand(
RuntimeError: The expanded size of the tensor (20) must match the existing size (2) at non-singleton dimension 2.  Target sizes: [128, 2, 20, 10].  Tensor sizes: [128, 2, 1]

Here is the relevant code (in SampledEfficientZeroModel._dynamics):

else:
    # continuous action space
    if len(action.shape) == 2:
        # (batch_size, action_dim) -> (batch_size, action_dim, 1, 1)
        # e.g., torch.Size([8, 2]) -> torch.Size([8, 2, 1, 1])
        action = action.unsqueeze(-1).unsqueeze(-1)
    elif len(action.shape) == 1:
        # (batch_size,) -> (batch_size, action_dim=1, 1, 1)
        # e.g., -> torch.Size([8, 2, 1, 1])
        action = action.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

    action_encoding_tmp = action
    action_encoding = action_encoding_tmp.expand(
        latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3]
    )

This method runs successfully during the collection step of train_muzero:

new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

During successful passes of the method, the action shape is torch.Size([6, 2]). This makes sense because I have 6 collector environments and the action space is 2D.

But the method fails with the above error during the train step of train_muzero:

# The core train steps for MCTS+RL algorithms.
log_vars = learner.train(train_data, collector.envstep)

During failure, the action shape is torch.Size([128, 2, 1]) (my batch size is 128). Here are some printouts from my debugger:

ipdb> action.shape
torch.Size([128, 2, 1])
ipdb> action[0]
tensor([[ 0.8633],
        [-0.6739]])
ipdb> action[0][0]
tensor([0.8633])
ipdb> action[0][0][0]
tensor(0.8633)
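
To make the mismatch concrete, here is a minimal, self-contained sketch (shapes taken from the error message; only torch is assumed) of why the collection-time action broadcasts cleanly while the training-time action does not:

import torch

batch, action_dim, h, w = 128, 2, 20, 10  # shapes from the error message above

# Collection-time shape (batch, action_dim) hits the len(action.shape) == 2 branch,
# becomes (batch, action_dim, 1, 1), and expands cleanly over the latent plane.
collect_action = torch.zeros(batch, action_dim).unsqueeze(-1).unsqueeze(-1)
collect_action.expand(batch, action_dim, h, w)  # OK

# Training-time shape (batch, action_dim, 1) matches neither branch, so the 3-D
# tensor reaches expand() unchanged and raises the RuntimeError shown above.
train_action = torch.zeros(batch, action_dim, 1)
train_action.expand(batch, action_dim, h, w)  # RuntimeError: shape mismatch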

I think either (1) the action passed to _dynamics has been erroneously shaped, or (2) the _dynamics method needs to be extended to support 3-dimensional action shapes:

            # continuous action space
            if len(action.shape) == 3:
                # <Added condition>
                # (batch_size, action_dim, 1) -> (batch_size, action_dim, 1, 1)
                action = action.unsqueeze(-1)
            elif len(action.shape) == 2:
                # (batch_size, action_dim) -> (batch_size, action_dim, 1, 1)
                # e.g.,  torch.Size([8, 2]) ->  torch.Size([8, 2, 1, 1])
                action = action.unsqueeze(-1).unsqueeze(-1)
            elif len(action.shape) == 1:
                # (batch_size,) -> (batch_size, action_dim=1, 1, 1)
                # e.g.,  -> torch.Size([8, 2, 1, 1])
                action = action.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

            action_encoding_tmp = action
            action_encoding = action_encoding_tmp.expand(
                latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3]
            )

After making this change, the program runs.

Please let me know if I can provide further information that's helpful.

ekiefl added a commit to ekiefl/LightZero that referenced this issue Dec 22, 2023
@puyuan1996
Collaborator

Hello, thank you for your feedback. We have identified this bug, and it has now been fixed in the latest commit 27188cf. By the way, as you can see, there may be a more efficient implementation for this data-processing code; we would greatly appreciate it if you could provide an optimized version. Best wishes!

@puyuan1996 added the bug label Dec 22, 2023
@puyuan1996
Collaborator

Of course. The current modification was made for the sake of compatibility; the better approach is to identify where the shape of the action tensor changes from torch.Size([128, 2]) to torch.Size([128, 2, 1]). We would greatly appreciate it if you could debug and locate the corresponding spot.

@ekiefl
Contributor Author

ekiefl commented Dec 23, 2023

Identifying the problem

I found the problem. I'm using a batch size of 32 in these examples, and a continuous action space of size 2.

At this point in the code, the action batch seems properly shaped:

obs_batch_ori, action_batch, child_sampled_actions_batch, mask_batch, indices, weights, make_time = current_batch

shape of action_batch: (32, 5, 2)

Then, there's this line:

# shape: (batch_size, num_unroll_steps, action_dim)
# NOTE: .float(), in continuous action space.
action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float().unsqueeze(-1)

shape of action_batch: torch.Size([32, 5, 2, 1])

This is a problem the next time action_batch is used:

network_output = self._learn_model.recurrent_inference(
    latent_state, reward_hidden_state, action_batch[:, step_k]
)

shape of action_batch[:, step_k]: torch.Size([32, 2, 1])
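
To illustrate the shape bookkeeping with dummy data (this is only a sketch of the slicing, not a proposed patch): the trailing .unsqueeze(-1) is what turns the per-step slice into the 3-D tensor that _dynamics cannot handle, while dropping it keeps the (batch_size, action_dim) shape the existing branch expects:

import numpy as np
import torch

batch_size, num_unroll_steps, action_dim = 32, 5, 2
action_batch_np = np.zeros((batch_size, num_unroll_steps, action_dim), dtype=np.float32)

# Current code: the trailing unsqueeze adds a fourth axis, so the per-step slice is 3-D.
with_unsqueeze = torch.from_numpy(action_batch_np).float().unsqueeze(-1)
print(with_unsqueeze[:, 0].shape)  # torch.Size([32, 2, 1])

# Without the unsqueeze, the per-step slice keeps the expected 2-D shape.
without_unsqueeze = torch.from_numpy(action_batch_np).float()
print(without_unsqueeze[:, 0].shape)  # torch.Size([32, 2])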

Testing variables

Here is the initial action_batch:

from numpy import array
array([[[-0.4490683972835541, -0.9950420260429382],
        [-0.4014930129051208, -0.7829872369766235],
        [-0.4213773608207703, -0.7039519548416138],
        [-0.4044414460659027, -0.834709882736206 ],
        [-0.1590566784143448, -0.2195666283369064]],

       [[-0.4014930129051208, -0.7829872369766235],
        [-0.4213773608207703, -0.7039519548416138],
        [-0.4044414460659027, -0.834709882736206 ],
        [-0.1590566784143448, -0.2195666283369064],
        [-0.7565191984176636, -0.2842720746994019]],

       [[-0.4213773608207703, -0.7039519548416138],
        [-0.4044414460659027, -0.834709882736206 ],
        [-0.1590566784143448, -0.2195666283369064],
        [-0.7565191984176636, -0.2842720746994019],
        [ 0.6578846573829651,  0.8156458139419556]],

       [[-0.4044414460659027, -0.834709882736206 ],
        [-0.1590566784143448, -0.2195666283369064],
        [-0.7565191984176636, -0.2842720746994019],
        [ 0.6578846573829651,  0.8156458139419556],
        [ 0.7340825796127319,  0.8159331679344177]],

       [[-0.1590566784143448, -0.2195666283369064],
        [-0.7565191984176636, -0.2842720746994019],
        [ 0.6578846573829651,  0.8156458139419556],
        [ 0.7340825796127319,  0.8159331679344177],
        [-0.4984613955020905,  0.2341146618127823]],

       [[ 0.6578846573829651,  0.8156458139419556],
        [ 0.7340825796127319,  0.8159331679344177],
        [-0.4984613955020905,  0.2341146618127823],
        [ 0.2147858291864395,  0.8802192211151123],
        [ 0.3157244141177024,  0.122345928192084 ]],

       [[ 0.7340825796127319,  0.8159331679344177],
        [-0.4984613955020905,  0.2341146618127823],
        [ 0.2147858291864395,  0.8802192211151123],
        [ 0.0447234260510479, -0.0061242674195186],
        [ 2.076320333718342 ,  0.0612052095300819]],

       [[-0.4984613955020905,  0.2341146618127823],
        [ 0.2147858291864395,  0.8802192211151123],
        [-0.0753794324983024,  0.018401418199892 ],
        [ 0.230092215421578 , -0.154497093666865 ],
        [ 0.8510225575898217,  0.0945327318063604]],

       [[ 0.0692359060049057, -0.5581157207489014],
        [-0.0967390909790993,  0.6544007062911987],
        [-0.5564347505569458, -0.5217161774635315],
        [ 0.227927178144455 , -0.5079849362373352],
        [-0.7363128662109375,  0.7372166514396667]],

       [[-0.5564347505569458, -0.5217161774635315],
        [ 0.227927178144455 , -0.5079849362373352],
        [-0.7363128662109375,  0.7372166514396667],
        [ 0.8861203789710999,  0.7655220031738281],
        [-0.8919497132301331,  0.7664076685905457]],

       [[ 0.8861203789710999,  0.7655220031738281],
        [-0.8919497132301331,  0.7664076685905457],
        [-0.8111757040023804,  0.3714333176612854],
        [ 0.9299460443093313,  0.62140516600195  ],
        [-0.7294452606963147, -0.5507582288642449]],

       [[-0.4489460289478302, -0.4065003991127014],
        [-0.0877294093370438,  0.6240969896316528],
        [-0.145333394408226 ,  0.0135438507422805],
        [ 0.7303228378295898,  0.324310839176178 ],
        [ 0.1930731981992722, -0.6004103422164917]],

       [[-0.0877294093370438,  0.6240969896316528],
        [-0.145333394408226 ,  0.0135438507422805],
        [ 0.7303228378295898,  0.324310839176178 ],
        [ 0.1930731981992722, -0.6004103422164917],
        [-0.2952068150043488,  0.7421513199806213]],

       [[-0.145333394408226 ,  0.0135438507422805],
        [ 0.7303228378295898,  0.324310839176178 ],
        [ 0.1930731981992722, -0.6004103422164917],
        [-0.2952068150043488,  0.7421513199806213],
        [ 0.3888700902462006, -0.6838327646255493]],

       [[ 0.7991305589675903, -0.6135419011116028],
        [-1.4756834812262265, -1.5111042458027768],
        [-1.0753735258729593, -0.0687822111641535],
        [-0.6506161214320586, -0.2692845067856774],
        [-1.509456993847584 ,  0.4623811175903334]],

       [[ 0.9131344556808472, -0.9124650955200195],
        [-0.6182714104652405,  0.8433078527450562],
        [ 0.9737944006919861,  0.3280912041664124],
        [-0.503142774105072 , -0.4597557783126831],
        [ 0.1592527478933334, -0.3711529672145844]],

       [[ 0.97501140832901  , -0.9102342128753662],
        [-0.6536456346511841, -0.6314908266067505],
        [-0.3102632761001587,  0.5417718887329102],
        [ 0.5856770137135533, -0.1399733906201436],
        [ 0.9160984176427144,  0.0171331457675763]],

       [[-0.6536456346511841, -0.6314908266067505],
        [-0.3102632761001587,  0.5417718887329102],
        [ 1.1057067309019377,  0.2459597965178402],
        [-2.78280593866222  , -2.279670235885282 ],
        [ 0.3490433187008873,  0.2584693277416903]],

       [[-0.3102632761001587,  0.5417718887329102],
        [ 1.7954634122483244, -0.1587822891046758],
        [ 0.8435877018233144, -1.792088634720315 ],
        [-0.6243843906905057,  1.004054336454013 ],
        [ 1.0931276206988136, -0.8503439391331218]],

       [[-0.0523804016411304,  0.9104548692703247],
        [-0.982227087020874 , -0.4803890585899353],
        [-0.9744316339492798,  0.9138423204421997],
        [-0.8951768279075623, -0.8566049337387085],
        [ 0.6736979484558105, -0.8682780265808105]],

       [[-0.9744316339492798,  0.9138423204421997],
        [-0.8951768279075623, -0.8566049337387085],
        [ 0.6736979484558105, -0.8682780265808105],
        [-0.9753063321113586, -0.1303187161684036],
        [ 0.0419043824076653, -0.9880978465080261]],

       [[ 0.6736979484558105, -0.8682780265808105],
        [-0.9753063321113586, -0.1303187161684036],
        [ 0.0419043824076653, -0.9880978465080261],
        [-0.8325971961021423, -0.4706742167472839],
        [ 0.4065617024898529, -0.1309379935264587]],

       [[-0.9753063321113586, -0.1303187161684036],
        [ 0.0419043824076653, -0.9880978465080261],
        [-0.8325971961021423, -0.4706742167472839],
        [ 0.4065617024898529, -0.1309379935264587],
        [-0.9776009976312126,  0.0929417044544182]],

       [[ 0.0419043824076653, -0.9880978465080261],
        [-0.8325971961021423, -0.4706742167472839],
        [ 0.4065617024898529, -0.1309379935264587],
        [-1.2604611072245877,  0.3255282013572517],
        [-0.6375371714023212, -0.2479576952053556]],

       [[ 0.4065617024898529, -0.1309379935264587],
        [ 1.0853122572285252, -1.113450095645758 ],
        [ 0.6383299872593756,  0.4615825320021247],
        [-0.4526769121750209, -0.5026150726186746],
        [-0.4429909946263973, -0.6435901784670306]],

       [[ 0.8820233941078186,  0.4469195306301117],
        [ 0.1421182006597519, -0.3563036918640137],
        [ 0.4080510437488556, -0.0753544196486473],
        [-0.9183218479156494, -0.715552031993866 ],
        [ 0.3346443176269531, -0.7762950658798218]],

       [[-0.9183218479156494, -0.715552031993866 ],
        [ 0.3346443176269531, -0.7762950658798218],
        [ 0.3761063516139984,  0.7810404896736145],
        [-0.9150596857070923,  0.6707392930984497],
        [ 0.1624187082052231, -0.0102332159876823]],

       [[ 0.3346443176269531, -0.7762950658798218],
        [ 0.3761063516139984,  0.7810404896736145],
        [-0.9150596857070923,  0.6707392930984497],
        [ 0.1624187082052231, -0.0102332159876823],
        [ 0.9643478393554688,  0.5449203252792358]],

       [[ 0.3761063516139984,  0.7810404896736145],
        [-0.9150596857070923,  0.6707392930984497],
        [ 0.1624187082052231, -0.0102332159876823],
        [ 0.9643478393554688,  0.5449203252792358],
        [-0.6978347897529602,  0.3335686028003693]],

       [[ 0.1624187082052231, -0.0102332159876823],
        [ 0.9643478393554688,  0.5449203252792358],
        [-0.6978347897529602,  0.3335686028003693],
        [ 0.1176412274461428, -1.5554193438851447],
        [-0.639407819647646 ,  0.070470209032387 ]],

       [[ 0.9643478393554688,  0.5449203252792358],
        [-0.6978347897529602,  0.3335686028003693],
        [ 0.9371315195670832, -0.5187535949083031],
        [ 1.6102262403242997,  0.7461320383571783],
        [-1.0328825079779835,  0.5870174496427473]],

       [[-0.6978347897529602,  0.3335686028003693],
        [ 0.4909302677064525, -1.8273185492431   ],
        [-1.3642024846362164,  0.4962908880233692],
        [ 0.9805795313754992, -1.0712426831281652],
        [ 0.7456210692367394, -2.0334058101172436]]])

And here is action_batch after transformation:

from torch import tensor
tensor([[[[-0.4491],
          [-0.9950]],

         [[-0.4015],
          [-0.7830]],

         [[-0.4214],
          [-0.7040]],

         [[-0.4044],
          [-0.8347]],

         [[-0.1591],
          [-0.2196]]],


        [[[-0.4015],
          [-0.7830]],

         [[-0.4214],
          [-0.7040]],

         [[-0.4044],
          [-0.8347]],

         [[-0.1591],
          [-0.2196]],

         [[-0.7565],
          [-0.2843]]],


        [[[-0.4214],
          [-0.7040]],

         [[-0.4044],
          [-0.8347]],

         [[-0.1591],
          [-0.2196]],

         [[-0.7565],
          [-0.2843]],

         [[ 0.6579],
          [ 0.8156]]],


        [[[-0.4044],
          [-0.8347]],

         [[-0.1591],
          [-0.2196]],

         [[-0.7565],
          [-0.2843]],

         [[ 0.6579],
          [ 0.8156]],

         [[ 0.7341],
          [ 0.8159]]],


        [[[-0.1591],
          [-0.2196]],

         [[-0.7565],
          [-0.2843]],

         [[ 0.6579],
          [ 0.8156]],

         [[ 0.7341],
          [ 0.8159]],

         [[-0.4985],
          [ 0.2341]]],


        [[[ 0.6579],
          [ 0.8156]],

         [[ 0.7341],
          [ 0.8159]],

         [[-0.4985],
          [ 0.2341]],

         [[ 0.2148],
          [ 0.8802]],

         [[ 0.3157],
          [ 0.1223]]],


        [[[ 0.7341],
          [ 0.8159]],

         [[-0.4985],
          [ 0.2341]],

         [[ 0.2148],
          [ 0.8802]],

         [[ 0.0447],
          [-0.0061]],

         [[ 2.0763],
          [ 0.0612]]],


        [[[-0.4985],
          [ 0.2341]],

         [[ 0.2148],
          [ 0.8802]],

         [[-0.0754],
          [ 0.0184]],

         [[ 0.2301],
          [-0.1545]],

         [[ 0.8510],
          [ 0.0945]]],


        [[[ 0.0692],
          [-0.5581]],

         [[-0.0967],
          [ 0.6544]],

         [[-0.5564],
          [-0.5217]],

         [[ 0.2279],
          [-0.5080]],

         [[-0.7363],
          [ 0.7372]]],


        [[[-0.5564],
          [-0.5217]],

         [[ 0.2279],
          [-0.5080]],

         [[-0.7363],
          [ 0.7372]],

         [[ 0.8861],
          [ 0.7655]],

         [[-0.8919],
          [ 0.7664]]],


        [[[ 0.8861],
          [ 0.7655]],

         [[-0.8919],
          [ 0.7664]],

         [[-0.8112],
          [ 0.3714]],

         [[ 0.9299],
          [ 0.6214]],

         [[-0.7294],
          [-0.5508]]],


        [[[-0.4489],
          [-0.4065]],

         [[-0.0877],
          [ 0.6241]],

         [[-0.1453],
          [ 0.0135]],

         [[ 0.7303],
          [ 0.3243]],

         [[ 0.1931],
          [-0.6004]]],


        [[[-0.0877],
          [ 0.6241]],

         [[-0.1453],
          [ 0.0135]],

         [[ 0.7303],
          [ 0.3243]],

         [[ 0.1931],
          [-0.6004]],

         [[-0.2952],
          [ 0.7422]]],


        [[[-0.1453],
          [ 0.0135]],

         [[ 0.7303],
          [ 0.3243]],

         [[ 0.1931],
          [-0.6004]],

         [[-0.2952],
          [ 0.7422]],

         [[ 0.3889],
          [-0.6838]]],


        [[[ 0.7991],
          [-0.6135]],

         [[-1.4757],
          [-1.5111]],

         [[-1.0754],
          [-0.0688]],

         [[-0.6506],
          [-0.2693]],

         [[-1.5095],
          [ 0.4624]]],


        [[[ 0.9131],
          [-0.9125]],

         [[-0.6183],
          [ 0.8433]],

         [[ 0.9738],
          [ 0.3281]],

         [[-0.5031],
          [-0.4598]],

         [[ 0.1593],
          [-0.3712]]],


        [[[ 0.9750],
          [-0.9102]],

         [[-0.6536],
          [-0.6315]],

         [[-0.3103],
          [ 0.5418]],

         [[ 0.5857],
          [-0.1400]],

         [[ 0.9161],
          [ 0.0171]]],


        [[[-0.6536],
          [-0.6315]],

         [[-0.3103],
          [ 0.5418]],

         [[ 1.1057],
          [ 0.2460]],

         [[-2.7828],
          [-2.2797]],

         [[ 0.3490],
          [ 0.2585]]],


        [[[-0.3103],
          [ 0.5418]],

         [[ 1.7955],
          [-0.1588]],

         [[ 0.8436],
          [-1.7921]],

         [[-0.6244],
          [ 1.0041]],

         [[ 1.0931],
          [-0.8503]]],


        [[[-0.0524],
          [ 0.9105]],

         [[-0.9822],
          [-0.4804]],

         [[-0.9744],
          [ 0.9138]],

         [[-0.8952],
          [-0.8566]],

         [[ 0.6737],
          [-0.8683]]],


        [[[-0.9744],
          [ 0.9138]],

         [[-0.8952],
          [-0.8566]],

         [[ 0.6737],
          [-0.8683]],

         [[-0.9753],
          [-0.1303]],

         [[ 0.0419],
          [-0.9881]]],


        [[[ 0.6737],
          [-0.8683]],

         [[-0.9753],
          [-0.1303]],

         [[ 0.0419],
          [-0.9881]],

         [[-0.8326],
          [-0.4707]],

         [[ 0.4066],
          [-0.1309]]],


        [[[-0.9753],
          [-0.1303]],

         [[ 0.0419],
          [-0.9881]],

         [[-0.8326],
          [-0.4707]],

         [[ 0.4066],
          [-0.1309]],

         [[-0.9776],
          [ 0.0929]]],


        [[[ 0.0419],
          [-0.9881]],

         [[-0.8326],
          [-0.4707]],

         [[ 0.4066],
          [-0.1309]],

         [[-1.2605],
          [ 0.3255]],

         [[-0.6375],
          [-0.2480]]],


        [[[ 0.4066],
          [-0.1309]],

         [[ 1.0853],
          [-1.1135]],

         [[ 0.6383],
          [ 0.4616]],

         [[-0.4527],
          [-0.5026]],

         [[-0.4430],
          [-0.6436]]],


        [[[ 0.8820],
          [ 0.4469]],

         [[ 0.1421],
          [-0.3563]],

         [[ 0.4081],
          [-0.0754]],

         [[-0.9183],
          [-0.7156]],

         [[ 0.3346],
          [-0.7763]]],


        [[[-0.9183],
          [-0.7156]],

         [[ 0.3346],
          [-0.7763]],

         [[ 0.3761],
          [ 0.7810]],

         [[-0.9151],
          [ 0.6707]],

         [[ 0.1624],
          [-0.0102]]],


        [[[ 0.3346],
          [-0.7763]],

         [[ 0.3761],
          [ 0.7810]],

         [[-0.9151],
          [ 0.6707]],

         [[ 0.1624],
          [-0.0102]],

         [[ 0.9643],
          [ 0.5449]]],


        [[[ 0.3761],
          [ 0.7810]],

         [[-0.9151],
          [ 0.6707]],

         [[ 0.1624],
          [-0.0102]],

         [[ 0.9643],
          [ 0.5449]],

         [[-0.6978],
          [ 0.3336]]],


        [[[ 0.1624],
          [-0.0102]],

         [[ 0.9643],
          [ 0.5449]],

         [[-0.6978],
          [ 0.3336]],

         [[ 0.1176],
          [-1.5554]],

         [[-0.6394],
          [ 0.0705]]],


        [[[ 0.9643],
          [ 0.5449]],

         [[-0.6978],
          [ 0.3336]],

         [[ 0.9371],
          [-0.5188]],

         [[ 1.6102],
          [ 0.7461]],

         [[-1.0329],
          [ 0.5870]]],


        [[[-0.6978],
          [ 0.3336]],

         [[ 0.4909],
          [-1.8273]],

         [[-1.3642],
          [ 0.4963]],

         [[ 0.9806],
          [-1.0712]],

         [[ 0.7456],
          [-2.0334]]]])

@puyuan1996
Collaborator

Hello, thank you for your detailed feedback. We have confirmed that this was a redundant operation and it has been fixed in the latest commit on the main branch. Thank you once again for your active contribution.

@ekiefl ekiefl closed this as completed Dec 26, 2023
puyuan1996 added a commit that referenced this issue Jul 4, 2024
* Add SumToThree pooltool env

* Woops

* Update datatypes and add single inference mode

* Move core into pooltool

* Add some speed and memory profiling for env debug

* Trying to get CNNs working

* Patch #172

* Setup first experiment

* Fix up sumtothreeimage

* Update obs space to be float

* Move image_representation into fork

- It was in pooltool ai-framework branch
- By moving it here, main branch of pooltool can be used

* Start a README

* Begin test suite for sum_to_three_env

* Add tests for datatypes

* Finish test suite for sum_to_three_env

* rename tests -> characterize

* Delete

* Increase to 300,000 replay buffer

* Finish README

* Fix image link

* Link the discussion page

* Update pooltool API calls to 0.3.0

* Switch to dataclasses

- attrs is not standard library, best not to impose my standards
- Also had some docs

* Progress on documentation and variable naming

* Finish docs for datatypes.py

* Data structure changes

- Additionally, move reward function into reward module and add options
  to select different rewards via cfg

* Parameterize action space bounds

- Remove clunky class methods

* Add a module docstring

* Finish docstrings for sum_to_three coordinate environment

* rm pooltool __init__.py

- LSP was getting confused with the `import pooltool` statement

* Add pytest

* Add pooltool-billiards

* Add docs for reward space

* Add tests for grayscale conversion, add docs

* Add module doc for reward.py

* Add docs for image_representation

* Fix image env

* Update info about px parameter

* Add serialie/deserialize methods for RenderConfig

* Three things:

- move px to RenderConfig
- serialize/deserialization methods for RenderConfig
- Mimic the refactor in cts env to the image env

* Use channels in renderconfig

* Buff image_representation visualization

- Add an animation

* Start consolidation

* More consolidation between observation types

* consolidate image and coordinate observation types

* Remove old file

* Add default config

* Single source state setting

* Add tests

* Unused

* Add default render config option

- Store as attribute

* Add speed test script

* Small changes

* Add sum to three to feature table

* Update pooltool README

* Move observation/ and reward.py into utils.py

* polish(pu): polish sum_to_three configs

* feature(pu): add sum_to_three_vector_obs_sac_config.py and polish related config names

* polish(pu): polish sum_to_three configs

* polish(pu): polish pooltool configs

---------

Co-authored-by: dyyoungg <yangdeyu@sensetime.com>
Co-authored-by: 蒲源 <2402552459@qq.com>
Co-authored-by: 蒲源 <48008469+puyuan1996@users.noreply.github.com>