Skip to content

Commit

Permalink
Aux loss is now optional. Config fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Sep 10, 2024
1 parent 80ee4df commit 39e0ff8
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
4 changes: 2 additions & 2 deletions rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ params:
concat_output: True

config:
name: PickCube_RGB_resnet18_LSTM_norm_embedding_64envs_auxloss
name: PickCube_RGB_resnet18_LSTM_norm_embedding_128envs_2e-4_linear_lr_first_layer_retrain
env_name: maniskill
reward_shaper:
scale_value: 1.0
Expand All @@ -72,7 +72,7 @@ params:
scale_value: 1.0
gamma: 0.99
tau : 0.95
learning_rate: 1e-4
learning_rate: 2e-4
lr_schedule: linear
kl_threshold: 0.008
max_epochs: 20000
Expand Down
7 changes: 4 additions & 3 deletions rl_games/envs/maniskill.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ def observation(self, observation: Dict):
# print("Observation:", observation.keys())
# for key, value in observation.items():
# print(key, value.keys())
aux_target = observation['extra']['aux_target']
del observation['extra']['aux_target']
if self.aux_loss:
aux_target = observation['extra']['aux_target']
del observation['extra']['aux_target']
# print("Input Obs:", observation.keys())
# print("Input Obs Agent:", observation['agent'].keys())
# print("Input Obs Extra:", observation['extra'].keys())
Expand Down Expand Up @@ -109,7 +110,7 @@ def __init__(self, config_name, num_envs, **kwargs):

# an observation type and space, see https://maniskill.readthedocs.io/en/latest/user_guide/concepts/observation.html for details
self.obs_mode = kwargs.pop('obs_mode', 'state') # can be one of ['pointcloud', 'rgbd', 'state_dict', 'state']
self.aux_loss = kwargs.pop('aux_loss', True)
self.aux_loss = kwargs.pop('aux_loss', False)

# a controller type / action space, see https://maniskill.readthedocs.io/en/latest/user_guide/concepts/controllers.html for a full list
# can be one of ['pd_ee_delta_pose', 'pd_ee_delta_pos', 'pd_joint_delta_pos', 'arm_pd_joint_pos_vel']
Expand Down
41 changes: 25 additions & 16 deletions rl_games/networks/vision_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ def __init__(self, params, **kwargs):
}

self.mlp = self._build_mlp(**mlp_args)

self.aux_loss_linear = nn.Linear(out_size, self.target_shape)

self.aux_loss_map = {
'aux_dist_loss': None
}
# TODO: implement for Impala
self.aux_loss_map = None
if self.use_aux_loss:
self.aux_loss_linear = nn.Linear(out_size, self.target_shape)
self.aux_loss_map = {
'aux_dist_loss': None
}

self.value = self._build_value_layer(out_size, self.value_size)
self.value_act = self.activations_factory.create(self.value_activation)
Expand Down Expand Up @@ -283,9 +285,13 @@ def __init__(self, params, **kwargs):

print('full_input_shape: ', full_input_shape)

self.target_key = 'aux_target'
self.target_shape = full_input_shape[self.target_key]
print("Target shape: ", self.target_shape)
self.use_aux_loss = kwargs.pop('use_aux_loss', False)

if self.use_aux_loss:
self.target_key = 'aux_target'
if 'aux_target' in full_input_shape:
self.target_shape = full_input_shape[self.target_key]
print("Target shape: ", self.target_shape)

print("Observations shape: ", full_input_shape)

Expand Down Expand Up @@ -341,11 +347,12 @@ def __init__(self, params, **kwargs):

self.mlp = self._build_mlp(**mlp_args)

self.aux_loss_linear = nn.Linear(out_size, self.target_shape[0])

self.aux_loss_map = {
'aux_dist_loss': None
}
self.aux_loss_map = None
if self.use_aux_loss:
self.aux_loss_linear = nn.Linear(out_size, self.target_shape)
self.aux_loss_map = {
'aux_dist_loss': None
}

self.value = self._build_value_layer(out_size, self.value_size)
self.value_act = self.activations_factory.create(self.value_activation)
Expand Down Expand Up @@ -392,7 +399,8 @@ def forward(self, obs_dict):
else:
obs = obs_dict['obs']

target_obs = obs_dict['obs'][self.target_key]
if self.use_aux_loss:
target_obs = obs_dict['obs'][self.target_key]

# print('obs.min(): ', obs.min())
# print('obs.max(): ', obs.max())
Expand Down Expand Up @@ -452,8 +460,9 @@ def forward(self, obs_dict):

value = self.value_act(self.value(out))

y = self.aux_loss_linear(out)
self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs)
if self.use_aux_loss:
y = self.aux_loss_linear(out)
self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs)

if self.is_discrete:
logits = self.logits(out)
Expand Down

0 comments on commit 39e0ff8

Please sign in to comment.