Skip to content

Commit

Permalink
Fixed input dictionary observation shapes for resnet network. Fixed amp.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Aug 17, 2024
1 parent ec4e4f1 commit 04d653a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 27 deletions.
2 changes: 1 addition & 1 deletion rl_games/algos_torch/a2c_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def calc_gradients(self, input_dict):
if self.zero_rnn_on_done:
batch_dict['dones'] = input_dict['dones']

with torch.cuda.amp.autocast(enabled=self.mixed_precision):
with torch.amp.autocast("cuda", enabled=self.mixed_precision):
res_dict = self.model(batch_dict)
action_log_probs = res_dict['prev_neglogp']
values = res_dict['values']
Expand Down
2 changes: 1 addition & 1 deletion rl_games/algos_torch/a2c_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def calc_gradients(self, input_dict):
if self.zero_rnn_on_done:
batch_dict['dones'] = input_dict['dones']

with torch.cuda.amp.autocast(enabled=self.mixed_precision):
with torch.amp.autocast("cuda", enabled=self.mixed_precision):
res_dict = self.model(batch_dict)
action_log_probs = res_dict['prev_neglogp']
values = res_dict['values']
Expand Down
72 changes: 47 additions & 25 deletions rl_games/networks/vision_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ def load(self, params):

class Network(NetworkBuilder.BaseNetwork):
def __init__(self, params, **kwargs):
self.actions_num = actions_num = kwargs.pop('actions_num')
input_shape = kwargs.pop('input_shape')
print('input_shape:', input_shape)
if type(input_shape) is dict:
input_shape = input_shape['camera']
proprio_shape = input_shape['proprio']
self.actions_num = kwargs.pop('actions_num')
full_input_shape = kwargs.pop('input_shape')
proprio_size = 0 # Number of proprioceptive features
if type(full_input_shape) is dict:
input_shape = full_input_shape['camera']
proprio_shape = full_input_shape['proprio']
proprio_size = proprio_shape[0]
else:
input_shape = full_input_shape

self.num_seqs = kwargs.pop('num_seqs', 1)
self.value_size = kwargs.pop('value_size', 1)
Expand All @@ -32,7 +35,6 @@ def __init__(self, params, **kwargs):

self.cnn = self._build_impala(input_shape, self.conv_depths)
cnn_output_size = self._calc_input_size(input_shape, self.cnn)
proprio_size = proprio_shape[0] # Number of proprioceptive features

mlp_input_size = cnn_output_size + proprio_size
if len(self.units) == 0:
Expand Down Expand Up @@ -66,18 +68,18 @@ def __init__(self, params, **kwargs):
self.flatten_act = self.activations_factory.create(self.activation)

if self.is_discrete:
self.logits = torch.nn.Linear(out_size, actions_num)
self.logits = torch.nn.Linear(out_size, self.actions_num)
if self.is_continuous:
self.mu = torch.nn.Linear(out_size, actions_num)
self.mu = torch.nn.Linear(out_size, self.actions_num)
self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
mu_init = self.init_factory.create(**self.space_config['mu_init'])
self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
sigma_init = self.init_factory.create(**self.space_config['sigma_init'])

if self.fixed_sigma:
self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)
self.sigma = nn.Parameter(torch.zeros(self.actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)
else:
self.sigma = torch.nn.Linear(out_size, actions_num)
self.sigma = torch.nn.Linear(out_size, self.actions_num)

mlp_init = self.init_factory.create(**self.initializer)

Expand All @@ -97,13 +99,14 @@ def __init__(self, params, **kwargs):
else:
sigma_init(self.sigma.weight)

mlp_init(self.value.weight)
mlp_init(self.value.weight)

def forward(self, obs_dict):
# for key in obs_dict:
# print(key)
obs = obs_dict['camera']
proprio = obs_dict['proprio']
# print(obs_dict.keys())
# print(obs_dict['obs'].keys())
# currently works only with a dictionary of camera and proprio observations
obs = obs_dict['obs']['camera']
proprio = obs_dict['obs']['proprio']
if self.permute_input:
obs = obs.permute((0, 3, 1, 2))

Expand All @@ -119,7 +122,9 @@ def forward(self, obs_dict):
out = torch.cat([out, proprio], dim=1)

if self.has_rnn:
seq_length = obs_dict['seq_length']
# TODO: Double check, it's not always present!
#seq_length = obs_dict['seq_length']
seq_length = obs_dict.get('seq_length', 1)

out_in = out
if not self.is_rnn_before_mlp:
Expand Down Expand Up @@ -179,7 +184,21 @@ def load(self, params):
self.space_config = params['space']['continuous']
self.fixed_sigma = self.space_config['fixed_sigma']
elif self.is_discrete:
self.space_config = params['sA2CVisionBuildernv_depths']
self.space_config = params['space']['discrete']
elif self.is_multi_discrete:
self.space_config = params['space']['multi_discrete']

self.has_rnn = 'rnn' in params
if self.has_rnn:
self.rnn_units = params['rnn']['units']
self.rnn_layers = params['rnn']['layers']
self.rnn_name = params['rnn']['name']
self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
self.rnn_ln = params['rnn'].get('layer_norm', False)

self.has_cnn = True
self.permute_input = params['cnn'].get('permute_input', True)
self.conv_depths = params['cnn']['conv_depths']
self.require_rewards = params.get('require_rewards')
self.require_last_actions = params.get('require_last_actions')

Expand All @@ -200,10 +219,10 @@ def is_rnn(self):
def get_default_rnn_state(self):
num_layers = self.rnn_layers
if self.rnn_name == 'lstm':
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
else:
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))

def build(self, name, **kwargs):
net = A2CVisionBuilder.Network(self.params, **kwargs)
Expand All @@ -220,11 +239,14 @@ def load(self, params):
class Network(NetworkBuilder.BaseNetwork):
def __init__(self, params, **kwargs):
self.actions_num = kwargs.pop('actions_num')
input_shape = kwargs.pop('input_shape')
print('input_shape:', input_shape)
if isinstance(input_shape, dict):
input_shape = input_shape['camera']
proprio_shape = input_shape['proprio']
full_input_shape = kwargs.pop('input_shape')
proprio_size = 0 # Number of proprioceptive features
if type(full_input_shape) is dict:
input_shape = full_input_shape['camera']
proprio_shape = full_input_shape['proprio']
proprio_size = proprio_shape[0]
else:
input_shape = full_input_shape

self.num_seqs = kwargs.pop('num_seqs', 1)
self.value_size = kwargs.pop('value_size', 1)
Expand Down

0 comments on commit 04d653a

Please sign in to comment.