Skip to content

Commit

Permalink
fix complex obs demo for ppo pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
zjowowen committed Mar 28, 2024
1 parent aeb4c9c commit 88fd50f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 1 deletion.
7 changes: 6 additions & 1 deletion ding/example/ppo_with_complex_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
cuda=True,
action_space='discrete',
model=dict(
obs_shape=None,
obs_shape=dict(
key_0=dict(k1=(), k2=()),
key_1=(5, 10),
key_2=(10, 10, 3),
key_3=(2, ),
),
action_shape=2,
action_space='discrete',
critic_head_hidden_size=138,
Expand Down
3 changes: 3 additions & 0 deletions ding/framework/middleware/functional/advantage_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ding.rl_utils import gae, gae_data, get_train_sample
from ding.framework import task
from ding.utils.data import ttorch_collate
from ding.utils.dict_helper import convert_easy_dict_to_dict
from ding.torch_utils import to_device

if TYPE_CHECKING:
Expand All @@ -33,9 +34,11 @@ def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = Non
# Unify the shape of obs and action
obs_shape = cfg['policy']['model']['obs_shape']
obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \
else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \
else torch.Size(torch.tensor(obs_shape).unsqueeze(0))
action_shape = cfg['policy']['model']['action_shape']
action_shape = torch.Size(torch.tensor(action_shape)) if isinstance(action_shape, list) \
else ttorch.size.Size(action_shape) if isinstance(action_shape, dict) \
else torch.Size(torch.tensor(action_shape).unsqueeze(0))

def _gae(ctx: "OnlineRLContext"):
Expand Down
1 change: 1 addition & 0 deletions ding/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
LimitedSpaceContainer, deep_merge_dicts, set_pkg_seed, flatten_dict, one_time_warning, split_data_generator, \
RunningMeanStd, make_key_as_identifier, remove_illegal_item
from .design_helper import SingletonMetaclass
from .dict_helper import convert_easy_dict_to_dict
from .file_helper import read_file, save_file, remove_file
from .import_helper import try_import_ceph, try_import_mc, try_import_link, import_module, try_import_redis, \
try_import_rediscluster
Expand Down

0 comments on commit 88fd50f

Please sign in to comment.