Commit

feature(xrk): add new env named Frozen Lake and DQN algorithm. (#781)
* first_commit

* environment tests pass

* create framework

* update the init file to the new function

* change code to address PR review comments

* add more assertions to validate that the environment operates correctly

* support running via serial_pipeline

* adjust the code to the new pipeline

* Compliance Check

* add gif

* format my code
rongkunxue authored Mar 13, 2024
1 parent c999b07 commit aeb4c9c
Showing 9 changed files with 502 additions and 196 deletions.
399 changes: 203 additions & 196 deletions README.md


45 changes: 45 additions & 0 deletions ding/example/dqn_frozen_lake.py
@@ -0,0 +1,45 @@
from ditk import logging
from ding.model import DQN
from ding.policy import DQNPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
    eps_greedy_handler, CkptSaver, nstep_reward_enhancer, final_ctx_saver
from ding.utils import set_pkg_seed
from dizoo.frozen_lake.config.frozen_lake_dqn_config import main_config, create_config
from dizoo.frozen_lake.envs import FrozenLakeEnv


def main():
    logging.getLogger().setLevel(logging.INFO)
    main_config.policy.nstep = 5
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: FrozenLakeEnv(cfg=cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: FrozenLakeEnv(cfg=cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
        )
        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = DQN(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        policy = DQNPolicy(cfg.policy, model=model)

        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(eps_greedy_handler(cfg))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(nstep_reward_enhancer(cfg))
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.use(final_ctx_saver(cfg.exp_name))
        task.run()


if __name__ == "__main__":
    main()
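For reference, the script above ends with a standard ``__main__`` guard, so it can presumably be run directly with `python ding/example/dqn_frozen_lake.py` once DI-engine and gymnasium are installed; the config file below also documents an equivalent `ding -m serial` CLI invocation.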
Binary file added dizoo/frozen_lake/FrozenLake.gif
Empty file added dizoo/frozen_lake/__init__.py
1 change: 1 addition & 0 deletions dizoo/frozen_lake/config/__init__.py
@@ -0,0 +1 @@
from .frozen_lake_dqn_config import main_config, create_config
64 changes: 64 additions & 0 deletions dizoo/frozen_lake/config/frozen_lake_dqn_config.py
@@ -0,0 +1,64 @@
from easydict import EasyDict

frozen_lake_dqn_config = dict(
    exp_name='frozen_lake_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=10,
        env_id='FrozenLake-v1',
        desc=None,
        map_name="4x4",
        is_slippery=False,
        save_replay_gif=False,
    ),
    policy=dict(
        cuda=True,
        load_path='frozen_lake_seed0/ckpt/ckpt_best.pth.tar',
        model=dict(
            obs_shape=16,
            action_shape=4,
            encoder_hidden_size_list=[128, 128, 64],
            dueling=True,
        ),
        nstep=3,
        discount_factor=0.97,
        learn=dict(
            update_per_collect=5,
            batch_size=256,
            learning_rate=0.001,
        ),
        collect=dict(n_sample=10),
        eval=dict(evaluator=dict(eval_freq=40, )),
        other=dict(
            eps=dict(
                type='exp',
                start=0.8,
                end=0.1,
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=20000, ),
        ),
    ),
)

frozen_lake_dqn_config = EasyDict(frozen_lake_dqn_config)
main_config = frozen_lake_dqn_config

frozen_lake_dqn_create_config = dict(
    env=dict(
        type='frozen_lake',
        import_names=['dizoo.frozen_lake.envs.frozen_lake_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='dqn'),
    replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
)

frozen_lake_dqn_create_config = EasyDict(frozen_lake_dqn_create_config)
create_config = frozen_lake_dqn_create_config

if __name__ == "__main__":
    # or you can enter `ding -m serial -c frozen_lake_dqn_config.py -s 0`
    from ding.entry import serial_pipeline
    serial_pipeline((main_config, create_config), max_env_step=5000, seed=0)
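The ``eps`` block above schedules exploration with an exponential decay from 0.8 down to 0.1. As a rough illustration of what such a schedule looks like (a sketch assuming the common form ``end + (start - end) * exp(-step / decay)``; DI-engine's ``eps_greedy_handler`` implements the actual schedule):

import math

def exp_eps(step: int, start: float = 0.8, end: float = 0.1, decay: float = 10000.) -> float:
    # Decays from ``start`` toward ``end`` with time constant ``decay`` (in env steps).
    return end + (start - end) * math.exp(-step / decay)

# exp_eps(0) == 0.8; exp_eps(10000) ~= 0.36; exp_eps(50000) ~= 0.10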
1 change: 1 addition & 0 deletions dizoo/frozen_lake/envs/__init__.py
@@ -0,0 +1 @@
from .frozen_lake_env import FrozenLakeEnv
144 changes: 144 additions & 0 deletions dizoo/frozen_lake/envs/frozen_lake_env.py
@@ -0,0 +1,144 @@
from typing import List, Optional
import imageio
import os
import gymnasium as gymn
import numpy as np
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY


@ENV_REGISTRY.register('frozen_lake')
class FrozenLakeEnv(BaseEnv):

    def __init__(self, cfg) -> None:
        self._cfg = cfg
        assert self._cfg.env_id == "FrozenLake-v1", "env_id must be FrozenLake-v1"
        self._init_flag = False
        self._save_replay_bool = False
        self._save_replay_count = 0
        self._frames = []
        self._replay_path = None

    def reset(self) -> np.ndarray:
        if not self._init_flag:
            # ``desc=None`` makes gymnasium fall back to the preset ``map_name`` layout,
            # so the environment can be created unconditionally.
            self._env = gymn.make(
                self._cfg.env_id,
                desc=self._cfg.desc,
                map_name=self._cfg.map_name,
                is_slippery=self._cfg.is_slippery,
                render_mode="rgb_array"
            )
            self._observation_space = self._env.observation_space
            self._action_space = self._env.action_space
            self._reward_space = gymn.spaces.Box(
                low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
            )
            self._init_flag = True
        self._eval_episode_return = 0
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._env_seed = self._seed + np_seed
        elif hasattr(self, '_seed'):
            self._env_seed = self._seed
        if hasattr(self, '_seed'):
            obs, info = self._env.reset(seed=self._env_seed)
        else:
            obs, info = self._env.reset()
        # One-hot encode the discrete state index (0-15) into a 16-dim float vector.
        obs = np.eye(16, dtype=np.float32)[obs]
        return obs

    def close(self) -> None:
        if self._init_flag:
            self._env.close()
        self._init_flag = False

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def step(self, action: np.ndarray) -> BaseEnvTimestep:
        obs, rew, terminated, truncated, info = self._env.step(action[0])
        self._eval_episode_return += rew
        # Same one-hot encoding as in ``reset``.
        obs = np.eye(16, dtype=np.float32)[obs]
        rew = to_ndarray([rew])
        if self._save_replay_bool:
            picture = self._env.render()
            self._frames.append(picture)
        done = terminated or truncated
        if done:
            info['eval_episode_return'] = self._eval_episode_return
            if self._save_replay_bool:
                assert self._replay_path is not None, "call ``enable_save_replay`` to set a replay path first"
                path = os.path.join(
                    self._replay_path, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count)
                )
                self.frames_to_gif(self._frames, path)
                self._frames = []
                self._save_replay_count += 1
        rew = rew.astype(np.float32)
        return BaseEnvTimestep(obs, rew, done, info)

    def random_action(self) -> List:
        raw_action = self._env.action_space.sample()
        return [raw_action]

    def __repr__(self) -> str:
        return "DI-engine Frozen Lake Env"

    @property
    def observation_space(self) -> gymn.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gymn.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gymn.spaces.Space:
        return self._reward_space

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path
        self._save_replay_bool = True
        self._save_replay_count = 0
        self._frames = []

    @staticmethod
    def frames_to_gif(frames: List[imageio.core.util.Array], gif_path: str, duration: float = 0.1) -> None:
        """
        Convert a list of frames into a GIF.
        Args:
            - frames (List[imageio.core.util.Array]): A list of frames, where each frame is an image.
            - gif_path (str): The path to save the GIF file.
            - duration (float): Duration between frames in the GIF (seconds).
        Returns:
            None. The GIF file is saved directly to the specified path.
        """
        # Save all frames as temporary image files
        temp_image_files = []
        for i, frame in enumerate(frames):
            temp_image_file = f"frame_{i}.png"  # Temporary file name
            imageio.imwrite(temp_image_file, frame)  # Save the frame as a PNG file
            temp_image_files.append(temp_image_file)

        # Use imageio to convert the temporary image files into a GIF
        with imageio.get_writer(gif_path, mode='I', duration=duration) as writer:
            for temp_image_file in temp_image_files:
                image = imageio.imread(temp_image_file)
                writer.append_data(image)

        # Clean up the temporary image files
        for temp_image_file in temp_image_files:
            os.remove(temp_image_file)
        print(f"GIF saved as {gif_path}")
44 changes: 44 additions & 0 deletions dizoo/frozen_lake/envs/test_frozen_lake_env.py
@@ -0,0 +1,44 @@
import numpy as np
import pytest
from dizoo.frozen_lake.envs import FrozenLakeEnv
from easydict import EasyDict


@pytest.mark.envtest
class TestFrozenLakeEnv:

    def test_my_lake(self):
        env = FrozenLakeEnv(
            EasyDict({
                'env_id': 'FrozenLake-v1',
                'desc': None,
                'map_name': "4x4",
                'is_slippery': False,
            })
        )
        for _ in range(5):
            env.seed(314, dynamic_seed=False)
            assert env._seed == 314
            obs = env.reset()
            assert obs.shape == (16, ), "With one-hot encoding, the observation should have a dimensionality of 16."
            for i in range(10):
                env.enable_save_replay("./video")
                # Both ``env.random_action()`` and sampling from the action space
                # (via ``np.random`` or ``action_space.sample()``) generate legal random actions.
                if i < 5:
                    random_action = np.array([env.action_space.sample()])
                else:
                    random_action = env.random_action()
                timestep = env.step(random_action)
                print(timestep)
                assert isinstance(timestep.obs, np.ndarray)
                assert isinstance(timestep.done, bool)
                assert timestep.obs.shape == (16, )
                assert timestep.reward.shape == (1, )
                assert timestep.reward >= env.reward_space.low
                assert timestep.reward <= env.reward_space.high

        print(env.observation_space, env.action_space, env.reward_space)
        env.close()
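Since the test class is marked with ``@pytest.mark.envtest``, it can presumably be selected with something like `pytest -m envtest dizoo/frozen_lake/envs/test_frozen_lake_env.py` (assuming the marker is registered in the project's pytest configuration, as DI-engine's other env tests suggest).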
