Commit 22ed628

polish ppof rewardclip and add atari config
1 parent 1e6f503 commit 22ed628

File tree

4 files changed: +94 -14 lines

ding/bonus/config.py

Lines changed: 69 additions & 0 deletions
@@ -77,6 +77,45 @@ def get_instance_config(env: str) -> EasyDict:
             critic_head_hidden_size=256,
             actor_head_hidden_size=256,
         )
+    elif env == 'qbert':
+        cfg.n_sample = 1024
+        cfg.batch_size = 128
+        cfg.epoch_per_collect = 10
+        cfg.learning_rate = 0.0001
+        cfg.model = dict(
+            obs_shape=[4, 84, 84],
+            action_shape=6,
+            encoder_hidden_size_list=[32, 64, 64, 128],
+            actor_head_hidden_size=128,
+            critic_head_hidden_size=128,
+            critic_head_layer_num=2,
+        )
+    elif env == 'kangaroo':
+        cfg.n_sample = 1024
+        cfg.batch_size = 128
+        cfg.epoch_per_collect = 10
+        cfg.learning_rate = 0.0001
+        cfg.model = dict(
+            obs_shape=[4, 84, 84],
+            action_shape=18,
+            encoder_hidden_size_list=[32, 64, 64, 128],
+            actor_head_hidden_size=128,
+            critic_head_hidden_size=128,
+            critic_head_layer_num=2,
+        )
+    elif env == 'bowling':
+        cfg.n_sample = 1024
+        cfg.batch_size = 128
+        cfg.epoch_per_collect = 10
+        cfg.learning_rate = 0.0001
+        cfg.model = dict(
+            obs_shape=[4, 84, 84],
+            action_shape=6,
+            encoder_hidden_size_list=[32, 64, 64, 128],
+            actor_head_hidden_size=128,
+            critic_head_hidden_size=128,
+            critic_head_layer_num=2,
+        )
     else:
         raise KeyError("not supported env type: {}".format(env))
     return cfg
@@ -152,6 +191,36 @@ def get_instance_env(env: str) -> BaseEnv:
             },
             seed_api=False,
         )
+    elif env == 'qbert':
+        from dizoo.atari.envs.atari_env import AtariEnv
+        cfg = EasyDict({
+            'env_id': 'QbertNoFrameskip-v4',
+            'env_wrapper': 'atari_default',
+        })
+        ding_env_atari = DingEnvWrapper(gym.make('QbertNoFrameskip-v4'), cfg=cfg)
+        # ding_env_atari.enable_save_replay('atari_log/')
+        obs = ding_env_atari.reset()
+        return ding_env_atari
+    elif env == 'kangaroo':
+        from dizoo.atari.envs.atari_env import AtariEnv
+        cfg = EasyDict({
+            'env_id': 'KangarooNoFrameskip-v4',
+            'env_wrapper': 'atari_default',
+        })
+        ding_env_atari = DingEnvWrapper(gym.make('KangarooNoFrameskip-v4'), cfg=cfg)
+        # ding_env_atari.enable_save_replay('atari_log/')
+        obs = ding_env_atari.reset()
+        return ding_env_atari
+    elif env == 'bowling':
+        from dizoo.atari.envs.atari_env import AtariEnv
+        cfg = EasyDict({
+            'env_id': 'BowlingNoFrameskip-v4',
+            'env_wrapper': 'atari_default',
+        })
+        ding_env_atari = DingEnvWrapper(gym.make('BowlingNoFrameskip-v4'), cfg=cfg)
+        # ding_env_atari.enable_save_replay('atari_log/')
+        obs = ding_env_atari.reset()
+        return ding_env_atari
     else:
         raise KeyError("not supported env type: {}".format(env))
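
For orientation, a minimal usage sketch of the new entries (assuming the import path ding.bonus.config, i.e. the file shown above): the config and env helpers are keyed by the same string, so 'qbert', 'kangaroo' and 'bowling' now resolve to a PPO hyper-parameter set and a wrapped Atari env.

from ding.bonus.config import get_instance_config, get_instance_env

# Hypothetical look-up of one of the entries added above; only the 'qbert'
# key and the two helper functions come from this diff.
cfg = get_instance_config('qbert')   # EasyDict with n_sample, batch_size, model, ...
env = get_instance_env('qbert')      # DingEnvWrapper around QbertNoFrameskip-v4
print(cfg.n_sample, env.observation_space.shape)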

ding/bonus/ppof.py

Lines changed: 13 additions & 5 deletions
@@ -1,6 +1,7 @@
 from typing import Optional, Union
 from ditk import logging
 from easydict import EasyDict
+from functools import partial
 import os
 import gym
 import torch
@@ -30,6 +31,10 @@ class PPOF:
         'mario',
         'di_sheep',
         'procgen_bigfish',
+        # atari
+        'qbert',
+        'kangaroo',
+        'bowling'
     ]

     def __init__(
@@ -67,8 +72,11 @@ def __init__(
             action_shape = action_space.shape
         if model is None:
             model = PPOFModel(
-                self.env.observation_space.shape, action_shape, action_space=self.cfg.action_space, **self.cfg.model
+                action_space=self.cfg.action_space, **self.cfg.model
             )
+            # model = PPOFModel(
+            #     self.env.observation_space.shape, action_shape, action_space=self.cfg.action_space, **self.cfg.model
+            # )
         self.policy = PPOFPolicy(self.cfg, model=model)

     def train(
@@ -86,7 +94,7 @@ def train(
             logging.debug(self.policy._model)
         # define env and policy
         collector_env = self._setup_env_manager(collector_env_num, context, debug)
-        evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug)
+        evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')

         with task.start(ctx=OnlineRLContext()):
             task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
@@ -168,7 +176,7 @@ def batch_evaluate(
         if debug:
             logging.getLogger().setLevel(logging.DEBUG)
         # define env and policy
-        env = self._setup_env_manager(env_num, context, debug)
+        env = self._setup_env_manager(env_num, context, debug, 'evaluator')
         if ckpt_path is None:
             ckpt_path = os.path.join(self.exp_name, 'ckpt/eval.pth.tar')
         state_dict = torch.load(ckpt_path, map_location='cpu')
@@ -179,7 +187,7 @@
             task.use(interaction_evaluator_ttorch(self.seed, self.policy, env, n_evaluator_episode))
             task.run(max_step=1)

-    def _setup_env_manager(self, env_num: int, context: Optional[str] = None, debug: bool = False) -> BaseEnvManagerV2:
+    def _setup_env_manager(self, env_num: int, context: Optional[str] = None, debug: bool = False, caller: str = 'collector') -> BaseEnvManagerV2:
         if debug:
             env_cls = BaseEnvManagerV2
             manager_cfg = env_cls.default_config()
@@ -188,4 +196,4 @@ def _setup_env_manager(self, env_num: int, context: Optional[str] = None, debug:
             manager_cfg = env_cls.default_config()
         if context is not None:
             manager_cfg.context = context
-        return env_cls([self.env.clone for _ in range(env_num)], manager_cfg)
+        return env_cls([partial(self.env.clone, caller) for _ in range(env_num)], manager_cfg)

ding/envs/env/default_wrapper.py

Lines changed: 4 additions & 2 deletions
@@ -5,7 +5,8 @@
 eval_episode_return_wrapper = EasyDict(type='eval_episode_return')


-def get_default_wrappers(env_wrapper_name: str, env_id: Optional[str] = None) -> List[dict]:
+def get_default_wrappers(env_wrapper_name: str, env_id: Optional[str] = None, caller: str = 'collector') -> List[dict]:
+    assert caller in ('collector', 'evaluator')
 if env_wrapper_name == 'mujoco_default':
         return [
             EasyDict(type='delay_reward', kwargs=dict(delay_reward_step=3)),
@@ -21,7 +22,8 @@ def get_default_wrappers(env_wrapper_name: str, env_id: Optional[str] = None) ->
         wrapper_list.append(EasyDict(type='fire_reset'))
         wrapper_list.append(EasyDict(type='warp_frame'))
         wrapper_list.append(EasyDict(type='scaled_float_frame'))
-        wrapper_list.append(EasyDict(type='clip_reward'))
+        if caller == 'collector':
+            wrapper_list.append(EasyDict(type='clip_reward'))
         wrapper_list.append(EasyDict(type='frame_stack', kwargs=dict(n_frames=4)))
         wrapper_list.append(copy.deepcopy(eval_episode_return_wrapper))
         return wrapper_list
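
The net effect of the two changes in this file is that the Atari wrapper stack only clips rewards on the collector side; evaluator episodes keep the raw game score. An illustrative sketch of sign-based reward clipping of the kind the 'clip_reward' entry enables (not the DI-engine wrapper itself):

import numpy as np

def clip_reward(reward: float) -> float:
    # squash the raw Atari score into {-1.0, 0.0, +1.0} for training
    return float(np.sign(reward))

print([clip_reward(r) for r in [0.0, 25.0, -100.0]])  # [0.0, 1.0, -1.0]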

ding/envs/env/ding_env_wrapper.py

Lines changed: 8 additions & 7 deletions
@@ -15,7 +15,7 @@

 class DingEnvWrapper(BaseEnv):

-    def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True) -> None:
+    def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
         """
         You can pass in either an env instance, or a config to create an env instance:
         - An env instance: Parameter `env` must not be `None`, but should be the instance.
@@ -25,6 +25,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True)
         self._raw_env = env
         self._cfg = cfg
         self._seed_api = seed_api  # some env may disable `env.seed` api
+        self._caller = caller
         if self._cfg is None:
             self._cfg = dict()
         self._cfg = EasyDict(self._cfg)
@@ -37,7 +38,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True)
         if env is not None:
             self._init_flag = True
             self._env = env
-            self._wrap_env()
+            self._wrap_env(caller)
             self._observation_space = self._env.observation_space
             self._action_space = self._env.action_space
             self._action_space.seed(0)  # default seed
@@ -57,7 +58,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True)
     def reset(self) -> None:
         if not self._init_flag:
             self._env = gym.make(self._cfg.env_id)
-            self._wrap_env()
+            self._wrap_env(self._caller)
             self._observation_space = self._env.observation_space
             self._action_space = self._env.action_space
             self._reward_space = gym.spaces.Box(
@@ -149,11 +150,11 @@ def random_action(self) -> np.ndarray:
         )
         return random_action

-    def _wrap_env(self) -> None:
+    def _wrap_env(self, caller: str = 'collector') -> None:
         # wrapper_cfgs: Union[str, List]
         wrapper_cfgs = self._cfg.env_wrapper
         if isinstance(wrapper_cfgs, str):
-            wrapper_cfgs = get_default_wrappers(wrapper_cfgs, self._cfg.env_id)
+            wrapper_cfgs = get_default_wrappers(wrapper_cfgs, self._cfg.env_id, caller)
         # self._wrapper_cfgs: List[Union[Callable, Dict]]
         self._wrapper_cfgs = wrapper_cfgs
         for wrapper in self._wrapper_cfgs:
@@ -197,12 +198,12 @@ def action_space(self) -> gym.spaces.Space:
     def reward_space(self) -> gym.spaces.Space:
         return self._reward_space

-    def clone(self) -> BaseEnv:
+    def clone(self, caller: str = 'collector') -> BaseEnv:
         try:
             spec = copy.deepcopy(self._raw_env.spec)
             raw_env = CloudPickleWrapper(self._raw_env)
             raw_env = copy.deepcopy(raw_env).data
             raw_env.__setattr__('spec', spec)
         except Exception:
             raw_env = self._raw_env
-        return DingEnvWrapper(raw_env, self._cfg, self._seed_api)
+        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller)
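
Putting the pieces together, a hedged sketch of the new argument in use, mirroring the Qbert entry added in ding/bonus/config.py: constructing the wrapper with caller='evaluator' drops the 'clip_reward' entry from the default Atari wrapper list, so evaluation returns are reported in raw game score.

import gym
from easydict import EasyDict
from ding.envs.env.ding_env_wrapper import DingEnvWrapper

cfg = EasyDict({'env_id': 'QbertNoFrameskip-v4', 'env_wrapper': 'atari_default'})
# evaluator-side copy: same wrappers as the collector except reward clipping
eval_env = DingEnvWrapper(gym.make('QbertNoFrameskip-v4'), cfg=cfg, caller='evaluator')
obs = eval_env.reset()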
