Skip to content

Commit

Permalink
Fixed impala network.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Sep 11, 2024
1 parent 39e0ff8 commit ee19a1c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
6 changes: 3 additions & 3 deletions rl_games/envs/maniskill.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import gymnasium.spaces.utils
from gymnasium.vector.utils import batch_space

from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common


VecEnvObs = Dict[str, torch.Tensor | Dict[str, torch.Tensor]]

Expand Down Expand Up @@ -53,6 +50,9 @@ class RlgFlattenRGBDObservationWrapper(gym2.ObservationWrapper):
"""

def __init__(self, env, rgb=True, depth=False, state=True, aux_loss=False) -> None:
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common

self.base_env: BaseEnv = env.unwrapped
self.aux_loss = aux_loss
super().__init__(env)
Expand Down
39 changes: 20 additions & 19 deletions rl_games/networks/vision_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ class Network(NetworkBuilder.BaseNetwork):
def __init__(self, params, **kwargs):
self.actions_num = actions_num = kwargs.pop('actions_num')
full_input_shape = kwargs.pop('input_shape')
proprio_size = 0 # Number of proprioceptive features
self.use_aux_loss = kwargs.pop('use_aux_loss', False)

self.proprio_size = 0 # Number of proprioceptive features
if type(full_input_shape) is dict:
input_shape = full_input_shape['camera']
proprio_shape = full_input_shape['proprio']

proprio_size = proprio_shape[0]
self.proprio_size = proprio_shape[0]
else:
input_shape = full_input_shape

self.normalize_emb = kwargs.pop('normalize_emb', False)

self.num_seqs = kwargs.pop('num_seqs', 1)
self.value_size = kwargs.pop('value_size', 1)

Expand All @@ -40,7 +40,7 @@ def __init__(self, params, **kwargs):
self.cnn = self._build_impala(input_shape, self.conv_depths)
cnn_output_size = self._calc_input_size(input_shape, self.cnn)

mlp_input_size = cnn_output_size + proprio_size
mlp_input_size = cnn_output_size + self.proprio_size
if len(self.units) == 0:
out_size = cnn_output_size
else:
Expand Down Expand Up @@ -71,7 +71,6 @@ def __init__(self, params, **kwargs):

self.mlp = self._build_mlp(**mlp_args)

# TODO: implement for Impala
self.aux_loss_map = None
if self.use_aux_loss:
self.aux_loss_linear = nn.Linear(out_size, self.target_shape)
Expand Down Expand Up @@ -129,9 +128,15 @@ def get_aux_loss(self):
return self.aux_loss_map

def forward(self, obs_dict):
obs = obs_dict['obs']['camera']
proprio = obs_dict['obs']['proprio']
target_obs = obs[self.target_key]
if self.proprio_size > 0:
obs = obs_dict['obs']['camera']
proprio = obs_dict['obs']['proprio']
else:
obs = obs_dict['obs']

if self.use_aux_loss:
target_obs = obs_dict['obs'][self.target_key]

if self.permute_input:
obs = obs.permute((0, 3, 1, 2))

Expand All @@ -144,7 +149,8 @@ def forward(self, obs_dict):
out = out.flatten(1)
out = self.flatten_act(out)

out = torch.cat([out, proprio], dim=1)
if self.proprio_size > 0:
out = torch.cat([out, proprio], dim=1)
out = self.layer_norm_emb(out)

if self.has_rnn:
Expand Down Expand Up @@ -181,8 +187,9 @@ def forward(self, obs_dict):

value = self.value_act(self.value(out))

y = self.aux_loss_linear(out)
self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs)
if self.use_aux_loss:
y = self.aux_loss_linear(out)
self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs)

if self.is_discrete:
logits = self.logits(out)
Expand Down Expand Up @@ -283,17 +290,15 @@ def __init__(self, params, **kwargs):
self.actions_num = kwargs.pop('actions_num')
full_input_shape = kwargs.pop('input_shape')

print('full_input_shape: ', full_input_shape)

self.use_aux_loss = kwargs.pop('use_aux_loss', False)

if self.use_aux_loss:
self.target_key = 'aux_target'
if 'aux_target' in full_input_shape:
self.target_shape = full_input_shape[self.target_key]
print("Target shape: ", self.target_shape)

print("Observations shape: ", full_input_shape)
print("Use aux loss: ", self.use_aux_loss)

self.proprio_size = 0 # Number of proprioceptive features
if isinstance(full_input_shape, dict):
Expand Down Expand Up @@ -402,10 +407,6 @@ def forward(self, obs_dict):
if self.use_aux_loss:
target_obs = obs_dict['obs'][self.target_key]

# print('obs.min(): ', obs.min())
# print('obs.max(): ', obs.max())
# print('obs.shape: ', obs.shape)

if self.permute_input:
obs = obs.permute((0, 3, 1, 2))

Expand Down

0 comments on commit ee19a1c

Please sign in to comment.