Skip to content

Commit

Permalink
Fixed observation dict.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Aug 21, 2024
1 parent 61d1bc4 commit c1eeeba
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions rl_games/algos_torch/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,29 +1181,24 @@ def __init__(self, params, **kwargs):

def forward(self, obs_dict):
if self.proprio_size > 0:
obs = obs_dict['camera']
proprio = obs_dict['proprio']
obs = obs_dict['obs']['camera']
proprio = obs_dict['obs']['proprio']
else:
obs = obs_dict['obs']

# print('obs.shape: ', obs.shape)
if self.permute_input:
obs = obs.permute((0, 3, 1, 2))

if self.preprocess_image:
obs = preprocess_image(obs)

# Assuming your input image is a tensor or PIL image, resize it to 224x224
#obs = self.resize_transform(obs)

dones = obs_dict.get('dones', None)
bptt_len = obs_dict.get('bptt_len', 0)
states = obs_dict.get('rnn_states', None)

out = obs
out = self.cnn(out)
out = out.flatten(1)

out = self.flatten_act(out)

if self.proprio_size > 0:
Expand Down Expand Up @@ -1298,9 +1293,9 @@ def _build_backbone(self, input_shape, backbone_params):

# TODO: add low-res parameter
backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
#backbone.maxpool = nn.Identity()
# backbone.maxpool = nn.Identity()
# if input_shape[0] != 3:
# model.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
# backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
# Remove the fully connected layer
backbone_output_size = backbone.fc.in_features
print('backbone_output_size: ', backbone_output_size)
Expand Down

0 comments on commit c1eeeba

Please sign in to comment.