Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(ek): add pooltool env and related configs #227

Merged
merged 70 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from 66 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
98fa06b
Add SumToThree pooltool env
ekiefl Dec 5, 2023
8a31208
Woops
ekiefl Dec 5, 2023
0035407
Update datatypes and add single inference mode
ekiefl Dec 9, 2023
5e26050
Move core into pooltool
ekiefl Dec 11, 2023
a75a0c6
Merge remote-tracking branch 'upstream/main'
ekiefl Dec 12, 2023
f45b9bb
Merge remote-tracking branch 'upstream/main'
ekiefl Dec 12, 2023
1a5b969
Merge remote-tracking branch 'upstream/main'
ekiefl Dec 12, 2023
d50e02a
Add some speed and memory profiling for env debug
ekiefl Dec 13, 2023
ce3c6cf
Trying to get CNNs working
ekiefl Dec 20, 2023
39aadc4
Merge remote-tracking branch 'upstream/main'
ekiefl Dec 22, 2023
054e37d
Patch https://github.com/opendilab/LightZero/issues/172
ekiefl Dec 22, 2023
b250458
Setup first experiment
ekiefl Dec 22, 2023
40aa25a
Fix up sumtothreeimage
ekiefl Dec 26, 2023
fde6a2c
Update obs space to be float
ekiefl Dec 29, 2023
a977cc9
Move image_representation into fork
ekiefl Jan 12, 2024
73fa10e
Merge remote-tracking branch 'upstream/main'
ekiefl Jan 12, 2024
a9cb4e6
Start a README
ekiefl Jan 13, 2024
4f2abb8
Begin test suite for sum_to_three_env
ekiefl Jan 14, 2024
40d0779
Add tests for datatypes
ekiefl Jan 15, 2024
6a15222
Finish test suite for sum_to_three_env
ekiefl Jan 15, 2024
a93e5eb
rename tests -> characterize
ekiefl Jan 17, 2024
e5ebc4c
Delete
ekiefl Jan 17, 2024
bc96d48
Increase to 300,000 replay buffer
ekiefl Jan 17, 2024
2e2c69e
Finish README
ekiefl Jan 21, 2024
76dbe2d
Fix image link
ekiefl Jan 21, 2024
ff6d6f7
Link the discussion page
ekiefl Jan 21, 2024
1259d53
Update pooltool API calls to 0.3.0
ekiefl Mar 19, 2024
e43f585
Switch to dataclasses
ekiefl Apr 7, 2024
0107be8
Progress on documentation and variable naming
ekiefl Apr 7, 2024
0b827b6
Merge remote-tracking branch 'upstream/main'
ekiefl Apr 11, 2024
4118e1f
Finish docs for datatypes.py
ekiefl Apr 11, 2024
68e91e4
Data structure changes
ekiefl Apr 11, 2024
b688e0d
Parameterize action space bounds
ekiefl Apr 11, 2024
f35bd77
Add a module docstring
ekiefl Apr 12, 2024
672df11
Finish docstrings for sum_to_three coordinate environment
ekiefl Apr 12, 2024
1fb508f
rm pooltool __init__.py
ekiefl Apr 22, 2024
43dbbd7
Add pytest
ekiefl Apr 22, 2024
bcffd22
Add pooltool-billiards
ekiefl Apr 22, 2024
d3b6053
Add docs for reward space
ekiefl Apr 22, 2024
ec0cf41
Add tests for grayscale conversion, add docs
ekiefl Apr 22, 2024
903f676
Add module doc for reward.py
ekiefl Apr 22, 2024
66058bb
Add docs for image_representation
ekiefl May 2, 2024
50bacab
Fix image env
ekiefl May 3, 2024
0159fc4
Update info about px parameter
ekiefl May 3, 2024
1e83417
Add serialie/deserialize methods for RenderConfig
ekiefl May 4, 2024
e488417
Three things:
ekiefl May 4, 2024
1e20b8a
Use channels in renderconfig
ekiefl May 4, 2024
af8f71c
Buff image_representation visualization
ekiefl May 4, 2024
083db9e
Start consolidation
ekiefl May 4, 2024
52a952f
More consolidation between observation types
ekiefl May 4, 2024
0f05b9e
consolidate image and coordinate observation types
ekiefl May 5, 2024
90db7cf
Remove old file
ekiefl May 5, 2024
326cd9d
Add default config
ekiefl May 6, 2024
ac26d39
Single source state setting
ekiefl May 6, 2024
21425b0
Add tests
ekiefl May 6, 2024
619b32c
Unused
ekiefl May 6, 2024
b06d085
Add default render config option
ekiefl May 6, 2024
9238f8a
Add speed test script
ekiefl May 6, 2024
b141b25
Merge branch 'opendilab:main' into main
ekiefl May 27, 2024
6221d6e
Small changes
ekiefl May 27, 2024
fbda12a
Add sum to three to feature table
ekiefl Jun 23, 2024
e6138a7
Update pooltool README
ekiefl Jun 23, 2024
de799f3
Move observation/ and reward.py into utils.py
ekiefl Jun 23, 2024
17f806e
polish(pu): polish sum_to_three configs
dyyoungg Jun 24, 2024
2cefe42
Merge branch 'main' of https://github.com/ekiefl/LightZero into dev-p…
dyyoungg Jun 24, 2024
d46f18f
feature(pu): add sum_to_three_vector_obs_sac_config.py and polish rel…
puyuan1996 Jun 24, 2024
74afa1f
Merge branch 'main' into main
puyuan1996 Jun 24, 2024
ec4d6df
polish(pu): polish sum_to_three configs
dyyoungg Jul 1, 2024
5bdcf48
Merge remote-tracking branch 'origin/main' into dev-pooltool
dyyoungg Jul 4, 2024
54f7928
polish(pu): polish pooltool configs
dyyoungg Jul 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -1444,4 +1444,7 @@ events.*
!/lzero/mcts/**/lib/*.h
**/tb/*
**/mcts/ctree/tests_cpp/*
**/*tmp*
**/*tmp*

# pooltool-specific stuff
puyuan1996 marked this conversation as resolved.
Show resolved Hide resolved
!/assets/pooltool/**
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ The environments and algorithms currently supported by LightZero are shown in th
| MiniGrid | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
| Bsuite | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
| Memory | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
| SumToThree (billiards) | --- | 🔒 | 🔒 | ✔ | 🔒 | 🔒 |

<sup>(1): "✔" means that the corresponding item is finished and well-tested.</sup>

Expand Down
1 change: 1 addition & 0 deletions README.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ LightZero 目前支持的环境及算法如下表所示:
| MiniGrid | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
| Bsuite | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
| Memory | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
| SumToThree (billiards) | --- | 🔒 | 🔒 | ✔ | 🔒 | 🔒 |

<sup>(1): "✔" 表示对应的项目已经完成并经过良好的测试。</sup>

Expand Down
Binary file added assets/pooltool/3hits.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/pooltool/4hits.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/pooltool/cts.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/pooltool/cts_zoom.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/pooltool/discrete.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/pooltool/feature_planes.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/pooltool/largecut.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/pooltool/nocut.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
81 changes: 58 additions & 23 deletions lzero/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import torch.nn as nn
from ding.torch_utils import MLP, ResBlock
from ding.utils import SequenceType

import torch.nn.init as init
import torch.nn.functional as F

# use dataclass to make the output of network more convenient to use
@dataclass
Expand All @@ -35,6 +36,31 @@ class MZNetworkOutput:
latent_state: torch.Tensor



class SimNorm(nn.Module):
"""
Simplicial normalization.
Adapted from https://arxiv.org/abs/2204.00616.
"""

def __init__(self, simnorm_dim):
super().__init__()
self.dim = simnorm_dim

def forward(self, x):
shp = x.shape
# Ensure that there is at least one simplex to normalize across.
if shp[1] != 0:
x = x.view(*shp[:-1], -1, self.dim)
x = F.softmax(x, dim=-1)
return x.view(*shp)
else:
return x

def __repr__(self):
return f"SimNorm(dim={self.dim})"


class DownSample(nn.Module):

def __init__(self, observation_shape: SequenceType, out_channels: int, activation: nn.Module = nn.ReLU(inplace=True),
Expand Down Expand Up @@ -140,6 +166,9 @@ def __init__(
downsample: bool = True,
activation: nn.Module = nn.ReLU(inplace=True),
norm_type: str = 'BN',
embedding_dim: int = 256,
group_size: int = 8,
use_sim_norm: bool = False,
) -> None:
"""
Overview:
Expand Down Expand Up @@ -174,19 +203,30 @@ def __init__(
self.norm = nn.BatchNorm2d(num_channels)
elif norm_type == 'LN':
if downsample:
self.norm = nn.LayerNorm([num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)])
self.norm = nn.LayerNorm(
[num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
eps=1e-5)
else:
self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]])
self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)

self.resblocks = nn.ModuleList(
[
ResBlock(
in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False
in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
) for _ in range(num_res_blocks)
]
)
self.activation = activation

self.use_sim_norm = use_sim_norm

if self.use_sim_norm:
self.embedding_dim = embedding_dim
self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False)
# Initialize weights using He initialization
init.kaiming_normal_(self.last_linear.weight, mode='fan_out', nonlinearity='relu')
self.sim_norm = SimNorm(simnorm_dim=group_size)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Shapes:
Expand All @@ -204,20 +244,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

for block in self.resblocks:
x = block(x)
return x

def get_param_mean(self) -> float:
"""
Overview:
Get the mean of parameters in the network for debug and visualization.
Returns:
- mean (:obj:`float`): The mean of parameters in the network.
"""
mean = []
for name, param in self.named_parameters():
mean += np.abs(param.detach().cpu().numpy().reshape(-1)).tolist()
mean = sum(mean) / len(mean)
return mean
if self.use_sim_norm:
# NOTE: very important.
# for atari 64,8,8 = 4096 -> 768
x = self.sim_norm(x)

return x


class RepresentationNetworkMLP(nn.Module):
Expand All @@ -227,9 +260,9 @@ def __init__(
observation_shape: int,
hidden_channels: int = 64,
layer_num: int = 2,
activation: Optional[nn.Module] = nn.ReLU(inplace=True),
last_linear_layer_init_zero: bool = True,
activation: nn.Module = nn.GELU(),
norm_type: Optional[str] = 'BN',
group_size: int = 8,
) -> torch.Tensor:
"""
Overview:
Expand All @@ -244,8 +277,6 @@ def __init__(
we don't need this module.
- activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \
Use the inplace operation to speed up.
- last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer with zeros, \
which can provide stable zero outputs in the beginning, defaults to True.
- norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
"""
super().__init__()
Expand All @@ -262,14 +293,18 @@ def __init__(
# last_linear_layer_init_zero=True is beneficial for convergence speed.
last_linear_layer_init_zero=True,
)
self.sim_norm = SimNorm(simnorm_dim=group_size)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Shapes:
- x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation.
- output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
"""
return self.fc_representation(x)
x = self.fc_representation(x)
x = self.sim_norm(x)
return x



class PredictionNetwork(nn.Module):
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ pympler
bsuite
minigrid
moviepy
pycolab
pycolab
pytest
pooltool-billiards>=0.3.1
35 changes: 27 additions & 8 deletions zoo/atari/config/atari_muzero_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from easydict import EasyDict
import torch
device = 1
torch.cuda.set_device(device)

# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
env_id = 'PongNoFrameskip-v4'
env_id = 'MsPacmanNoFrameskip-v4'

if env_id == 'PongNoFrameskip-v4':
action_space_size = 6
Expand All @@ -17,21 +20,27 @@
# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
# collector_env_num = 8
# n_episode = 8
collector_env_num = 1
n_episode = 1
evaluator_env_num = 3
num_simulations = 50
update_per_collect = 1000
# update_per_collect = 1000
update_per_collect = None
model_update_ratio = 0.25
batch_size = 256
max_env_step = int(1e6)
reanalyze_ratio = 0.
# max_env_step = int(1e6)
max_env_step = int(1e8)
# reanalyze_ratio = 0.
reanalyze_ratio = 1
eps_greedy_exploration_in_collect = False
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

atari_muzero_config = dict(
exp_name=f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
exp_name=f'data_muzero_tune/{env_id[:-14]}_muzero_collect{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_no-priority_seed0',
env=dict(
stop_value=int(1e6),
env_id=env_id,
Expand All @@ -42,6 +51,13 @@
manager=dict(shared_memory=False, ),
),
policy=dict(
learn=dict(
learner=dict(
hook=dict(
save_ckpt_after_iter=1000000, # default is 10000
),
),
),
model=dict(
observation_shape=(4, 96, 96),
frame_stack_num=4,
Expand All @@ -65,8 +81,10 @@
end=0.05,
decay=int(1e5),
),
use_priority=False, # TODO
use_augmentation=True,
update_per_collect=update_per_collect,
model_update_ratio=model_update_ratio,
batch_size=batch_size,
optim_type='SGD',
lr_piecewise_constant_decay=True,
Expand All @@ -75,7 +93,8 @@
reanalyze_ratio=reanalyze_ratio,
ssl_loss_weight=2, # default is 0
n_episode=n_episode,
eval_freq=int(2e3),
# eval_freq=int(2e3),
eval_freq=int(1e4),
replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
Expand Down
123 changes: 123 additions & 0 deletions zoo/pooltool/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Billiards RL

Welcome to the documentation for billiards simulation within the LightZero framework. Billiards offers an intriguing learning environment for reinforcement learning due to its continuous action space, turn-based play, and the need for long-term planning and strategy formulation.

## Pooltool

Pooltool is a general purpose billiards simulator crafted specifically for science and engineering applications (learn more [here](https://github.com/ekiefl/pooltool)). It has been incorporated into LightZero to create diverse learning environments for billiards games.

## Testing your installation

Pooltool comes pre-installed with LightZero. If you are using a custom setup, follow the _pip_ install instructions [here](https://pooltool.readthedocs.io/en/latest/getting_started/install.html#install-option-1-pip).

Verify pooltool is found in your python path:

```bash
python -c "import pooltool; print(pooltool.__version__)"
```

Further test your installation by opening the interactive interface:

```bash
# Unix
run_pooltool

# Windows
run_pooltool.bat
```

(For instructions on how to play, check out the [Getting Started tutorial](https://pooltool.readthedocs.io/en/latest/getting_started/interface.html))

## Supported Games

Currently supports the following games:

1. **Sum to Three**: A simplified billiards game designed to make learning easier for agents.
2. **Standard Billiards Games** (planned for future updates): Including 8-ball, 9-ball, and snooker.

The rest of the document provides details for each supported game.

## Game 1: Sum to Three

Standard billiards games like 8-ball, 9-ball, and snooker have complex rulesets which make learning more difficult.

In contrast, _sum to three_ is a fictitious billiards game with a simple ruleset.

### Rules

1. The game is played on a table with no pockets
1. There are 2 balls: a cue ball and an object ball
1. The player must hit the object ball with the cue ball
1. The player scores a point if the number of times a ball hits a cushion is 3
1. The player takes 10 shots, and their final score is the number of points they achieve

For example, this is a successful shot because there are three ball-cushion collisions:

<img src="../../assets/pooltool/3hits.gif" width="600" />

This is an unsuccessful shot because there are four ball-cushion collisions:

<img src="../../assets/pooltool/4hits.gif" width="600" />

### Observation / Action Spaces

Continuous and discrete observatwon spaces are supported. The continuous observation space uses the coordinates of the two balls as the observation. The discrete observation space is based on configurable image-based feature planes.

In general, when an agent strikes a cue ball, the cue stick is described by 5 continuous parameters:

```
V0 : positive float
What initial velocity does the cue strike the ball?
phi : float (degrees)
The direction you strike the ball
theta : float (degrees)
How elevated is the cue from the playing surface, in degrees?
a : float
How much side english should be put on? -1 being rightmost side of ball, +1 being
leftmost side of ball
b : float
How much vertical english should be put on? -1 being bottom-most side of ball, +1 being
topmost side of ball
```

Since sum to three is a simple game, only a reduced action space with 2 parameters is supported:

1. V0: The speed of the cue stick. Increasing this means the cue ball travels further
1. cut angle: The angle that the cue ball hits the object ball with

For example, in this shot, the cut angle is -70 (hitting the left side of the object ball):

<img src="../../assets/pooltool/largecut.gif" width="600" />

For example, in this shot, the cut angle is 0 (head-on collision):

<img src="../../assets/pooltool/nocut.gif" width="600" />

Based on the game dimensions, a suitable bound for the action parameters is used: [0.3, 3] for speed and [-70, 70] for cut angle.

### Experiments

You can conduct experiments using different observation spaces:

1. **Continuous Observation Space Experiment**:
- Run the experiment with:
```bash
python ./zoo/pooltool/sum_to_three/config/sum_to_three_config.py
```
- Results will be saved in `./data_pooltool_sampled_efficientzero/image-obs`.

2. **Discrete Observation Space Experiment**:
- Run the experiment with:
```bash
python ./zoo/pooltool/sum_to_three/config/sum_to_three_image_config.py
```
- Modify the feature plane information by editing `./zoo/pooltool/sum_to_three/config/feature_plane_config.json`. View the usage example in `./zoo/pooltool/image_representation.py` for details about the feature plane content.
- Results will be saved in `./data_pooltool_sampled_efficientzero/vector-obs`.

### Results

TODO(puyuan1996)

## Game 2: 8-ball / 9-ball / 3-cushion / snooker

What billiards game would you like to see next?
Loading
Loading